diff --git a/beets/importer.py b/beets/importer.py index 308270577..063e8f326 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -929,24 +929,24 @@ def query_tasks(session): yield ImportTask(None, [album.item_dir()], items) -def lookup_candidates(session): +@pipeline.stage +def lookup_candidates(session, task): """A coroutine for performing the initial MusicBrainz lookup for an album. It accepts lists of Items and yields (items, cur_artist, cur_album, candidates, rec) tuples. If no match is found, all of the yielded parameters (except items) are None. """ - task = None - while True: - task = yield task - if task.skip: - continue + if task.skip: + return task - plugins.send('import_task_start', session=session, task=task) - log.debug('Looking up: %s' % displayable_path(task.paths)) - task.lookup_candidates() + plugins.send('import_task_start', session=session, task=task) + log.debug('Looking up: %s' % displayable_path(task.paths)) + task.lookup_candidates() + return task -def user_query(session): +@pipeline.stage +def user_query(session, task): """A coroutine for interfacing with the user about the tagging process. @@ -959,40 +959,38 @@ def user_query(session): acces to the choice via the ``taks.choice_flag`` property and may choose to change it. """ - task = None - while True: - task = yield task - if task.skip: - continue + if task.skip: + return task - # Ask the user for a choice. - task.choose_match(session) - plugins.send('import_task_choice', session=session, task=task) + # Ask the user for a choice. + task.choose_match(session) + plugins.send('import_task_choice', session=session, task=task) - # As-tracks: transition to singleton workflow. - if task.choice_flag is action.TRACKS: - # Set up a little pipeline for dealing with the singletons. - def emitter(task): - for item in task.items: - yield SingletonImportTask(item) - yield SentinelImportTask(task.toppath, task.paths) + # As-tracks: transition to singleton workflow. + if task.choice_flag is action.TRACKS: + # Set up a little pipeline for dealing with the singletons. + def emitter(task): + for item in task.items: + yield SingletonImportTask(item) + yield SentinelImportTask(task.toppath, task.paths) - ipl = pipeline.Pipeline([ - emitter(task), - lookup_candidates(session), - user_query(session), - ]) - task = pipeline.multiple(ipl.pull()) + ipl = pipeline.Pipeline([ + emitter(task), + lookup_candidates(session), + user_query(session), + ]) + task = pipeline.multiple(ipl.pull()) - # As albums: group items by albums and create task for each album - elif task.choice_flag is action.ALBUMS: - ipl = pipeline.Pipeline([ - iter([task]), - group_albums(session), - lookup_candidates(session), - user_query(session) - ]) - task = pipeline.multiple(ipl.pull()) + # As albums: group items by albums and create task for each album + elif task.choice_flag is action.ALBUMS: + ipl = pipeline.Pipeline([ + iter([task]), + group_albums(session), + lookup_candidates(session), + user_query(session) + ]) + task = pipeline.multiple(ipl.pull()) + return task def resolve_duplicates(session): @@ -1015,90 +1013,85 @@ def resolve_duplicates(session): recent.add(ident) -def import_asis(session): +@pipeline.stage +def import_asis(session, task): """Select the `action.ASIS` choice for all tasks. This stage replaces the initial_lookup and user_query stages when the importer is run without autotagging. """ - task = None - while True: - task = yield task - if task.skip: - continue + if task.skip: + return task - log.info(displayable_path(task.paths)) + log.info(displayable_path(task.paths)) - # Behave as if ASIS were selected. - task.set_null_candidates() - task.set_choice(action.ASIS) + # Behave as if ASIS were selected. + task.set_null_candidates() + task.set_choice(action.ASIS) + return task -def apply_choices(session): +@pipeline.stage +def apply_choices(session, task): """A coroutine for applying changes to albums and singletons during the autotag process. """ - task = None - while True: - task = yield task - if task.skip: - continue + if task.skip: + return task - # Change metadata. - if task.apply: - task.apply_metadata() - plugins.send('import_task_apply', session=session, task=task) + # Change metadata. + if task.apply: + task.apply_metadata() + plugins.send('import_task_apply', session=session, task=task) - # Infer album-level fields. - if task.is_album: - task.infer_album_fields() + # Infer album-level fields. + if task.is_album: + task.infer_album_fields() - task.add(session.lib) + task.add(session.lib) + return task -def plugin_stage(session, func): +@pipeline.stage +def plugin_stage(session, func, task): """A coroutine (pipeline stage) that calls the given function with each non-skipped import task. These stages occur between applying metadata changes and moving/copying/writing files. """ - task = None - while True: - task = yield task - if task.skip: - continue - func(session, task) + if task.skip: + return task + func(session, task) - # Stage may modify DB, so re-load cached item data. - for item in task.imported_items(): - item.load() + # Stage may modify DB, so re-load cached item data. + for item in task.imported_items(): + item.load() + return task -def manipulate_files(session): +@pipeline.stage +def manipulate_files(session, task): """A coroutine (pipeline stage) that performs necessary file manipulations *after* items have been added to the library. """ - task = None - while True: - task = yield task - if task.skip: - continue + if task.skip: + return task - if task.remove_duplicates: - task.do_remove_duplicates(session.lib) + if task.remove_duplicates: + task.do_remove_duplicates(session.lib) - task.manipulate_files( - move=config['import']['move'], - copy=config['import']['copy'], - write=config['import']['write'], - session=session, - ) + task.manipulate_files( + move=config['import']['move'], + copy=config['import']['copy'], + write=config['import']['write'], + session=session, + ) + return task # TODO Get rid of this. -def finalize(session): - while True: - task = yield - task.finalize(session) +@pipeline.stage +def finalize(session, task): + task.finalize(session) def group_albums(session): diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index 95b77b4da..1591122e6 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -137,6 +137,29 @@ def multiple(messages): return MultiMessage(messages) +def stage(func): + """Decorate a function to become a simple stage. + + >>> @stage + ... def add(n, i): + ... return i + n + >>> pipe = Pipeline([ + ... iter([1, 2, 3]), + ... add(2), + ... ]) + >>> list(pipe.pull()) + [3, 4, 5] + """ + + def coro(*args): + task = None + while True: + task = yield task + task = func(*(args + (task,))) + return coro + + + def _allmsgs(obj): """Returns a list of all the messages encapsulated in obj. If obj is a MultiMessage, returns its enclosed messages. If obj is BUBBLE, diff --git a/test/test_pipeline.py b/test/test_pipeline.py index cd371af12..62a087eb2 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -208,6 +208,20 @@ class MultiMessageTest(unittest.TestCase): self.assertEqual(list(pl.pull()), [0, 0, 1, -1, 2, -2, 3, -3, 4, -4]) +class StageDecoratorTest(unittest.TestCase): + + def test_decorator(self): + @pipeline.stage + def add(n, i): + return i + n + + pl = pipeline.Pipeline([ + iter([1, 2, 3]), + add(2) + ]) + self.assertEqual(list(pl.pull()), [3, 4, 5]) + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)