From 7600a8c0ada948d05b97da7034f44b0eb69863b6 Mon Sep 17 00:00:00 2001 From: Adrian Sampson Date: Mon, 11 Apr 2011 23:52:51 -0700 Subject: [PATCH] pipeline revamp: parallel stages, immediate exception abort The pipeline module now supports stages that have multiple threads working in parallel; this can bring ordinary task parallelism to parts of a pipelined workflow. This change also involves making the pipeline terminate immediately when an exception is raised in a coroutine. --- beets/util/pipeline.py | 277 +++++++++++++++++++++++------------------ 1 file changed, 157 insertions(+), 120 deletions(-) diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index 5aaf26f2b..e91f3e874 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -1,5 +1,5 @@ # This file is part of beets. -# Copyright 2010, Adrian Sampson. +# Copyright 2011, Adrian Sampson. # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the @@ -28,25 +28,30 @@ from __future__ import with_statement # for Python 2.5 import Queue from threading import Thread, Lock import sys +import types BUBBLE = '__PIPELINE_BUBBLE__' POISON = '__PIPELINE_POISON__' DEFAULT_QUEUE_SIZE = 16 -def clear_queue(q): - """Safely empty a queue.""" - # This very hacky approach to clearing the queue - # compliments of Tim Peters: - # http://www.mail-archive.com/python-list@python.org/msg95322.html - q.mutex.acquire() - try: - q.queue.clear() - q.unfinished_tasks = 0 - q.not_full.notify() - q.all_tasks_done.notifyAll() - finally: - q.mutex.release() +def invalidate_queue(q): + """Breaks a Queue such that it never blocks, always has size 1, + and has no maximum size. 
+ """ + def _qsize(len=len): + return 1 + def _put(item): + pass + def _get(): + return None + with q.mutex: + q.maxsize = 0 + q._qsize = _qsize + q._put = _put + q._get = _get + q.not_empty.notify() + q.not_full.notify() class PipelineError(object): """An indication that an exception occurred in the pipeline. The @@ -58,24 +63,38 @@ class PipelineError(object): class PipelineThread(Thread): """Abstract base class for pipeline-stage threads.""" - def __init__(self): + def __init__(self, all_threads): super(PipelineThread, self).__init__() self.abort_lock = Lock() self.abort_flag = False + self.all_threads = all_threads + self.exc_info = None def abort(self): """Shut down the thread at the next chance possible. """ with self.abort_lock: self.abort_flag = True - # Empty the channel before poisoning it. + + # Ensure that we are not blocking on a queue read or write. + if hasattr(self, 'in_queue'): + invalidate_queue(self.in_queue) + if hasattr(self, 'out_queue'): + invalidate_queue(self.out_queue) + + def abort_all(self, exc_info): + """Abort all other threads in the system for an exception. + """ + self.exc_info = exc_info + for thread in self.all_threads: + thread.abort() class FirstPipelineThread(PipelineThread): """The thread running the first stage in a parallel pipeline setup. The coroutine should just be a generator. """ - def __init__(self, coro, out_queue): - super(FirstPipelineThread, self).__init__() + def __init__(self, coro, out_queue, all_threads): + super(FirstPipelineThread, self).__init__(all_threads) self.coro = coro self.out_queue = out_queue @@ -83,25 +102,27 @@ class FirstPipelineThread(PipelineThread): self.abort_flag = False def run(self): - while True: - # Time to abort? - with self.abort_lock: - if self.abort_flag: - return - - # Get the value from the generator. - try: - msg = self.coro.next() - except StopIteration: - break - except: - self.out_queue.put(PipelineError(sys.exc_info())) - return - - # Send it to the next stage. 
- if msg is BUBBLE: - continue - self.out_queue.put(msg) + try: + while True: + # Time to abort? + with self.abort_lock: + if self.abort_flag: + return + + # Get the value from the generator. + try: + msg = self.coro.next() + except StopIteration: + break + + # Send it to the next stage. + if msg is BUBBLE: + continue + self.out_queue.put(msg) + + except: + self.abort_all(sys.exc_info()) + return # Generator finished; shut down the pipeline. self.out_queue.put(POISON) @@ -110,50 +131,49 @@ class MiddlePipelineThread(PipelineThread): """A thread running any stage in the pipeline except the first or last. """ - def __init__(self, coro, in_queue, out_queue): - super(MiddlePipelineThread, self).__init__() + def __init__(self, coro, in_queue, out_queue, all_threads): + super(MiddlePipelineThread, self).__init__(all_threads) self.coro = coro self.in_queue = in_queue self.out_queue = out_queue def run(self): - # Prime the coroutine. - self.coro.next() - - while True: - with self.abort_lock: - if self.abort_flag: - return + try: + # Prime the coroutine. + self.coro.next() + + while True: + with self.abort_lock: + if self.abort_flag: + return - # Get the message from the previous stage. - msg = self.in_queue.get() - if msg is POISON: - break - elif isinstance(msg, PipelineError): - self.out_queue.put(msg) - return - - # Invoke the current stage. - try: + # Get the message from the previous stage. + msg = self.in_queue.get() + if msg is POISON: + break + + # Invoke the current stage. out = self.coro.send(msg) - except: - self.out_queue.put(PipelineError(sys.exc_info())) - return - - # Send message to next stage. - if out is BUBBLE: - continue - self.out_queue.put(out) + + # Send message to next stage. + if out is BUBBLE: + continue + self.out_queue.put(out) + + except: + self.abort_all(sys.exc_info()) + return # Pipeline is shutting down normally. 
+ self.in_queue.put(POISON) self.out_queue.put(POISON) class LastPipelineThread(PipelineThread): """A thread running the last stage in a pipeline. The coroutine should yield nothing. """ - def __init__(self, coro, in_queue): - super(LastPipelineThread, self).__init__() + def __init__(self, coro, in_queue, all_threads): + super(LastPipelineThread, self).__init__(all_threads) self.coro = coro self.in_queue = in_queue @@ -161,28 +181,26 @@ class LastPipelineThread(PipelineThread): # Prime the coroutine. self.coro.next() - while True: - with self.abort_lock: - if self.abort_flag: - return + try: + while True: + with self.abort_lock: + if self.abort_flag: + return + + # Get the message from the previous stage. + msg = self.in_queue.get() + if msg is POISON: + break - # Get the message from the previous stage. - msg = self.in_queue.get() - if msg is POISON: - break - elif isinstance(msg, PipelineError): - self.exc_info = msg.exc_info - return - - # Send to consumer. - try: + # Send to consumer. self.coro.send(msg) - except: - self.exc_info = sys.exc_info() - return + + except: + self.abort_all(sys.exc_info()) + return - # No exception raised in pipeline. - self.exc_info = None + # Pipeline is shutting down normally. + self.in_queue.put(POISON) class Pipeline(object): """Represents a staged pattern of work. Each stage in the pipeline @@ -195,20 +213,29 @@ class Pipeline(object): """ if len(stages) < 2: raise ValueError('pipeline must have at least two stages') - self.stages = stages + self.stages = [] + for stage in stages: + if isinstance(stage, types.GeneratorType): + # Default to one thread per stage. + self.stages.append((stage,)) + else: + self.stages.append(stage) def run_sequential(self): """Run the pipeline sequentially in the current thread. The - stages are run one after the other. + stages are run one after the other. Only the first coroutine + in each stage is used. """ + coros = [stage[0] for stage in self.stages] + # "Prime" the coroutines. 
- for coro in self.stages[1:]: + for coro in coros: coro.next() # Begin the pipeline. - for msg in self.stages[0]: - for stage in self.stages[1:]: - msg = stage.send(msg) + for msg in coros[0]: + for coro in coros[1:]: + msg = coro.send(msg) if msg is BUBBLE: # Don't continue to the next stage. break @@ -219,37 +246,44 @@ class Pipeline(object): size. """ queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)] - threads = [FirstPipelineThread(self.stages[0], queues[0])] + threads = [] + + # Set up first stage. + for coro in self.stages[0]: + threads.append(FirstPipelineThread(coro, queues[0], threads)) + + + # Middle stages. for i in range(1, len(self.stages)-1): - threads.append(MiddlePipelineThread( - self.stages[i], queues[i-1], queues[i] - )) - threads.append(LastPipelineThread(self.stages[-1], queues[-1])) + for coro in self.stages[i]: + threads.append(MiddlePipelineThread( + coro, queues[i-1], queues[i], threads + )) + + # Last stage. + for coro in self.stages[-1]: + threads.append( + LastPipelineThread(coro, queues[-1], threads) + ) # Start threads. for thread in threads: thread.start() - # Wait for termination. - try: - # The final thread lasts the longest. - threads[-1].join() - finally: - # Halt the pipeline in case there was an exception. - for thread in threads: - thread.abort() - for queue in queues: - clear_queue(queue) + # Wait for termination. The final thread lasts the longest. + threads[-1].join() # Make completely sure that all the threads have finished - # before we return. + # before we return. They should already be either finished, + # in normal operation, or aborted, in case of an exception. for thread in threads[:-1]: thread.join() - exc_info = threads[-1].exc_info - if exc_info: - # Make the exception appear as it was raised originally. - raise exc_info[0], exc_info[1], exc_info[2] + for thread in threads: + exc_info = thread.exc_info + if exc_info: + # Make the exception appear as it was raised originally. 
+ raise exc_info[0], exc_info[1], exc_info[2] # Smoke test. if __name__ == '__main__': @@ -259,39 +293,42 @@ if __name__ == '__main__': # in parallel. def produce(): for i in range(5): - print 'generating', i + print 'generating %i' % i time.sleep(1) yield i def work(): num = yield while True: - print 'processing', num + print 'processing %i' % num time.sleep(2) num = yield num*2 def consume(): while True: num = yield time.sleep(1) - print 'received', num + print 'received %i' % num ts_start = time.time() Pipeline([produce(), work(), consume()]).run_sequential() - ts_middle = time.time() + ts_seq = time.time() Pipeline([produce(), work(), consume()]).run_parallel() + ts_par = time.time() + Pipeline([produce(), (work(), work()), consume()]).run_parallel() ts_end = time.time() - print 'Sequential time:', ts_middle - ts_start - print 'Parallel time:', ts_end - ts_middle + print 'Sequential time:', ts_seq - ts_start + print 'Parallel time:', ts_par - ts_seq + print 'Multiply-parallel time:', ts_end - ts_par print # Test a pipeline that raises an exception. def exc_produce(): for i in range(10): - print 'generating', i + print 'generating %i' % i time.sleep(1) yield i def exc_work(): num = yield while True: - print 'processing', num + print 'processing %i' % num time.sleep(3) if num == 3: raise Exception() @@ -301,5 +338,5 @@ if __name__ == '__main__': num = yield #if num == 4: # raise Exception() - print 'received', num + print 'received %i' % num Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)