mirror of
https://github.com/beetbox/beets.git
synced 2026-01-30 20:13:37 +01:00
Add pull() generator to Pipeline
This commit is contained in:
parent
736835ce72
commit
565a284c03
2 changed files with 53 additions and 15 deletions
|
|
@ -315,21 +315,7 @@ class Pipeline(object):
|
|||
stages are run one after the other. Only the first coroutine
|
||||
in each stage is used.
|
||||
"""
|
||||
coros = [stage[0] for stage in self.stages]
|
||||
|
||||
# "Prime" the coroutines.
|
||||
for coro in coros[1:]:
|
||||
coro.next()
|
||||
|
||||
# Begin the pipeline.
|
||||
for out in coros[0]:
|
||||
msgs = _allmsgs(out)
|
||||
for coro in coros[1:]:
|
||||
next_msgs = []
|
||||
for msg in msgs:
|
||||
out = coro.send(msg)
|
||||
next_msgs.extend(_allmsgs(out))
|
||||
msgs = next_msgs
|
||||
list(self.pull())
|
||||
|
||||
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
|
||||
"""Run the pipeline in parallel using one thread per stage. The
|
||||
|
|
@ -386,6 +372,30 @@ class Pipeline(object):
|
|||
# Make the exception appear as it was raised originally.
|
||||
raise exc_info[0], exc_info[1], exc_info[2]
|
||||
|
||||
def pull(self):
|
||||
"""Yield elements from the end of the pipeline. Runs the stages
|
||||
sequentially until the last yields some messages. Each of the messages
|
||||
is then yielded by ``pulled.next()``. Only the first coroutine in each
|
||||
stage is used.
|
||||
"""
|
||||
coros = [stage[0] for stage in self.stages]
|
||||
|
||||
# "Prime" the coroutines.
|
||||
for coro in coros[1:]:
|
||||
coro.next()
|
||||
|
||||
# Begin the pipeline.
|
||||
for out in coros[0]:
|
||||
msgs = _allmsgs(out)
|
||||
for coro in coros[1:]:
|
||||
next_msgs = []
|
||||
for msg in msgs:
|
||||
out = coro.send(msg)
|
||||
next_msgs.extend(_allmsgs(out))
|
||||
msgs = next_msgs
|
||||
for msg in msgs:
|
||||
yield msg
|
||||
|
||||
# Smoke test.
|
||||
if __name__ == '__main__':
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -71,6 +71,15 @@ class SimplePipelineTest(unittest.TestCase):
|
|||
self.pl.run_parallel()
|
||||
self.assertEqual(self.l, [0,2,4,6,8])
|
||||
|
||||
def test_pull(self):
|
||||
pl = pipeline.Pipeline((_produce(), _work()))
|
||||
self.assertEqual(list(pl.pull()), [0,2,4,6,8])
|
||||
|
||||
def test_pull_chain(self):
|
||||
pl = pipeline.Pipeline((_produce(), _work()))
|
||||
pl2 = pipeline.Pipeline((pl.pull(), _work()))
|
||||
self.assertEqual(list(pl2.pull()), [0,4,8,12,16])
|
||||
|
||||
class ParallelStageTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
|
|
@ -87,6 +96,10 @@ class ParallelStageTest(unittest.TestCase):
|
|||
# Order possibly not preserved; use set equality.
|
||||
self.assertEqual(set(self.l), set([0,2,4,6,8]))
|
||||
|
||||
def test_pull(self):
|
||||
pl = pipeline.Pipeline((_produce(), (_work(),_work())))
|
||||
self.assertEqual(list(pl.pull()), [0,2,4,6,8])
|
||||
|
||||
class ExceptionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
|
|
@ -98,6 +111,12 @@ class ExceptionTest(unittest.TestCase):
|
|||
def test_run_parallel(self):
|
||||
self.assertRaises(TestException, self.pl.run_parallel)
|
||||
|
||||
def test_pull(self):
|
||||
pl = pipeline.Pipeline((_produce(), _exc_work()))
|
||||
pull = pl.pull()
|
||||
for i in range(3): pull.next()
|
||||
self.assertRaises(TestException, pull.next)
|
||||
|
||||
class ParallelExceptionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
|
|
@ -144,6 +163,10 @@ class BubbleTest(unittest.TestCase):
|
|||
self.pl.run_parallel()
|
||||
self.assertEqual(self.l, [0,2,4,8])
|
||||
|
||||
def test_pull(self):
|
||||
pl = pipeline.Pipeline((_produce(), _bub_work()))
|
||||
self.assertEqual(list(pl.pull()), [0,2,4,8])
|
||||
|
||||
class MultiMessageTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
|
|
@ -159,6 +182,11 @@ class MultiMessageTest(unittest.TestCase):
|
|||
self.pl.run_parallel()
|
||||
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
|
||||
|
||||
def test_pull(self):
|
||||
pl = pipeline.Pipeline((_produce(), _multi_work()))
|
||||
self.assertEqual(list(pl.pull()), [0,0,1,-1,2,-2,3,-3,4,-4])
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue