multiple() function for sending many messages to next stage

This commit is contained in:
Adrian Sampson 2011-04-17 08:18:54 -07:00
parent 85cd3cdb84
commit 828f1aa4f1
2 changed files with 91 additions and 25 deletions

View file

@ -41,7 +41,7 @@ POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def invalidate_queue(q, val=None, sync=True):
def _invalidate_queue(q, val=None, sync=True):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size. get()ing from the queue returns `val`,
which defaults to None. `sync` controls whether a lock is
@ -106,7 +106,7 @@ class CountedQueue(Queue.Queue):
def _get():
out = _old_get()
if not self.queue:
invalidate_queue(self, POISON, False)
_invalidate_queue(self, POISON, False)
return out
if self.queue:
@ -114,15 +114,31 @@ class CountedQueue(Queue.Queue):
self._get = _get
else:
# No items. Invalidate immediately.
invalidate_queue(self, POISON, False)
_invalidate_queue(self, POISON, False)
class PipelineError(object):
"""An indication that an exception occurred in the pipeline. The
object is passed through the pipeline to shut down all threads
before it is raised again in the main thread.
class MultiMessage(object):
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, exc_info):
self.exc_info = exc_info
def __init__(self, messages):
self.messages = messages
def multiple(messages):
"""Yield multiple([message, ..]) from a pipeline stage to send
multiple values to the next pipeline stage.
"""
return MultiMessage(messages)
def _allmsgs(obj):
"""Returns a list of all the messages encapsulated in obj. If obj
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
returns an empty list. Otherwise, returns a list containing obj.
"""
if isinstance(obj, MultiMessage):
return obj.messages
elif obj == BUBBLE:
return []
else:
return [obj]
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
@ -141,9 +157,9 @@ class PipelineThread(Thread):
# Ensure that we are not blocking on a queue read or write.
if hasattr(self, 'in_queue'):
invalidate_queue(self.in_queue)
_invalidate_queue(self.in_queue)
if hasattr(self, 'out_queue'):
invalidate_queue(self.out_queue)
_invalidate_queue(self.out_queue)
def abort_all(self, exc_info):
"""Abort all other threads in the system for an exception.
@ -168,7 +184,6 @@ class FirstPipelineThread(PipelineThread):
def run(self):
try:
while True:
# Time to abort?
with self.abort_lock:
if self.abort_flag:
return
@ -179,10 +194,12 @@ class FirstPipelineThread(PipelineThread):
except StopIteration:
break
# Send it to the next stage.
if msg is BUBBLE:
continue
self.out_queue.put(msg)
# Send messages to the next stage.
for msg in _allmsgs(msg):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except:
self.abort_all(sys.exc_info())
@ -224,10 +241,12 @@ class MiddlePipelineThread(PipelineThread):
# Invoke the current stage.
out = self.coro.send(msg)
# Send message to next stage.
if out is BUBBLE:
continue
self.out_queue.put(out)
# Send messages to next stage.
for msg in _allmsgs(out):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except:
self.abort_all(sys.exc_info())
@ -302,12 +321,14 @@ class Pipeline(object):
coro.next()
# Begin the pipeline.
for msg in coros[0]:
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
msg = coro.send(msg)
if msg is BUBBLE:
# Don't continue to the next stage.
break
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
"""Run the pipeline in parallel using one thread per stage. The

View file

@ -44,6 +44,23 @@ def _exc_work(num=3):
raise TestException()
i *= 2
# A worker that yields a bubble.
def _bub_work(num=3):
i = None
while True:
i = yield i
if i == num:
i = pipeline.BUBBLE
else:
i *= 2
# Yet another worker that yields multiple messages.
def _multi_work():
i = None
while True:
i = yield i
i = pipeline.multiple([i, -i])
class SimplePipelineTest(unittest.TestCase):
def setUp(self):
self.l = []
@ -117,6 +134,34 @@ class ConstrainedThreadedPipelineTest(unittest.TestCase):
pl.run_parallel(1)
self.assertEqual(set(l), set(i*2 for i in range(1000)))
class BubbleTest(unittest.TestCase):
def setUp(self):
self.l = []
self.pl = pipeline.Pipeline((_produce(), _bub_work(), _consume(self.l)))
def test_run_sequential(self):
self.pl.run_sequential()
self.assertEqual(self.l, [0,2,4,8])
def test_run_parallel(self):
self.pl.run_parallel()
self.assertEqual(self.l, [0,2,4,8])
class MultiMessageTest(unittest.TestCase):
def setUp(self):
self.l = []
self.pl = pipeline.Pipeline((
_produce(), _multi_work(), _consume(self.l)
))
def test_run_sequential(self):
self.pl.run_sequential()
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
def test_run_parallel(self):
self.pl.run_parallel()
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)