diff --git a/beets/autotag/match.py b/beets/autotag/match.py index 0279e27f0..ae5be88a4 100644 --- a/beets/autotag/match.py +++ b/beets/autotag/match.py @@ -173,42 +173,33 @@ def current_metadata(items): consensus[key] = (freq == len(values)) return likelies['artist'], likelies['album'], consensus['artist'] -def order_items(items, trackinfo): - """Orders the items based on how they match some canonical track - information. Returns a list of Items whose length is equal to the - length of ``trackinfo``. This always produces a result if the - numbers of items is at most the number of TrackInfo objects - (otherwise, returns None). In the case of a partial match, the - returned list may contain None in some positions. +def assign_items(items, tracks): + """Given a list of Items and a list of TrackInfo objects, find the + best mapping between them. Returns a mapping from Items to TrackInfo + objects, a set of extra Items, and a set of extra TrackInfo + objects. These "extra" objects occur when there is an unequal number + of objects of the two types. """ - # Make sure lengths match: If there is less items, it might just be that - # there is some tracks missing. - if len(items) > len(trackinfo): - return None - # Construct the cost matrix. costs = [] - for cur_item in items: + for item in items: row = [] - for i, canon_item in enumerate(trackinfo): - row.append(track_distance(cur_item, canon_item, i+1)) + for i, track in enumerate(tracks): + row.append(track_distance(item, track)) costs.append(row) # Find a minimum-cost bipartite matching. matching = Munkres().compute(costs) - # Order items based on the matching. - ordered_items = [None]*len(trackinfo) - for cur_idx, canon_idx in matching: - ordered_items[canon_idx] = items[cur_idx] - return ordered_items + # Produce the output matching. + mapping = dict((items[i], tracks[j]) for (i, j) in matching) + extra_items = set(items) - set(mapping.keys()) + extra_tracks = set(tracks) - set(mapping.values()) + return mapping, extra_items, extra_tracks -def track_distance(item, track_info, track_index=None, incl_artist=False): - """Determines the significance of a track metadata change. Returns - a float in [0.0,1.0]. `track_index` is the track number of the - `track_info` metadata set. If `track_index` is provided and - item.track is set, then these indices are used as a component of - the distance calculation. `incl_artist` indicates that a distance +def track_distance(item, track_info, incl_artist=False): + """Determines the significance of a track metadata change. Returns a + float in [0.0,1.0]. `incl_artist` indicates that a distance component should be included for the track artist (i.e., for various-artist releases). """ @@ -239,8 +230,8 @@ def track_distance(item, track_info, track_index=None, incl_artist=False): dist_max += TRACK_ARTIST_WEIGHT # Track index. - if track_index and item.track: - if item.track not in (track_index, track_info.medium_index): + if track_info.index and item.track: + if item.track not in (track_info.index, track_info.medium_index): dist += TRACK_INDEX_WEIGHT dist_max += TRACK_INDEX_WEIGHT @@ -280,7 +271,7 @@ def distance(items, album_info): # Track distances. for i, (item, track_info) in enumerate(zip(items, album_info.tracks)): if item: - dist += track_distance(item, track_info, i+1, album_info.va) * \ + dist += track_distance(item, track_info, album_info.va) * \ TRACK_WEIGHT dist_max += TRACK_WEIGHT else: @@ -348,7 +339,7 @@ def recommendation(results): rec = RECOMMEND_NONE return rec -def validate_candidate(items, tuple_dict, info): +def validate_candidate(items, results, info): """Given a candidate AlbumInfo object, attempt to add the candidate to the output dictionary of result tuples. This involves checking the track count, ordering the items, checking for duplicates, and @@ -357,7 +348,7 @@ def validate_candidate(items, tuple_dict, info): log.debug('Candidate: %s - %s' % (info.artist, info.album)) # Don't duplicate. - if info.album_id in tuple_dict: + if info.album_id in results: log.debug('Duplicate.') return @@ -368,16 +359,17 @@ def validate_candidate(items, tuple_dict, info): return # Put items in order. - ordered = order_items(items, info.tracks) - if not ordered: - log.debug('Not orderable.') - return + mapping, extra_items, extra_tracks = assign_items(items, info.tracks) + # TEMPORARY: make ordered item list with gaps. + ordered = [None] * len(info.tracks) + for item, track_info in mapping.iteritems(): + ordered[track_info.index - 1] = item # Get the change distance. dist = distance(ordered, info) log.debug('Success. Distance: %f' % dist) - tuple_dict[info.album_id] = dist, ordered, info + results[info.album_id] = dist, ordered, info def tag_album(items, timid=False, search_artist=None, search_album=None, search_id=None): diff --git a/test/test_autotag.py b/test/test_autotag.py index a976634de..8da79cf24 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -345,14 +345,14 @@ class MultiDiscAlbumsInDirTest(unittest.TestCase): albums = list(autotag.albums_in_dir(self.base)) self.assertEquals(len(albums), 0) -class OrderingTest(unittest.TestCase): +class AssignmentTest(unittest.TestCase): def item(self, title, track): return Item({ 'title': title, 'track': track, 'mb_trackid': '', 'mb_albumid': '', 'mb_artistid': '', }) - def test_order_corrects_metadata(self): + def test_reorder_when_track_numbers_incorrect(self): items = [] items.append(self.item('one', 1)) items.append(self.item('three', 2)) @@ -361,12 +361,17 @@ class OrderingTest(unittest.TestCase): trackinfo.append(TrackInfo('one', None)) trackinfo.append(TrackInfo('two', None)) trackinfo.append(TrackInfo('three', None)) - ordered = match.order_items(items, trackinfo) - self.assertEqual(ordered[0].title, 'one') - self.assertEqual(ordered[1].title, 'two') - self.assertEqual(ordered[2].title, 'three') + mapping, extra_items, extra_tracks = \ + match.assign_items(items, trackinfo) + self.assertEqual(extra_items, set()) + self.assertEqual(extra_tracks, set()) + self.assertEqual(mapping, { + items[0]: trackinfo[0], + items[1]: trackinfo[2], + items[2]: trackinfo[1], + }) - def test_order_works_with_incomplete_metadata(self): + def test_order_works_with_invalid_track_numbers(self): items = [] items.append(self.item('one', 1)) items.append(self.item('three', 1)) @@ -375,21 +380,15 @@ class OrderingTest(unittest.TestCase): trackinfo.append(TrackInfo('one', None)) trackinfo.append(TrackInfo('two', None)) trackinfo.append(TrackInfo('three', None)) - ordered = match.order_items(items, trackinfo) - self.assertEqual(ordered[0].title, 'one') - self.assertEqual(ordered[1].title, 'two') - self.assertEqual(ordered[2].title, 'three') - - def test_order_returns_none_for_length_mismatch(self): - items = [] - items.append(self.item('one', 1)) - items.append(self.item('two', 2)) - items.append(self.item('three', 3)) - items.append(self.item('four',4)) - trackinfo = [] - trackinfo.append(TrackInfo('one', None)) - ordered = match.order_items(items, trackinfo) - self.assertEqual(ordered, None) + mapping, extra_items, extra_tracks = \ + match.assign_items(items, trackinfo) + self.assertEqual(extra_items, set()) + self.assertEqual(extra_tracks, set()) + self.assertEqual(mapping, { + items[0]: trackinfo[0], + items[1]: trackinfo[2], + items[2]: trackinfo[1], + }) def test_order_works_with_missing_tracks(self): items = [] @@ -399,12 +398,16 @@ class OrderingTest(unittest.TestCase): trackinfo.append(TrackInfo('one', None)) trackinfo.append(TrackInfo('two', None)) trackinfo.append(TrackInfo('three', None)) - ordered = match.order_items(items, trackinfo) - self.assertEqual(ordered[0].title, 'one') - self.assertEqual(ordered[1], None) - self.assertEqual(ordered[2].title, 'three') + mapping, extra_items, extra_tracks = \ + match.assign_items(items, trackinfo) + self.assertEqual(extra_items, set()) + self.assertEqual(extra_tracks, set([trackinfo[1]])) + self.assertEqual(mapping, { + items[0]: trackinfo[0], + items[1]: trackinfo[2], + }) - def test_order_returns_none_for_extra_tracks(self): + def test_order_works_with_extra_tracks(self): items = [] items.append(self.item('one', 1)) items.append(self.item('two', 2)) @@ -412,10 +415,16 @@ class OrderingTest(unittest.TestCase): trackinfo = [] trackinfo.append(TrackInfo('one', None)) trackinfo.append(TrackInfo('three', None)) - ordered = match.order_items(items, trackinfo) - self.assertEqual(ordered, None) + mapping, extra_items, extra_tracks = \ + match.assign_items(items, trackinfo) + self.assertEqual(extra_items, set([items[1]])) + self.assertEqual(extra_tracks, set()) + self.assertEqual(mapping, { + items[0]: trackinfo[0], + items[2]: trackinfo[1], + }) - def test_order_corrects_when_track_names_are_entirely_wrong(self): + def test_order_works_when_track_names_are_entirely_wrong(self): # A real-world test case contributed by a user. def item(i, length): return Item({ @@ -440,25 +449,28 @@ class OrderingTest(unittest.TestCase): items.append(item(11, 243.57001238834192)) items.append(item(12, 186.45916150485752)) - def info(title, length): - return TrackInfo(title, None, length=length) + def info(index, title, length): + return TrackInfo(title, None, length=length, index=index) trackinfo = [] - trackinfo.append(info('Alone', 238.893)) - trackinfo.append(info('The Woman in You', 341.44)) - trackinfo.append(info('Less', 245.59999999999999)) - trackinfo.append(info('Two Hands of a Prayer', 470.49299999999999)) - trackinfo.append(info('Please Bleed', 277.86599999999999)) - trackinfo.append(info('Suzie Blue', 269.30599999999998)) - trackinfo.append(info('Steal My Kisses', 245.36000000000001)) - trackinfo.append(info('Burn to Shine', 214.90600000000001)) - trackinfo.append(info('Show Me a Little Shame', 224.09299999999999)) - trackinfo.append(info('Forgiven', 317.19999999999999)) - trackinfo.append(info('Beloved One', 243.733)) - trackinfo.append(info('In the Lord\'s Arms', 186.13300000000001)) + trackinfo.append(info(1, 'Alone', 238.893)) + trackinfo.append(info(2, 'The Woman in You', 341.44)) + trackinfo.append(info(3, 'Less', 245.59999999999999)) + trackinfo.append(info(4, 'Two Hands of a Prayer', 470.49299999999999)) + trackinfo.append(info(5, 'Please Bleed', 277.86599999999999)) + trackinfo.append(info(6, 'Suzie Blue', 269.30599999999998)) + trackinfo.append(info(7, 'Steal My Kisses', 245.36000000000001)) + trackinfo.append(info(8, 'Burn to Shine', 214.90600000000001)) + trackinfo.append(info(9, 'Show Me a Little Shame', 224.09299999999999)) + trackinfo.append(info(10, 'Forgiven', 317.19999999999999)) + trackinfo.append(info(11, 'Beloved One', 243.733)) + trackinfo.append(info(12, 'In the Lord\'s Arms', 186.13300000000001)) - ordered = match.order_items(items, trackinfo) - for i, item in enumerate(ordered): - self.assertEqual(i+1, item.track) + mapping, extra_items, extra_tracks = \ + match.assign_items(items, trackinfo) + self.assertEqual(extra_items, set()) + self.assertEqual(extra_tracks, set()) + for item, info in mapping.iteritems(): + self.assertEqual(items.index(item), trackinfo.index(info)) class ApplyTest(unittest.TestCase): def setUp(self):