diff --git a/beets/ui/pipeline.py b/beets/ui/pipeline.py index 3b3cd8379..f1d38cb87 100644 --- a/beets/ui/pipeline.py +++ b/beets/ui/pipeline.py @@ -25,7 +25,7 @@ shutdown when the processing is complete and when a stage raises an exception. """ from __future__ import with_statement # for Python 2.5 -from Queue import Queue +import Queue from threading import Thread, Lock BUBBLE = '__PIPELINE_BUBBLE__' @@ -58,6 +58,15 @@ class FirstPipelineThread(Thread): # Time to abort? with self.abort_lock: if self.abort_flag: + # We may have accidentally added one more object + # to the queue *after* it was cleared by the + # abort() method. Remove it if present. + try: + self.out_queue.get_nowait() + except Queue.Empty: + pass + + # Stop generating and poison. break # Get the value from the generator. @@ -82,6 +91,20 @@ class FirstPipelineThread(Thread): poisoning out_channel. """ with self.abort_lock: + # Empty the channel before poisoning it. + # This very hacky approach to clearing the queue is + # compliments of Tim Peters: + # http://www.mail-archive.com/python-list@python.org/msg95322.html + self.out_queue.mutex.acquire() + try: + self.out_queue.queue.clear() + self.out_queue.unfinished_tasks = 0 + self.out_queue.not_full.notify() + self.out_queue.all_tasks_done.notifyAll() + finally: + self.out_queue.mutex.release() + + # Notify the generator thread. self.abort_flag = True class MiddlePipelineThread(Thread): @@ -188,7 +211,7 @@ class Pipeline(object): messages between the stages are stored in queues of the given size. """ - queues = [Queue(queue_size) for i in range(len(self.stages)-1)] + queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)] threads = [FirstPipelineThread(self.stages[0], queues[0])] for i in range(1, len(self.stages)-1): threads.append(MiddlePipelineThread( @@ -202,15 +225,22 @@ class Pipeline(object): # Wait for termination. try: - for thread in threads: - thread.join() + # The final thread lasts the longest. + threads[-1].join() except: # Shut down the pipeline by telling the first thread to # poison its channel. threads[0].abort() raise - # Was there an exception? + # Halt the pipeline in case there was an exception. + threads[0].abort() + + # Make completely sure that all the threads have finished + # before we return. + for thread in threads[:-1]: + thread.join() + exc = threads[-1].exc if exc: raise exc @@ -219,6 +249,8 @@ class Pipeline(object): if __name__ == '__main__': import time + # Test a normally-terminating pipeline both in sequence and + # in parallel. def produce(): for i in range(5): print 'generating', i @@ -235,12 +267,32 @@ if __name__ == '__main__': num = yield time.sleep(1) print 'received', num - ts_start = time.time() - Pipeline([produce(), work(), consume()]).run_sequential() + # Pipeline([produce(), work(), consume()]).run_sequential() ts_middle = time.time() - Pipeline([produce(), work(), consume()]).run_parallel() + # Pipeline([produce(), work(), consume()]).run_parallel() ts_end = time.time() - print 'Sequential time:', ts_middle - ts_start print 'Parallel time:', ts_end - ts_middle + print + + # Test a pipeline that raises an exception. + def exc_produce(): + for i in range(10): + print 'generating', i + time.sleep(1) + yield i + def exc_work(): + num = yield + while True: + print 'processing', num + time.sleep(3) + if num == 3: + raise Exception() + num = yield num * 2 + def exc_consume(): + while True: + num = yield + print 'received', num + Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1) +