Fix track matching

I had previously tested the `munkres` -> `lapjv` replacement
extensively, so I was today surprised to find that nothing gets matched
correctly when I tried importing some new tracks.

On the other hand I now remember making a small adjustment in the logic
to make autotagging tests pass which is when I introduced a bug: I did
not realize that `lapjv` returns index '-1' for each unmatched item.

This issue did not get caught by tests because this 'unmatched' item
index '-1' anecdotally ended up pointing to the last (expected) item in
the test making it pass.

This commit adjusts the aforementioned test to catch this issue and
fixes the logic to correctly identify unmatched tracks.
This commit is contained in:
Šarūnas Nejus 2024-12-31 00:08:10 +00:00
parent faf7529aa9
commit 0d6393e712
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
2 changed files with 14 additions and 8 deletions

View file

@ -127,15 +127,21 @@ def assign_items(
objects. These "extra" objects occur when there is an unequal number objects. These "extra" objects occur when there is an unequal number
of objects of the two types. of objects of the two types.
""" """
log.debug("Computing track assignment...")
# Construct the cost matrix. # Construct the cost matrix.
costs = [[float(track_distance(i, t)) for t in tracks] for i in items] costs = [[float(track_distance(i, t)) for t in tracks] for i in items]
# Find a minimum-cost bipartite matching. # Assign items to tracks
log.debug("Computing track assignment...") _, _, assigned_item_idxs = lap.lapjv(np.array(costs), extend_cost=True)
cost, _, assigned_idxs = lap.lapjv(np.array(costs), extend_cost=True)
log.debug("...done.") log.debug("...done.")
# Produce the output matching. # Each item in `assigned_item_idxs` list corresponds to a track in the
mapping = {items[i]: tracks[t] for (t, i) in enumerate(assigned_idxs)} # `tracks` list. Each value is either an index into the assigned item in
# `items` list, or -1 if that track has no match.
mapping = {
items[iidx]: t
for iidx, t in zip(assigned_item_idxs, tracks)
if iidx != -1
}
extra_items = list(set(items) - mapping.keys()) extra_items = list(set(items) - mapping.keys())
extra_items.sort(key=lambda i: (i.disc, i.track, i.title)) extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values())) extra_tracks = list(set(tracks) - set(mapping.values()))

View file

@ -551,7 +551,7 @@ class AssignmentTest(unittest.TestCase):
def test_order_works_with_missing_tracks(self): def test_order_works_with_missing_tracks(self):
items = [] items = []
items.append(self.item("one", 1)) items.append(self.item("one", 1))
items.append(self.item("three", 3)) items.append(self.item("two", 2))
trackinfo = [] trackinfo = []
trackinfo.append(TrackInfo(title="one")) trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="two")) trackinfo.append(TrackInfo(title="two"))
@ -560,8 +560,8 @@ class AssignmentTest(unittest.TestCase):
items, trackinfo items, trackinfo
) )
assert extra_items == [] assert extra_items == []
assert extra_tracks == [trackinfo[1]] assert extra_tracks == [trackinfo[2]]
assert mapping == {items[0]: trackinfo[0], items[1]: trackinfo[2]} assert mapping == {items[0]: trackinfo[0], items[1]: trackinfo[1]}
def test_order_works_with_extra_tracks(self): def test_order_works_with_extra_tracks(self):
items = [] items = []