Add pull() generator to Pipeline

This commit is contained in:
Thomas Scholtes 2014-02-01 13:43:05 +01:00
parent 736835ce72
commit 565a284c03
2 changed files with 53 additions and 15 deletions

View file

@ -315,21 +315,7 @@ class Pipeline(object):
stages are run one after the other. Only the first coroutine
in each stage is used.
"""
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in coros[1:]:
coro.next()
# Begin the pipeline.
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
list(self.pull())
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
"""Run the pipeline in parallel using one thread per stage. The
@ -386,6 +372,30 @@ class Pipeline(object):
# Make the exception appear as it was raised originally.
raise exc_info[0], exc_info[1], exc_info[2]
def pull(self):
"""Yield elements from the end of the pipeline. Runs the stages
sequentially until the last yields some messages. Each of the messages
is then yielded by ``pulled.next()``. Only the first coroutine in each
stage is used.
"""
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in coros[1:]:
coro.next()
# Begin the pipeline.
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
for msg in msgs:
yield msg
# Smoke test.
if __name__ == '__main__':
import time

View file

@ -71,6 +71,15 @@ class SimplePipelineTest(unittest.TestCase):
self.pl.run_parallel()
self.assertEqual(self.l, [0,2,4,6,8])
def test_pull(self):
pl = pipeline.Pipeline((_produce(), _work()))
self.assertEqual(list(pl.pull()), [0,2,4,6,8])
def test_pull_chain(self):
pl = pipeline.Pipeline((_produce(), _work()))
pl2 = pipeline.Pipeline((pl.pull(), _work()))
self.assertEqual(list(pl2.pull()), [0,4,8,12,16])
class ParallelStageTest(unittest.TestCase):
def setUp(self):
self.l = []
@ -87,6 +96,10 @@ class ParallelStageTest(unittest.TestCase):
# Order possibly not preserved; use set equality.
self.assertEqual(set(self.l), set([0,2,4,6,8]))
def test_pull(self):
pl = pipeline.Pipeline((_produce(), (_work(),_work())))
self.assertEqual(list(pl.pull()), [0,2,4,6,8])
class ExceptionTest(unittest.TestCase):
def setUp(self):
self.l = []
@ -98,6 +111,12 @@ class ExceptionTest(unittest.TestCase):
def test_run_parallel(self):
self.assertRaises(TestException, self.pl.run_parallel)
def test_pull(self):
pl = pipeline.Pipeline((_produce(), _exc_work()))
pull = pl.pull()
for i in range(3): pull.next()
self.assertRaises(TestException, pull.next)
class ParallelExceptionTest(unittest.TestCase):
def setUp(self):
self.l = []
@ -144,6 +163,10 @@ class BubbleTest(unittest.TestCase):
self.pl.run_parallel()
self.assertEqual(self.l, [0,2,4,8])
def test_pull(self):
pl = pipeline.Pipeline((_produce(), _bub_work()))
self.assertEqual(list(pl.pull()), [0,2,4,8])
class MultiMessageTest(unittest.TestCase):
def setUp(self):
self.l = []
@ -159,6 +182,11 @@ class MultiMessageTest(unittest.TestCase):
self.pl.run_parallel()
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
def test_pull(self):
pl = pipeline.Pipeline((_produce(), _multi_work()))
self.assertEqual(list(pl.pull()), [0,0,1,-1,2,-2,3,-3,4,-4])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)