mirror of
https://github.com/beetbox/beets.git
synced 2026-01-07 16:34:45 +01:00
add pipeline tests, fixing a bug with parallel stages in the process
This commit is contained in:
parent
bf5c569884
commit
85cd3cdb84
2 changed files with 190 additions and 11 deletions
|
|
@ -41,17 +41,23 @@ POISON = '__PIPELINE_POISON__'
|
|||
|
||||
DEFAULT_QUEUE_SIZE = 16
|
||||
|
||||
def invalidate_queue(q):
|
||||
def invalidate_queue(q, val=None, sync=True):
|
||||
"""Breaks a Queue such that it never blocks, always has size 1,
|
||||
and has no maximum size.
|
||||
and has no maximum size. get()ing from the queue returns `val`,
|
||||
which defaults to None. `sync` controls whether a lock is
|
||||
required (because it's not reentrant!).
|
||||
"""
|
||||
def _qsize(len=len):
|
||||
return 1
|
||||
def _put(item):
|
||||
pass
|
||||
def _get():
|
||||
return None
|
||||
with q.mutex:
|
||||
return val
|
||||
|
||||
if sync:
|
||||
q.mutex.acquire()
|
||||
|
||||
try:
|
||||
q.maxsize = 0
|
||||
q._qsize = _qsize
|
||||
q._put = _put
|
||||
|
|
@ -59,6 +65,57 @@ def invalidate_queue(q):
|
|||
q.not_empty.notify()
|
||||
q.not_full.notify()
|
||||
|
||||
finally:
|
||||
if sync:
|
||||
q.mutex.release()
|
||||
|
||||
class CountedQueue(Queue.Queue):
    """A queue that counts the threads still feeding items into it.

    Producers register with `acquire` and sign off with `release`.
    Once the last producer has released the queue it is marked as
    poisoned, and it is invalidated as soon as no items remain.
    """
    def __init__(self, maxsize=0):
        Queue.Queue.__init__(self, maxsize)
        # Nothing is feeding the queue until `acquire` is called.
        self.poisoned = False
        self.nthreads = 0

    def acquire(self):
        """Register one more thread that will put items into this
        queue. Must not be called once the queue has been poisoned.
        """
        with self.mutex:
            assert not self.poisoned
            assert self.nthreads >= 0
            self.nthreads = self.nthreads + 1

    def release(self):
        """Sign off a thread that was putting items into this queue.
        When the last such thread signs off, the queue is poisoned.
        """
        with self.mutex:
            self.nthreads -= 1
            assert self.nthreads >= 0
            if self.nthreads != 0:
                # Other threads are still feeding the queue.
                return

            # Every producer is done; the queue is poisoned once it
            # has been drained.
            self.poisoned = True

            if not self.queue:
                # Already empty: invalidate right away. The mutex is
                # held here, so invalidate_queue must not lock again.
                invalidate_queue(self, POISON, False)
                return

            # Items remain. Wrap _get so that the queue invalidates
            # itself when the final item is taken out.
            original_get = self._get

            def draining_get():
                item = original_get()
                if not self.queue:
                    invalidate_queue(self, POISON, False)
                return item

            self._get = draining_get
