diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index 98a1addce..cebde0f23 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -492,64 +492,3 @@ class Pipeline: msgs = next_msgs for msg in msgs: yield msg - - -# Smoke test. -if __name__ == "__main__": - import time - - # Test a normally-terminating pipeline both in sequence and - # in parallel. - def produce(): - for i in range(5): - print("generating %i" % i) - time.sleep(1) - yield i - - def work(): - num = yield - while True: - print("processing %i" % num) - time.sleep(2) - num = yield num * 2 - - def consume(): - while True: - num = yield - time.sleep(1) - print("received %i" % num) - - ts_start = time.time() - Pipeline([produce(), work(), consume()]).run_sequential() - ts_seq = time.time() - Pipeline([produce(), work(), consume()]).run_parallel() - ts_par = time.time() - Pipeline([produce(), (work(), work()), consume()]).run_parallel() - ts_end = time.time() - print("Sequential time:", ts_seq - ts_start) - print("Parallel time:", ts_par - ts_seq) - print("Multiply-parallel time:", ts_end - ts_par) - print() - - # Test a pipeline that raises an exception. - def exc_produce(): - for i in range(10): - print("generating %i" % i) - time.sleep(1) - yield i - - def exc_work(): - num = yield - while True: - print("processing %i" % num) - time.sleep(3) - if num == 3: - raise Exception() - num = yield num * 2 - - def exc_consume(): - while True: - num = yield - print("received %i" % num) - - Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1) diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 83b8d744c..5007ad826 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -39,11 +39,16 @@ def _consume(result): result.append(i) -# A worker that raises an exception. +# Pipeline stages that raise an exception. class PipelineError(Exception): pass +def _exc_produce(num=5): + yield from range(num) + raise PipelineError() + + def _exc_work(num=3): i = None while True: @@ -53,6 +58,14 @@ def _exc_work(num=3): i *= 2 +def _exc_consume(result, num=4): + while True: + i = yield + if i == num: + raise PipelineError() + result.append(i) + + # A worker that yields a bubble. def _bub_work(num=3): i = None @@ -121,17 +134,32 @@ class ParallelStageTest(unittest.TestCase): class ExceptionTest(unittest.TestCase): def setUp(self): self.result = [] - self.pl = pipeline.Pipeline( - (_produce(), _exc_work(), _consume(self.result)) - ) + + def run_sequential(self, *stages): + pl = pipeline.Pipeline(stages) + with pytest.raises(PipelineError): + pl.run_sequential() + + def run_parallel(self, *stages): + pl = pipeline.Pipeline(stages) + with pytest.raises(PipelineError): + pl.run_parallel() def test_run_sequential(self): - with pytest.raises(PipelineError): - self.pl.run_sequential() + """Test that exceptions from various stages of the pipeline are + properly propagated when running sequentially. + """ + self.run_sequential(_exc_produce(), _work(), _consume(self.result)) + self.run_sequential(_produce(), _exc_work(), _consume(self.result)) + self.run_sequential(_produce(), _work(), _exc_consume(self.result)) def test_run_parallel(self): - with pytest.raises(PipelineError): - self.pl.run_parallel() + """Test that exceptions from various stages of the pipeline are + properly propagated when running in parallel. + """ + self.run_parallel(_exc_produce(), _work(), _consume(self.result)) + self.run_parallel(_produce(), _exc_work(), _consume(self.result)) + self.run_parallel(_produce(), _work(), _exc_consume(self.result)) def test_pull(self): pl = pipeline.Pipeline((_produce(), _exc_work()))