diff --git a/test/autotag/test_distance.py b/test/autotag/test_distance.py
index ec00ebcdf..e3ce9f891 100644
--- a/test/autotag/test_distance.py
+++ b/test/autotag/test_distance.py
@@ -1,176 +1,108 @@
 import re
-import unittest
 
-from beets import config
-from beets.autotag import AlbumInfo, TrackInfo, match
-from beets.autotag.distance import Distance, string_dist
+import pytest
+
+from beets.autotag import AlbumInfo, TrackInfo
+from beets.autotag.distance import (
+    Distance,
+    distance,
+    string_dist,
+    track_distance,
+)
 from beets.library import Item
-from beets.test.helper import BeetsTestCase
+from beets.test.helper import ConfigMixin
+
+_p = pytest.param
 
 
-def _make_item(title, track, artist="some artist"):
-    return Item(
-        title=title,
-        track=track,
-        artist=artist,
-        album="some album",
-        length=1,
-        mb_trackid="",
-        mb_albumid="",
-        mb_artistid="",
-    )
+class TestDistance:
+    @pytest.fixture(scope="class")
+    def config(self):
+        return ConfigMixin().config
 
+    @pytest.fixture
+    def dist(self, config):
+        config["match"]["distance_weights"]["source"] = 2.0
+        config["match"]["distance_weights"]["album"] = 4.0
+        config["match"]["distance_weights"]["medium"] = 2.0
 
-def _make_trackinfo():
-    return [
-        TrackInfo(
-            title="one", track_id=None, artist="some artist", length=1, index=1
-        ),
-        TrackInfo(
-            title="two", track_id=None, artist="some artist", length=1, index=2
-        ),
-        TrackInfo(
-            title="three",
-            track_id=None,
-            artist="some artist",
-            length=1,
-            index=3,
-        ),
-    ]
+        Distance.__dict__["_weights"].cache = {}
 
+        return Distance()
 
-def _clear_weights():
-    """Hack around the lazy descriptor used to cache weights for
-    Distance calculations.
-    """
-
-
-class DistanceTest(BeetsTestCase):
-    def tearDown(self):
-        super().tearDown()
-        _clear_weights()
-
-    def test_add(self):
-        dist = Distance()
+    def test_add(self, dist):
         dist.add("add", 1.0)
+        assert dist._penalties == {"add": [1.0]}
 
-    def test_add_equality(self):
-        dist = Distance()
-        dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
-        assert dist._penalties["equality"] == [0.0]
+    @pytest.mark.parametrize(
+        "key, args_with_expected",
+        [
+            (
+                "equality",
+                [
+                    (("ghi", ["abc", "def", "ghi"]), [0.0]),
+                    (("xyz", ["abc", "def", "ghi"]), [0.0, 1.0]),
+                    (("abc", re.compile(r"ABC", re.I)), [0.0, 1.0, 0.0]),
+                ],
+            ),
+            ("expr", [((True,), [1.0]), ((False,), [1.0, 0.0])]),
+            (
+                "number",
+                [
+                    ((1, 1), [0.0]),
+                    ((1, 2), [0.0, 1.0]),
+                    ((2, 1), [0.0, 1.0, 1.0]),
+                    ((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
+                ],
+            ),
+            (
+                "priority",
+                [
+                    (("abc", "abc"), [0.0]),
+                    (("def", ["abc", "def"]), [0.0, 0.5]),
+                    (("gh", ["ab", "cd", "ef", re.compile("GH", re.I)]), [0.0, 0.5, 0.75]),  # noqa: E501
+                    (("xyz", ["abc", "def"]), [0.0, 0.5, 0.75, 1.0]),
+                ],
+            ),
+            (
+                "ratio",
+                [
+                    ((25, 100), [0.25]),
+                    ((10, 5), [0.25, 1.0]),
+                    ((-5, 5), [0.25, 1.0, 0.0]),
+                    ((5, 0), [0.25, 1.0, 0.0, 0.0]),
+                ],
+            ),
+            (
+                "string",
+                [
+                    (("abc", "bcd"), [2 / 3]),
+                    (("abc", None), [2 / 3, 1]),
+                    ((None, None), [2 / 3, 1, 0]),
+                ],
+            ),
+        ],
+    )  # fmt: skip
+    def test_add_methods(self, dist, key, args_with_expected):
+        method = getattr(dist, f"add_{key}")
+        for arg_set, expected in args_with_expected:
+            method(key, *arg_set)
+            assert dist._penalties[key] == expected
 
-        dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
-        assert dist._penalties["equality"] == [0.0, 1.0]
-
-        dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
-        assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
-
-    def test_add_expr(self):
-        dist = Distance()
-        dist.add_expr("expr", True)
-        assert dist._penalties["expr"] == [1.0]
-
-        dist.add_expr("expr", False)
-        assert dist._penalties["expr"] == [1.0, 0.0]
-
-    def test_add_number(self):
-        dist = Distance()
-        # Add a full penalty for each number of difference between two numbers.
-
-        dist.add_number("number", 1, 1)
-        assert dist._penalties["number"] == [0.0]
-
-        dist.add_number("number", 1, 2)
-        assert dist._penalties["number"] == [0.0, 1.0]
-
-        dist.add_number("number", 2, 1)
-        assert dist._penalties["number"] == [0.0, 1.0, 1.0]
-
-        dist.add_number("number", -1, 2)
-        assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
-
-    def test_add_priority(self):
-        dist = Distance()
-        dist.add_priority("priority", "abc", "abc")
-        assert dist._penalties["priority"] == [0.0]
-
-        dist.add_priority("priority", "def", ["abc", "def"])
-        assert dist._penalties["priority"] == [0.0, 0.5]
-
-        dist.add_priority(
-            "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
-        )
-        assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
-
-        dist.add_priority("priority", "xyz", ["abc", "def"])
-        assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
-
-    def test_add_ratio(self):
-        dist = Distance()
-        dist.add_ratio("ratio", 25, 100)
-        assert dist._penalties["ratio"] == [0.25]
-
-        dist.add_ratio("ratio", 10, 5)
-        assert dist._penalties["ratio"] == [0.25, 1.0]
-
-        dist.add_ratio("ratio", -5, 5)
-        assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
-
-        dist.add_ratio("ratio", 5, 0)
-        assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
-
-    def test_add_string(self):
-        dist = Distance()
-        sdist = string_dist("abc", "bcd")
-        dist.add_string("string", "abc", "bcd")
-        assert dist._penalties["string"] == [sdist]
-        assert dist._penalties["string"] != [0]
-
-    def test_add_string_none(self):
-        dist = Distance()
-        dist.add_string("string", None, "string")
-        assert dist._penalties["string"] == [1]
-
-    def test_add_string_both_none(self):
-        dist = Distance()
-        dist.add_string("string", None, None)
-        assert dist._penalties["string"] == [0]
-
-    def test_distance(self):
-        config["match"]["distance_weights"]["album"] = 2.0
-        config["match"]["distance_weights"]["medium"] = 1.0
-        _clear_weights()
-
-        dist = Distance()
+    def test_distance(self, dist):
         dist.add("album", 0.5)
         dist.add("media", 0.25)
         dist.add("media", 0.75)
+        assert dist.distance == 0.5
+        assert dist.max_distance == 6.0
+        assert dist.raw_distance == 3.0
 
-        # __getitem__()
-        assert dist["album"] == 0.25
-        assert dist["media"] == 0.25
+        assert dist["album"] == 1 / 3
+        assert dist["media"] == 1 / 6
 
-    def test_max_distance(self):
-        config["match"]["distance_weights"]["album"] = 3.0
-        config["match"]["distance_weights"]["medium"] = 1.0
-        _clear_weights()
-
-        dist = Distance()
-        dist.add("album", 0.5)
-        dist.add("medium", 0.0)
-        dist.add("medium", 0.0)
-        assert dist.max_distance == 5.0
-
-    def test_operators(self):
-        config["match"]["distance_weights"]["source"] = 1.0
-        config["match"]["distance_weights"]["album"] = 2.0
-        config["match"]["distance_weights"]["medium"] = 1.0
-        _clear_weights()
-
-        dist = Distance()
+    def test_operators(self, dist):
         dist.add("source", 0.0)
         dist.add("album", 0.5)
         dist.add("medium", 0.25)
@@ -184,23 +116,7 @@ class DistanceTest(BeetsTestCase):
         assert 0.4 - dist == 0.0
         assert float(dist) == 0.4
 
-    def test_raw_distance(self):
-        config["match"]["distance_weights"]["album"] = 3.0
-        config["match"]["distance_weights"]["medium"] = 1.0
-        _clear_weights()
-
-        dist = Distance()
-        dist.add("album", 0.5)
-        dist.add("medium", 0.25)
-        dist.add("medium", 0.5)
-        assert dist.raw_distance == 2.25
-
-    def test_items(self):
-        config["match"]["distance_weights"]["album"] = 4.0
-        config["match"]["distance_weights"]["medium"] = 2.0
-        _clear_weights()
-
-        dist = Distance()
+    def test_penalties_sort(self, dist):
         dist.add("album", 0.1875)
         dist.add("medium", 0.75)
         assert dist.items() == [("medium", 0.25), ("album", 0.125)]
@@ -211,8 +127,8 @@ class DistanceTest(BeetsTestCase):
         dist.add("medium", 0.75)
         assert dist.items() == [("album", 0.25), ("medium", 0.25)]
 
-    def test_update(self):
-        dist1 = Distance()
+    def test_update(self, dist):
+        dist1 = dist
         dist1.add("album", 0.5)
         dist1.add("media", 1.0)
 
@@ -229,248 +145,155 @@ class DistanceTest(BeetsTestCase):
         }
 
 
-class TrackDistanceTest(BeetsTestCase):
-    def test_identical_tracks(self):
-        item = _make_item("one", 1)
-        info = _make_trackinfo()[0]
-        dist = match.track_distance(item, info, incl_artist=True)
-        assert dist == 0.0
+class TestTrackDistance:
+    @pytest.fixture(scope="class")
+    def info(self):
+        return TrackInfo(title="title", artist="artist")
 
-    def test_different_title(self):
-        item = _make_item("foo", 1)
-        info = _make_trackinfo()[0]
-        dist = match.track_distance(item, info, incl_artist=True)
-        assert dist != 0.0
+    @pytest.mark.parametrize(
+        "title, artist, expected_penalty",
+        [
+            _p("title", "artist", False, id="identical"),
+            _p("title", "Various Artists", False, id="tolerate-va"),
+            _p("title", "different artist", True, id="different-artist"),
+            _p("different title", "artist", True, id="different-title"),
+        ],
+    )
+    def test_track_distance(self, info, title, artist, expected_penalty):
+        item = Item(artist=artist, title=title)
 
-    def test_different_artist(self):
-        item = _make_item("one", 1)
-        item.artist = "foo"
-        info = _make_trackinfo()[0]
-        dist = match.track_distance(item, info, incl_artist=True)
-        assert dist != 0.0
-
-    def test_various_artists_tolerated(self):
-        item = _make_item("one", 1)
-        item.artist = "Various Artists"
-        info = _make_trackinfo()[0]
-        dist = match.track_distance(item, info, incl_artist=True)
-        assert dist == 0.0
+        assert (
+            bool(track_distance(item, info, incl_artist=True))
+            == expected_penalty
+        )
 
 
-class AlbumDistanceTest(BeetsTestCase):
-    def _mapping(self, items, info):
-        out = {}
-        for i, t in zip(items, info.tracks):
-            out[i] = t
-        return out
+class TestAlbumDistance:
+    @pytest.fixture(scope="class")
+    def items(self):
+        return [
+            Item(
+                title=title,
+                track=track,
+                artist="artist",
+                album="album",
+                length=1,
+            )
+            for title, track in [("one", 1), ("two", 2), ("three", 3)]
+        ]
 
-    def _dist(self, items, info):
-        return match.distance(items, info, self._mapping(items, info))
+    @pytest.fixture
+    def get_dist(self, items):
+        def inner(info: AlbumInfo):
+            return distance(items, info, dict(zip(items, info.tracks)))
 
-    def test_identical_albums(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
+        return inner
+
+    @pytest.fixture
+    def info(self, items):
+        return AlbumInfo(
+            artist="artist",
+            album="album",
+            tracks=[
+                TrackInfo(
+                    title=i.title,
+                    artist=i.artist,
+                    index=i.track,
+                    length=i.length,
+                )
+                for i in items
+            ],
             va=False,
         )
-        assert self._dist(items, info) == 0
 
-    def test_incomplete_album(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=False,
+    def test_identical_albums(self, get_dist, info):
+        assert get_dist(info) == 0
+
+    def test_incomplete_album(self, get_dist, info):
+        info.tracks.pop(2)
+
+        assert 0 < float(get_dist(info)) < 0.2
+
+    def test_overly_complete_album(self, get_dist, info):
+        info.tracks.append(
+            TrackInfo(index=4, title="four", artist="artist", length=1)
         )
-        dist = self._dist(items, info)
-        assert dist != 0
-        # Make sure the distance is not too great
-        assert dist < 0.2
 
-    def test_global_artists_differ(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="someone else",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=False,
-        )
-        assert self._dist(items, info) != 0
+        assert 0 < float(get_dist(info)) < 0.2
 
-    def test_comp_track_artists_match(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="should be ignored",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=True,
-        )
-        assert self._dist(items, info) == 0
+    @pytest.mark.parametrize("va", [True, False])
+    def test_albumartist(self, get_dist, info, va):
+        info.artist = "another artist"
+        info.va = va
 
-    def test_comp_no_track_artists(self):
+        assert bool(get_dist(info)) is not va
+
+    def test_comp_no_track_artists(self, get_dist, info):
         # Some VA releases don't have track artists (incomplete metadata).
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="should be ignored",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=True,
-        )
-        info.tracks[0].artist = None
-        info.tracks[1].artist = None
-        info.tracks[2].artist = None
-        assert self._dist(items, info) == 0
+        info.artist = "another artist"
+        info.va = True
+        for track in info.tracks:
+            track.artist = None
 
-    def test_comp_track_artists_do_not_match(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2, "someone else"))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=True,
-        )
-        assert self._dist(items, info) != 0
+        assert get_dist(info) == 0
 
-    def test_tracks_out_of_order(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("three", 2))
-        items.append(_make_item("two", 3))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=False,
-        )
-        dist = self._dist(items, info)
-        assert 0 < dist < 0.2
+    def test_comp_track_artists_do_not_match(self, get_dist, info):
+        info.va = True
+        info.tracks[0].artist = "another artist"
 
-    def test_two_medium_release(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 3))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=False,
-        )
+        assert get_dist(info) != 0
+
+    def test_tracks_out_of_order(self, get_dist, info):
+        tracks = info.tracks
+        tracks[1].title, tracks[2].title = tracks[2].title, tracks[1].title
+
+        assert 0 < float(get_dist(info)) < 0.2
+
+    def test_two_medium_release(self, get_dist, info):
         info.tracks[0].medium_index = 1
         info.tracks[1].medium_index = 2
         info.tracks[2].medium_index = 1
-        dist = self._dist(items, info)
-        assert dist == 0
 
-    def test_per_medium_track_numbers(self):
-        items = []
-        items.append(_make_item("one", 1))
-        items.append(_make_item("two", 2))
-        items.append(_make_item("three", 1))
-        info = AlbumInfo(
-            artist="some artist",
-            album="some album",
-            tracks=_make_trackinfo(),
-            va=False,
-        )
-        info.tracks[0].medium_index = 1
-        info.tracks[1].medium_index = 2
-        info.tracks[2].medium_index = 1
-        dist = self._dist(items, info)
-        assert dist == 0
+        assert get_dist(info) == 0
 
 
-class StringDistanceTest(unittest.TestCase):
-    def test_equal_strings(self):
-        dist = string_dist("Some String", "Some String")
-        assert dist == 0.0
+class TestStringDistance:
+    @pytest.mark.parametrize(
+        "string1, string2",
+        [
+            ("Some String", "Some String"),
+            ("Some String", "Some.String!"),
+            ("Some String", "sOME sTring"),
+            ("My Song (EP)", "My Song"),
+            ("The Song Title", "Song Title, The"),
+            ("A Song Title", "Song Title, A"),
+            ("An Album Title", "Album Title, An"),
+            ("", ""),
+            ("Untitled", "[Untitled]"),
+            ("And", "&"),
+            ("\xe9\xe1\xf1", "ean"),
+        ],
+    )
+    def test_matching_distance(self, string1, string2):
+        assert string_dist(string1, string2) == 0.0
 
-    def test_different_strings(self):
-        dist = string_dist("Some String", "Totally Different")
-        assert dist != 0.0
+    def test_different_distance(self):
+        assert string_dist("Some String", "Totally Different") != 0.0
 
-    def test_punctuation_ignored(self):
-        dist = string_dist("Some String", "Some.String!")
-        assert dist == 0.0
-
-    def test_case_ignored(self):
-        dist = string_dist("Some String", "sOME sTring")
-        assert dist == 0.0
-
-    def test_leading_the_has_lower_weight(self):
-        dist1 = string_dist("XXX Band Name", "Band Name")
-        dist2 = string_dist("The Band Name", "Band Name")
-        assert dist2 < dist1
-
-    def test_parens_have_lower_weight(self):
-        dist1 = string_dist("One .Two.", "One")
-        dist2 = string_dist("One (Two)", "One")
-        assert dist2 < dist1
-
-    def test_brackets_have_lower_weight(self):
-        dist1 = string_dist("One .Two.", "One")
-        dist2 = string_dist("One [Two]", "One")
-        assert dist2 < dist1
-
-    def test_ep_label_has_zero_weight(self):
-        dist = string_dist("My Song (EP)", "My Song")
-        assert dist == 0.0
-
-    def test_featured_has_lower_weight(self):
-        dist1 = string_dist("My Song blah Someone", "My Song")
-        dist2 = string_dist("My Song feat Someone", "My Song")
-        assert dist2 < dist1
-
-    def test_postfix_the(self):
-        dist = string_dist("The Song Title", "Song Title, The")
-        assert dist == 0.0
-
-    def test_postfix_a(self):
-        dist = string_dist("A Song Title", "Song Title, A")
-        assert dist == 0.0
-
-    def test_postfix_an(self):
-        dist = string_dist("An Album Title", "Album Title, An")
-        assert dist == 0.0
-
-    def test_empty_strings(self):
-        dist = string_dist("", "")
-        assert dist == 0.0
+    @pytest.mark.parametrize(
+        "string1, string2, reference",
+        [
+            ("XXX Band Name", "The Band Name", "Band Name"),
+            ("One .Two.", "One (Two)", "One"),
+            ("One .Two.", "One [Two]", "One"),
+            ("My Song blah Someone", "My Song feat Someone", "My Song"),
+        ],
+    )
+    def test_relative_weights(self, string1, string2, reference):
+        assert string_dist(string2, reference) < string_dist(string1, reference)
 
     def test_solo_pattern(self):
         # Just make sure these don't crash.
         string_dist("The ", "")
         string_dist("(EP)", "(EP)")
         string_dist(", An", "")
-
-    def test_heuristic_does_not_harm_distance(self):
-        dist = string_dist("Untitled", "[Untitled]")
-        assert dist == 0.0
-
-    def test_ampersand_expansion(self):
-        dist = string_dist("And", "&")
-        assert dist == 0.0
-
-    def test_accented_characters(self):
-        dist = string_dist("\xe9\xe1\xf1", "ean")
-        assert dist == 0.0