Merge branch 'master' of git://github.com/sampsyo/beets

mdecker 2011-04-17 23:17:10 +02:00
commit 9c2b7a9653
4 changed files with 325 additions and 72 deletions

beets/importer.py

@@ -141,7 +141,7 @@ class ImportConfig(object):
 class ImportTask(object):
     """Represents a single set of items to be imported along with its
-    intermediate state. May represent an album or just a set of items.
+    intermediate state. May represent an album or a single item.
     """
     def __init__(self, toppath=None, path=None, items=None):
         self.toppath = toppath
@@ -162,7 +162,7 @@ class ImportTask(object):
     def item_task(cls, item):
         """Creates an ImportTask for a single item."""
         obj = cls()
-        obj.items = [item]
+        obj.item = item
         obj.is_album = False
         return obj
@@ -182,26 +182,18 @@ class ImportTask(object):
         """
         self.set_match(None, None, None, None)

-    def set_item_matches(self, item_matches):
-        """Sets the candidates for this set of items after an initial
-        match. `item_matches` should be a list of match tuples,
-        one for each item.
-        """
-        assert len(self.items) == len(item_matches)
-        self.item_matches = item_matches
-        self.is_album = False
-
     def set_item_match(self, candidates, rec):
         """Set the match for a single-item task."""
-        assert len(self.items) == 1
-        self.item_matches = [(candidates, rec)]
+        assert not self.is_album
+        assert self.item is not None
+        self.item_match = (candidates, rec)

     def set_null_item_match(self):
         """For single-item tasks, mark the item as having no matches.
         """
-        assert len(self.items) == 1
-        self.item_matches = [None]
+        assert not self.is_album
+        assert self.item is not None
+        self.item_match = None

     def set_choice(self, choice):
         """Given either an (info, items) tuple or an action constant,
@@ -215,7 +207,11 @@ class ImportTask(object):
         self.choice_flag = choice
         self.info = None
         if choice == action.SKIP:
-            self.items = None # Items no longer needed.
+            # Items are no longer needed.
+            if self.is_album:
+                self.items = None
+            else:
+                self.item = None
         else:
             assert not isinstance(choice, action)
             if self.is_album:
@@ -356,15 +352,31 @@ def user_query(config):
         choice = config.choose_match_func(task, config)
         task.set_choice(choice)

+        # As-tracks: transition to singleton workflow.
+        if choice is action.TRACKS:
+            # Set up a little pipeline for dealing with the singletons.
+            item_tasks = []
+            def emitter():
+                for item in task.items:
+                    yield ImportTask.item_task(item)
+            def collector():
+                while True:
+                    item_task = yield
+                    item_tasks.append(item_task)
+            ipl = pipeline.Pipeline((emitter(), item_lookup(config),
+                                     item_query(config), collector()))
+            ipl.run_sequential()
+            task = pipeline.multiple(item_tasks)
+
         # Log certain choices.
         if choice is action.ASIS:
             tag_log(config.logfile, 'asis', task.path)
         elif choice is action.SKIP:
             tag_log(config.logfile, 'skip', task.path)

-        # Check for duplicates if we have a match.
-        if choice == action.ASIS or isinstance(choice, tuple):
-            if choice == action.ASIS:
+        # Check for duplicates if we have a match (or ASIS).
+        if choice is action.ASIS or isinstance(choice, tuple):
+            if choice is action.ASIS:
                 artist = task.cur_artist
                 album = task.cur_album
             else:
@@ -412,12 +424,11 @@ def apply_choices(config):
         if task.is_album:
             autotag.apply_metadata(task.items, task.info)
         else:
-            for item, info in zip(task.items, task.info):
-                autotag.apply_item_metadata(item, info)
+            autotag.apply_item_metadata(task.item, task.info)

+        items = task.items if task.is_album else [task.item]
         if config.copy and config.delete:
-            old_paths = [os.path.realpath(item.path)
-                         for item in task.items]
-        for item in task.items:
+            old_paths = [os.path.realpath(syspath(item.path)) for item in items]
+        for item in items:
             if config.copy:
                 item.move(lib, True, task.should_create_album())
             if config.write and task.should_write_tags():
@@ -431,7 +442,7 @@ def apply_choices(config):
                                       infer_aa = task.should_infer_aa())
         else:
             # Add tracks.
-            for item in task.items:
+            for item in items:
                 lib.add(item)
         lib.save()
@@ -446,12 +457,11 @@ def apply_choices(config):
         if task.should_create_album():
             plugins.send('album_imported', lib=lib, album=albuminfo)
         else:
-            for item in task.items:
-                plugins.send('item_imported', lib=lib, item=item)
+            plugins.send('item_imported', lib=lib, item=task.item)

         # Finally, delete old files.
         if config.copy and config.delete:
-            new_paths = [os.path.realpath(item.path) for item in task.items]
+            new_paths = [os.path.realpath(item.path) for item in items]
             for old_path in old_paths:
                 # Only delete files that were actually moved.
                 if old_path not in new_paths:
@@ -481,7 +491,7 @@ def item_lookup(config):
     task = None
     while True:
         task = yield task
-        task.set_item_match(*autotag.tag_item(task.items[0]))
+        task.set_item_match(*autotag.tag_item(task.item))

 def item_query(config):
     """A coroutine that queries the user for input on single-item
@@ -491,7 +501,7 @@ def item_query(config):
     while True:
         task = yield task
         choice = config.choose_item_func(task, config)
-        task.set_choice([choice])
+        task.set_choice(choice)

 def item_progress(config):
     """Skips the lookup and query stages in a non-autotagged singleton
@@ -501,7 +511,7 @@ def item_progress(config):
     log.info('Importing items:')
     while True:
         task = yield task
-        log.info(task.items[0].path)
+        log.info(task.item.path)
         task.set_null_item_match()
         task.set_choice(action.ASIS)

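The as-tracks transition above depends on the new `pipeline.multiple` mechanism: a stage that returns `multiple([...])` fans one incoming message out into several messages for the next stage. A minimal standalone sketch (not part of this commit; the stage names are invented for illustration):

    from beets.util import pipeline

    def produce():
        # Feed three numbers into the pipeline.
        for i in range(3):
            yield i

    def fan_out():
        # For each input, emit both the input and ten times the input.
        msg = None
        while True:
            msg = yield msg
            msg = pipeline.multiple([msg, msg * 10])

    def collect(results):
        # Gather everything that reaches the end of the pipeline.
        while True:
            results.append((yield))

    results = []
    pipeline.Pipeline((produce(), fan_out(), collect(results))).run_sequential()
    assert results == [0, 0, 1, 10, 2, 20]

This is how one album task becomes a whole batch of singleton tasks: user_query() returns pipeline.multiple(item_tasks), and each enclosed task is delivered downstream as a separate message.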
beets/ui/commands.py

@@ -351,20 +351,18 @@ def choose_match(task, config):
             return choice

 def choose_item(task, config):
-    """Ask the user for a choice about tagging a set of items. Returns
+    """Ask the user for a choice about tagging a single item. Returns
     either an action constant or a track info dictionary.
     """
     print_()
-    print_(task.items[0].path)
-    #TODO multiple items.
-    candidates, rec = task.item_matches[0]
+    print_(task.item.path)
+    candidates, rec = task.item_match

     if config.quiet:
         # Quiet mode; make a decision.
         if task.rec == autotag.RECOMMEND_STRONG:
             dist, track_info = candidates[0]
-            show_item_change(task.items[0], track_info, dist, config.color)
+            show_item_change(task.item, track_info, dist, config.color)
             return track_info
         else:
             return _quiet_fall_back(config)
@@ -372,7 +370,7 @@ def choose_item(task, config):
     while True:
         # Ask for a choice.
         choice = choose_candidate(candidates, True, rec, config.color,
-                                  item=task.items[0])
+                                  item=task.item)

         if choice in (importer.action.SKIP, importer.action.ASIS):
             return choice
@@ -381,7 +379,7 @@ def choose_item(task, config):
         elif choice == importer.action.MANUAL:
             # Continue in the loop with a new set of candidates.
             search_artist, search_title = manual_search(False)
-            candidates, rec = autotag.tag_item(task.items[0], search_artist,
+            candidates, rec = autotag.tag_item(task.item, search_artist,
                                                search_title)
         else:
             # Chose a candidate.

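For reference, the single-item plumbing these UI changes rely on can be summarized in a short sketch (a hypothetical helper, not part of the commit; it mirrors what item_lookup() does before choose_item() runs):

    from beets import autotag, importer

    def make_item_task(item):
        # Build a singleton task for one library item, then attach the
        # autotagger's candidates to it.
        task = importer.ImportTask.item_task(item)     # sets task.item; is_album is False
        candidates, rec = autotag.tag_item(task.item)  # as in item_lookup()
        task.set_item_match(candidates, rec)           # stored as task.item_match
        return task                                    # choose_item() unpacks task.item_match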
beets/util/pipeline.py

@@ -41,17 +41,23 @@ POISON = '__PIPELINE_POISON__'
 DEFAULT_QUEUE_SIZE = 16

-def invalidate_queue(q):
+def _invalidate_queue(q, val=None, sync=True):
     """Breaks a Queue such that it never blocks, always has size 1,
-    and has no maximum size.
+    and has no maximum size. get()ing from the queue returns `val`,
+    which defaults to None. `sync` controls whether a lock is
+    required (because it's not reentrant!).
     """
     def _qsize(len=len):
         return 1
     def _put(item):
         pass
     def _get():
-        return None
-    with q.mutex:
+        return val

+    if sync:
+        q.mutex.acquire()
+    try:
         q.maxsize = 0
         q._qsize = _qsize
         q._put = _put
@@ -59,13 +65,80 @@ def invalidate_queue(q):
         q.not_empty.notify()
         q.not_full.notify()
+    finally:
+        if sync:
+            q.mutex.release()

-class PipelineError(object):
-    """An indication that an exception occurred in the pipeline. The
-    object is passed through the pipeline to shut down all threads
-    before it is raised again in the main thread.
-    """
-    def __init__(self, exc_info):
-        self.exc_info = exc_info
+class CountedQueue(Queue.Queue):
+    """A queue that keeps track of the number of threads that are
+    still feeding into it. The queue is poisoned when all threads are
+    finished with the queue.
+    """
+    def __init__(self, maxsize=0):
+        Queue.Queue.__init__(self, maxsize)
+        self.nthreads = 0
+        self.poisoned = False
+
+    def acquire(self):
+        """Indicate that a thread will start putting into this queue.
+        Should not be called after the queue is already poisoned.
+        """
+        with self.mutex:
+            assert not self.poisoned
+            assert self.nthreads >= 0
+            self.nthreads += 1
+
+    def release(self):
+        """Indicate that a thread that was putting into this queue has
+        exited. If this is the last thread using the queue, the queue
+        is poisoned.
+        """
+        with self.mutex:
+            self.nthreads -= 1
+            assert self.nthreads >= 0
+            if self.nthreads == 0:
+                # All threads are done adding to this queue. Poison it
+                # when it becomes empty.
+                self.poisoned = True
+
+                # Replacement _get invalidates when no items remain.
+                _old_get = self._get
+                def _get():
+                    out = _old_get()
+                    if not self.queue:
+                        _invalidate_queue(self, POISON, False)
+                    return out
+
+                if self.queue:
+                    # Items remain.
+                    self._get = _get
+                else:
+                    # No items. Invalidate immediately.
+                    _invalidate_queue(self, POISON, False)
+
+class MultiMessage(object):
+    """A message yielded by a pipeline stage encapsulating multiple
+    values to be sent to the next stage.
+    """
+    def __init__(self, messages):
+        self.messages = messages
+
+def multiple(messages):
+    """Yield multiple([message, ..]) from a pipeline stage to send
+    multiple values to the next pipeline stage.
+    """
+    return MultiMessage(messages)
+
+def _allmsgs(obj):
+    """Returns a list of all the messages encapsulated in obj. If obj
+    is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
+    returns an empty list. Otherwise, returns a list containing obj.
+    """
+    if isinstance(obj, MultiMessage):
+        return obj.messages
+    elif obj == BUBBLE:
+        return []
+    else:
+        return [obj]

 class PipelineThread(Thread):
     """Abstract base class for pipeline-stage threads."""
@@ -84,9 +157,9 @@ class PipelineThread(Thread):
         # Ensure that we are not blocking on a queue read or write.
         if hasattr(self, 'in_queue'):
-            invalidate_queue(self.in_queue)
+            _invalidate_queue(self.in_queue)
         if hasattr(self, 'out_queue'):
-            invalidate_queue(self.out_queue)
+            _invalidate_queue(self.out_queue)

     def abort_all(self, exc_info):
         """Abort all other threads in the system for an exception.
@@ -103,6 +176,7 @@ class FirstPipelineThread(PipelineThread):
         super(FirstPipelineThread, self).__init__(all_threads)
         self.coro = coro
         self.out_queue = out_queue
+        self.out_queue.acquire()

         self.abort_lock = Lock()
         self.abort_flag = False
@@ -110,7 +184,6 @@ class FirstPipelineThread(PipelineThread):
     def run(self):
         try:
             while True:
-                # Time to abort?
                 with self.abort_lock:
                     if self.abort_flag:
                         return
@@ -121,17 +194,19 @@ class FirstPipelineThread(PipelineThread):
                 except StopIteration:
                     break

-                # Send it to the next stage.
-                if msg is BUBBLE:
-                    continue
-                self.out_queue.put(msg)
+                # Send messages to the next stage.
+                for msg in _allmsgs(msg):
+                    with self.abort_lock:
+                        if self.abort_flag:
+                            return
+                    self.out_queue.put(msg)
         except:
             self.abort_all(sys.exc_info())
             return

         # Generator finished; shut down the pipeline.
-        self.out_queue.put(POISON)
+        self.out_queue.release()

 class MiddlePipelineThread(PipelineThread):
     """A thread running any stage in the pipeline except the first or
@@ -142,6 +217,7 @@ class MiddlePipelineThread(PipelineThread):
         self.coro = coro
         self.in_queue = in_queue
         self.out_queue = out_queue
+        self.out_queue.acquire()

     def run(self):
         try:
@@ -165,18 +241,19 @@ class MiddlePipelineThread(PipelineThread):
                 # Invoke the current stage.
                 out = self.coro.send(msg)

-                # Send message to next stage.
-                if out is BUBBLE:
-                    continue
-                self.out_queue.put(out)
+                # Send messages to next stage.
+                for msg in _allmsgs(out):
+                    with self.abort_lock:
+                        if self.abort_flag:
+                            return
+                    self.out_queue.put(msg)
         except:
             self.abort_all(sys.exc_info())
             return

         # Pipeline is shutting down normally.
-        self.in_queue.put(POISON)
-        self.out_queue.put(POISON)
+        self.out_queue.release()

 class LastPipelineThread(PipelineThread):
     """A thread running the last stage in a pipeline. The coroutine
@@ -212,9 +289,6 @@ class LastPipelineThread(PipelineThread):
         except:
             self.abort_all(sys.exc_info())
             return

-        # Pipeline is shutting down normally.
-        self.in_queue.put(POISON)

 class Pipeline(object):
     """Represents a staged pattern of work. Each stage in the pipeline
@@ -247,19 +321,21 @@ class Pipeline(object):
             coro.next()

         # Begin the pipeline.
-        for msg in coros[0]:
+        for out in coros[0]:
+            msgs = _allmsgs(out)
             for coro in coros[1:]:
-                msg = coro.send(msg)
-                if msg is BUBBLE:
-                    # Don't continue to the next stage.
-                    break
+                next_msgs = []
+                for msg in msgs:
+                    out = coro.send(msg)
+                    next_msgs.extend(_allmsgs(out))
+                msgs = next_msgs

     def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
         """Run the pipeline in parallel using one thread per stage. The
         messages between the stages are stored in queues of the given
         size.
         """
-        queues = [Queue.Queue(queue_size) for i in range(len(self.stages)-1)]
+        queues = [CountedQueue(queue_size) for i in range(len(self.stages)-1)]
         threads = []

         # Set up first stage.

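The shutdown protocol that replaces the explicit POISON puts is easiest to see in isolation. A small sketch (not from the commit) of the CountedQueue lifecycle, assuming the module is importable as beets.util.pipeline:

    from beets.util.pipeline import POISON, CountedQueue

    q = CountedQueue()
    q.acquire()                # a producer thread registers itself
    q.put('a')
    q.put('b')
    q.release()                # last producer done: the queue is poisoned

    assert q.get() == 'a'      # remaining items still drain in order
    assert q.get() == 'b'
    assert q.get() is POISON   # once empty, every get() returns POISON

Because a poisoned queue hands POISON to every reader, consumers no longer need to re-insert POISON for their peers; that is why the middle and last pipeline threads drop their put(POISON) calls.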
test/test_pipeline.py (new file, 169 lines)

@@ -0,0 +1,169 @@
# This file is part of beets.
# Copyright 2010, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Test the "pipeline.py" restricted parallel programming library.
"""
import unittest
import _common
from beets.util import pipeline
# Some simple pipeline stages for testing.
def _produce(num=5):
    for i in range(num):
        yield i

def _work():
    i = None
    while True:
        i = yield i
        i *= 2

def _consume(l):
    while True:
        i = yield
        l.append(i)

# A worker that raises an exception.
class TestException(Exception): pass

def _exc_work(num=3):
    i = None
    while True:
        i = yield i
        if i == num:
            raise TestException()
        i *= 2

# A worker that yields a bubble.
def _bub_work(num=3):
    i = None
    while True:
        i = yield i
        if i == num:
            i = pipeline.BUBBLE
        else:
            i *= 2

# Yet another worker that yields multiple messages.
def _multi_work():
    i = None
    while True:
        i = yield i
        i = pipeline.multiple([i, -i])

class SimplePipelineTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((_produce(), _work(), _consume(self.l)))

    def test_run_sequential(self):
        self.pl.run_sequential()
        self.assertEqual(self.l, [0, 2, 4, 6, 8])

    def test_run_parallel(self):
        self.pl.run_parallel()
        self.assertEqual(self.l, [0, 2, 4, 6, 8])

class ParallelStageTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((
            _produce(), (_work(), _work()), _consume(self.l)
        ))

    def test_run_sequential(self):
        self.pl.run_sequential()
        self.assertEqual(self.l, [0, 2, 4, 6, 8])

    def test_run_parallel(self):
        self.pl.run_parallel()
        # Order possibly not preserved; use set equality.
        self.assertEqual(set(self.l), set([0, 2, 4, 6, 8]))

class ExceptionTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((_produce(), _exc_work(), _consume(self.l)))

    def test_run_sequential(self):
        self.assertRaises(TestException, self.pl.run_sequential)

    def test_run_parallel(self):
        self.assertRaises(TestException, self.pl.run_parallel)

class ParallelExceptionTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((
            _produce(), (_exc_work(), _exc_work()), _consume(self.l)
        ))

    def test_run_parallel(self):
        self.assertRaises(TestException, self.pl.run_parallel)

class ConstrainedThreadedPipelineTest(unittest.TestCase):
    def test_constrained(self):
        l = []
        # Do a "significant" amount of work...
        pl = pipeline.Pipeline((_produce(1000), _work(), _consume(l)))
        # ... with only a single queue slot.
        pl.run_parallel(1)
        self.assertEqual(l, [i * 2 for i in range(1000)])

    def test_constrained_exception(self):
        # Raise an exception in a constrained pipeline.
        l = []
        pl = pipeline.Pipeline((_produce(1000), _exc_work(), _consume(l)))
        self.assertRaises(TestException, pl.run_parallel, 1)

    def test_constrained_parallel(self):
        l = []
        pl = pipeline.Pipeline((
            _produce(1000), (_work(), _work()), _consume(l)
        ))
        pl.run_parallel(1)
        self.assertEqual(set(l), set(i * 2 for i in range(1000)))

class BubbleTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((_produce(), _bub_work(), _consume(self.l)))

    def test_run_sequential(self):
        self.pl.run_sequential()
        self.assertEqual(self.l, [0, 2, 4, 8])

    def test_run_parallel(self):
        self.pl.run_parallel()
        self.assertEqual(self.l, [0, 2, 4, 8])

class MultiMessageTest(unittest.TestCase):
    def setUp(self):
        self.l = []
        self.pl = pipeline.Pipeline((
            _produce(), _multi_work(), _consume(self.l)
        ))

    def test_run_sequential(self):
        self.pl.run_sequential()
        self.assertEqual(self.l, [0, 0, 1, -1, 2, -2, 3, -3, 4, -4])

    def test_run_parallel(self):
        self.pl.run_parallel()
        self.assertEqual(self.l, [0, 0, 1, -1, 2, -2, 3, -3, 4, -4])

def suite():
    return unittest.TestLoader().loadTestsFromName(__name__)

if __name__ == '__main__':
    unittest.main(defaultTest='suite')