|
||||
|
||||
class PipelineError(object):
|
||||
"""An indication that an exception occurred in the pipeline. The
|
||||
object is passed through the pipeline to shut down all threads
|
||||
|
|
@ -103,6 +160,7 @@ class FirstPipelineThread(PipelineThread):
|
|||
super(FirstPipelineThread, self).__init__(all_threads)
|
||||
self.coro = coro
|
||||
self.out_queue = out_queue
|
||||
self.out_queue.acquire()
|
||||
|
||||
self.abort_lock = Lock()
|
||||
self.abort_flag = False
|
||||
|
|
@ -131,7 +189,7 @@ class FirstPipelineThread(PipelineThread):
|
|||
return
|
||||
|
||||
# Generator finished; shut down the pipeline.
|
||||
self.out_queue.put(POISON)
|
||||
self.out_queue.release()
|
||||
|
||||
class MiddlePipelineThread(PipelineThread):
|
||||
"""A thread running any stage in the pipeline except the first or
|
||||
|
|
@ -142,6 +200,7 @@ class MiddlePipelineThread(PipelineThread):
|
|||
self.coro = coro
|
||||
self.in_queue = in_queue
|
||||
self.out_queue = out_queue
|
||||
self.out_queue.acquire()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
|
|
@ -175,8 +234,7 @@ class MiddlePipelineThread(PipelineThread):
|
|||
return
|
||||
|
||||
# Pipeline is shutting down normally.
|
||||
self.in_queue.put(POISON)
|
||||
self.out_queue.put(POISON)
|
||||
self.out_queue.release()
|
||||
|
||||
class LastPipelineThread(PipelineThread):
|
||||
"""A thread running the last stage in a pipeline. The coroutine
|
||||
|
|
@ -212,9 +270,6 @@ class LastPipelineThread(PipelineThread):
|
|||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
return
|
||||
|
||||
# Pipeline is shutting down normally.
|
||||
self.in_queue.put(POISON)
|
||||
|
||||
class Pipeline(object):
|
||||
"""Represents a staged pattern of work. Each stage in the pipeline
|
||||
|
|
@ -259,7 +314,7 @@ class Pipeline(object):
|
|||
messages between the stages are stored in queues of the given
|
||||
size.
|
||||
"""
|
||||
queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
|
||||
queues = [CountedQueue(queue_size) for i in range(len(self.stages)-1)]
|
||||
threads = []
|
||||
|
||||
# Set up first stage.
|
||||
|
|
|
|||
124
test/test_pipeline.py
Normal file
124
test/test_pipeline.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
# This file is part of beets.
|
||||
# Copyright 2010, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Test the "pipeline.py" restricted parallel programming library.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import _common
|
||||
from beets.util import pipeline
|
||||
|
||||
# Some simple pipeline stages for testing.
|
||||
def _produce(num=5):
|
||||
for i in range(num):
|
||||
yield i
|
||||
def _work():
|
||||
i = None
|
||||
while True:
|
||||
i = yield i
|
||||
i *= 2
|
||||
def _consume(l):
|
||||
while True:
|
||||
i = yield
|
||||
l.append(i)
|
||||
|
||||
# A worker that raises an exception.
|
||||
class TestException(Exception):
    """Raised on purpose by the failing test stage."""
|
||||
def _exc_work(num=3):
|
||||
i = None
|
||||
while True:
|
||||
i = yield i
|
||||
if i == num:
|
||||
raise TestException()
|
||||
i *= 2
|
||||
|
||||
class SimplePipelineTest(unittest.TestCase):
    """Run a produce -> work -> consume pipeline in both modes and
    check the values collected by the last stage.
    """
    def setUp(self):
        self.l = []
        stages = (_produce(), _work(), _consume(self.l))
        self.pl = pipeline.Pipeline(stages)

    def test_run_sequential(self):
        self.pl.run_sequential()
        expected = [0, 2, 4, 6, 8]
        self.assertEqual(self.l, expected)

    def test_run_parallel(self):
        self.pl.run_parallel()
        expected = [0, 2, 4, 6, 8]
        self.assertEqual(self.l, expected)
|
||||
|
||||
class ParallelStageTest(unittest.TestCase):
    """Run a pipeline whose middle stage consists of two workers."""
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((
            _produce(), (_work(), _work()), _consume(self.l)
        ))

    def test_run_sequential(self):
        self.pl.run_sequential()
        self.assertEqual(self.l, [0, 2, 4, 6, 8])

    def test_run_parallel(self):
        self.pl.run_parallel()
        # Order possibly not preserved, so compare sorted results.
        # Unlike set equality, this also fails if an item reaches
        # the consumer more than once.
        self.assertEqual(sorted(self.l), [0, 2, 4, 6, 8])
|
||||
|
||||
class ExceptionTest(unittest.TestCase):
    """Check that an exception raised in a middle stage is raised
    from the pipeline's run methods.
    """
    def setUp(self):
        self.l = []
        stages = (_produce(), _exc_work(), _consume(self.l))
        self.pl = pipeline.Pipeline(stages)

    def test_run_sequential(self):
        run = self.pl.run_sequential
        self.assertRaises(TestException, run)

    def test_run_parallel(self):
        run = self.pl.run_parallel
        self.assertRaises(TestException, run)
|
||||
|
||||
class ParallelExceptionTest(unittest.TestCase):
    """Check exception handling when the failing stage consists of
    two workers.
    """
    def setUp(self):
        self.l = []
        failing_stage = (_exc_work(), _exc_work())
        self.pl = pipeline.Pipeline(
            (_produce(), failing_stage, _consume(self.l))
        )

    def test_run_parallel(self):
        run = self.pl.run_parallel
        self.assertRaises(TestException, run)
|
||||
|
||||
class ConstrainedThreadedPipelineTest(unittest.TestCase):
    """Run larger pipelines in parallel with `run_parallel(1)`, i.e.
    with only a single queue slot.
    """
    def test_constrained(self):
        l = []
        # Do a "significant" amount of work...
        pl = pipeline.Pipeline((_produce(1000), _work(), _consume(l)))
        # ... with only a single queue slot.
        pl.run_parallel(1)
        self.assertEqual(l, [i*2 for i in range(1000)])

    def test_constrained_exception(self):
        # Raise an exception in a constrained pipeline.
        l = []
        pl = pipeline.Pipeline((_produce(1000), _exc_work(), _consume(l)))
        self.assertRaises(TestException, pl.run_parallel, 1)

    def test_constrained_parallel(self):
        l = []
        pl = pipeline.Pipeline((
            _produce(1000), (_work(), _work()), _consume(l)
        ))
        pl.run_parallel(1)
        # Order is not guaranteed with two workers, so compare the
        # sorted results. Unlike set equality, this also fails if an
        # item reaches the consumer more than once.
        self.assertEqual(sorted(l), [i*2 for i in range(1000)])
|
||||
|
||||
def suite():
    """Return a test suite loaded from this module by name."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)
|
||||
|
||||
# When executed as a script, run the suite built by suite().
if __name__ == '__main__':
    unittest.main(defaultTest='suite')
|
||||
Loading…
Reference in a new issue