Rewrite distance tests

This commit is contained in:
Šarūnas Nejus 2025-05-25 13:43:20 +01:00
parent adbd50b237
commit 318a840af2
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435

View file

@ -1,176 +1,108 @@
import re
import unittest
from beets import config
from beets.autotag import AlbumInfo, TrackInfo, match
from beets.autotag.distance import Distance, string_dist
import pytest
from beets.autotag import AlbumInfo, TrackInfo
from beets.autotag.distance import (
Distance,
distance,
string_dist,
track_distance,
)
from beets.library import Item
from beets.test.helper import BeetsTestCase
from beets.test.helper import ConfigMixin
_p = pytest.param
def _make_item(title, track, artist="some artist"):
return Item(
title=title,
track=track,
artist=artist,
album="some album",
length=1,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)
class TestDistance:
@pytest.fixture(scope="class")
def config(self):
return ConfigMixin().config
@pytest.fixture
def dist(self, config):
config["match"]["distance_weights"]["source"] = 2.0
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
def _make_trackinfo():
return [
TrackInfo(
title="one", track_id=None, artist="some artist", length=1, index=1
),
TrackInfo(
title="two", track_id=None, artist="some artist", length=1, index=2
),
TrackInfo(
title="three",
track_id=None,
artist="some artist",
length=1,
index=3,
),
]
Distance.__dict__["_weights"].cache = {}
return Distance()
def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].cache = {}
class DistanceTest(BeetsTestCase):
def tearDown(self):
super().tearDown()
_clear_weights()
def test_add(self):
dist = Distance()
def test_add(self, dist):
dist.add("add", 1.0)
assert dist._penalties == {"add": [1.0]}
def test_add_equality(self):
dist = Distance()
dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0]
@pytest.mark.parametrize(
"key, args_with_expected",
[
(
"equality",
[
(("ghi", ["abc", "def", "ghi"]), [0.0]),
(("xyz", ["abc", "def", "ghi"]), [0.0, 1.0]),
(("abc", re.compile(r"ABC", re.I)), [0.0, 1.0, 0.0]),
],
),
("expr", [((True,), [1.0]), ((False,), [1.0, 0.0])]),
(
"number",
[
((1, 1), [0.0]),
((1, 2), [0.0, 1.0]),
((2, 1), [0.0, 1.0, 1.0]),
((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
],
),
(
"priority",
[
(("abc", "abc"), [0.0]),
(("def", ["abc", "def"]), [0.0, 0.5]),
(("gh", ["ab", "cd", "ef", re.compile("GH", re.I)]), [0.0, 0.5, 0.75]), # noqa: E501
(("xyz", ["abc", "def"]), [0.0, 0.5, 0.75, 1.0]),
],
),
(
"ratio",
[
((25, 100), [0.25]),
((10, 5), [0.25, 1.0]),
((-5, 5), [0.25, 1.0, 0.0]),
((5, 0), [0.25, 1.0, 0.0, 0.0]),
],
),
(
"string",
[
(("abc", "bcd"), [2 / 3]),
(("abc", None), [2 / 3, 1]),
((None, None), [2 / 3, 1, 0]),
],
),
],
) # fmt: skip
def test_add_methods(self, dist, key, args_with_expected):
method = getattr(dist, f"add_{key}")
for arg_set, expected in args_with_expected:
method(key, *arg_set)
assert dist._penalties[key] == expected
dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0, 1.0]
dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
def test_add_expr(self):
dist = Distance()
dist.add_expr("expr", True)
assert dist._penalties["expr"] == [1.0]
dist.add_expr("expr", False)
assert dist._penalties["expr"] == [1.0, 0.0]
def test_add_number(self):
dist = Distance()
# Add a full penalty for each number of difference between two numbers.
dist.add_number("number", 1, 1)
assert dist._penalties["number"] == [0.0]
dist.add_number("number", 1, 2)
assert dist._penalties["number"] == [0.0, 1.0]
dist.add_number("number", 2, 1)
assert dist._penalties["number"] == [0.0, 1.0, 1.0]
dist.add_number("number", -1, 2)
assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
def test_add_priority(self):
dist = Distance()
dist.add_priority("priority", "abc", "abc")
assert dist._penalties["priority"] == [0.0]
dist.add_priority("priority", "def", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5]
dist.add_priority(
"priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
)
assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
dist.add_priority("priority", "xyz", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
def test_add_ratio(self):
dist = Distance()
dist.add_ratio("ratio", 25, 100)
assert dist._penalties["ratio"] == [0.25]
dist.add_ratio("ratio", 10, 5)
assert dist._penalties["ratio"] == [0.25, 1.0]
dist.add_ratio("ratio", -5, 5)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
dist.add_ratio("ratio", 5, 0)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
def test_add_string(self):
dist = Distance()
sdist = string_dist("abc", "bcd")
dist.add_string("string", "abc", "bcd")
assert dist._penalties["string"] == [sdist]
assert dist._penalties["string"] != [0]
def test_add_string_none(self):
dist = Distance()
dist.add_string("string", None, "string")
assert dist._penalties["string"] == [1]
def test_add_string_both_none(self):
dist = Distance()
dist.add_string("string", None, None)
assert dist._penalties["string"] == [0]
def test_distance(self):
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
def test_distance(self, dist):
dist.add("album", 0.5)
dist.add("media", 0.25)
dist.add("media", 0.75)
assert dist.distance == 0.5
assert dist.max_distance == 6.0
assert dist.raw_distance == 3.0
# __getitem__()
assert dist["album"] == 0.25
assert dist["media"] == 0.25
assert dist["album"] == 1 / 3
assert dist["media"] == 1 / 6
def test_max_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.0)
dist.add("medium", 0.0)
assert dist.max_distance == 5.0
def test_operators(self):
config["match"]["distance_weights"]["source"] = 1.0
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
def test_operators(self, dist):
dist.add("source", 0.0)
dist.add("album", 0.5)
dist.add("medium", 0.25)
@ -184,23 +116,7 @@ class DistanceTest(BeetsTestCase):
assert 0.4 - dist == 0.0
assert float(dist) == 0.4
def test_raw_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.5)
assert dist.raw_distance == 2.25
def test_items(self):
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
_clear_weights()
dist = Distance()
def test_penalties_sort(self, dist):
dist.add("album", 0.1875)
dist.add("medium", 0.75)
assert dist.items() == [("medium", 0.25), ("album", 0.125)]
@ -211,8 +127,8 @@ class DistanceTest(BeetsTestCase):
dist.add("medium", 0.75)
assert dist.items() == [("album", 0.25), ("medium", 0.25)]
def test_update(self):
dist1 = Distance()
def test_update(self, dist):
dist1 = dist
dist1.add("album", 0.5)
dist1.add("media", 1.0)
@ -229,248 +145,155 @@ class DistanceTest(BeetsTestCase):
}
class TrackDistanceTest(BeetsTestCase):
def test_identical_tracks(self):
item = _make_item("one", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
class TestTrackDistance:
@pytest.fixture(scope="class")
def info(self):
return TrackInfo(title="title", artist="artist")
def test_different_title(self):
item = _make_item("foo", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
@pytest.mark.parametrize(
"title, artist, expected_penalty",
[
_p("title", "artist", False, id="identical"),
_p("title", "Various Artists", False, id="tolerate-va"),
_p("title", "different artist", True, id="different-artist"),
_p("different title", "artist", True, id="different-title"),
],
)
def test_track_distance(self, info, title, artist, expected_penalty):
item = Item(artist=artist, title=title)
def test_different_artist(self):
item = _make_item("one", 1)
item.artist = "foo"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_various_artists_tolerated(self):
item = _make_item("one", 1)
item.artist = "Various Artists"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
assert (
bool(track_distance(item, info, incl_artist=True))
== expected_penalty
)
class AlbumDistanceTest(BeetsTestCase):
def _mapping(self, items, info):
out = {}
for i, t in zip(items, info.tracks):
out[i] = t
return out
class TestAlbumDistance:
@pytest.fixture(scope="class")
def items(self):
return [
Item(
title=title,
track=track,
artist="artist",
album="album",
length=1,
)
for title, track in [("one", 1), ("two", 2), ("three", 3)]
]
def _dist(self, items, info):
return match.distance(items, info, self._mapping(items, info))
@pytest.fixture
def get_dist(self, items):
def inner(info: AlbumInfo):
return distance(items, info, dict(zip(items, info.tracks)))
def test_identical_albums(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
return inner
@pytest.fixture
def info(self, items):
return AlbumInfo(
artist="artist",
album="album",
tracks=[
TrackInfo(
title=i.title,
artist=i.artist,
index=i.track,
length=i.length,
)
for i in items
],
va=False,
)
assert self._dist(items, info) == 0
def test_incomplete_album(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
def test_identical_albums(self, get_dist, info):
assert get_dist(info) == 0
def test_incomplete_album(self, get_dist, info):
info.tracks.pop(2)
assert 0 < float(get_dist(info)) < 0.2
def test_overly_complete_album(self, get_dist, info):
info.tracks.append(
Item(index=4, title="four", artist="artist", length=1)
)
dist = self._dist(items, info)
assert dist != 0
# Make sure the distance is not too great
assert dist < 0.2
def test_global_artists_differ(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="someone else",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) != 0
assert 0 < float(get_dist(info)) < 0.2
def test_comp_track_artists_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) == 0
@pytest.mark.parametrize("va", [True, False])
def test_albumartist(self, get_dist, info, va):
info.artist = "another artist"
info.va = va
def test_comp_no_track_artists(self):
assert bool(get_dist(info)) is not va
def test_comp_no_track_artists(self, get_dist, info):
# Some VA releases don't have track artists (incomplete metadata).
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
info.tracks[0].artist = None
info.tracks[1].artist = None
info.tracks[2].artist = None
assert self._dist(items, info) == 0
info.artist = "another artist"
info.va = True
for track in info.tracks:
track.artist = None
def test_comp_track_artists_do_not_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2, "someone else"))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) != 0
assert get_dist(info) == 0
def test_tracks_out_of_order(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 2))
items.append(_make_item("two", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert 0 < dist < 0.2
def test_comp_track_artists_do_not_match(self, get_dist, info):
info.va = True
info.tracks[0].artist = "another artist"
def test_two_medium_release(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert get_dist(info) != 0
def test_tracks_out_of_order(self, get_dist, info):
tracks = info.tracks
tracks[1].title, tracks[2].title = tracks[2].title, tracks[1].title
assert 0 < float(get_dist(info)) < 0.2
def test_two_medium_release(self, get_dist, info):
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
def test_per_medium_track_numbers(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 1))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
assert get_dist(info) == 0
class StringDistanceTest(unittest.TestCase):
def test_equal_strings(self):
dist = string_dist("Some String", "Some String")
assert dist == 0.0
class TestStringDistance:
@pytest.mark.parametrize(
"string1, string2",
[
("Some String", "Some String"),
("Some String", "Some.String!"),
("Some String", "sOME sTring"),
("My Song (EP)", "My Song"),
("The Song Title", "Song Title, The"),
("A Song Title", "Song Title, A"),
("An Album Title", "Album Title, An"),
("", ""),
("Untitled", "[Untitled]"),
("And", "&"),
("\xe9\xe1\xf1", "ean"),
],
)
def test_matching_distance(self, string1, string2):
assert string_dist(string1, string2) == 0.0
def test_different_strings(self):
dist = string_dist("Some String", "Totally Different")
assert dist != 0.0
def test_different_distance(self):
assert string_dist("Some String", "Totally Different") != 0.0
def test_punctuation_ignored(self):
dist = string_dist("Some String", "Some.String!")
assert dist == 0.0
def test_case_ignored(self):
dist = string_dist("Some String", "sOME sTring")
assert dist == 0.0
def test_leading_the_has_lower_weight(self):
dist1 = string_dist("XXX Band Name", "Band Name")
dist2 = string_dist("The Band Name", "Band Name")
assert dist2 < dist1
def test_parens_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One (Two)", "One")
assert dist2 < dist1
def test_brackets_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One [Two]", "One")
assert dist2 < dist1
def test_ep_label_has_zero_weight(self):
dist = string_dist("My Song (EP)", "My Song")
assert dist == 0.0
def test_featured_has_lower_weight(self):
dist1 = string_dist("My Song blah Someone", "My Song")
dist2 = string_dist("My Song feat Someone", "My Song")
assert dist2 < dist1
def test_postfix_the(self):
dist = string_dist("The Song Title", "Song Title, The")
assert dist == 0.0
def test_postfix_a(self):
dist = string_dist("A Song Title", "Song Title, A")
assert dist == 0.0
def test_postfix_an(self):
dist = string_dist("An Album Title", "Album Title, An")
assert dist == 0.0
def test_empty_strings(self):
dist = string_dist("", "")
assert dist == 0.0
@pytest.mark.parametrize(
"string1, string2, reference",
[
("XXX Band Name", "The Band Name", "Band Name"),
("One .Two.", "One (Two)", "One"),
("One .Two.", "One [Two]", "One"),
("My Song blah Someone", "My Song feat Someone", "My Song"),
],
)
def test_relative_weights(self, string1, string2, reference):
assert string_dist(string2, reference) < string_dist(string1, reference)
def test_solo_pattern(self):
# Just make sure these don't crash.
string_dist("The ", "")
string_dist("(EP)", "(EP)")
string_dist(", An", "")
def test_heuristic_does_not_harm_distance(self):
dist = string_dist("Untitled", "[Untitled]")
assert dist == 0.0
def test_ampersand_expansion(self):
dist = string_dist("And", "&")
assert dist == 0.0
def test_accented_characters(self):
dist = string_dist("\xe9\xe1\xf1", "ean")
assert dist == 0.0