From 828f1aa4f1c4cc665dd40b1672decb508a87a4b0 Mon Sep 17 00:00:00 2001 From: Adrian Sampson Date: Sun, 17 Apr 2011 08:18:54 -0700 Subject: [PATCH] multiple() function for sending many messages to next stage --- beets/util/pipeline.py | 71 +++++++++++++++++++++++++++--------------- test/test_pipeline.py | 45 ++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index 700a018de..549ad0913 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -41,7 +41,7 @@ POISON = '__PIPELINE_POISON__' DEFAULT_QUEUE_SIZE = 16 -def invalidate_queue(q, val=None, sync=True): +def _invalidate_queue(q, val=None, sync=True): """Breaks a Queue such that it never blocks, always has size 1, and has no maximum size. get()ing from the queue returns `val`, which defaults to None. `sync` controls whether a lock is @@ -106,7 +106,7 @@ class CountedQueue(Queue.Queue): def _get(): out = _old_get() if not self.queue: - invalidate_queue(self, POISON, False) + _invalidate_queue(self, POISON, False) return out if self.queue: @@ -114,15 +114,31 @@ class CountedQueue(Queue.Queue): self._get = _get else: # No items. Invalidate immediately. - invalidate_queue(self, POISON, False) + _invalidate_queue(self, POISON, False) -class PipelineError(object): - """An indication that an exception occurred in the pipeline. The - object is passed through the pipeline to shut down all threads - before it is raised again in the main thread. +class MultiMessage(object): + """A message yielded by a pipeline stage encapsulating multiple + values to be sent to the next stage. """ - def __init__(self, exc_info): - self.exc_info = exc_info + def __init__(self, messages): + self.messages = messages +def multiple(messages): + """Yield multiple([message, ..]) from a pipeline stage to send + multiple values to the next pipeline stage. + """ + return MultiMessage(messages) + +def _allmsgs(obj): + """Returns a list of all the messages encapsulated in obj. If obj + is a MultiMessage, returns its enclosed messages. If obj is BUBBLE, + returns an empty list. Otherwise, returns a list containing obj. + """ + if isinstance(obj, MultiMessage): + return obj.messages + elif obj == BUBBLE: + return [] + else: + return [obj] class PipelineThread(Thread): """Abstract base class for pipeline-stage threads.""" @@ -141,9 +157,9 @@ class PipelineThread(Thread): # Ensure that we are not blocking on a queue read or write. if hasattr(self, 'in_queue'): - invalidate_queue(self.in_queue) + _invalidate_queue(self.in_queue) if hasattr(self, 'out_queue'): - invalidate_queue(self.out_queue) + _invalidate_queue(self.out_queue) def abort_all(self, exc_info): """Abort all other threads in the system for an exception. @@ -168,7 +184,6 @@ class FirstPipelineThread(PipelineThread): def run(self): try: while True: - # Time to abort? with self.abort_lock: if self.abort_flag: return @@ -179,10 +194,12 @@ class FirstPipelineThread(PipelineThread): except StopIteration: break - # Send it to the next stage. - if msg is BUBBLE: - continue - self.out_queue.put(msg) + # Send messages to the next stage. + for msg in _allmsgs(msg): + with self.abort_lock: + if self.abort_flag: + return + self.out_queue.put(msg) except: self.abort_all(sys.exc_info()) @@ -224,10 +241,12 @@ class MiddlePipelineThread(PipelineThread): # Invoke the current stage. out = self.coro.send(msg) - # Send message to next stage. - if out is BUBBLE: - continue - self.out_queue.put(out) + # Send messages to next stage. + for msg in _allmsgs(out): + with self.abort_lock: + if self.abort_flag: + return + self.out_queue.put(msg) except: self.abort_all(sys.exc_info()) @@ -302,12 +321,14 @@ class Pipeline(object): coro.next() # Begin the pipeline. - for msg in coros[0]: + for out in coros[0]: + msgs = _allmsgs(out) for coro in coros[1:]: - msg = coro.send(msg) - if msg is BUBBLE: - # Don't continue to the next stage. - break + next_msgs = [] + for msg in msgs: + out = coro.send(msg) + next_msgs.extend(_allmsgs(out)) + msgs = next_msgs def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE): """Run the pipeline in parallel using one thread per stage. The diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 276db8e66..a732b1f6a 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -44,6 +44,23 @@ def _exc_work(num=3): raise TestException() i *= 2 +# A worker that yields a bubble. +def _bub_work(num=3): + i = None + while True: + i = yield i + if i == num: + i = pipeline.BUBBLE + else: + i *= 2 + +# Yet another worker that yields multiple messages. +def _multi_work(): + i = None + while True: + i = yield i + i = pipeline.multiple([i, -i]) + class SimplePipelineTest(unittest.TestCase): def setUp(self): self.l = [] @@ -117,6 +134,34 @@ class ConstrainedThreadedPipelineTest(unittest.TestCase): pl.run_parallel(1) self.assertEqual(set(l), set(i*2 for i in range(1000))) +class BubbleTest(unittest.TestCase): + def setUp(self): + self.l = [] + self.pl = pipeline.Pipeline((_produce(), _bub_work(), _consume(self.l))) + + def test_run_sequential(self): + self.pl.run_sequential() + self.assertEqual(self.l, [0,2,4,8]) + + def test_run_parallel(self): + self.pl.run_parallel() + self.assertEqual(self.l, [0,2,4,8]) + +class MultiMessageTest(unittest.TestCase): + def setUp(self): + self.l = [] + self.pl = pipeline.Pipeline(( + _produce(), _multi_work(), _consume(self.l) + )) + + def test_run_sequential(self): + self.pl.run_sequential() + self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4]) + + def test_run_parallel(self): + self.pl.run_parallel() + self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4]) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)