diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py
index 61e045db0..700a018de 100644
--- a/beets/util/pipeline.py
+++ b/beets/util/pipeline.py
@@ -41,17 +41,23 @@ POISON = '__PIPELINE_POISON__'
 
 DEFAULT_QUEUE_SIZE = 16
 
-def invalidate_queue(q):
+def invalidate_queue(q, val=None, sync=True):
     """Breaks a Queue such that it never blocks, always has size 1,
-    and has no maximum size.
+    and has no maximum size. get()ing from the queue returns `val`,
+    which defaults to None. `sync` controls whether a lock is
+    required (because it's not reentrant!).
     """
     def _qsize(len=len):
         return 1
     def _put(item):
         pass
     def _get():
-        return None
-    with q.mutex:
+        return val
+
+    if sync:
+        q.mutex.acquire()
+
+    try:
         q.maxsize = 0
         q._qsize = _qsize
         q._put = _put
@@ -59,6 +65,57 @@ def invalidate_queue(q):
         q.not_empty.notify()
         q.not_full.notify()
+    finally:
+        if sync:
+            q.mutex.release()
+
+class CountedQueue(Queue.Queue):
+    """A queue that keeps track of the number of threads that are
+    still feeding into it. The queue is poisoned when all threads are
+    finished with the queue.
+    """
+    def __init__(self, maxsize=0):
+        Queue.Queue.__init__(self, maxsize)
+        self.nthreads = 0
+        self.poisoned = False
+
+    def acquire(self):
+        """Indicate that a thread will start putting into this queue.
+        Should not be called after the queue is already poisoned.
+        """
+        with self.mutex:
+            assert not self.poisoned
+            assert self.nthreads >= 0
+            self.nthreads += 1
+
+    def release(self):
+        """Indicate that a thread that was putting into this queue has
+        exited. If this is the last thread using the queue, the queue
+        is poisoned.
+        """
+        with self.mutex:
+            self.nthreads -= 1
+            assert self.nthreads >= 0
+            if self.nthreads == 0:
+                # All threads are done adding to this queue. Poison it
+                # when it becomes empty.
+                self.poisoned = True
+
+                # Replacement _get invalidates when no items remain.
+                _old_get = self._get
+                def _get():
+                    out = _old_get()
+                    if not self.queue:
+                        invalidate_queue(self, POISON, False)
+                    return out
+
+                if self.queue:
+                    # Items remain.
+                    self._get = _get
+                else:
+                    # No items. Invalidate immediately.
+                    invalidate_queue(self, POISON, False)
+
 
 class PipelineError(object):
     """An indication that an exception occurred in the pipeline. The
     object is passed through the pipeline to shut down all threads
@@ -103,6 +160,7 @@ class FirstPipelineThread(PipelineThread):
         super(FirstPipelineThread, self).__init__(all_threads)
         self.coro = coro
         self.out_queue = out_queue
+        self.out_queue.acquire()
 
         self.abort_lock = Lock()
         self.abort_flag = False
@@ -131,7 +189,7 @@ class FirstPipelineThread(PipelineThread):
             return
 
         # Generator finished; shut down the pipeline.
-        self.out_queue.put(POISON)
+        self.out_queue.release()
 
 class MiddlePipelineThread(PipelineThread):
     """A thread running any stage in the pipeline except the first or
@@ -142,6 +200,7 @@ class MiddlePipelineThread(PipelineThread):
         self.coro = coro
         self.in_queue = in_queue
         self.out_queue = out_queue
+        self.out_queue.acquire()
 
     def run(self):
         try:
@@ -175,8 +234,7 @@ class MiddlePipelineThread(PipelineThread):
             return
 
         # Pipeline is shutting down normally.
-        self.in_queue.put(POISON)
-        self.out_queue.put(POISON)
+        self.out_queue.release()
 
 class LastPipelineThread(PipelineThread):
     """A thread running the last stage in a pipeline. The coroutine
@@ -212,9 +270,6 @@ class LastPipelineThread(PipelineThread):
         except:
             self.abort_all(sys.exc_info())
             return
-
-        # Pipeline is shutting down normally.
-        self.in_queue.put(POISON)
 
 class Pipeline(object):
     """Represents a staged pattern of work. Each stage in the pipeline
@@ -259,7 +314,7 @@ class Pipeline(object):
         messages between the stages are stored in queues of the given
         size.
         """
-        queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
+        queues = [CountedQueue(queue_size) for i in range(len(self.stages)-1)]
         threads = []
 
         # Set up first stage.
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
new file mode 100644
index 000000000..276db8e66
--- /dev/null
+++ b/test/test_pipeline.py
@@ -0,0 +1,124 @@
+# This file is part of beets.
+# Copyright 2010, Adrian Sampson.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+"""Test the "pipeline.py" restricted parallel programming library.
+"""
+
+import unittest
+
+import _common
+from beets.util import pipeline
+
+# Some simple pipeline stages for testing.
+def _produce(num=5):
+    for i in range(num):
+        yield i
+def _work():
+    i = None
+    while True:
+        i = yield i
+        i *= 2
+def _consume(l):
+    while True:
+        i = yield
+        l.append(i)
+
+# A worker that raises an exception.
+class TestException(Exception): pass
+def _exc_work(num=3):
+    i = None
+    while True:
+        i = yield i
+        if i == num:
+            raise TestException()
+        i *= 2
+
+class SimplePipelineTest(unittest.TestCase):
+    def setUp(self):
+        self.l = []
+        self.pl = pipeline.Pipeline((_produce(), _work(), _consume(self.l)))
+
+    def test_run_sequential(self):
+        self.pl.run_sequential()
+        self.assertEqual(self.l, [0,2,4,6,8])
+
+    def test_run_parallel(self):
+        self.pl.run_parallel()
+        self.assertEqual(self.l, [0,2,4,6,8])
+
+class ParallelStageTest(unittest.TestCase):
+    def setUp(self):
+        self.l = []
+        self.pl = pipeline.Pipeline((
+            _produce(), (_work(), _work()), _consume(self.l)
+        ))
+
+    def test_run_sequential(self):
+        self.pl.run_sequential()
+        self.assertEqual(self.l, [0,2,4,6,8])
+
+    def test_run_parallel(self):
+        self.pl.run_parallel()
+        # Order possibly not preserved; use set equality.
+        self.assertEqual(set(self.l), set([0,2,4,6,8]))
+
+class ExceptionTest(unittest.TestCase):
+    def setUp(self):
+        self.l = []
+        self.pl = pipeline.Pipeline((_produce(), _exc_work(), _consume(self.l)))
+
+    def test_run_sequential(self):
+        self.assertRaises(TestException, self.pl.run_sequential)
+
+    def test_run_parallel(self):
+        self.assertRaises(TestException, self.pl.run_parallel)
+
+class ParallelExceptionTest(unittest.TestCase):
+    def setUp(self):
+        self.l = []
+        self.pl = pipeline.Pipeline((
+            _produce(), (_exc_work(), _exc_work()), _consume(self.l)
+        ))
+
+    def test_run_parallel(self):
+        self.assertRaises(TestException, self.pl.run_parallel)
+
+class ConstrainedThreadedPipelineTest(unittest.TestCase):
+    def test_constrained(self):
+        l = []
+        # Do a "significant" amount of work...
+        pl = pipeline.Pipeline((_produce(1000), _work(), _consume(l)))
+        # ... with only a single queue slot.
+        pl.run_parallel(1)
+        self.assertEqual(l, [i*2 for i in range(1000)])
+
+    def test_constrained_exception(self):
+        # Raise an exception in a constrained pipeline.
+        l = []
+        pl = pipeline.Pipeline((_produce(1000), _exc_work(), _consume(l)))
+        self.assertRaises(TestException, pl.run_parallel, 1)
+
+    def test_constrained_parallel(self):
+        l = []
+        pl = pipeline.Pipeline((
+            _produce(1000), (_work(), _work()), _consume(l)
+        ))
+        pl.run_parallel(1)
+        self.assertEqual(set(l), set(i*2 for i in range(1000)))
+
+def suite():
+    return unittest.TestLoader().loadTestsFromName(__name__)
+
+if __name__ == '__main__':
+    unittest.main(defaultTest='suite')
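
Usage note (not part of the patch): a minimal sketch of the poisoning behaviour that CountedQueue introduces, assuming the patched beets.util.pipeline is importable. A producer registers with acquire() and signals completion with release(); once the last producer has released and the queue drains, get() stops blocking and returns POISON.

    from beets.util import pipeline

    q = pipeline.CountedQueue()

    q.acquire()      # one producer registers with the queue
    q.put('item')
    q.release()      # last producer done; queue poisons itself once drained

    print(q.get())                       # 'item'
    print(q.get() is pipeline.POISON)    # True: get() no longer blocks

This mirrors how the pipeline threads now shut down: each stage calls release() on its output queue instead of putting POISON explicitly, so a fan-in queue is only poisoned after every parallel producer has finished.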