reliably terminate pipeline when exception is raised

Previously, the producer thread (i.e., the first stage) would continue running
to completion even when an exception was raised. Worse, depending on the size
of the queue, a deadlock was possible if the next stage was no longer
consuming the produced values.
This commit is contained in:
Adrian Sampson 2010-08-02 19:07:47 -07:00
parent 4239c08127
commit df766abcb4

View file

@ -25,7 +25,7 @@ shutdown when the processing is complete and when a stage raises an
exception.
"""
from __future__ import with_statement # for Python 2.5
from Queue import Queue
import Queue
from threading import Thread, Lock
BUBBLE = '__PIPELINE_BUBBLE__'
@ -58,6 +58,15 @@ class FirstPipelineThread(Thread):
# Time to abort?
with self.abort_lock:
if self.abort_flag:
# We may have accidentally added one more object
# to the queue *after* it was cleared by the
# abort() method. Remove it if present.
try:
self.out_queue.get_nowait()
except Queue.Empty:
pass
# Stop generating and poison.
break
# Get the value from the generator.
@ -82,6 +91,20 @@ class FirstPipelineThread(Thread):
poisoning out_channel.
"""
with self.abort_lock:
# Empty the channel before poisoning it.
# This very hacky approach to clearing the queue is
# compliments of Tim Peters:
# http://www.mail-archive.com/python-list@python.org/msg95322.html
self.out_queue.mutex.acquire()
try:
self.out_queue.queue.clear()
self.out_queue.unfinished_tasks = 0
self.out_queue.not_full.notify()
self.out_queue.all_tasks_done.notifyAll()
finally:
self.out_queue.mutex.release()
# Notify the generator thread.
self.abort_flag = True
class MiddlePipelineThread(Thread):
@ -188,7 +211,7 @@ class Pipeline(object):
messages between the stages are stored in queues of the given
size.
"""
queues = [Queue(queue_size) for i in range(len(self.stages)-1)]
queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
threads = [FirstPipelineThread(self.stages[0], queues[0])]
for i in range(1, len(self.stages)-1):
threads.append(MiddlePipelineThread(
@ -202,15 +225,22 @@ class Pipeline(object):
# Wait for termination.
try:
for thread in threads:
thread.join()
# The final thread lasts the longest.
threads[-1].join()
except:
# Shut down the pipeline by telling the first thread to
# poison its channel.
threads[0].abort()
raise
# Was there an exception?
# Halt the pipeline in case there was an exception.
threads[0].abort()
# Make completely sure that all the threads have finished
# before we return.
for thread in threads[:-1]:
thread.join()
exc = threads[-1].exc
if exc:
raise exc
@ -219,6 +249,8 @@ class Pipeline(object):
if __name__ == '__main__':
import time
# Test a normally-terminating pipeline both in sequence and
# in parallel.
def produce():
for i in range(5):
print 'generating', i
@ -235,12 +267,32 @@ if __name__ == '__main__':
num = yield
time.sleep(1)
print 'received', num
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
# Pipeline([produce(), work(), consume()]).run_sequential()
ts_middle = time.time()
Pipeline([produce(), work(), consume()]).run_parallel()
# Pipeline([produce(), work(), consume()]).run_parallel()
ts_end = time.time()
print 'Sequential time:', ts_middle - ts_start
print 'Parallel time:', ts_end - ts_middle
print
# Test a pipeline that raises an exception.
# Test producer: a generator that yields the integers 0 through 9. Before
# each yield it prints 'generating <i>' and sleeps for one second.
def exc_produce():
for i in range(10):
print 'generating', i
time.sleep(1)
yield i
# Test worker: a generator that receives each input value as the result of a
# yield expression. For every value it prints 'processing <num>' and sleeps
# for three seconds; it then either raises Exception (when the value is 3)
# or yields num * 2 and waits for the next value.
def exc_work():
# The initial bare yield receives the first input value.
num = yield
while True:
print 'processing', num
time.sleep(3)
# Deliberate failure: this stage exists to test a pipeline that raises.
if num == 3:
raise Exception()
# Emit the doubled value; the yield expression evaluates to the next input.
num = yield num * 2
# Test consumer: loops forever, receiving each value as the result of a bare
# yield and printing 'received <num>'. It never yields a value of its own.
def exc_consume():
while True:
num = yield
print 'received', num
# Run the three exception-test stages through run_parallel.
# NOTE(review): the argument 1 is presumably the queue size; run_parallel's
# signature is not visible in this diff -- confirm.
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)