pipeline revamp: parallel stages, immediate exception abort

The pipeline module now supports stages that have multiple threads working in
parallel; this brings ordinary task parallelism to parts of a pipelined
workflow. The change also makes the pipeline terminate immediately when an
exception is raised in any coroutine.
Adrian Sampson 2011-04-11 23:52:51 -07:00
parent 5d22405e63
commit 7600a8c0ad
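
As a quick orientation before the diff, here is a minimal sketch of the new stage syntax (the stage names are illustrative, and the import path assumes the module lives at beets.util.pipeline): a stage that should run on several threads is given as a tuple of coroutines rather than a single coroutine.

    from beets.util import pipeline  # assumed module path

    def produce():
        for i in range(4):
            yield i

    def double():
        # Middle-stage coroutine: receive a value, send on its double.
        num = yield
        while True:
            num = yield num * 2

    def consume():
        while True:
            num = yield
            print 'got %i' % num

    # A bare coroutine gets one thread; a tuple gets one thread per coroutine.
    # Here the middle stage is doubled up so two workers share its queues.
    pipeline.Pipeline([produce(), (double(), double()), consume()]).run_parallel()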


@@ -1,5 +1,5 @@
# This file is part of beets.
# Copyright 2010, Adrian Sampson.
# Copyright 2011, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
@@ -28,25 +28,30 @@ from __future__ import with_statement # for Python 2.5
import Queue
from threading import Thread, Lock
import sys
import types
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def clear_queue(q):
"""Safely empty a queue."""
# This very hacky approach to clearing the queue
# compliments of Tim Peters:
# http://www.mail-archive.com/python-list@python.org/msg95322.html
q.mutex.acquire()
try:
q.queue.clear()
q.unfinished_tasks = 0
q.not_full.notify()
q.all_tasks_done.notifyAll()
finally:
q.mutex.release()
def invalidate_queue(q):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size.
"""
def _qsize(len=len):
return 1
def _put(item):
pass
def _get():
return None
with q.mutex:
q.maxsize = 0
q._qsize = _qsize
q._put = _put
q._get = _get
q.not_empty.notify()
q.not_full.notify()
class PipelineError(object):
"""An indication that an exception occurred in the pipeline. The
@@ -58,24 +63,38 @@ class PipelineError(object):
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self):
def __init__(self, all_threads):
super(PipelineThread, self).__init__()
self.abort_lock = Lock()
self.abort_flag = False
self.all_threads = all_threads
self.exc_info = None
def abort(self):
"""Shut down the thread at the next chance possible.
"""
with self.abort_lock:
self.abort_flag = True
# Empty the channel before poisoning it.
# Ensure that we are not blocking on a queue read or write.
if hasattr(self, 'in_queue'):
invalidate_queue(self.in_queue)
if hasattr(self, 'out_queue'):
invalidate_queue(self.out_queue)
def abort_all(self, exc_info):
"""Abort all other threads in the system for an exception.
"""
self.exc_info = exc_info
for thread in self.all_threads:
thread.abort()
class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
"""
def __init__(self, coro, out_queue):
super(FirstPipelineThread, self).__init__()
def __init__(self, coro, out_queue, all_threads):
super(FirstPipelineThread, self).__init__(all_threads)
self.coro = coro
self.out_queue = out_queue
@@ -83,25 +102,27 @@ class FirstPipelineThread(PipelineThread):
self.abort_flag = False
def run(self):
while True:
# Time to abort?
with self.abort_lock:
if self.abort_flag:
return
# Get the value from the generator.
try:
msg = self.coro.next()
except StopIteration:
break
except:
self.out_queue.put(PipelineError(sys.exc_info()))
return
# Send it to the next stage.
if msg is BUBBLE:
continue
self.out_queue.put(msg)
try:
while True:
# Time to abort?
with self.abort_lock:
if self.abort_flag:
return
# Get the value from the generator.
try:
msg = self.coro.next()
except StopIteration:
break
# Send it to the next stage.
if msg is BUBBLE:
continue
self.out_queue.put(msg)
except:
self.abort_all(sys.exc_info())
return
# Generator finished; shut down the pipeline.
self.out_queue.put(POISON)
@@ -110,50 +131,49 @@ class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
"""
def __init__(self, coro, in_queue, out_queue):
super(MiddlePipelineThread, self).__init__()
def __init__(self, coro, in_queue, out_queue, all_threads):
super(MiddlePipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
self.out_queue = out_queue
def run(self):
# Prime the coroutine.
self.coro.next()
while True:
with self.abort_lock:
if self.abort_flag:
return
try:
# Prime the coroutine.
self.coro.next()
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
elif isinstance(msg, PipelineError):
self.out_queue.put(msg)
return
# Invoke the current stage.
try:
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
# Invoke the current stage.
out = self.coro.send(msg)
except:
self.out_queue.put(PipelineError(sys.exc_info()))
return
# Send message to next stage.
if out is BUBBLE:
continue
self.out_queue.put(out)
# Send message to next stage.
if out is BUBBLE:
continue
self.out_queue.put(out)
except:
self.abort_all(sys.exc_info())
return
# Pipeline is shutting down normally.
self.in_queue.put(POISON)
self.out_queue.put(POISON)
class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
"""
def __init__(self, coro, in_queue):
super(LastPipelineThread, self).__init__()
def __init__(self, coro, in_queue, all_threads):
super(LastPipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
@@ -161,28 +181,26 @@ class LastPipelineThread(PipelineThread):
# Prime the coroutine.
self.coro.next()
while True:
with self.abort_lock:
if self.abort_flag:
return
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
elif isinstance(msg, PipelineError):
self.exc_info = msg.exc_info
return
# Send to consumer.
try:
# Send to consumer.
self.coro.send(msg)
except:
self.exc_info = sys.exc_info()
return
except:
self.abort_all(sys.exc_info())
return
# No exception raised in pipeline.
self.exc_info = None
# Pipeline is shutting down normally.
self.in_queue.put(POISON)
class Pipeline(object):
"""Represents a staged pattern of work. Each stage in the pipeline
@@ -195,20 +213,29 @@ class Pipeline(object):
"""
if len(stages) < 2:
raise ValueError('pipeline must have at least two stages')
self.stages = stages
self.stages = []
for stage in stages:
if isinstance(stage, types.GeneratorType):
# Default to one thread per stage.
self.stages.append((stage,))
else:
self.stages.append(stage)
def run_sequential(self):
"""Run the pipeline sequentially in the current thread. The
stages are run one after the other.
stages are run one after the other. Only the first coroutine
in each stage is used.
"""
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in self.stages[1:]:
for coro in coros:
coro.next()
# Begin the pipeline.
for msg in self.stages[0]:
for stage in self.stages[1:]:
msg = stage.send(msg)
for msg in coros[0]:
for coro in coros[1:]:
msg = coro.send(msg)
if msg is BUBBLE:
# Don't continue to the next stage.
break
@@ -219,37 +246,44 @@ class Pipeline(object):
size.
"""
queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
threads = [FirstPipelineThread(self.stages[0], queues[0])]
threads = []
# Set up first stage.
for coro in self.stages[0]:
threads.append(FirstPipelineThread(coro, queues[0], threads))
# Middle stages.
for i in range(1, len(self.stages)-1):
threads.append(MiddlePipelineThread(
self.stages[i], queues[i-1], queues[i]
))
threads.append(LastPipelineThread(self.stages[-1], queues[-1]))
for coro in self.stages[i]:
threads.append(MiddlePipelineThread(
coro, queues[i-1], queues[i], threads
))
# Last stage.
for coro in self.stages[-1]:
threads.append(
LastPipelineThread(coro, queues[-1], threads)
)
# Start threads.
for thread in threads:
thread.start()
# Wait for termination.
try:
# The final thread lasts the longest.
threads[-1].join()
finally:
# Halt the pipeline in case there was an exception.
for thread in threads:
thread.abort()
for queue in queues:
clear_queue(queue)
# Wait for termination. The final thread lasts the longest.
threads[-1].join()
# Make completely sure that all the threads have finished
# before we return.
# before we return. They should already be either finished,
# in normal operation, or aborted, in case of an exception.
for thread in threads[:-1]:
thread.join()
exc_info = threads[-1].exc_info
if exc_info:
# Make the exception appear as it was raised originally.
raise exc_info[0], exc_info[1], exc_info[2]
for thread in threads:
exc_info = thread.exc_info
if exc_info:
# Make the exception appear as it was raised originally.
raise exc_info[0], exc_info[1], exc_info[2]
# Smoke test.
if __name__ == '__main__':
@@ -259,39 +293,42 @@ if __name__ == '__main__':
# in parallel.
def produce():
for i in range(5):
print 'generating', i
print 'generating %i' % i
time.sleep(1)
yield i
def work():
num = yield
while True:
print 'processing', num
print 'processing %i' % num
time.sleep(2)
num = yield num*2
def consume():
while True:
num = yield
time.sleep(1)
print 'received', num
print 'received %i' % num
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
ts_middle = time.time()
ts_seq = time.time()
Pipeline([produce(), work(), consume()]).run_parallel()
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
print 'Sequential time:', ts_middle - ts_start
print 'Parallel time:', ts_end - ts_middle
print 'Sequential time:', ts_seq - ts_start
print 'Parallel time:', ts_par - ts_seq
print 'Multiply-parallel time:', ts_end - ts_par
print
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
print 'generating', i
print 'generating %i' % i
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
print 'processing', num
print 'processing %i' % num
time.sleep(3)
if num == 3:
raise Exception()
@@ -301,5 +338,5 @@ if __name__ == '__main__':
num = yield
#if num == 4:
# raise Exception()
print 'received', num
print 'received %i' % num
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
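
For the new abort behavior, a hedged sketch of what this means at the call site (the coroutine names and the ValueError are illustrative, not part of the module): when any stage raises, every stage thread is aborted and run_parallel() re-raises the original exception in the calling thread, so it can be handled like any other exception.

    from beets.util import pipeline  # assumed module path

    def gen():
        for i in range(10):
            yield i

    def explode():
        num = yield
        while True:
            if num == 3:
                raise ValueError('boom')
            num = yield num

    def sink():
        while True:
            num = yield

    try:
        pipeline.Pipeline([gen(), explode(), sink()]).run_parallel()
    except ValueError:
        # The original exception is re-raised (with its traceback) once all
        # stage threads have been told to abort.
        print 'pipeline aborted'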