diff --git a/beets/importer.py b/beets/importer.py index 063e8f326..8be2c2c8e 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -929,7 +929,7 @@ def query_tasks(session): yield ImportTask(None, [album.item_dir()], items) -@pipeline.stage +@pipeline.mutator_stage def lookup_candidates(session, task): """A coroutine for performing the initial MusicBrainz lookup for an album. It accepts lists of Items and yields @@ -937,12 +937,11 @@ def lookup_candidates(session, task): is found, all of the yielded parameters (except items) are None. """ if task.skip: - return task + return plugins.send('import_task_start', session=session, task=task) log.debug('Looking up: %s' % displayable_path(task.paths)) task.lookup_candidates() - return task @pipeline.stage @@ -990,6 +989,7 @@ def user_query(session, task): user_query(session) ]) task = pipeline.multiple(ipl.pull()) + return task @@ -1013,7 +1013,7 @@ def resolve_duplicates(session): recent.add(ident) -@pipeline.stage +@pipeline.mutator_stage def import_asis(session, task): """Select the `action.ASIS` choice for all tasks. @@ -1021,23 +1021,22 @@ def import_asis(session, task): when the importer is run without autotagging. """ if task.skip: - return task + return log.info(displayable_path(task.paths)) # Behave as if ASIS were selected. task.set_null_candidates() task.set_choice(action.ASIS) - return task -@pipeline.stage +@pipeline.mutator_stage def apply_choices(session, task): """A coroutine for applying changes to albums and singletons during the autotag process. """ if task.skip: - return task + return # Change metadata. if task.apply: @@ -1049,32 +1048,31 @@ def apply_choices(session, task): task.infer_album_fields() task.add(session.lib) - return task -@pipeline.stage +@pipeline.mutator_stage def plugin_stage(session, func, task): """A coroutine (pipeline stage) that calls the given function with each non-skipped import task. These stages occur between applying metadata changes and moving/copying/writing files. """ if task.skip: - return task + return + func(session, task) # Stage may modify DB, so re-load cached item data. for item in task.imported_items(): item.load() - return task -@pipeline.stage +@pipeline.mutator_stage def manipulate_files(session, task): """A coroutine (pipeline stage) that performs necessary file manipulations *after* items have been added to the library. """ if task.skip: - return task + return if task.remove_duplicates: task.do_remove_duplicates(session.lib) @@ -1085,7 +1083,6 @@ def manipulate_files(session, task): write=config['import']['write'], session=session, ) - return task # TODO Get rid of this. diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index 1591122e6..d267789c8 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -159,6 +159,28 @@ def stage(func): return coro +def mutator_stage(func): + """Decorate a function that manipulates items in a coroutine to + become a simple stage. + + >>> @mutator_stage + ... def setkey(key, item): + ... item[key] = True + >>> pipe = Pipeline([ + ... iter([{'x': False}, {'a': False}]), + ... setkey('x'), + ... ]) + >>> list(pipe.pull()) + [{'x': True}, {'a': False, 'x': True}] + """ + + def coro(*args): + task = None + while True: + task = yield task + func(*(args + (task,))) + return coro + def _allmsgs(obj): """Returns a list of all the messages encapsulated in obj. If obj diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 62a087eb2..0c4de6836 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -210,7 +210,7 @@ class MultiMessageTest(unittest.TestCase): class StageDecoratorTest(unittest.TestCase): - def test_decorator(self): + def test_stage_decorator(self): @pipeline.stage def add(n, i): return i + n @@ -221,6 +221,18 @@ class StageDecoratorTest(unittest.TestCase): ]) self.assertEqual(list(pl.pull()), [3, 4, 5]) + def test_mutator_stage_decorator(self): + @pipeline.mutator_stage + def setkey(key, item): + item[key] = True + + pl = pipeline.Pipeline([ + iter([{'x': False}, {'a': False}]), + setkey('x'), + ]) + self.assertEqual(list(pl.pull()), + [{'x': True}, {'a': False, 'x': True}]) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)