match ordering without length assumptions

This replaces order_items with assign_items, the first step toward allowing
unequal numbers of items on either side of the equation (user files and
canonical tracks). Rather than returning a "holey" list (one with None in the
unmatched positions) and assuming that the TrackInfo objects stay static, the
function returns a dictionary mapping Item objects to TrackInfo objects. To
indicate unmatched objects, two sets are also returned: the extra Items and the
extra TrackInfo objects.

For the moment, some temporary code is included in validate_candidate to turn
the result of this new function back into the old format (a holey Item list).
This allowed me to test this change in isolation before plunging ahead with the
refactoring needed to expose all of this to the importer workflow, etc.
This commit is contained in:
Adrian Sampson 2012-06-29 15:11:25 -07:00
parent 8d7397135f
commit d4c3ea74c6
2 changed files with 87 additions and 83 deletions

View file

@ -173,42 +173,33 @@ def current_metadata(items):
consensus[key] = (freq == len(values))
return likelies['artist'], likelies['album'], consensus['artist']
def order_items(items, trackinfo):
"""Orders the items based on how they match some canonical track
information. Returns a list of Items whose length is equal to the
length of ``trackinfo``. This always produces a result if the
numbers of items is at most the number of TrackInfo objects
(otherwise, returns None). In the case of a partial match, the
returned list may contain None in some positions.
def assign_items(items, tracks):
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
objects, a set of extra Items, and a set of extra TrackInfo
objects. These "extra" objects occur when there is an unequal number
of objects of the two types.
"""
# Make sure lengths match: If there is less items, it might just be that
# there is some tracks missing.
if len(items) > len(trackinfo):
return None
# Construct the cost matrix.
costs = []
for cur_item in items:
for item in items:
row = []
for i, canon_item in enumerate(trackinfo):
row.append(track_distance(cur_item, canon_item, i+1))
for i, track in enumerate(tracks):
row.append(track_distance(item, track))
costs.append(row)
# Find a minimum-cost bipartite matching.
matching = Munkres().compute(costs)
# Order items based on the matching.
ordered_items = [None]*len(trackinfo)
for cur_idx, canon_idx in matching:
ordered_items[canon_idx] = items[cur_idx]
return ordered_items
# Produce the output matching.
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
extra_items = set(items) - set(mapping.keys())
extra_tracks = set(tracks) - set(mapping.values())
return mapping, extra_items, extra_tracks
def track_distance(item, track_info, track_index=None, incl_artist=False):
"""Determines the significance of a track metadata change. Returns
a float in [0.0,1.0]. `track_index` is the track number of the
`track_info` metadata set. If `track_index` is provided and
item.track is set, then these indices are used as a component of
the distance calculation. `incl_artist` indicates that a distance
def track_distance(item, track_info, incl_artist=False):
"""Determines the significance of a track metadata change. Returns a
float in [0.0,1.0]. `incl_artist` indicates that a distance
component should be included for the track artist (i.e., for
various-artist releases).
"""
@ -239,8 +230,8 @@ def track_distance(item, track_info, track_index=None, incl_artist=False):
dist_max += TRACK_ARTIST_WEIGHT
# Track index.
if track_index and item.track:
if item.track not in (track_index, track_info.medium_index):
if track_info.index and item.track:
if item.track not in (track_info.index, track_info.medium_index):
dist += TRACK_INDEX_WEIGHT
dist_max += TRACK_INDEX_WEIGHT
@ -280,7 +271,7 @@ def distance(items, album_info):
# Track distances.
for i, (item, track_info) in enumerate(zip(items, album_info.tracks)):
if item:
dist += track_distance(item, track_info, i+1, album_info.va) * \
dist += track_distance(item, track_info, album_info.va) * \
TRACK_WEIGHT
dist_max += TRACK_WEIGHT
else:
@ -348,7 +339,7 @@ def recommendation(results):
rec = RECOMMEND_NONE
return rec
def validate_candidate(items, tuple_dict, info):
def validate_candidate(items, results, info):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of result tuples. This involves checking
the track count, ordering the items, checking for duplicates, and
@ -357,7 +348,7 @@ def validate_candidate(items, tuple_dict, info):
log.debug('Candidate: %s - %s' % (info.artist, info.album))
# Don't duplicate.
if info.album_id in tuple_dict:
if info.album_id in results:
log.debug('Duplicate.')
return
@ -368,16 +359,17 @@ def validate_candidate(items, tuple_dict, info):
return
# Put items in order.
ordered = order_items(items, info.tracks)
if not ordered:
log.debug('Not orderable.')
return
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
# TEMPORARY: make ordered item list with gaps.
ordered = [None] * len(info.tracks)
for item, track_info in mapping.iteritems():
ordered[track_info.index - 1] = item
# Get the change distance.
dist = distance(ordered, info)
log.debug('Success. Distance: %f' % dist)
tuple_dict[info.album_id] = dist, ordered, info
results[info.album_id] = dist, ordered, info
def tag_album(items, timid=False, search_artist=None, search_album=None,
search_id=None):

View file

@ -345,14 +345,14 @@ class MultiDiscAlbumsInDirTest(unittest.TestCase):
albums = list(autotag.albums_in_dir(self.base))
self.assertEquals(len(albums), 0)
class OrderingTest(unittest.TestCase):
class AssignmentTest(unittest.TestCase):
def item(self, title, track):
return Item({
'title': title, 'track': track,
'mb_trackid': '', 'mb_albumid': '', 'mb_artistid': '',
})
def test_order_corrects_metadata(self):
def test_reorder_when_track_numbers_incorrect(self):
items = []
items.append(self.item('one', 1))
items.append(self.item('three', 2))
@ -361,12 +361,17 @@ class OrderingTest(unittest.TestCase):
trackinfo.append(TrackInfo('one', None))
trackinfo.append(TrackInfo('two', None))
trackinfo.append(TrackInfo('three', None))
ordered = match.order_items(items, trackinfo)
self.assertEqual(ordered[0].title, 'one')
self.assertEqual(ordered[1].title, 'two')
self.assertEqual(ordered[2].title, 'three')
mapping, extra_items, extra_tracks = \
match.assign_items(items, trackinfo)
self.assertEqual(extra_items, set())
self.assertEqual(extra_tracks, set())
self.assertEqual(mapping, {
items[0]: trackinfo[0],
items[1]: trackinfo[2],
items[2]: trackinfo[1],
})
def test_order_works_with_incomplete_metadata(self):
def test_order_works_with_invalid_track_numbers(self):
items = []
items.append(self.item('one', 1))
items.append(self.item('three', 1))
@ -375,21 +380,15 @@ class OrderingTest(unittest.TestCase):
trackinfo.append(TrackInfo('one', None))
trackinfo.append(TrackInfo('two', None))
trackinfo.append(TrackInfo('three', None))
ordered = match.order_items(items, trackinfo)
self.assertEqual(ordered[0].title, 'one')
self.assertEqual(ordered[1].title, 'two')
self.assertEqual(ordered[2].title, 'three')
def test_order_returns_none_for_length_mismatch(self):
items = []
items.append(self.item('one', 1))
items.append(self.item('two', 2))
items.append(self.item('three', 3))
items.append(self.item('four',4))
trackinfo = []
trackinfo.append(TrackInfo('one', None))
ordered = match.order_items(items, trackinfo)
self.assertEqual(ordered, None)
mapping, extra_items, extra_tracks = \
match.assign_items(items, trackinfo)
self.assertEqual(extra_items, set())
self.assertEqual(extra_tracks, set())
self.assertEqual(mapping, {
items[0]: trackinfo[0],
items[1]: trackinfo[2],
items[2]: trackinfo[1],
})
def test_order_works_with_missing_tracks(self):
items = []
@ -399,12 +398,16 @@ class OrderingTest(unittest.TestCase):
trackinfo.append(TrackInfo('one', None))
trackinfo.append(TrackInfo('two', None))
trackinfo.append(TrackInfo('three', None))
ordered = match.order_items(items, trackinfo)
self.assertEqual(ordered[0].title, 'one')
self.assertEqual(ordered[1], None)
self.assertEqual(ordered[2].title, 'three')
mapping, extra_items, extra_tracks = \
match.assign_items(items, trackinfo)
self.assertEqual(extra_items, set())
self.assertEqual(extra_tracks, set([trackinfo[1]]))
self.assertEqual(mapping, {
items[0]: trackinfo[0],
items[1]: trackinfo[2],
})
def test_order_returns_none_for_extra_tracks(self):
def test_order_works_with_extra_tracks(self):
items = []
items.append(self.item('one', 1))
items.append(self.item('two', 2))
@ -412,10 +415,16 @@ class OrderingTest(unittest.TestCase):
trackinfo = []
trackinfo.append(TrackInfo('one', None))
trackinfo.append(TrackInfo('three', None))
ordered = match.order_items(items, trackinfo)
self.assertEqual(ordered, None)
mapping, extra_items, extra_tracks = \
match.assign_items(items, trackinfo)
self.assertEqual(extra_items, set([items[1]]))
self.assertEqual(extra_tracks, set())
self.assertEqual(mapping, {
items[0]: trackinfo[0],
items[2]: trackinfo[1],
})
def test_order_corrects_when_track_names_are_entirely_wrong(self):
def test_order_works_when_track_names_are_entirely_wrong(self):
# A real-world test case contributed by a user.
def item(i, length):
return Item({
@ -440,25 +449,28 @@ class OrderingTest(unittest.TestCase):
items.append(item(11, 243.57001238834192))
items.append(item(12, 186.45916150485752))
def info(title, length):
return TrackInfo(title, None, length=length)
def info(index, title, length):
return TrackInfo(title, None, length=length, index=index)
trackinfo = []
trackinfo.append(info('Alone', 238.893))
trackinfo.append(info('The Woman in You', 341.44))
trackinfo.append(info('Less', 245.59999999999999))
trackinfo.append(info('Two Hands of a Prayer', 470.49299999999999))
trackinfo.append(info('Please Bleed', 277.86599999999999))
trackinfo.append(info('Suzie Blue', 269.30599999999998))
trackinfo.append(info('Steal My Kisses', 245.36000000000001))
trackinfo.append(info('Burn to Shine', 214.90600000000001))
trackinfo.append(info('Show Me a Little Shame', 224.09299999999999))
trackinfo.append(info('Forgiven', 317.19999999999999))
trackinfo.append(info('Beloved One', 243.733))
trackinfo.append(info('In the Lord\'s Arms', 186.13300000000001))
trackinfo.append(info(1, 'Alone', 238.893))
trackinfo.append(info(2, 'The Woman in You', 341.44))
trackinfo.append(info(3, 'Less', 245.59999999999999))
trackinfo.append(info(4, 'Two Hands of a Prayer', 470.49299999999999))
trackinfo.append(info(5, 'Please Bleed', 277.86599999999999))
trackinfo.append(info(6, 'Suzie Blue', 269.30599999999998))
trackinfo.append(info(7, 'Steal My Kisses', 245.36000000000001))
trackinfo.append(info(8, 'Burn to Shine', 214.90600000000001))
trackinfo.append(info(9, 'Show Me a Little Shame', 224.09299999999999))
trackinfo.append(info(10, 'Forgiven', 317.19999999999999))
trackinfo.append(info(11, 'Beloved One', 243.733))
trackinfo.append(info(12, 'In the Lord\'s Arms', 186.13300000000001))
ordered = match.order_items(items, trackinfo)
for i, item in enumerate(ordered):
self.assertEqual(i+1, item.track)
mapping, extra_items, extra_tracks = \
match.assign_items(items, trackinfo)
self.assertEqual(extra_items, set())
self.assertEqual(extra_tracks, set())
for item, info in mapping.iteritems():
self.assertEqual(items.index(item), trackinfo.index(info))
class ApplyTest(unittest.TestCase):
def setUp(self):