Mirror of https://github.com/beetbox/beets.git
pipeline revamp: parallel stages, immediate exception abort
The pipeline module now supports stages that have multiple threads working in parallel; this can bring ordinary task parallelism to parts of a pipelined workflow. The change also makes the pipeline terminate immediately when an exception is raised in a coroutine.
parent 5d22405e63
commit 7600a8c0ad
1 changed file with 157 additions and 120 deletions
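For orientation before the diff: a stage may now be either a single coroutine or a tuple of coroutines, in which case that stage runs with one thread per coroutine. Below is a minimal usage sketch based on the smoke test at the bottom of the diff; the produce/work/consume generators are illustrative placeholders, and the import name `pipeline` is an assumption about how this module is exposed.

# Hypothetical usage sketch (mirrors the smoke test below); assumes this
# module is importable as `pipeline`.
import time
import pipeline

def produce():                      # first stage: a plain generator
    for i in range(5):
        yield i

def work():                         # middle stage: a coroutine primed by the pipeline
    num = yield
    while True:
        time.sleep(1)               # simulate slow per-item work
        num = yield num * 2

def consume():                      # last stage: yields nothing onward
    while True:
        num = yield
        print 'received', num

# A tuple of coroutines makes that stage run in two worker threads.
pipeline.Pipeline([produce(), (work(), work()), consume()]).run_parallel()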
@@ -1,5 +1,5 @@
# This file is part of beets.
# Copyright 2010, Adrian Sampson.
# Copyright 2011, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -28,25 +28,30 @@ from __future__ import with_statement # for Python 2.5
import Queue
from threading import Thread, Lock
import sys
import types

BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'

DEFAULT_QUEUE_SIZE = 16

def clear_queue(q):
    """Safely empty a queue."""
    # This very hacky approach to clearing the queue
    # compliments of Tim Peters:
    # http://www.mail-archive.com/python-list@python.org/msg95322.html
    q.mutex.acquire()
    try:
        q.queue.clear()
        q.unfinished_tasks = 0
        q.not_full.notify()
        q.all_tasks_done.notifyAll()
    finally:
        q.mutex.release()
def invalidate_queue(q):
    """Breaks a Queue such that it never blocks, always has size 1,
    and has no maximum size.
    """
    def _qsize(len=len):
        return 1
    def _put(item):
        pass
    def _get():
        return None
    with q.mutex:
        q.maxsize = 0
        q._qsize = _qsize
        q._put = _put
        q._get = _get
        q.not_empty.notify()
        q.not_full.notify()

class PipelineError(object):
    """An indication that an exception occurred in the pipeline. The
@@ -58,24 +63,38 @@ class PipelineError(object):

class PipelineThread(Thread):
    """Abstract base class for pipeline-stage threads."""
    def __init__(self):
    def __init__(self, all_threads):
        super(PipelineThread, self).__init__()
        self.abort_lock = Lock()
        self.abort_flag = False
        self.all_threads = all_threads
        self.exc_info = None

    def abort(self):
        """Shut down the thread at the next chance possible.
        """
        with self.abort_lock:
            self.abort_flag = True
        # Empty the channel before poisoning it.

        # Ensure that we are not blocking on a queue read or write.
        if hasattr(self, 'in_queue'):
            invalidate_queue(self.in_queue)
        if hasattr(self, 'out_queue'):
            invalidate_queue(self.out_queue)

    def abort_all(self, exc_info):
        """Abort all other threads in the system for an exception.
        """
        self.exc_info = exc_info
        for thread in self.all_threads:
            thread.abort()

class FirstPipelineThread(PipelineThread):
    """The thread running the first stage in a parallel pipeline setup.
    The coroutine should just be a generator.
    """
    def __init__(self, coro, out_queue):
        super(FirstPipelineThread, self).__init__()
    def __init__(self, coro, out_queue, all_threads):
        super(FirstPipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.out_queue = out_queue
@@ -83,25 +102,27 @@ class FirstPipelineThread(PipelineThread):
        self.abort_flag = False

    def run(self):
        while True:
            # Time to abort?
            with self.abort_lock:
                if self.abort_flag:
                    return

            # Get the value from the generator.
            try:
                msg = self.coro.next()
            except StopIteration:
                break
            except:
                self.out_queue.put(PipelineError(sys.exc_info()))
                return

            # Send it to the next stage.
            if msg is BUBBLE:
                continue
            self.out_queue.put(msg)
        try:
            while True:
                # Time to abort?
                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Get the value from the generator.
                try:
                    msg = self.coro.next()
                except StopIteration:
                    break

                # Send it to the next stage.
                if msg is BUBBLE:
                    continue
                self.out_queue.put(msg)

        except:
            self.abort_all(sys.exc_info())
            return

        # Generator finished; shut down the pipeline.
        self.out_queue.put(POISON)
@@ -110,50 +131,49 @@ class MiddlePipelineThread(PipelineThread):
    """A thread running any stage in the pipeline except the first or
    last.
    """
    def __init__(self, coro, in_queue, out_queue):
        super(MiddlePipelineThread, self).__init__()
    def __init__(self, coro, in_queue, out_queue, all_threads):
        super(MiddlePipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.in_queue = in_queue
        self.out_queue = out_queue

    def run(self):
        # Prime the coroutine.
        self.coro.next()

        while True:
            with self.abort_lock:
                if self.abort_flag:
                    return
        try:
            # Prime the coroutine.
            self.coro.next()

            while True:
                with self.abort_lock:
                    if self.abort_flag:
                        return

            # Get the message from the previous stage.
            msg = self.in_queue.get()
            if msg is POISON:
                break
            elif isinstance(msg, PipelineError):
                self.out_queue.put(msg)
                return

            # Invoke the current stage.
            try:
                # Get the message from the previous stage.
                msg = self.in_queue.get()
                if msg is POISON:
                    break

                # Invoke the current stage.
                out = self.coro.send(msg)
            except:
                self.out_queue.put(PipelineError(sys.exc_info()))
                return

            # Send message to next stage.
            if out is BUBBLE:
                continue
            self.out_queue.put(out)

                # Send message to next stage.
                if out is BUBBLE:
                    continue
                self.out_queue.put(out)

        except:
            self.abort_all(sys.exc_info())
            return

        # Pipeline is shutting down normally.
        self.in_queue.put(POISON)
        self.out_queue.put(POISON)

class LastPipelineThread(PipelineThread):
    """A thread running the last stage in a pipeline. The coroutine
    should yield nothing.
    """
    def __init__(self, coro, in_queue):
        super(LastPipelineThread, self).__init__()
    def __init__(self, coro, in_queue, all_threads):
        super(LastPipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.in_queue = in_queue
@@ -161,28 +181,26 @@ class LastPipelineThread(PipelineThread):
        # Prime the coroutine.
        self.coro.next()

        while True:
            with self.abort_lock:
                if self.abort_flag:
                    return
        try:
            while True:
                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Get the message from the previous stage.
                msg = self.in_queue.get()
                if msg is POISON:
                    break

            # Get the message from the previous stage.
            msg = self.in_queue.get()
            if msg is POISON:
                break
            elif isinstance(msg, PipelineError):
                self.exc_info = msg.exc_info
                return

            # Send to consumer.
            try:
                # Send to consumer.
                self.coro.send(msg)
            except:
                self.exc_info = sys.exc_info()
                return

        except:
            self.abort_all(sys.exc_info())
            return

        # No exception raised in pipeline.
        self.exc_info = None
        # Pipeline is shutting down normally.
        self.in_queue.put(POISON)

class Pipeline(object):
    """Represents a staged pattern of work. Each stage in the pipeline
@@ -195,20 +213,29 @@ class Pipeline(object):
        """
        if len(stages) < 2:
            raise ValueError('pipeline must have at least two stages')
        self.stages = stages
        self.stages = []
        for stage in stages:
            if isinstance(stage, types.GeneratorType):
                # Default to one thread per stage.
                self.stages.append((stage,))
            else:
                self.stages.append(stage)

    def run_sequential(self):
        """Run the pipeline sequentially in the current thread. The
        stages are run one after the other.
        stages are run one after the other. Only the first coroutine
        in each stage is used.
        """
        coros = [stage[0] for stage in self.stages]

        # "Prime" the coroutines.
        for coro in self.stages[1:]:
        for coro in coros:
            coro.next()

        # Begin the pipeline.
        for msg in self.stages[0]:
            for stage in self.stages[1:]:
                msg = stage.send(msg)
        for msg in coros[0]:
            for coro in coros[1:]:
                msg = coro.send(msg)
                if msg is BUBBLE:
                    # Don't continue to the next stage.
                    break
@@ -219,37 +246,44 @@ class Pipeline(object):
        size.
        """
        queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
        threads = [FirstPipelineThread(self.stages[0], queues[0])]
        threads = []

        # Set up first stage.
        for coro in self.stages[0]:
            threads.append(FirstPipelineThread(coro, queues[0], threads))

        # Middle stages.
        for i in range(1, len(self.stages)-1):
            threads.append(MiddlePipelineThread(
                self.stages[i], queues[i-1], queues[i]
            ))
        threads.append(LastPipelineThread(self.stages[-1], queues[-1]))
            for coro in self.stages[i]:
                threads.append(MiddlePipelineThread(
                    coro, queues[i-1], queues[i], threads
                ))

        # Last stage.
        for coro in self.stages[-1]:
            threads.append(
                LastPipelineThread(coro, queues[-1], threads)
            )

        # Start threads.
        for thread in threads:
            thread.start()

        # Wait for termination.
        try:
            # The final thread lasts the longest.
            threads[-1].join()
        finally:
            # Halt the pipeline in case there was an exception.
            for thread in threads:
                thread.abort()
            for queue in queues:
                clear_queue(queue)
        # Wait for termination. The final thread lasts the longest.
        threads[-1].join()

        # Make completely sure that all the threads have finished
        # before we return.
        # before we return. They should already be either finished,
        # in normal operation, or aborted, in case of an exception.
        for thread in threads[:-1]:
            thread.join()

        exc_info = threads[-1].exc_info
        if exc_info:
            # Make the exception appear as it was raised originally.
            raise exc_info[0], exc_info[1], exc_info[2]
        for thread in threads:
            exc_info = thread.exc_info
            if exc_info:
                # Make the exception appear as it was raised originally.
                raise exc_info[0], exc_info[1], exc_info[2]

# Smoke test.
if __name__ == '__main__':
@@ -259,39 +293,42 @@ if __name__ == '__main__':
    # in parallel.
    def produce():
        for i in range(5):
            print 'generating', i
            print 'generating %i' % i
            time.sleep(1)
            yield i
    def work():
        num = yield
        while True:
            print 'processing', num
            print 'processing %i' % num
            time.sleep(2)
            num = yield num*2
    def consume():
        while True:
            num = yield
            time.sleep(1)
            print 'received', num
            print 'received %i' % num
    ts_start = time.time()
    Pipeline([produce(), work(), consume()]).run_sequential()
    ts_middle = time.time()
    ts_seq = time.time()
    Pipeline([produce(), work(), consume()]).run_parallel()
    ts_par = time.time()
    Pipeline([produce(), (work(), work()), consume()]).run_parallel()
    ts_end = time.time()
    print 'Sequential time:', ts_middle - ts_start
    print 'Parallel time:', ts_end - ts_middle
    print 'Sequential time:', ts_seq - ts_start
    print 'Parallel time:', ts_par - ts_seq
    print 'Multiply-parallel time:', ts_end - ts_par
    print

    # Test a pipeline that raises an exception.
    def exc_produce():
        for i in range(10):
            print 'generating', i
            print 'generating %i' % i
            time.sleep(1)
            yield i
    def exc_work():
        num = yield
        while True:
            print 'processing', num
            print 'processing %i' % num
            time.sleep(3)
            if num == 3:
                raise Exception()
@@ -301,5 +338,5 @@ if __name__ == '__main__':
            num = yield
            #if num == 4:
            #    raise Exception()
            print 'received', num
            print 'received %i' % num
    Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
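The exception-handling half of the change is visible in run_parallel() above: when any stage raises, abort_all() shuts down every thread and the saved exc_info is re-raised, with its original traceback, in the thread that called run_parallel(). A rough sketch of what that looks like from the caller's side, reusing the illustrative exc_* generators from the smoke test (not a definitive API guarantee beyond what the diff shows):

try:
    Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
except Exception:
    # The exception raised inside exc_work() surfaces here; all pipeline
    # threads have already been told to abort.
    print 'pipeline aborted'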