mirror of
https://github.com/beetbox/beets.git
synced 2026-01-17 05:34:23 +01:00
multiple() function for sending many messages to next stage
This commit is contained in:
parent
85cd3cdb84
commit
828f1aa4f1
2 changed files with 91 additions and 25 deletions
|
|
@ -41,7 +41,7 @@ POISON = '__PIPELINE_POISON__'
|
|||
|
||||
DEFAULT_QUEUE_SIZE = 16
|
||||
|
||||
def invalidate_queue(q, val=None, sync=True):
|
||||
def _invalidate_queue(q, val=None, sync=True):
|
||||
"""Breaks a Queue such that it never blocks, always has size 1,
|
||||
and has no maximum size. get()ing from the queue returns `val`,
|
||||
which defaults to None. `sync` controls whether a lock is
|
||||
|
|
@ -106,7 +106,7 @@ class CountedQueue(Queue.Queue):
|
|||
def _get():
|
||||
out = _old_get()
|
||||
if not self.queue:
|
||||
invalidate_queue(self, POISON, False)
|
||||
_invalidate_queue(self, POISON, False)
|
||||
return out
|
||||
|
||||
if self.queue:
|
||||
|
|
@ -114,15 +114,31 @@ class CountedQueue(Queue.Queue):
|
|||
self._get = _get
|
||||
else:
|
||||
# No items. Invalidate immediately.
|
||||
invalidate_queue(self, POISON, False)
|
||||
_invalidate_queue(self, POISON, False)
|
||||
|
||||
class PipelineError(object):
|
||||
"""An indication that an exception occurred in the pipeline. The
|
||||
object is passed through the pipeline to shut down all threads
|
||||
before it is raised again in the main thread.
|
||||
class MultiMessage(object):
|
||||
"""A message yielded by a pipeline stage encapsulating multiple
|
||||
values to be sent to the next stage.
|
||||
"""
|
||||
def __init__(self, exc_info):
|
||||
self.exc_info = exc_info
|
||||
def __init__(self, messages):
|
||||
self.messages = messages
|
||||
def multiple(messages):
|
||||
"""Yield multiple([message, ..]) from a pipeline stage to send
|
||||
multiple values to the next pipeline stage.
|
||||
"""
|
||||
return MultiMessage(messages)
|
||||
|
||||
def _allmsgs(obj):
|
||||
"""Returns a list of all the messages encapsulated in obj. If obj
|
||||
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
|
||||
returns an empty list. Otherwise, returns a list containing obj.
|
||||
"""
|
||||
if isinstance(obj, MultiMessage):
|
||||
return obj.messages
|
||||
elif obj == BUBBLE:
|
||||
return []
|
||||
else:
|
||||
return [obj]
|
||||
|
||||
class PipelineThread(Thread):
|
||||
"""Abstract base class for pipeline-stage threads."""
|
||||
|
|
@ -141,9 +157,9 @@ class PipelineThread(Thread):
|
|||
|
||||
# Ensure that we are not blocking on a queue read or write.
|
||||
if hasattr(self, 'in_queue'):
|
||||
invalidate_queue(self.in_queue)
|
||||
_invalidate_queue(self.in_queue)
|
||||
if hasattr(self, 'out_queue'):
|
||||
invalidate_queue(self.out_queue)
|
||||
_invalidate_queue(self.out_queue)
|
||||
|
||||
def abort_all(self, exc_info):
|
||||
"""Abort all other threads in the system for an exception.
|
||||
|
|
@ -168,7 +184,6 @@ class FirstPipelineThread(PipelineThread):
|
|||
def run(self):
|
||||
try:
|
||||
while True:
|
||||
# Time to abort?
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
|
|
@ -179,10 +194,12 @@ class FirstPipelineThread(PipelineThread):
|
|||
except StopIteration:
|
||||
break
|
||||
|
||||
# Send it to the next stage.
|
||||
if msg is BUBBLE:
|
||||
continue
|
||||
self.out_queue.put(msg)
|
||||
# Send messages to the next stage.
|
||||
for msg in _allmsgs(msg):
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
self.out_queue.put(msg)
|
||||
|
||||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
|
|
@ -224,10 +241,12 @@ class MiddlePipelineThread(PipelineThread):
|
|||
# Invoke the current stage.
|
||||
out = self.coro.send(msg)
|
||||
|
||||
# Send message to next stage.
|
||||
if out is BUBBLE:
|
||||
continue
|
||||
self.out_queue.put(out)
|
||||
# Send messages to next stage.
|
||||
for msg in _allmsgs(out):
|
||||
with self.abort_lock:
|
||||
if self.abort_flag:
|
||||
return
|
||||
self.out_queue.put(msg)
|
||||
|
||||
except:
|
||||
self.abort_all(sys.exc_info())
|
||||
|
|
@ -302,12 +321,14 @@ class Pipeline(object):
|
|||
coro.next()
|
||||
|
||||
# Begin the pipeline.
|
||||
for msg in coros[0]:
|
||||
for out in coros[0]:
|
||||
msgs = _allmsgs(out)
|
||||
for coro in coros[1:]:
|
||||
msg = coro.send(msg)
|
||||
if msg is BUBBLE:
|
||||
# Don't continue to the next stage.
|
||||
break
|
||||
next_msgs = []
|
||||
for msg in msgs:
|
||||
out = coro.send(msg)
|
||||
next_msgs.extend(_allmsgs(out))
|
||||
msgs = next_msgs
|
||||
|
||||
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
|
||||
"""Run the pipeline in parallel using one thread per stage. The
|
||||
|
|
|
|||
|
|
@ -44,6 +44,23 @@ def _exc_work(num=3):
|
|||
raise TestException()
|
||||
i *= 2
|
||||
|
||||
# A worker that yields a bubble.
|
||||
def _bub_work(num=3):
|
||||
i = None
|
||||
while True:
|
||||
i = yield i
|
||||
if i == num:
|
||||
i = pipeline.BUBBLE
|
||||
else:
|
||||
i *= 2
|
||||
|
||||
# Yet another worker that yields multiple messages.
|
||||
def _multi_work():
|
||||
i = None
|
||||
while True:
|
||||
i = yield i
|
||||
i = pipeline.multiple([i, -i])
|
||||
|
||||
class SimplePipelineTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
|
|
@ -117,6 +134,34 @@ class ConstrainedThreadedPipelineTest(unittest.TestCase):
|
|||
pl.run_parallel(1)
|
||||
self.assertEqual(set(l), set(i*2 for i in range(1000)))
|
||||
|
||||
class BubbleTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
self.pl = pipeline.Pipeline((_produce(), _bub_work(), _consume(self.l)))
|
||||
|
||||
def test_run_sequential(self):
|
||||
self.pl.run_sequential()
|
||||
self.assertEqual(self.l, [0,2,4,8])
|
||||
|
||||
def test_run_parallel(self):
|
||||
self.pl.run_parallel()
|
||||
self.assertEqual(self.l, [0,2,4,8])
|
||||
|
||||
class MultiMessageTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.l = []
|
||||
self.pl = pipeline.Pipeline((
|
||||
_produce(), _multi_work(), _consume(self.l)
|
||||
))
|
||||
|
||||
def test_run_sequential(self):
|
||||
self.pl.run_sequential()
|
||||
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
|
||||
|
||||
def test_run_parallel(self):
|
||||
self.pl.run_parallel()
|
||||
self.assertEqual(self.l, [0,0,1,-1,2,-2,3,-3,4,-4])
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue