From 565a284c033130ee352129f520da341847e35db9 Mon Sep 17 00:00:00 2001 From: Thomas Scholtes Date: Sat, 1 Feb 2014 13:43:05 +0100 Subject: [PATCH] Add pull() generator to Pipeline --- beets/util/pipeline.py | 40 +++++++++++++++++++++++++--------------- test/test_pipeline.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index f2923dc78..22531bc7d 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -315,21 +315,7 @@ class Pipeline(object): stages are run one after the other. Only the first coroutine in each stage is used. """ - coros = [stage[0] for stage in self.stages] - - # "Prime" the coroutines. - for coro in coros[1:]: - coro.next() - - # Begin the pipeline. - for out in coros[0]: - msgs = _allmsgs(out) - for coro in coros[1:]: - next_msgs = [] - for msg in msgs: - out = coro.send(msg) - next_msgs.extend(_allmsgs(out)) - msgs = next_msgs + list(self.pull()) def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE): """Run the pipeline in parallel using one thread per stage. The @@ -386,6 +372,30 @@ class Pipeline(object): # Make the exception appear as it was raised originally. raise exc_info[0], exc_info[1], exc_info[2] + def pull(self): + """Yield elements from the end of the pipeline. Runs the stages + sequentially until the last yields some messages. Each of the messages + is then yielded by ``pulled.next()``. Only the first coroutine in each + stage is used. + """ + coros = [stage[0] for stage in self.stages] + + # "Prime" the coroutines. + for coro in coros[1:]: + coro.next() + + # Begin the pipeline. + for out in coros[0]: + msgs = _allmsgs(out) + for coro in coros[1:]: + next_msgs = [] + for msg in msgs: + out = coro.send(msg) + next_msgs.extend(_allmsgs(out)) + msgs = next_msgs + for msg in msgs: + yield msg + # Smoke test. if __name__ == '__main__': import time diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 1e5bb72d3..4163a3aba 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -71,6 +71,15 @@ class SimplePipelineTest(unittest.TestCase): self.pl.run_parallel() self.assertEqual(self.l, [0,2,4,6,8]) + def test_pull(self): + pl = pipeline.Pipeline((_produce(), _work())) + self.assertEqual(list(pl.pull()), [0,2,4,6,8]) + + def test_pull_chain(self): + pl = pipeline.Pipeline((_produce(), _work())) + pl2 = pipeline.Pipeline((pl.pull(), _work())) + self.assertEqual(list(pl2.pull()), [0,4,8,12,16]) + class ParallelStageTest(unittest.TestCase): def setUp(self): self.l = [] @@ -87,6 +96,10 @@ class ParallelStageTest(unittest.TestCase): # Order possibly not preserved; use set equality. self.assertEqual(set(self.l), set([0,2,4,6,8])) + def test_pull(self): + pl = pipeline.Pipeline((_produce(), (_work(),_work()))) + self.assertEqual(list(pl.pull()), [0,2,4,6,8]) + class ExceptionTest(unittest.TestCase): def setUp(self): self.l = [] @@ -98,6 +111,12 @@ class ExceptionTest(unittest.TestCase): def test_run_parallel(self): self.assertRaises(TestException, self.pl.run_parallel) + def test_pull(self): + pl = pipeline.Pipeline((_produce(), _exc_work())) + pull = pl.pull() + for i in range(3): pull.next() + self.assertRaises(TestException, pull.next) + class ParallelExceptionTest(unittest.TestCase): def setUp(self): self.l = [] @@ -144,6 +163,10 @@ class BubbleTest(unittest.TestCase): self.pl.run_parallel() self.assertEqual(self.l, [0,2,4,8]) + def test_pull(self): + pl = pipeline.Pipeline((_produce(), _bub_work())) + self.assertEqual(list(pl.pull()), [0,2,4,8]) + class MultiMessageTest(unittest.TestCase): def setUp(self): self.l = [] @@ -159,6 +182,11 @@ class MultiMessageTest(unittest.TestCase): self.pl.run_parallel() self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4]) + def test_pull(self): + pl = pipeline.Pipeline((_produce(), _multi_work())) + self.assertEqual(list(pl.pull()), [0,0,1,-1,2,-2,3,-3,4,-4]) + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)