From 509cbdcbe472aa3cf6559f6fba3eb2e6c9dcf47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 24 May 2025 15:16:02 +0100 Subject: [PATCH 1/8] Move sanitize_pairs/choices from plugins to util module --- beets/plugins.py | 60 ------------------------------------ beets/util/config.py | 66 ++++++++++++++++++++++++++++++++++++++++ beetsplug/fetchart.py | 3 +- beetsplug/lyrics.py | 3 +- test/test_plugins.py | 10 ------ test/util/test_config.py | 15 +++++++++ 6 files changed, 85 insertions(+), 72 deletions(-) create mode 100644 beets/util/config.py create mode 100644 test/util/test_config.py diff --git a/beets/plugins.py b/beets/plugins.py index d87dd5d1e..6d3a8447e 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -654,66 +654,6 @@ def feat_tokens(for_artist: bool = True) -> str: ) -def sanitize_choices( - choices: Sequence[str], choices_all: Sequence[str] -) -> list[str]: - """Clean up a stringlist configuration attribute: keep only choices - elements present in choices_all, remove duplicate elements, expand '*' - wildcard while keeping original stringlist order. - """ - seen: set[str] = set() - others = [x for x in choices_all if x not in choices] - res: list[str] = [] - for s in choices: - if s not in seen: - if s in list(choices_all): - res.append(s) - elif s == "*": - res.extend(others) - seen.add(s) - return res - - -def sanitize_pairs( - pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]] -) -> list[tuple[str, str]]: - """Clean up a single-element mapping configuration attribute as returned - by Confuse's `Pairs` template: keep only two-element tuples present in - pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*') - wildcards while keeping the original order. Note that ('*', '*') and - ('*', 'whatever') have the same effect. - - For example, - - >>> sanitize_pairs( - ... [('foo', 'baz bar'), ('key', '*'), ('*', '*')], - ... [('foo', 'bar'), ('foo', 'baz'), ('foo', 'foobar'), - ... ('key', 'value')] - ... ) - [('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')] - """ - pairs_all = list(pairs_all) - seen: set[tuple[str, str]] = set() - others = [x for x in pairs_all if x not in pairs] - res: list[tuple[str, str]] = [] - for k, values in pairs: - for v in values.split(): - x = (k, v) - if x in pairs_all: - if x not in seen: - seen.add(x) - res.append(x) - elif k == "*": - new = [o for o in others if o not in seen] - seen.update(new) - res.extend(new) - elif v == "*": - new = [o for o in others if o not in seen and o[0] == k] - seen.update(new) - res.extend(new) - return res - - def get_distance( config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo ) -> Distance: diff --git a/beets/util/config.py b/beets/util/config.py new file mode 100644 index 000000000..218a9d133 --- /dev/null +++ b/beets/util/config.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Collection, Sequence + + +def sanitize_choices( + choices: Sequence[str], choices_all: Collection[str] +) -> list[str]: + """Clean up a stringlist configuration attribute: keep only choices + elements present in choices_all, remove duplicate elements, expand '*' + wildcard while keeping original stringlist order. + """ + seen: set[str] = set() + others = [x for x in choices_all if x not in choices] + res: list[str] = [] + for s in choices: + if s not in seen: + if s in list(choices_all): + res.append(s) + elif s == "*": + res.extend(others) + seen.add(s) + return res + + +def sanitize_pairs( + pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]] +) -> list[tuple[str, str]]: + """Clean up a single-element mapping configuration attribute as returned + by Confuse's `Pairs` template: keep only two-element tuples present in + pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*') + wildcards while keeping the original order. Note that ('*', '*') and + ('*', 'whatever') have the same effect. + + For example, + + >>> sanitize_pairs( + ... [('foo', 'baz bar'), ('key', '*'), ('*', '*')], + ... [('foo', 'bar'), ('foo', 'baz'), ('foo', 'foobar'), + ... ('key', 'value')] + ... ) + [('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')] + """ + pairs_all = list(pairs_all) + seen: set[tuple[str, str]] = set() + others = [x for x in pairs_all if x not in pairs] + res: list[tuple[str, str]] = [] + for k, values in pairs: + for v in values.split(): + x = (k, v) + if x in pairs_all: + if x not in seen: + seen.add(x) + res.append(x) + elif k == "*": + new = [o for o in others if o not in seen] + seen.update(new) + res.extend(new) + elif v == "*": + new = [o for o in others if o not in seen and o[0] == k] + seen.update(new) + res.extend(new) + return res diff --git a/beetsplug/fetchart.py b/beetsplug/fetchart.py index 3473fe08b..b442633da 100644 --- a/beetsplug/fetchart.py +++ b/beetsplug/fetchart.py @@ -32,6 +32,7 @@ from mediafile import image_mime_type from beets import config, importer, plugins, ui, util from beets.util import bytestring_path, get_temp_filename, sorted_walk, syspath from beets.util.artresizer import ArtResizer +from beets.util.config import sanitize_pairs if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -1396,7 +1397,7 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin): if s_cls.available(self._log, self.config) for c in s_cls.VALID_MATCHING_CRITERIA ] - sources = plugins.sanitize_pairs( + sources = sanitize_pairs( self.config["sources"].as_pairs(default_value="*"), available_sources, ) diff --git a/beetsplug/lyrics.py b/beetsplug/lyrics.py index 3e979221c..e2c0c7fd2 100644 --- a/beetsplug/lyrics.py +++ b/beetsplug/lyrics.py @@ -39,6 +39,7 @@ from unidecode import unidecode import beets from beets import plugins, ui from beets.autotag.hooks import string_dist +from beets.util.config import sanitize_choices if TYPE_CHECKING: from logging import Logger @@ -957,7 +958,7 @@ class LyricsPlugin(RequestHandler, plugins.BeetsPlugin): def backends(self) -> list[Backend]: user_sources = self.config["sources"].get() - chosen = plugins.sanitize_choices(user_sources, self.BACKEND_BY_NAME) + chosen = sanitize_choices(user_sources, self.BACKEND_BY_NAME) if "google" in chosen and not self.config["google_API_key"].get(): self.warn("Disabling Google source: no API key configured.") chosen.remove("google") diff --git a/test/test_plugins.py b/test/test_plugins.py index 3e809e492..207522430 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -15,7 +15,6 @@ import itertools import os -import unittest from unittest.mock import ANY, Mock, patch import pytest @@ -215,15 +214,6 @@ class EventsTest(PluginImportTestCase): ] -class HelpersTest(unittest.TestCase): - def test_sanitize_choices(self): - assert plugins.sanitize_choices(["A", "Z"], ("A", "B")) == ["A"] - assert plugins.sanitize_choices(["A", "A"], ("A")) == ["A"] - assert plugins.sanitize_choices( - ["D", "*", "A"], ("A", "B", "C", "D") - ) == ["D", "B", "C", "A"] - - class ListenersTest(PluginLoaderTestCase): def test_register(self): class DummyPlugin(plugins.BeetsPlugin): diff --git a/test/util/test_config.py b/test/util/test_config.py new file mode 100644 index 000000000..0c49f85b1 --- /dev/null +++ b/test/util/test_config.py @@ -0,0 +1,15 @@ +import unittest + +from beets.util.config import sanitize_choices + + +class HelpersTest(unittest.TestCase): + def test_sanitize_choices(self): + assert sanitize_choices(["A", "Z"], ("A", "B")) == ["A"] + assert sanitize_choices(["A", "A"], ("A")) == ["A"] + assert sanitize_choices(["D", "*", "A"], ("A", "B", "C", "D")) == [ + "D", + "B", + "C", + "A", + ] From 1c9aebd36c4ebd9a0a08ac6a131ee06927f2bad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 10:52:39 +0100 Subject: [PATCH 2/8] match.current_metadata -> util.get_most_common_tags --- beets/autotag/__init__.py | 9 +---- beets/autotag/match.py | 44 ++------------------- beets/importer/tasks.py | 2 +- beets/util/__init__.py | 40 +++++++++++++++++++ test/test_autotag.py | 80 -------------------------------------- test/test_util.py | 82 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 127 insertions(+), 130 deletions(-) diff --git a/beets/autotag/__init__.py b/beets/autotag/__init__.py index 42f957b0d..5b6a11195 100644 --- a/beets/autotag/__init__.py +++ b/beets/autotag/__init__.py @@ -24,13 +24,7 @@ from beets.library import Album, Item, LibModel from beets.util import unique_list from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch -from .match import ( - Proposal, - Recommendation, - current_metadata, - tag_album, - tag_item, -) +from .match import Proposal, Recommendation, tag_album, tag_item __all__ = [ "AlbumInfo", @@ -43,7 +37,6 @@ __all__ = [ "apply_album_metadata", "apply_item_metadata", "apply_metadata", - "current_metadata", "tag_album", "tag_item", ] diff --git a/beets/autotag/match.py b/beets/autotag/match.py index 91a315de0..4dc4c1052 100644 --- a/beets/autotag/match.py +++ b/beets/autotag/match.py @@ -36,7 +36,7 @@ from beets.autotag import ( TrackMatch, hooks, ) -from beets.util import plurality +from beets.util import get_most_common_tags if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -80,44 +80,6 @@ class Proposal(NamedTuple): # Primary matching functionality. -def current_metadata( - items: Iterable[Item], -) -> tuple[dict[str, Any], dict[str, Any]]: - """Extract the likely current metadata for an album given a list of its - items. Return two dictionaries: - - The most common value for each field. - - Whether each field's value was unanimous (values are booleans). - """ - assert items # Must be nonempty. - - likelies = {} - consensus = {} - fields = [ - "artist", - "album", - "albumartist", - "year", - "disctotal", - "mb_albumid", - "label", - "barcode", - "catalognum", - "country", - "media", - "albumdisambig", - ] - for field in fields: - values = [item[field] for item in items if item] - likelies[field], freq = plurality(values) - consensus[field] = freq == len(values) - - # If there's an album artist consensus, use this for the artist. - if consensus["albumartist"] and likelies["albumartist"]: - likelies["artist"] = likelies["albumartist"] - - return likelies, consensus - - def assign_items( items: Sequence[Item], tracks: Sequence[TrackInfo], @@ -231,7 +193,7 @@ def distance( keys are a subset of `items` and the values are a subset of `album_info.tracks`. """ - likelies, _ = current_metadata(items) + likelies, _ = get_most_common_tags(items) dist = hooks.Distance() @@ -499,7 +461,7 @@ def tag_album( candidates. """ # Get current metadata. - likelies, consensus = current_metadata(items) + likelies, consensus = get_most_common_tags(items) cur_artist: str = likelies["artist"] cur_album: str = likelies["album"] log.debug("Tagging {0} - {1}", cur_artist, cur_album) diff --git a/beets/importer/tasks.py b/beets/importer/tasks.py index d2f638c55..75f04cf5a 100644 --- a/beets/importer/tasks.py +++ b/beets/importer/tasks.py @@ -228,7 +228,7 @@ class ImportTask(BaseImportTask): or APPLY (in which case the data comes from the choice). """ if self.choice_flag in (Action.ASIS, Action.RETAG): - likelies, consensus = autotag.current_metadata(self.items) + likelies, consensus = util.get_most_common_tags(self.items) return likelies elif self.choice_flag is Action.APPLY and self.match: return self.match.info.copy() diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 6bc4d14ee..9bd7451f8 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -56,6 +56,8 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence from logging import Logger + from beets.library import Item + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -814,6 +816,44 @@ def plurality(objs: Iterable[T]) -> tuple[T, int]: return c.most_common(1)[0] +def get_most_common_tags( + items: Sequence[Item], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract the likely current metadata for an album given a list of its + items. Return two dictionaries: + - The most common value for each field. + - Whether each field's value was unanimous (values are booleans). + """ + assert items # Must be nonempty. + + likelies = {} + consensus = {} + fields = [ + "artist", + "album", + "albumartist", + "year", + "disctotal", + "mb_albumid", + "label", + "barcode", + "catalognum", + "country", + "media", + "albumdisambig", + ] + for field in fields: + values = [item[field] for item in items if item] + likelies[field], freq = plurality(values) + consensus[field] = freq == len(values) + + # If there's an album artist consensus, use this for the artist. + if consensus["albumartist"] and likelies["albumartist"]: + likelies["artist"] = likelies["albumartist"] + + return likelies, consensus + + # stdout and stderr as bytes class CommandOutput(NamedTuple): stdout: bytes diff --git a/test/test_autotag.py b/test/test_autotag.py index 7f8ed3d2e..bd4205806 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -24,86 +24,6 @@ from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match from beets.autotag.hooks import Distance, string_dist from beets.library import Item from beets.test.helper import BeetsTestCase, ConfigMixin -from beets.util import plurality - - -class PluralityTest(BeetsTestCase): - def test_plurality_consensus(self): - objs = [1, 1, 1, 1] - obj, freq = plurality(objs) - assert obj == 1 - assert freq == 4 - - def test_plurality_near_consensus(self): - objs = [1, 1, 2, 1] - obj, freq = plurality(objs) - assert obj == 1 - assert freq == 3 - - def test_plurality_conflict(self): - objs = [1, 1, 2, 2, 3] - obj, freq = plurality(objs) - assert obj in (1, 2) - assert freq == 2 - - def test_plurality_empty_sequence_raises_error(self): - with pytest.raises(ValueError, match="must be non-empty"): - plurality([]) - - def test_current_metadata_finds_pluralities(self): - items = [ - Item(artist="The Beetles", album="The White Album"), - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="Teh White Album"), - ] - likelies, consensus = match.current_metadata(items) - assert likelies["artist"] == "The Beatles" - assert likelies["album"] == "The White Album" - assert not consensus["artist"] - - def test_current_metadata_artist_consensus(self): - items = [ - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="Teh White Album"), - ] - likelies, consensus = match.current_metadata(items) - assert likelies["artist"] == "The Beatles" - assert likelies["album"] == "The White Album" - assert consensus["artist"] - - def test_albumartist_consensus(self): - items = [ - Item(artist="tartist1", album="album", albumartist="aartist"), - Item(artist="tartist2", album="album", albumartist="aartist"), - Item(artist="tartist3", album="album", albumartist="aartist"), - ] - likelies, consensus = match.current_metadata(items) - assert likelies["artist"] == "aartist" - assert not consensus["artist"] - - def test_current_metadata_likelies(self): - fields = [ - "artist", - "album", - "albumartist", - "year", - "disctotal", - "mb_albumid", - "label", - "barcode", - "catalognum", - "country", - "media", - "albumdisambig", - ] - items = [Item(**{f: f"{f}_{i or 1}" for f in fields}) for i in range(5)] - likelies, _ = match.current_metadata(items) - for f in fields: - if isinstance(likelies[f], int): - assert likelies[f] == 0 - else: - assert likelies[f] == f"{f}_1" def _make_item(title, track, artist="some artist"): diff --git a/test/test_util.py b/test/test_util.py index d08868619..5aa6c5dc7 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -24,7 +24,10 @@ from unittest.mock import Mock, patch import pytest from beets import util +from beets.library import Item from beets.test import _common +from beets.test.helper import BeetsTestCase +from beets.util import plurality class UtilTest(unittest.TestCase): @@ -217,3 +220,82 @@ class TestPathLegalization: expected_path, expected_truncated, ) + + +class PluralityTest(BeetsTestCase): + def test_plurality_consensus(self): + objs = [1, 1, 1, 1] + obj, freq = plurality(objs) + assert obj == 1 + assert freq == 4 + + def test_plurality_near_consensus(self): + objs = [1, 1, 2, 1] + obj, freq = plurality(objs) + assert obj == 1 + assert freq == 3 + + def test_plurality_conflict(self): + objs = [1, 1, 2, 2, 3] + obj, freq = plurality(objs) + assert obj in (1, 2) + assert freq == 2 + + def test_plurality_empty_sequence_raises_error(self): + with pytest.raises(ValueError, match="must be non-empty"): + plurality([]) + + def test_current_metadata_finds_pluralities(self): + items = [ + Item(artist="The Beetles", album="The White Album"), + Item(artist="The Beatles", album="The White Album"), + Item(artist="The Beatles", album="Teh White Album"), + ] + likelies, consensus = util.get_most_common_tags(items) + assert likelies["artist"] == "The Beatles" + assert likelies["album"] == "The White Album" + assert not consensus["artist"] + + def test_current_metadata_artist_consensus(self): + items = [ + Item(artist="The Beatles", album="The White Album"), + Item(artist="The Beatles", album="The White Album"), + Item(artist="The Beatles", album="Teh White Album"), + ] + likelies, consensus = util.get_most_common_tags(items) + assert likelies["artist"] == "The Beatles" + assert likelies["album"] == "The White Album" + assert consensus["artist"] + + def test_albumartist_consensus(self): + items = [ + Item(artist="tartist1", album="album", albumartist="aartist"), + Item(artist="tartist2", album="album", albumartist="aartist"), + Item(artist="tartist3", album="album", albumartist="aartist"), + ] + likelies, consensus = util.get_most_common_tags(items) + assert likelies["artist"] == "aartist" + assert not consensus["artist"] + + def test_current_metadata_likelies(self): + fields = [ + "artist", + "album", + "albumartist", + "year", + "disctotal", + "mb_albumid", + "label", + "barcode", + "catalognum", + "country", + "media", + "albumdisambig", + ] + items = [Item(**{f: f"{f}_{i or 1}" for f in fields}) for i in range(5)] + likelies, _ = util.get_most_common_tags(items) + for f in fields: + if isinstance(likelies[f], int): + assert likelies[f] == 0 + else: + assert likelies[f] == f"{f}_1" From 01b6ea78987931804d08f7ae377ab54f46b95f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 02:50:57 +0100 Subject: [PATCH 3/8] Simplify and speed up plurality/album tags retrieval tests --- test/test_util.py | 95 +++++++++++++---------------------------------- 1 file changed, 26 insertions(+), 69 deletions(-) diff --git a/test/test_util.py b/test/test_util.py index 5aa6c5dc7..d8a4ca0db 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -26,8 +26,6 @@ import pytest from beets import util from beets.library import Item from beets.test import _common -from beets.test.helper import BeetsTestCase -from beets.util import plurality class UtilTest(unittest.TestCase): @@ -222,80 +220,39 @@ class TestPathLegalization: ) -class PluralityTest(BeetsTestCase): - def test_plurality_consensus(self): - objs = [1, 1, 1, 1] - obj, freq = plurality(objs) - assert obj == 1 - assert freq == 4 +class TestPlurality: + @pytest.mark.parametrize( + "objs, expected_obj, expected_freq", + [ + pytest.param([1, 1, 1, 1], 1, 4, id="consensus"), + pytest.param([1, 1, 2, 1], 1, 3, id="near consensus"), + pytest.param([1, 1, 2, 2, 3], 1, 2, id="conflict-first-wins"), + ], + ) + def test_plurality(self, objs, expected_obj, expected_freq): + assert (expected_obj, expected_freq) == util.plurality(objs) - def test_plurality_near_consensus(self): - objs = [1, 1, 2, 1] - obj, freq = plurality(objs) - assert obj == 1 - assert freq == 3 - - def test_plurality_conflict(self): - objs = [1, 1, 2, 2, 3] - obj, freq = plurality(objs) - assert obj in (1, 2) - assert freq == 2 - - def test_plurality_empty_sequence_raises_error(self): + def test_empty_sequence_raises_error(self): with pytest.raises(ValueError, match="must be non-empty"): - plurality([]) + util.plurality([]) - def test_current_metadata_finds_pluralities(self): + def test_get_most_common_tags(self): items = [ - Item(artist="The Beetles", album="The White Album"), - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="Teh White Album"), + Item(albumartist="aartist", label="label 1", album="album"), + Item(albumartist="aartist", label="label 2", album="album"), + Item(albumartist="aartist", label="label 3", album="another album"), ] - likelies, consensus = util.get_most_common_tags(items) - assert likelies["artist"] == "The Beatles" - assert likelies["album"] == "The White Album" - assert not consensus["artist"] - def test_current_metadata_artist_consensus(self): - items = [ - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="The White Album"), - Item(artist="The Beatles", album="Teh White Album"), - ] likelies, consensus = util.get_most_common_tags(items) - assert likelies["artist"] == "The Beatles" - assert likelies["album"] == "The White Album" - assert consensus["artist"] - def test_albumartist_consensus(self): - items = [ - Item(artist="tartist1", album="album", albumartist="aartist"), - Item(artist="tartist2", album="album", albumartist="aartist"), - Item(artist="tartist3", album="album", albumartist="aartist"), - ] - likelies, consensus = util.get_most_common_tags(items) + assert likelies["albumartist"] == "aartist" + assert likelies["album"] == "album" + # albumartist consensus overrides artist assert likelies["artist"] == "aartist" - assert not consensus["artist"] + assert likelies["label"] == "label 1" + assert likelies["year"] == 0 - def test_current_metadata_likelies(self): - fields = [ - "artist", - "album", - "albumartist", - "year", - "disctotal", - "mb_albumid", - "label", - "barcode", - "catalognum", - "country", - "media", - "albumdisambig", - ] - items = [Item(**{f: f"{f}_{i or 1}" for f in fields}) for i in range(5)] - likelies, _ = util.get_most_common_tags(items) - for f in fields: - if isinstance(likelies[f], int): - assert likelies[f] == 0 - else: - assert likelies[f] == f"{f}_1" + assert consensus["year"] + assert consensus["albumartist"] + assert not consensus["album"] + assert not consensus["label"] From adbd50b2374edb4fe9f9455da9401dc1225a4818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 02:59:27 +0100 Subject: [PATCH 4/8] Move distance to a separate module --- beets/autotag/__init__.py | 14 +- beets/autotag/distance.py | 531 ++++++++++++++++++++++++++++++++++ beets/autotag/hooks.py | 334 +-------------------- beets/autotag/match.py | 205 +------------ beets/plugins.py | 13 +- beetsplug/chroma.py | 2 +- beetsplug/discogs.py | 3 +- beetsplug/lyrics.py | 2 +- test/autotag/test_distance.py | 476 ++++++++++++++++++++++++++++++ test/test_autotag.py | 472 ------------------------------ 10 files changed, 1028 insertions(+), 1024 deletions(-) create mode 100644 beets/autotag/distance.py create mode 100644 test/autotag/test_distance.py diff --git a/beets/autotag/__init__.py b/beets/autotag/__init__.py index 5b6a11195..5b16b012e 100644 --- a/beets/autotag/__init__.py +++ b/beets/autotag/__init__.py @@ -14,22 +14,26 @@ """Facilities for automatically determining files' correct metadata.""" -from collections.abc import Mapping, Sequence -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Union from beets import config, logging -from beets.library import Album, Item, LibModel # Parts of external interface. from beets.util import unique_list -from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch +from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch from .match import Proposal, Recommendation, tag_album, tag_item +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from beets.library import Album, Item, LibModel + __all__ = [ "AlbumInfo", "AlbumMatch", - "Distance", "TrackInfo", "TrackMatch", "Proposal", diff --git a/beets/autotag/distance.py b/beets/autotag/distance.py new file mode 100644 index 000000000..d146c27f0 --- /dev/null +++ b/beets/autotag/distance.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import datetime +import re +from functools import cache, total_ordering +from typing import TYPE_CHECKING, Any + +from jellyfish import levenshtein_distance +from unidecode import unidecode + +from beets import config, plugins +from beets.util import as_string, cached_classproperty, get_most_common_tags + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from beets.library import Item + + from .hooks import AlbumInfo, TrackInfo + +# Candidate distance scoring. + +# Artist signals that indicate "various artists". These are used at the +# album level to determine whether a given release is likely a VA +# release and also on the track level to to remove the penalty for +# differing artists. +VA_ARTISTS = ("", "various artists", "various", "va", "unknown") + +# Parameters for string distance function. +# Words that can be moved to the end of a string using a comma. +SD_END_WORDS = ["the", "a", "an"] +# Reduced weights for certain portions of the string. +SD_PATTERNS = [ + (r"^the ", 0.1), + (r"[\[\(]?(ep|single)[\]\)]?", 0.0), + (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1), + (r"\(.*?\)", 0.3), + (r"\[.*?\]", 0.3), + (r"(, )?(pt\.|part) .+", 0.2), +] +# Replacements to use before testing distance. +SD_REPLACE = [ + (r"&", "and"), +] + + +def _string_dist_basic(str1: str, str2: str) -> float: + """Basic edit distance between two strings, ignoring + non-alphanumeric characters and case. Comparisons are based on a + transliteration/lowering to ASCII characters. Normalized by string + length. + """ + assert isinstance(str1, str) + assert isinstance(str2, str) + str1 = as_string(unidecode(str1)) + str2 = as_string(unidecode(str2)) + str1 = re.sub(r"[^a-z0-9]", "", str1.lower()) + str2 = re.sub(r"[^a-z0-9]", "", str2.lower()) + if not str1 and not str2: + return 0.0 + return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2))) + + +def string_dist(str1: str | None, str2: str | None) -> float: + """Gives an "intuitive" edit distance between two strings. This is + an edit distance, normalized by the string length, with a number of + tweaks that reflect intuition about text. + """ + if str1 is None and str2 is None: + return 0.0 + if str1 is None or str2 is None: + return 1.0 + + str1 = str1.lower() + str2 = str2.lower() + + # Don't penalize strings that move certain words to the end. For + # example, "the something" should be considered equal to + # "something, the". + for word in SD_END_WORDS: + if str1.endswith(", %s" % word): + str1 = "{} {}".format(word, str1[: -len(word) - 2]) + if str2.endswith(", %s" % word): + str2 = "{} {}".format(word, str2[: -len(word) - 2]) + + # Perform a couple of basic normalizing substitutions. + for pat, repl in SD_REPLACE: + str1 = re.sub(pat, repl, str1) + str2 = re.sub(pat, repl, str2) + + # Change the weight for certain string portions matched by a set + # of regular expressions. We gradually change the strings and build + # up penalties associated with parts of the string that were + # deleted. + base_dist = _string_dist_basic(str1, str2) + penalty = 0.0 + for pat, weight in SD_PATTERNS: + # Get strings that drop the pattern. + case_str1 = re.sub(pat, "", str1) + case_str2 = re.sub(pat, "", str2) + + if case_str1 != str1 or case_str2 != str2: + # If the pattern was present (i.e., it is deleted in the + # the current case), recalculate the distances for the + # modified strings. + case_dist = _string_dist_basic(case_str1, case_str2) + case_delta = max(0.0, base_dist - case_dist) + if case_delta == 0.0: + continue + + # Shift our baseline strings down (to avoid rematching the + # same part of the string) and add a scaled distance + # amount to the penalties. + str1 = case_str1 + str2 = case_str2 + base_dist = case_dist + penalty += weight * case_delta + + return base_dist + penalty + + +@total_ordering +class Distance: + """Keeps track of multiple distance penalties. Provides a single + weighted distance for all penalties as well as a weighted distance + for each individual penalty. + """ + + def __init__(self) -> None: + self._penalties: dict[str, list[float]] = {} + self.tracks: dict[TrackInfo, Distance] = {} + + @cached_classproperty + def _weights(cls) -> dict[str, float]: + """A dictionary from keys to floating-point weights.""" + weights_view = config["match"]["distance_weights"] + weights = {} + for key in weights_view.keys(): + weights[key] = weights_view[key].as_number() + return weights + + # Access the components and their aggregates. + + @property + def distance(self) -> float: + """Return a weighted and normalized distance across all + penalties. + """ + dist_max = self.max_distance + if dist_max: + return self.raw_distance / self.max_distance + return 0.0 + + @property + def max_distance(self) -> float: + """Return the maximum distance penalty (normalization factor).""" + dist_max = 0.0 + for key, penalty in self._penalties.items(): + dist_max += len(penalty) * self._weights[key] + return dist_max + + @property + def raw_distance(self) -> float: + """Return the raw (denormalized) distance.""" + dist_raw = 0.0 + for key, penalty in self._penalties.items(): + dist_raw += sum(penalty) * self._weights[key] + return dist_raw + + def items(self) -> list[tuple[str, float]]: + """Return a list of (key, dist) pairs, with `dist` being the + weighted distance, sorted from highest to lowest. Does not + include penalties with a zero value. + """ + list_ = [] + for key in self._penalties: + dist = self[key] + if dist: + list_.append((key, dist)) + # Convert distance into a negative float we can sort items in + # ascending order (for keys, when the penalty is equal) and + # still get the items with the biggest distance first. + return sorted( + list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0]) + ) + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other) -> bool: + return self.distance == other + + # Behave like a float. + + def __lt__(self, other) -> bool: + return self.distance < other + + def __float__(self) -> float: + return self.distance + + def __sub__(self, other) -> float: + return self.distance - other + + def __rsub__(self, other) -> float: + return other - self.distance + + def __str__(self) -> str: + return f"{self.distance:.2f}" + + # Behave like a dict. + + def __getitem__(self, key) -> float: + """Returns the weighted distance for a named penalty.""" + dist = sum(self._penalties[key]) * self._weights[key] + dist_max = self.max_distance + if dist_max: + return dist / dist_max + return 0.0 + + def __iter__(self) -> Iterator[tuple[str, float]]: + return iter(self.items()) + + def __len__(self) -> int: + return len(self.items()) + + def keys(self) -> list[str]: + return [key for key, _ in self.items()] + + def update(self, dist: Distance): + """Adds all the distance penalties from `dist`.""" + if not isinstance(dist, Distance): + raise ValueError( + "`dist` must be a Distance object, not {}".format(type(dist)) + ) + for key, penalties in dist._penalties.items(): + self._penalties.setdefault(key, []).extend(penalties) + + # Adding components. + + def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool: + """Returns True if `value1` is equal to `value2`. `value1` may + be a compiled regular expression, in which case it will be + matched against `value2`. + """ + if isinstance(value1, re.Pattern): + return bool(value1.match(value2)) + return value1 == value2 + + def add(self, key: str, dist: float): + """Adds a distance penalty. `key` must correspond with a + configured weight setting. `dist` must be a float between 0.0 + and 1.0, and will be added to any existing distance penalties + for the same key. + """ + if not 0.0 <= dist <= 1.0: + raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}") + self._penalties.setdefault(key, []).append(dist) + + def add_equality( + self, + key: str, + value: Any, + options: list[Any] | tuple[Any, ...] | Any, + ): + """Adds a distance penalty of 1.0 if `value` doesn't match any + of the values in `options`. If an option is a compiled regular + expression, it will be considered equal if it matches against + `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + for opt in options: + if self._eq(opt, value): + dist = 0.0 + break + else: + dist = 1.0 + self.add(key, dist) + + def add_expr(self, key: str, expr: bool): + """Adds a distance penalty of 1.0 if `expr` evaluates to True, + or 0.0. + """ + if expr: + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_number(self, key: str, number1: int, number2: int): + """Adds a distance penalty of 1.0 for each number of difference + between `number1` and `number2`, or 0.0 when there is no + difference. Use this when there is no upper limit on the + difference between the two numbers. + """ + diff = abs(number1 - number2) + if diff: + for i in range(diff): + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_priority( + self, + key: str, + value: Any, + options: list[Any] | tuple[Any, ...] | Any, + ): + """Adds a distance penalty that corresponds to the position at + which `value` appears in `options`. A distance penalty of 0.0 + for the first option, or 1.0 if there is no matching option. If + an option is a compiled regular expression, it will be + considered equal if it matches against `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + unit = 1.0 / (len(options) or 1) + for i, opt in enumerate(options): + if self._eq(opt, value): + dist = i * unit + break + else: + dist = 1.0 + self.add(key, dist) + + def add_ratio( + self, + key: str, + number1: int | float, + number2: int | float, + ): + """Adds a distance penalty for `number1` as a ratio of `number2`. + `number1` is bound at 0 and `number2`. + """ + number = float(max(min(number1, number2), 0)) + if number2: + dist = number / number2 + else: + dist = 0.0 + self.add(key, dist) + + def add_string(self, key: str, str1: str | None, str2: str | None): + """Adds a distance penalty based on the edit distance between + `str1` and `str2`. + """ + dist = string_dist(str1, str2) + self.add(key, dist) + + +@cache +def get_track_length_grace() -> float: + """Get cached grace period for track length matching.""" + return config["match"]["track_length_grace"].as_number() + + +@cache +def get_track_length_max() -> float: + """Get cached maximum track length for track length matching.""" + return config["match"]["track_length_max"].as_number() + + +def track_index_changed(item: Item, track_info: TrackInfo) -> bool: + """Returns True if the item and track info index is different. Tolerates + per disc and per release numbering. + """ + return item.track not in (track_info.medium_index, track_info.index) + + +def track_distance( + item: Item, + track_info: TrackInfo, + incl_artist: bool = False, +) -> Distance: + """Determines the significance of a track metadata change. Returns a + Distance object. `incl_artist` indicates that a distance component should + be included for the track artist (i.e., for various-artist releases). + + ``track_length_grace`` and ``track_length_max`` configuration options are + cached because this function is called many times during the matching + process and their access comes with a performance overhead. + """ + dist = Distance() + + # Length. + if info_length := track_info.length: + diff = abs(item.length - info_length) - get_track_length_grace() + dist.add_ratio("track_length", diff, get_track_length_max()) + + # Title. + dist.add_string("track_title", item.title, track_info.title) + + # Artist. Only check if there is actually an artist in the track data. + if ( + incl_artist + and track_info.artist + and item.artist.lower() not in VA_ARTISTS + ): + dist.add_string("track_artist", item.artist, track_info.artist) + + # Track index. + if track_info.index and item.track: + dist.add_expr("track_index", track_index_changed(item, track_info)) + + # Track ID. + if item.mb_trackid: + dist.add_expr("track_id", item.mb_trackid != track_info.track_id) + + # Penalize mismatching disc numbers. + if track_info.medium and item.disc: + dist.add_expr("medium", item.disc != track_info.medium) + + # Plugins. + dist.update(plugins.track_distance(item, track_info)) + + return dist + + +def distance( + items: Sequence[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], +) -> Distance: + """Determines how "significant" an album metadata change would be. + Returns a Distance object. `album_info` is an AlbumInfo object + reflecting the album to be compared. `items` is a sequence of all + Item objects that will be matched (order is not important). + `mapping` is a dictionary mapping Items to TrackInfo objects; the + keys are a subset of `items` and the values are a subset of + `album_info.tracks`. + """ + likelies, _ = get_most_common_tags(items) + + dist = Distance() + + # Artist, if not various. + if not album_info.va: + dist.add_string("artist", likelies["artist"], album_info.artist) + + # Album. + dist.add_string("album", likelies["album"], album_info.album) + + preferred_config = config["match"]["preferred"] + # Current or preferred media. + if album_info.media: + # Preferred media options. + media_patterns: Sequence[str] = preferred_config["media"].as_str_seq() + options = [ + re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns + ] + if options: + dist.add_priority("media", album_info.media, options) + # Current media. + elif likelies["media"]: + dist.add_equality("media", album_info.media, likelies["media"]) + + # Mediums. + if likelies["disctotal"] and album_info.mediums: + dist.add_number("mediums", likelies["disctotal"], album_info.mediums) + + # Prefer earliest release. + if album_info.year and preferred_config["original_year"]: + # Assume 1889 (earliest first gramophone discs) if we don't know the + # original year. + original = album_info.original_year or 1889 + diff = abs(album_info.year - original) + diff_max = abs(datetime.date.today().year - original) + dist.add_ratio("year", diff, diff_max) + # Year. + elif likelies["year"] and album_info.year: + if likelies["year"] in (album_info.year, album_info.original_year): + # No penalty for matching release or original year. + dist.add("year", 0.0) + elif album_info.original_year: + # Prefer matchest closest to the release year. + diff = abs(likelies["year"] - album_info.year) + diff_max = abs( + datetime.date.today().year - album_info.original_year + ) + dist.add_ratio("year", diff, diff_max) + else: + # Full penalty when there is no original year. + dist.add("year", 1.0) + + # Preferred countries. + country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq() + options = [re.compile(pat, re.I) for pat in country_patterns] + if album_info.country and options: + dist.add_priority("country", album_info.country, options) + # Country. + elif likelies["country"] and album_info.country: + dist.add_string("country", likelies["country"], album_info.country) + + # Label. + if likelies["label"] and album_info.label: + dist.add_string("label", likelies["label"], album_info.label) + + # Catalog number. + if likelies["catalognum"] and album_info.catalognum: + dist.add_string( + "catalognum", likelies["catalognum"], album_info.catalognum + ) + + # Disambiguation. + if likelies["albumdisambig"] and album_info.albumdisambig: + dist.add_string( + "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig + ) + + # Album ID. + if likelies["mb_albumid"]: + dist.add_equality( + "album_id", likelies["mb_albumid"], album_info.album_id + ) + + # Tracks. + dist.tracks = {} + for item, track in mapping.items(): + dist.tracks[track] = track_distance(item, track, album_info.va) + dist.add("tracks", dist.tracks[track].distance) + + # Missing tracks. + for _ in range(len(album_info.tracks) - len(mapping)): + dist.add("missing_tracks", 1.0) + + # Unmatched tracks. + for _ in range(len(items) - len(mapping)): + dist.add("unmatched_tracks", 1.0) + + # Plugins. + dist.update(plugins.album_distance(items, album_info, mapping)) + + return dist diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index 641a6cb4f..7cd215fc4 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -16,21 +16,15 @@ from __future__ import annotations -import re -from functools import total_ordering from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar -from jellyfish import levenshtein_distance -from unidecode import unidecode - -from beets import config, logging -from beets.util import as_string, cached_classproperty +from beets import logging if TYPE_CHECKING: - from collections.abc import Iterator - from beets.library import Item + from .distance import Distance + log = logging.getLogger("beets") V = TypeVar("V") @@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]): return dupe -# Candidate distance scoring. - -# Parameters for string distance function. -# Words that can be moved to the end of a string using a comma. -SD_END_WORDS = ["the", "a", "an"] -# Reduced weights for certain portions of the string. -SD_PATTERNS = [ - (r"^the ", 0.1), - (r"[\[\(]?(ep|single)[\]\)]?", 0.0), - (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1), - (r"\(.*?\)", 0.3), - (r"\[.*?\]", 0.3), - (r"(, )?(pt\.|part) .+", 0.2), -] -# Replacements to use before testing distance. -SD_REPLACE = [ - (r"&", "and"), -] - - -def _string_dist_basic(str1: str, str2: str) -> float: - """Basic edit distance between two strings, ignoring - non-alphanumeric characters and case. Comparisons are based on a - transliteration/lowering to ASCII characters. Normalized by string - length. - """ - assert isinstance(str1, str) - assert isinstance(str2, str) - str1 = as_string(unidecode(str1)) - str2 = as_string(unidecode(str2)) - str1 = re.sub(r"[^a-z0-9]", "", str1.lower()) - str2 = re.sub(r"[^a-z0-9]", "", str2.lower()) - if not str1 and not str2: - return 0.0 - return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2))) - - -def string_dist(str1: str | None, str2: str | None) -> float: - """Gives an "intuitive" edit distance between two strings. This is - an edit distance, normalized by the string length, with a number of - tweaks that reflect intuition about text. - """ - if str1 is None and str2 is None: - return 0.0 - if str1 is None or str2 is None: - return 1.0 - - str1 = str1.lower() - str2 = str2.lower() - - # Don't penalize strings that move certain words to the end. For - # example, "the something" should be considered equal to - # "something, the". - for word in SD_END_WORDS: - if str1.endswith(", %s" % word): - str1 = "{} {}".format(word, str1[: -len(word) - 2]) - if str2.endswith(", %s" % word): - str2 = "{} {}".format(word, str2[: -len(word) - 2]) - - # Perform a couple of basic normalizing substitutions. - for pat, repl in SD_REPLACE: - str1 = re.sub(pat, repl, str1) - str2 = re.sub(pat, repl, str2) - - # Change the weight for certain string portions matched by a set - # of regular expressions. We gradually change the strings and build - # up penalties associated with parts of the string that were - # deleted. - base_dist = _string_dist_basic(str1, str2) - penalty = 0.0 - for pat, weight in SD_PATTERNS: - # Get strings that drop the pattern. - case_str1 = re.sub(pat, "", str1) - case_str2 = re.sub(pat, "", str2) - - if case_str1 != str1 or case_str2 != str2: - # If the pattern was present (i.e., it is deleted in the - # the current case), recalculate the distances for the - # modified strings. - case_dist = _string_dist_basic(case_str1, case_str2) - case_delta = max(0.0, base_dist - case_dist) - if case_delta == 0.0: - continue - - # Shift our baseline strings down (to avoid rematching the - # same part of the string) and add a scaled distance - # amount to the penalties. - str1 = case_str1 - str2 = case_str2 - base_dist = case_dist - penalty += weight * case_delta - - return base_dist + penalty - - -@total_ordering -class Distance: - """Keeps track of multiple distance penalties. Provides a single - weighted distance for all penalties as well as a weighted distance - for each individual penalty. - """ - - def __init__(self) -> None: - self._penalties: dict[str, list[float]] = {} - self.tracks: dict[TrackInfo, Distance] = {} - - @cached_classproperty - def _weights(cls) -> dict[str, float]: - """A dictionary from keys to floating-point weights.""" - weights_view = config["match"]["distance_weights"] - weights = {} - for key in weights_view.keys(): - weights[key] = weights_view[key].as_number() - return weights - - # Access the components and their aggregates. - - @property - def distance(self) -> float: - """Return a weighted and normalized distance across all - penalties. - """ - dist_max = self.max_distance - if dist_max: - return self.raw_distance / self.max_distance - return 0.0 - - @property - def max_distance(self) -> float: - """Return the maximum distance penalty (normalization factor).""" - dist_max = 0.0 - for key, penalty in self._penalties.items(): - dist_max += len(penalty) * self._weights[key] - return dist_max - - @property - def raw_distance(self) -> float: - """Return the raw (denormalized) distance.""" - dist_raw = 0.0 - for key, penalty in self._penalties.items(): - dist_raw += sum(penalty) * self._weights[key] - return dist_raw - - def items(self) -> list[tuple[str, float]]: - """Return a list of (key, dist) pairs, with `dist` being the - weighted distance, sorted from highest to lowest. Does not - include penalties with a zero value. - """ - list_ = [] - for key in self._penalties: - dist = self[key] - if dist: - list_.append((key, dist)) - # Convert distance into a negative float we can sort items in - # ascending order (for keys, when the penalty is equal) and - # still get the items with the biggest distance first. - return sorted( - list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0]) - ) - - def __hash__(self) -> int: - return id(self) - - def __eq__(self, other) -> bool: - return self.distance == other - - # Behave like a float. - - def __lt__(self, other) -> bool: - return self.distance < other - - def __float__(self) -> float: - return self.distance - - def __sub__(self, other) -> float: - return self.distance - other - - def __rsub__(self, other) -> float: - return other - self.distance - - def __str__(self) -> str: - return f"{self.distance:.2f}" - - # Behave like a dict. - - def __getitem__(self, key) -> float: - """Returns the weighted distance for a named penalty.""" - dist = sum(self._penalties[key]) * self._weights[key] - dist_max = self.max_distance - if dist_max: - return dist / dist_max - return 0.0 - - def __iter__(self) -> Iterator[tuple[str, float]]: - return iter(self.items()) - - def __len__(self) -> int: - return len(self.items()) - - def keys(self) -> list[str]: - return [key for key, _ in self.items()] - - def update(self, dist: Distance): - """Adds all the distance penalties from `dist`.""" - if not isinstance(dist, Distance): - raise ValueError( - "`dist` must be a Distance object, not {}".format(type(dist)) - ) - for key, penalties in dist._penalties.items(): - self._penalties.setdefault(key, []).extend(penalties) - - # Adding components. - - def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool: - """Returns True if `value1` is equal to `value2`. `value1` may - be a compiled regular expression, in which case it will be - matched against `value2`. - """ - if isinstance(value1, re.Pattern): - return bool(value1.match(value2)) - return value1 == value2 - - def add(self, key: str, dist: float): - """Adds a distance penalty. `key` must correspond with a - configured weight setting. `dist` must be a float between 0.0 - and 1.0, and will be added to any existing distance penalties - for the same key. - """ - if not 0.0 <= dist <= 1.0: - raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}") - self._penalties.setdefault(key, []).append(dist) - - def add_equality( - self, - key: str, - value: Any, - options: list[Any] | tuple[Any, ...] | Any, - ): - """Adds a distance penalty of 1.0 if `value` doesn't match any - of the values in `options`. If an option is a compiled regular - expression, it will be considered equal if it matches against - `value`. - """ - if not isinstance(options, (list, tuple)): - options = [options] - for opt in options: - if self._eq(opt, value): - dist = 0.0 - break - else: - dist = 1.0 - self.add(key, dist) - - def add_expr(self, key: str, expr: bool): - """Adds a distance penalty of 1.0 if `expr` evaluates to True, - or 0.0. - """ - if expr: - self.add(key, 1.0) - else: - self.add(key, 0.0) - - def add_number(self, key: str, number1: int, number2: int): - """Adds a distance penalty of 1.0 for each number of difference - between `number1` and `number2`, or 0.0 when there is no - difference. Use this when there is no upper limit on the - difference between the two numbers. - """ - diff = abs(number1 - number2) - if diff: - for i in range(diff): - self.add(key, 1.0) - else: - self.add(key, 0.0) - - def add_priority( - self, - key: str, - value: Any, - options: list[Any] | tuple[Any, ...] | Any, - ): - """Adds a distance penalty that corresponds to the position at - which `value` appears in `options`. A distance penalty of 0.0 - for the first option, or 1.0 if there is no matching option. If - an option is a compiled regular expression, it will be - considered equal if it matches against `value`. - """ - if not isinstance(options, (list, tuple)): - options = [options] - unit = 1.0 / (len(options) or 1) - for i, opt in enumerate(options): - if self._eq(opt, value): - dist = i * unit - break - else: - dist = 1.0 - self.add(key, dist) - - def add_ratio( - self, - key: str, - number1: int | float, - number2: int | float, - ): - """Adds a distance penalty for `number1` as a ratio of `number2`. - `number1` is bound at 0 and `number2`. - """ - number = float(max(min(number1, number2), 0)) - if number2: - dist = number / number2 - else: - dist = 0.0 - self.add(key, dist) - - def add_string(self, key: str, str1: str | None, str2: str | None): - """Adds a distance penalty based on the edit distance between - `str1` and `str2`. - """ - dist = string_dist(str1, str2) - self.add(key, dist) - - # Structures that compose all the information for a candidate match. diff --git a/beets/autotag/match.py b/beets/autotag/match.py index 4dc4c1052..64572cf3b 100644 --- a/beets/autotag/match.py +++ b/beets/autotag/match.py @@ -18,37 +18,23 @@ releases and tracks. from __future__ import annotations -import datetime -import re from enum import IntEnum -from functools import cache from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar import lap import numpy as np from beets import config, logging, plugins -from beets.autotag import ( - AlbumInfo, - AlbumMatch, - Distance, - TrackInfo, - TrackMatch, - hooks, -) +from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks from beets.util import get_most_common_tags +from .distance import VA_ARTISTS, distance, track_distance + if TYPE_CHECKING: from collections.abc import Iterable, Sequence from beets.library import Item -# Artist signals that indicate "various artists". These are used at the -# album level to determine whether a given release is likely a VA -# release and also on the track level to to remove the penalty for -# differing artists. -VA_ARTISTS = ("", "various artists", "various", "va", "unknown") - # Global logger. log = logging.getLogger("beets") @@ -112,191 +98,6 @@ def assign_items( return mapping, extra_items, extra_tracks -def track_index_changed(item: Item, track_info: TrackInfo) -> bool: - """Returns True if the item and track info index is different. Tolerates - per disc and per release numbering. - """ - return item.track not in (track_info.medium_index, track_info.index) - - -@cache -def get_track_length_grace() -> float: - """Get cached grace period for track length matching.""" - return config["match"]["track_length_grace"].as_number() - - -@cache -def get_track_length_max() -> float: - """Get cached maximum track length for track length matching.""" - return config["match"]["track_length_max"].as_number() - - -def track_distance( - item: Item, - track_info: TrackInfo, - incl_artist: bool = False, -) -> Distance: - """Determines the significance of a track metadata change. Returns a - Distance object. `incl_artist` indicates that a distance component should - be included for the track artist (i.e., for various-artist releases). - - ``track_length_grace`` and ``track_length_max`` configuration options are - cached because this function is called many times during the matching - process and their access comes with a performance overhead. - """ - dist = hooks.Distance() - - # Length. - if info_length := track_info.length: - diff = abs(item.length - info_length) - get_track_length_grace() - dist.add_ratio("track_length", diff, get_track_length_max()) - - # Title. - dist.add_string("track_title", item.title, track_info.title) - - # Artist. Only check if there is actually an artist in the track data. - if ( - incl_artist - and track_info.artist - and item.artist.lower() not in VA_ARTISTS - ): - dist.add_string("track_artist", item.artist, track_info.artist) - - # Track index. - if track_info.index and item.track: - dist.add_expr("track_index", track_index_changed(item, track_info)) - - # Track ID. - if item.mb_trackid: - dist.add_expr("track_id", item.mb_trackid != track_info.track_id) - - # Penalize mismatching disc numbers. - if track_info.medium and item.disc: - dist.add_expr("medium", item.disc != track_info.medium) - - # Plugins. - dist.update(plugins.track_distance(item, track_info)) - - return dist - - -def distance( - items: Sequence[Item], - album_info: AlbumInfo, - mapping: dict[Item, TrackInfo], -) -> Distance: - """Determines how "significant" an album metadata change would be. - Returns a Distance object. `album_info` is an AlbumInfo object - reflecting the album to be compared. `items` is a sequence of all - Item objects that will be matched (order is not important). - `mapping` is a dictionary mapping Items to TrackInfo objects; the - keys are a subset of `items` and the values are a subset of - `album_info.tracks`. - """ - likelies, _ = get_most_common_tags(items) - - dist = hooks.Distance() - - # Artist, if not various. - if not album_info.va: - dist.add_string("artist", likelies["artist"], album_info.artist) - - # Album. - dist.add_string("album", likelies["album"], album_info.album) - - preferred_config = config["match"]["preferred"] - # Current or preferred media. - if album_info.media: - # Preferred media options. - media_patterns: Sequence[str] = preferred_config["media"].as_str_seq() - options = [ - re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns - ] - if options: - dist.add_priority("media", album_info.media, options) - # Current media. - elif likelies["media"]: - dist.add_equality("media", album_info.media, likelies["media"]) - - # Mediums. - if likelies["disctotal"] and album_info.mediums: - dist.add_number("mediums", likelies["disctotal"], album_info.mediums) - - # Prefer earliest release. - if album_info.year and preferred_config["original_year"]: - # Assume 1889 (earliest first gramophone discs) if we don't know the - # original year. - original = album_info.original_year or 1889 - diff = abs(album_info.year - original) - diff_max = abs(datetime.date.today().year - original) - dist.add_ratio("year", diff, diff_max) - # Year. - elif likelies["year"] and album_info.year: - if likelies["year"] in (album_info.year, album_info.original_year): - # No penalty for matching release or original year. - dist.add("year", 0.0) - elif album_info.original_year: - # Prefer matchest closest to the release year. - diff = abs(likelies["year"] - album_info.year) - diff_max = abs( - datetime.date.today().year - album_info.original_year - ) - dist.add_ratio("year", diff, diff_max) - else: - # Full penalty when there is no original year. - dist.add("year", 1.0) - - # Preferred countries. - country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq() - options = [re.compile(pat, re.I) for pat in country_patterns] - if album_info.country and options: - dist.add_priority("country", album_info.country, options) - # Country. - elif likelies["country"] and album_info.country: - dist.add_string("country", likelies["country"], album_info.country) - - # Label. - if likelies["label"] and album_info.label: - dist.add_string("label", likelies["label"], album_info.label) - - # Catalog number. - if likelies["catalognum"] and album_info.catalognum: - dist.add_string( - "catalognum", likelies["catalognum"], album_info.catalognum - ) - - # Disambiguation. - if likelies["albumdisambig"] and album_info.albumdisambig: - dist.add_string( - "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig - ) - - # Album ID. - if likelies["mb_albumid"]: - dist.add_equality( - "album_id", likelies["mb_albumid"], album_info.album_id - ) - - # Tracks. - dist.tracks = {} - for item, track in mapping.items(): - dist.tracks[track] = track_distance(item, track, album_info.va) - dist.add("tracks", dist.tracks[track].distance) - - # Missing tracks. - for _ in range(len(album_info.tracks) - len(mapping)): - dist.add("missing_tracks", 1.0) - - # Unmatched tracks. - for _ in range(len(items) - len(mapping)): - dist.add("unmatched_tracks", 1.0) - - # Plugins. - dist.update(plugins.album_distance(items, album_info, mapping)) - - return dist - - def match_by_id(items: Iterable[Item]) -> AlbumInfo | None: """If the items are tagged with an external source ID, return an AlbumInfo object for the corresponding album. Otherwise, returns diff --git a/beets/plugins.py b/beets/plugins.py index 6d3a8447e..cd66435b5 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -37,6 +37,7 @@ import mediafile import beets from beets import logging +from beets.autotag.distance import Distance from beets.util.id_extractors import extract_release_id if TYPE_CHECKING: @@ -53,7 +54,7 @@ if TYPE_CHECKING: from confuse import ConfigView - from beets.autotag import AlbumInfo, Distance, TrackInfo + from beets.autotag import AlbumInfo, TrackInfo from beets.dbcore import Query from beets.dbcore.db import FieldQueryType from beets.dbcore.types import Type @@ -224,8 +225,6 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every track comparison. """ - from beets.autotag.hooks import Distance - return Distance() def album_distance( @@ -237,8 +236,6 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every album-level comparison. """ - from beets.autotag.hooks import Distance - return Distance() def candidates( @@ -458,8 +455,6 @@ def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ - from beets.autotag.hooks import Distance - dist = Distance() for plugin in find_plugins(): dist.update(plugin.track_distance(item, info)) @@ -472,8 +467,6 @@ def album_distance( mapping: dict[Item, TrackInfo], ) -> Distance: """Returns the album distance calculated by plugins.""" - from beets.autotag.hooks import Distance - dist = Distance() for plugin in find_plugins(): dist.update(plugin.album_distance(items, album_info, mapping)) @@ -660,8 +653,6 @@ def get_distance( """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ - from beets.autotag.hooks import Distance - dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) diff --git a/beetsplug/chroma.py b/beetsplug/chroma.py index 518a41776..5c718154b 100644 --- a/beetsplug/chroma.py +++ b/beetsplug/chroma.py @@ -24,7 +24,7 @@ import acoustid import confuse from beets import config, plugins, ui, util -from beets.autotag.hooks import Distance +from beets.autotag.distance import Distance from beetsplug.musicbrainz import MusicBrainzPlugin API_KEY = "1vOwZtEn" diff --git a/beetsplug/discogs.py b/beetsplug/discogs.py index 696f1d1ac..2408f3498 100644 --- a/beetsplug/discogs.py +++ b/beetsplug/discogs.py @@ -38,7 +38,8 @@ from typing_extensions import TypedDict import beets import beets.ui from beets import config -from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist +from beets.autotag.distance import string_dist +from beets.autotag.hooks import AlbumInfo, TrackInfo from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance from beets.util.id_extractors import extract_release_id diff --git a/beetsplug/lyrics.py b/beetsplug/lyrics.py index e2c0c7fd2..f1c40ab24 100644 --- a/beetsplug/lyrics.py +++ b/beetsplug/lyrics.py @@ -38,7 +38,7 @@ from unidecode import unidecode import beets from beets import plugins, ui -from beets.autotag.hooks import string_dist +from beets.autotag.distance import string_dist from beets.util.config import sanitize_choices if TYPE_CHECKING: diff --git a/test/autotag/test_distance.py b/test/autotag/test_distance.py new file mode 100644 index 000000000..ec00ebcdf --- /dev/null +++ b/test/autotag/test_distance.py @@ -0,0 +1,476 @@ +import re +import unittest + +from beets import config +from beets.autotag import AlbumInfo, TrackInfo, match +from beets.autotag.distance import Distance, string_dist +from beets.library import Item +from beets.test.helper import BeetsTestCase + + +def _make_item(title, track, artist="some artist"): + return Item( + title=title, + track=track, + artist=artist, + album="some album", + length=1, + mb_trackid="", + mb_albumid="", + mb_artistid="", + ) + + +def _make_trackinfo(): + return [ + TrackInfo( + title="one", track_id=None, artist="some artist", length=1, index=1 + ), + TrackInfo( + title="two", track_id=None, artist="some artist", length=1, index=2 + ), + TrackInfo( + title="three", + track_id=None, + artist="some artist", + length=1, + index=3, + ), + ] + + +def _clear_weights(): + """Hack around the lazy descriptor used to cache weights for + Distance calculations. + """ + Distance.__dict__["_weights"].cache = {} + + +class DistanceTest(BeetsTestCase): + def tearDown(self): + super().tearDown() + _clear_weights() + + def test_add(self): + dist = Distance() + dist.add("add", 1.0) + assert dist._penalties == {"add": [1.0]} + + def test_add_equality(self): + dist = Distance() + dist.add_equality("equality", "ghi", ["abc", "def", "ghi"]) + assert dist._penalties["equality"] == [0.0] + + dist.add_equality("equality", "xyz", ["abc", "def", "ghi"]) + assert dist._penalties["equality"] == [0.0, 1.0] + + dist.add_equality("equality", "abc", re.compile(r"ABC", re.I)) + assert dist._penalties["equality"] == [0.0, 1.0, 0.0] + + def test_add_expr(self): + dist = Distance() + dist.add_expr("expr", True) + assert dist._penalties["expr"] == [1.0] + + dist.add_expr("expr", False) + assert dist._penalties["expr"] == [1.0, 0.0] + + def test_add_number(self): + dist = Distance() + # Add a full penalty for each number of difference between two numbers. + + dist.add_number("number", 1, 1) + assert dist._penalties["number"] == [0.0] + + dist.add_number("number", 1, 2) + assert dist._penalties["number"] == [0.0, 1.0] + + dist.add_number("number", 2, 1) + assert dist._penalties["number"] == [0.0, 1.0, 1.0] + + dist.add_number("number", -1, 2) + assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0] + + def test_add_priority(self): + dist = Distance() + dist.add_priority("priority", "abc", "abc") + assert dist._penalties["priority"] == [0.0] + + dist.add_priority("priority", "def", ["abc", "def"]) + assert dist._penalties["priority"] == [0.0, 0.5] + + dist.add_priority( + "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)] + ) + assert dist._penalties["priority"] == [0.0, 0.5, 0.75] + + dist.add_priority("priority", "xyz", ["abc", "def"]) + assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0] + + def test_add_ratio(self): + dist = Distance() + dist.add_ratio("ratio", 25, 100) + assert dist._penalties["ratio"] == [0.25] + + dist.add_ratio("ratio", 10, 5) + assert dist._penalties["ratio"] == [0.25, 1.0] + + dist.add_ratio("ratio", -5, 5) + assert dist._penalties["ratio"] == [0.25, 1.0, 0.0] + + dist.add_ratio("ratio", 5, 0) + assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0] + + def test_add_string(self): + dist = Distance() + sdist = string_dist("abc", "bcd") + dist.add_string("string", "abc", "bcd") + assert dist._penalties["string"] == [sdist] + assert dist._penalties["string"] != [0] + + def test_add_string_none(self): + dist = Distance() + dist.add_string("string", None, "string") + assert dist._penalties["string"] == [1] + + def test_add_string_both_none(self): + dist = Distance() + dist.add_string("string", None, None) + assert dist._penalties["string"] == [0] + + def test_distance(self): + config["match"]["distance_weights"]["album"] = 2.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("media", 0.25) + dist.add("media", 0.75) + assert dist.distance == 0.5 + + # __getitem__() + assert dist["album"] == 0.25 + assert dist["media"] == 0.25 + + def test_max_distance(self): + config["match"]["distance_weights"]["album"] = 3.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("medium", 0.0) + dist.add("medium", 0.0) + assert dist.max_distance == 5.0 + + def test_operators(self): + config["match"]["distance_weights"]["source"] = 1.0 + config["match"]["distance_weights"]["album"] = 2.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("source", 0.0) + dist.add("album", 0.5) + dist.add("medium", 0.25) + dist.add("medium", 0.75) + assert len(dist) == 2 + assert list(dist) == [("album", 0.2), ("medium", 0.2)] + assert dist == 0.4 + assert dist < 1.0 + assert dist > 0.0 + assert dist - 0.4 == 0.0 + assert 0.4 - dist == 0.0 + assert float(dist) == 0.4 + + def test_raw_distance(self): + config["match"]["distance_weights"]["album"] = 3.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("medium", 0.25) + dist.add("medium", 0.5) + assert dist.raw_distance == 2.25 + + def test_items(self): + config["match"]["distance_weights"]["album"] = 4.0 + config["match"]["distance_weights"]["medium"] = 2.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.1875) + dist.add("medium", 0.75) + assert dist.items() == [("medium", 0.25), ("album", 0.125)] + + # Sort by key if distance is equal. + dist = Distance() + dist.add("album", 0.375) + dist.add("medium", 0.75) + assert dist.items() == [("album", 0.25), ("medium", 0.25)] + + def test_update(self): + dist1 = Distance() + dist1.add("album", 0.5) + dist1.add("media", 1.0) + + dist2 = Distance() + dist2.add("album", 0.75) + dist2.add("album", 0.25) + dist2.add("media", 0.05) + + dist1.update(dist2) + + assert dist1._penalties == { + "album": [0.5, 0.75, 0.25], + "media": [1.0, 0.05], + } + + +class TrackDistanceTest(BeetsTestCase): + def test_identical_tracks(self): + item = _make_item("one", 1) + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist == 0.0 + + def test_different_title(self): + item = _make_item("foo", 1) + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist != 0.0 + + def test_different_artist(self): + item = _make_item("one", 1) + item.artist = "foo" + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist != 0.0 + + def test_various_artists_tolerated(self): + item = _make_item("one", 1) + item.artist = "Various Artists" + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist == 0.0 + + +class AlbumDistanceTest(BeetsTestCase): + def _mapping(self, items, info): + out = {} + for i, t in zip(items, info.tracks): + out[i] = t + return out + + def _dist(self, items, info): + return match.distance(items, info, self._mapping(items, info)) + + def test_identical_albums(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + assert self._dist(items, info) == 0 + + def test_incomplete_album(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + dist = self._dist(items, info) + assert dist != 0 + # Make sure the distance is not too great + assert dist < 0.2 + + def test_global_artists_differ(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="someone else", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + assert self._dist(items, info) != 0 + + def test_comp_track_artists_match(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="should be ignored", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + assert self._dist(items, info) == 0 + + def test_comp_no_track_artists(self): + # Some VA releases don't have track artists (incomplete metadata). + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="should be ignored", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + info.tracks[0].artist = None + info.tracks[1].artist = None + info.tracks[2].artist = None + assert self._dist(items, info) == 0 + + def test_comp_track_artists_do_not_match(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2, "someone else")) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + assert self._dist(items, info) != 0 + + def test_tracks_out_of_order(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("three", 2)) + items.append(_make_item("two", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + dist = self._dist(items, info) + assert 0 < dist < 0.2 + + def test_two_medium_release(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + info.tracks[0].medium_index = 1 + info.tracks[1].medium_index = 2 + info.tracks[2].medium_index = 1 + dist = self._dist(items, info) + assert dist == 0 + + def test_per_medium_track_numbers(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 1)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + info.tracks[0].medium_index = 1 + info.tracks[1].medium_index = 2 + info.tracks[2].medium_index = 1 + dist = self._dist(items, info) + assert dist == 0 + + +class StringDistanceTest(unittest.TestCase): + def test_equal_strings(self): + dist = string_dist("Some String", "Some String") + assert dist == 0.0 + + def test_different_strings(self): + dist = string_dist("Some String", "Totally Different") + assert dist != 0.0 + + def test_punctuation_ignored(self): + dist = string_dist("Some String", "Some.String!") + assert dist == 0.0 + + def test_case_ignored(self): + dist = string_dist("Some String", "sOME sTring") + assert dist == 0.0 + + def test_leading_the_has_lower_weight(self): + dist1 = string_dist("XXX Band Name", "Band Name") + dist2 = string_dist("The Band Name", "Band Name") + assert dist2 < dist1 + + def test_parens_have_lower_weight(self): + dist1 = string_dist("One .Two.", "One") + dist2 = string_dist("One (Two)", "One") + assert dist2 < dist1 + + def test_brackets_have_lower_weight(self): + dist1 = string_dist("One .Two.", "One") + dist2 = string_dist("One [Two]", "One") + assert dist2 < dist1 + + def test_ep_label_has_zero_weight(self): + dist = string_dist("My Song (EP)", "My Song") + assert dist == 0.0 + + def test_featured_has_lower_weight(self): + dist1 = string_dist("My Song blah Someone", "My Song") + dist2 = string_dist("My Song feat Someone", "My Song") + assert dist2 < dist1 + + def test_postfix_the(self): + dist = string_dist("The Song Title", "Song Title, The") + assert dist == 0.0 + + def test_postfix_a(self): + dist = string_dist("A Song Title", "Song Title, A") + assert dist == 0.0 + + def test_postfix_an(self): + dist = string_dist("An Album Title", "Album Title, An") + assert dist == 0.0 + + def test_empty_strings(self): + dist = string_dist("", "") + assert dist == 0.0 + + def test_solo_pattern(self): + # Just make sure these don't crash. + string_dist("The ", "") + string_dist("(EP)", "(EP)") + string_dist(", An", "") + + def test_heuristic_does_not_harm_distance(self): + dist = string_dist("Untitled", "[Untitled]") + assert dist == 0.0 + + def test_ampersand_expansion(self): + dist = string_dist("And", "&") + assert dist == 0.0 + + def test_accented_characters(self): + dist = string_dist("\xe9\xe1\xf1", "ean") + assert dist == 0.0 diff --git a/test/test_autotag.py b/test/test_autotag.py index bd4205806..8d467e5ed 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -14,410 +14,14 @@ """Tests for autotagging functionality.""" -import re -import unittest - import pytest from beets import autotag, config from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match -from beets.autotag.hooks import Distance, string_dist from beets.library import Item from beets.test.helper import BeetsTestCase, ConfigMixin -def _make_item(title, track, artist="some artist"): - return Item( - title=title, - track=track, - artist=artist, - album="some album", - length=1, - mb_trackid="", - mb_albumid="", - mb_artistid="", - ) - - -def _make_trackinfo(): - return [ - TrackInfo( - title="one", track_id=None, artist="some artist", length=1, index=1 - ), - TrackInfo( - title="two", track_id=None, artist="some artist", length=1, index=2 - ), - TrackInfo( - title="three", - track_id=None, - artist="some artist", - length=1, - index=3, - ), - ] - - -def _clear_weights(): - """Hack around the lazy descriptor used to cache weights for - Distance calculations. - """ - Distance.__dict__["_weights"].cache = {} - - -class DistanceTest(BeetsTestCase): - def tearDown(self): - super().tearDown() - _clear_weights() - - def test_add(self): - dist = Distance() - dist.add("add", 1.0) - assert dist._penalties == {"add": [1.0]} - - def test_add_equality(self): - dist = Distance() - dist.add_equality("equality", "ghi", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0] - - dist.add_equality("equality", "xyz", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0, 1.0] - - dist.add_equality("equality", "abc", re.compile(r"ABC", re.I)) - assert dist._penalties["equality"] == [0.0, 1.0, 0.0] - - def test_add_expr(self): - dist = Distance() - dist.add_expr("expr", True) - assert dist._penalties["expr"] == [1.0] - - dist.add_expr("expr", False) - assert dist._penalties["expr"] == [1.0, 0.0] - - def test_add_number(self): - dist = Distance() - # Add a full penalty for each number of difference between two numbers. - - dist.add_number("number", 1, 1) - assert dist._penalties["number"] == [0.0] - - dist.add_number("number", 1, 2) - assert dist._penalties["number"] == [0.0, 1.0] - - dist.add_number("number", 2, 1) - assert dist._penalties["number"] == [0.0, 1.0, 1.0] - - dist.add_number("number", -1, 2) - assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0] - - def test_add_priority(self): - dist = Distance() - dist.add_priority("priority", "abc", "abc") - assert dist._penalties["priority"] == [0.0] - - dist.add_priority("priority", "def", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5] - - dist.add_priority( - "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)] - ) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75] - - dist.add_priority("priority", "xyz", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0] - - def test_add_ratio(self): - dist = Distance() - dist.add_ratio("ratio", 25, 100) - assert dist._penalties["ratio"] == [0.25] - - dist.add_ratio("ratio", 10, 5) - assert dist._penalties["ratio"] == [0.25, 1.0] - - dist.add_ratio("ratio", -5, 5) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0] - - dist.add_ratio("ratio", 5, 0) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0] - - def test_add_string(self): - dist = Distance() - sdist = string_dist("abc", "bcd") - dist.add_string("string", "abc", "bcd") - assert dist._penalties["string"] == [sdist] - assert dist._penalties["string"] != [0] - - def test_add_string_none(self): - dist = Distance() - dist.add_string("string", None, "string") - assert dist._penalties["string"] == [1] - - def test_add_string_both_none(self): - dist = Distance() - dist.add_string("string", None, None) - assert dist._penalties["string"] == [0] - - def test_distance(self): - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("media", 0.25) - dist.add("media", 0.75) - assert dist.distance == 0.5 - - # __getitem__() - assert dist["album"] == 0.25 - assert dist["media"] == 0.25 - - def test_max_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.0) - dist.add("medium", 0.0) - assert dist.max_distance == 5.0 - - def test_operators(self): - config["match"]["distance_weights"]["source"] = 1.0 - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("source", 0.0) - dist.add("album", 0.5) - dist.add("medium", 0.25) - dist.add("medium", 0.75) - assert len(dist) == 2 - assert list(dist) == [("album", 0.2), ("medium", 0.2)] - assert dist == 0.4 - assert dist < 1.0 - assert dist > 0.0 - assert dist - 0.4 == 0.0 - assert 0.4 - dist == 0.0 - assert float(dist) == 0.4 - - def test_raw_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.25) - dist.add("medium", 0.5) - assert dist.raw_distance == 2.25 - - def test_items(self): - config["match"]["distance_weights"]["album"] = 4.0 - config["match"]["distance_weights"]["medium"] = 2.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.1875) - dist.add("medium", 0.75) - assert dist.items() == [("medium", 0.25), ("album", 0.125)] - - # Sort by key if distance is equal. - dist = Distance() - dist.add("album", 0.375) - dist.add("medium", 0.75) - assert dist.items() == [("album", 0.25), ("medium", 0.25)] - - def test_update(self): - dist1 = Distance() - dist1.add("album", 0.5) - dist1.add("media", 1.0) - - dist2 = Distance() - dist2.add("album", 0.75) - dist2.add("album", 0.25) - dist2.add("media", 0.05) - - dist1.update(dist2) - - assert dist1._penalties == { - "album": [0.5, 0.75, 0.25], - "media": [1.0, 0.05], - } - - -class TrackDistanceTest(BeetsTestCase): - def test_identical_tracks(self): - item = _make_item("one", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 - - def test_different_title(self): - item = _make_item("foo", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 - - def test_different_artist(self): - item = _make_item("one", 1) - item.artist = "foo" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 - - def test_various_artists_tolerated(self): - item = _make_item("one", 1) - item.artist = "Various Artists" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 - - -class AlbumDistanceTest(BeetsTestCase): - def _mapping(self, items, info): - out = {} - for i, t in zip(items, info.tracks): - out[i] = t - return out - - def _dist(self, items, info): - return match.distance(items, info, self._mapping(items, info)) - - def test_identical_albums(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - assert self._dist(items, info) == 0 - - def test_incomplete_album(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - dist = self._dist(items, info) - assert dist != 0 - # Make sure the distance is not too great - assert dist < 0.2 - - def test_global_artists_differ(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="someone else", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - assert self._dist(items, info) != 0 - - def test_comp_track_artists_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) == 0 - - def test_comp_no_track_artists(self): - # Some VA releases don't have track artists (incomplete metadata). - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - info.tracks[0].artist = None - info.tracks[1].artist = None - info.tracks[2].artist = None - assert self._dist(items, info) == 0 - - def test_comp_track_artists_do_not_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2, "someone else")) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) != 0 - - def test_tracks_out_of_order(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 2)) - items.append(_make_item("two", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - dist = self._dist(items, info) - assert 0 < dist < 0.2 - - def test_two_medium_release(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - info.tracks[0].medium_index = 1 - info.tracks[1].medium_index = 2 - info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 - - def test_per_medium_track_numbers(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 1)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - info.tracks[0].medium_index = 1 - info.tracks[1].medium_index = 2 - info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 - - class TestAssignment(ConfigMixin): A = "one" B = "two" @@ -840,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil): assert self.items[1].comp -class StringDistanceTest(unittest.TestCase): - def test_equal_strings(self): - dist = string_dist("Some String", "Some String") - assert dist == 0.0 - - def test_different_strings(self): - dist = string_dist("Some String", "Totally Different") - assert dist != 0.0 - - def test_punctuation_ignored(self): - dist = string_dist("Some String", "Some.String!") - assert dist == 0.0 - - def test_case_ignored(self): - dist = string_dist("Some String", "sOME sTring") - assert dist == 0.0 - - def test_leading_the_has_lower_weight(self): - dist1 = string_dist("XXX Band Name", "Band Name") - dist2 = string_dist("The Band Name", "Band Name") - assert dist2 < dist1 - - def test_parens_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One (Two)", "One") - assert dist2 < dist1 - - def test_brackets_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One [Two]", "One") - assert dist2 < dist1 - - def test_ep_label_has_zero_weight(self): - dist = string_dist("My Song (EP)", "My Song") - assert dist == 0.0 - - def test_featured_has_lower_weight(self): - dist1 = string_dist("My Song blah Someone", "My Song") - dist2 = string_dist("My Song feat Someone", "My Song") - assert dist2 < dist1 - - def test_postfix_the(self): - dist = string_dist("The Song Title", "Song Title, The") - assert dist == 0.0 - - def test_postfix_a(self): - dist = string_dist("A Song Title", "Song Title, A") - assert dist == 0.0 - - def test_postfix_an(self): - dist = string_dist("An Album Title", "Album Title, An") - assert dist == 0.0 - - def test_empty_strings(self): - dist = string_dist("", "") - assert dist == 0.0 - - def test_solo_pattern(self): - # Just make sure these don't crash. - string_dist("The ", "") - string_dist("(EP)", "(EP)") - string_dist(", An", "") - - def test_heuristic_does_not_harm_distance(self): - dist = string_dist("Untitled", "[Untitled]") - assert dist == 0.0 - - def test_ampersand_expansion(self): - dist = string_dist("And", "&") - assert dist == 0.0 - - def test_accented_characters(self): - dist = string_dist("\xe9\xe1\xf1", "ean") - assert dist == 0.0 - - @pytest.mark.parametrize( "single_field,list_field", [ From 318a840af2e4e80f5e8c6465a0d0d44d31dad4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 13:43:20 +0100 Subject: [PATCH 5/8] Rewrite distance tests --- test/autotag/test_distance.py | 597 ++++++++++++---------------------- 1 file changed, 210 insertions(+), 387 deletions(-) diff --git a/test/autotag/test_distance.py b/test/autotag/test_distance.py index ec00ebcdf..e3ce9f891 100644 --- a/test/autotag/test_distance.py +++ b/test/autotag/test_distance.py @@ -1,176 +1,108 @@ import re -import unittest -from beets import config -from beets.autotag import AlbumInfo, TrackInfo, match -from beets.autotag.distance import Distance, string_dist +import pytest + +from beets.autotag import AlbumInfo, TrackInfo +from beets.autotag.distance import ( + Distance, + distance, + string_dist, + track_distance, +) from beets.library import Item -from beets.test.helper import BeetsTestCase +from beets.test.helper import ConfigMixin + +_p = pytest.param -def _make_item(title, track, artist="some artist"): - return Item( - title=title, - track=track, - artist=artist, - album="some album", - length=1, - mb_trackid="", - mb_albumid="", - mb_artistid="", - ) +class TestDistance: + @pytest.fixture(scope="class") + def config(self): + return ConfigMixin().config + @pytest.fixture + def dist(self, config): + config["match"]["distance_weights"]["source"] = 2.0 + config["match"]["distance_weights"]["album"] = 4.0 + config["match"]["distance_weights"]["medium"] = 2.0 -def _make_trackinfo(): - return [ - TrackInfo( - title="one", track_id=None, artist="some artist", length=1, index=1 - ), - TrackInfo( - title="two", track_id=None, artist="some artist", length=1, index=2 - ), - TrackInfo( - title="three", - track_id=None, - artist="some artist", - length=1, - index=3, - ), - ] + Distance.__dict__["_weights"].cache = {} + return Distance() -def _clear_weights(): - """Hack around the lazy descriptor used to cache weights for - Distance calculations. - """ - Distance.__dict__["_weights"].cache = {} - - -class DistanceTest(BeetsTestCase): - def tearDown(self): - super().tearDown() - _clear_weights() - - def test_add(self): - dist = Distance() + def test_add(self, dist): dist.add("add", 1.0) + assert dist._penalties == {"add": [1.0]} - def test_add_equality(self): - dist = Distance() - dist.add_equality("equality", "ghi", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0] + @pytest.mark.parametrize( + "key, args_with_expected", + [ + ( + "equality", + [ + (("ghi", ["abc", "def", "ghi"]), [0.0]), + (("xyz", ["abc", "def", "ghi"]), [0.0, 1.0]), + (("abc", re.compile(r"ABC", re.I)), [0.0, 1.0, 0.0]), + ], + ), + ("expr", [((True,), [1.0]), ((False,), [1.0, 0.0])]), + ( + "number", + [ + ((1, 1), [0.0]), + ((1, 2), [0.0, 1.0]), + ((2, 1), [0.0, 1.0, 1.0]), + ((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]), + ], + ), + ( + "priority", + [ + (("abc", "abc"), [0.0]), + (("def", ["abc", "def"]), [0.0, 0.5]), + (("gh", ["ab", "cd", "ef", re.compile("GH", re.I)]), [0.0, 0.5, 0.75]), # noqa: E501 + (("xyz", ["abc", "def"]), [0.0, 0.5, 0.75, 1.0]), + ], + ), + ( + "ratio", + [ + ((25, 100), [0.25]), + ((10, 5), [0.25, 1.0]), + ((-5, 5), [0.25, 1.0, 0.0]), + ((5, 0), [0.25, 1.0, 0.0, 0.0]), + ], + ), + ( + "string", + [ + (("abc", "bcd"), [2 / 3]), + (("abc", None), [2 / 3, 1]), + ((None, None), [2 / 3, 1, 0]), + ], + ), + ], + ) # fmt: skip + def test_add_methods(self, dist, key, args_with_expected): + method = getattr(dist, f"add_{key}") + for arg_set, expected in args_with_expected: + method(key, *arg_set) + assert dist._penalties[key] == expected - dist.add_equality("equality", "xyz", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0, 1.0] - - dist.add_equality("equality", "abc", re.compile(r"ABC", re.I)) - assert dist._penalties["equality"] == [0.0, 1.0, 0.0] - - def test_add_expr(self): - dist = Distance() - dist.add_expr("expr", True) - assert dist._penalties["expr"] == [1.0] - - dist.add_expr("expr", False) - assert dist._penalties["expr"] == [1.0, 0.0] - - def test_add_number(self): - dist = Distance() - # Add a full penalty for each number of difference between two numbers. - - dist.add_number("number", 1, 1) - assert dist._penalties["number"] == [0.0] - - dist.add_number("number", 1, 2) - assert dist._penalties["number"] == [0.0, 1.0] - - dist.add_number("number", 2, 1) - assert dist._penalties["number"] == [0.0, 1.0, 1.0] - - dist.add_number("number", -1, 2) - assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0] - - def test_add_priority(self): - dist = Distance() - dist.add_priority("priority", "abc", "abc") - assert dist._penalties["priority"] == [0.0] - - dist.add_priority("priority", "def", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5] - - dist.add_priority( - "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)] - ) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75] - - dist.add_priority("priority", "xyz", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0] - - def test_add_ratio(self): - dist = Distance() - dist.add_ratio("ratio", 25, 100) - assert dist._penalties["ratio"] == [0.25] - - dist.add_ratio("ratio", 10, 5) - assert dist._penalties["ratio"] == [0.25, 1.0] - - dist.add_ratio("ratio", -5, 5) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0] - - dist.add_ratio("ratio", 5, 0) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0] - - def test_add_string(self): - dist = Distance() - sdist = string_dist("abc", "bcd") - dist.add_string("string", "abc", "bcd") - assert dist._penalties["string"] == [sdist] - assert dist._penalties["string"] != [0] - - def test_add_string_none(self): - dist = Distance() - dist.add_string("string", None, "string") - assert dist._penalties["string"] == [1] - - def test_add_string_both_none(self): - dist = Distance() - dist.add_string("string", None, None) - assert dist._penalties["string"] == [0] - - def test_distance(self): - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() + def test_distance(self, dist): dist.add("album", 0.5) dist.add("media", 0.25) dist.add("media", 0.75) + assert dist.distance == 0.5 + assert dist.max_distance == 6.0 + assert dist.raw_distance == 3.0 - # __getitem__() - assert dist["album"] == 0.25 - assert dist["media"] == 0.25 + assert dist["album"] == 1 / 3 + assert dist["media"] == 1 / 6 - def test_max_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.0) - dist.add("medium", 0.0) - assert dist.max_distance == 5.0 - - def test_operators(self): - config["match"]["distance_weights"]["source"] = 1.0 - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() + def test_operators(self, dist): dist.add("source", 0.0) dist.add("album", 0.5) dist.add("medium", 0.25) @@ -184,23 +116,7 @@ class DistanceTest(BeetsTestCase): assert 0.4 - dist == 0.0 assert float(dist) == 0.4 - def test_raw_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.25) - dist.add("medium", 0.5) - assert dist.raw_distance == 2.25 - - def test_items(self): - config["match"]["distance_weights"]["album"] = 4.0 - config["match"]["distance_weights"]["medium"] = 2.0 - _clear_weights() - - dist = Distance() + def test_penalties_sort(self, dist): dist.add("album", 0.1875) dist.add("medium", 0.75) assert dist.items() == [("medium", 0.25), ("album", 0.125)] @@ -211,8 +127,8 @@ class DistanceTest(BeetsTestCase): dist.add("medium", 0.75) assert dist.items() == [("album", 0.25), ("medium", 0.25)] - def test_update(self): - dist1 = Distance() + def test_update(self, dist): + dist1 = dist dist1.add("album", 0.5) dist1.add("media", 1.0) @@ -229,248 +145,155 @@ class DistanceTest(BeetsTestCase): } -class TrackDistanceTest(BeetsTestCase): - def test_identical_tracks(self): - item = _make_item("one", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 +class TestTrackDistance: + @pytest.fixture(scope="class") + def info(self): + return TrackInfo(title="title", artist="artist") - def test_different_title(self): - item = _make_item("foo", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 + @pytest.mark.parametrize( + "title, artist, expected_penalty", + [ + _p("title", "artist", False, id="identical"), + _p("title", "Various Artists", False, id="tolerate-va"), + _p("title", "different artist", True, id="different-artist"), + _p("different title", "artist", True, id="different-title"), + ], + ) + def test_track_distance(self, info, title, artist, expected_penalty): + item = Item(artist=artist, title=title) - def test_different_artist(self): - item = _make_item("one", 1) - item.artist = "foo" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 - - def test_various_artists_tolerated(self): - item = _make_item("one", 1) - item.artist = "Various Artists" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 + assert ( + bool(track_distance(item, info, incl_artist=True)) + == expected_penalty + ) -class AlbumDistanceTest(BeetsTestCase): - def _mapping(self, items, info): - out = {} - for i, t in zip(items, info.tracks): - out[i] = t - return out +class TestAlbumDistance: + @pytest.fixture(scope="class") + def items(self): + return [ + Item( + title=title, + track=track, + artist="artist", + album="album", + length=1, + ) + for title, track in [("one", 1), ("two", 2), ("three", 3)] + ] - def _dist(self, items, info): - return match.distance(items, info, self._mapping(items, info)) + @pytest.fixture + def get_dist(self, items): + def inner(info: AlbumInfo): + return distance(items, info, dict(zip(items, info.tracks))) - def test_identical_albums(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), + return inner + + @pytest.fixture + def info(self, items): + return AlbumInfo( + artist="artist", + album="album", + tracks=[ + TrackInfo( + title=i.title, + artist=i.artist, + index=i.track, + length=i.length, + ) + for i in items + ], va=False, ) - assert self._dist(items, info) == 0 - def test_incomplete_album(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, + def test_identical_albums(self, get_dist, info): + assert get_dist(info) == 0 + + def test_incomplete_album(self, get_dist, info): + info.tracks.pop(2) + + assert 0 < float(get_dist(info)) < 0.2 + + def test_overly_complete_album(self, get_dist, info): + info.tracks.append( + Item(index=4, title="four", artist="artist", length=1) ) - dist = self._dist(items, info) - assert dist != 0 - # Make sure the distance is not too great - assert dist < 0.2 - def test_global_artists_differ(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="someone else", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - assert self._dist(items, info) != 0 + assert 0 < float(get_dist(info)) < 0.2 - def test_comp_track_artists_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) == 0 + @pytest.mark.parametrize("va", [True, False]) + def test_albumartist(self, get_dist, info, va): + info.artist = "another artist" + info.va = va - def test_comp_no_track_artists(self): + assert bool(get_dist(info)) is not va + + def test_comp_no_track_artists(self, get_dist, info): # Some VA releases don't have track artists (incomplete metadata). - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - info.tracks[0].artist = None - info.tracks[1].artist = None - info.tracks[2].artist = None - assert self._dist(items, info) == 0 + info.artist = "another artist" + info.va = True + for track in info.tracks: + track.artist = None - def test_comp_track_artists_do_not_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2, "someone else")) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) != 0 + assert get_dist(info) == 0 - def test_tracks_out_of_order(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 2)) - items.append(_make_item("two", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - dist = self._dist(items, info) - assert 0 < dist < 0.2 + def test_comp_track_artists_do_not_match(self, get_dist, info): + info.va = True + info.tracks[0].artist = "another artist" - def test_two_medium_release(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) + assert get_dist(info) != 0 + + def test_tracks_out_of_order(self, get_dist, info): + tracks = info.tracks + tracks[1].title, tracks[2].title = tracks[2].title, tracks[1].title + + assert 0 < float(get_dist(info)) < 0.2 + + def test_two_medium_release(self, get_dist, info): info.tracks[0].medium_index = 1 info.tracks[1].medium_index = 2 info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 - def test_per_medium_track_numbers(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 1)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - info.tracks[0].medium_index = 1 - info.tracks[1].medium_index = 2 - info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 + assert get_dist(info) == 0 -class StringDistanceTest(unittest.TestCase): - def test_equal_strings(self): - dist = string_dist("Some String", "Some String") - assert dist == 0.0 +class TestStringDistance: + @pytest.mark.parametrize( + "string1, string2", + [ + ("Some String", "Some String"), + ("Some String", "Some.String!"), + ("Some String", "sOME sTring"), + ("My Song (EP)", "My Song"), + ("The Song Title", "Song Title, The"), + ("A Song Title", "Song Title, A"), + ("An Album Title", "Album Title, An"), + ("", ""), + ("Untitled", "[Untitled]"), + ("And", "&"), + ("\xe9\xe1\xf1", "ean"), + ], + ) + def test_matching_distance(self, string1, string2): + assert string_dist(string1, string2) == 0.0 - def test_different_strings(self): - dist = string_dist("Some String", "Totally Different") - assert dist != 0.0 + def test_different_distance(self): + assert string_dist("Some String", "Totally Different") != 0.0 - def test_punctuation_ignored(self): - dist = string_dist("Some String", "Some.String!") - assert dist == 0.0 - - def test_case_ignored(self): - dist = string_dist("Some String", "sOME sTring") - assert dist == 0.0 - - def test_leading_the_has_lower_weight(self): - dist1 = string_dist("XXX Band Name", "Band Name") - dist2 = string_dist("The Band Name", "Band Name") - assert dist2 < dist1 - - def test_parens_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One (Two)", "One") - assert dist2 < dist1 - - def test_brackets_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One [Two]", "One") - assert dist2 < dist1 - - def test_ep_label_has_zero_weight(self): - dist = string_dist("My Song (EP)", "My Song") - assert dist == 0.0 - - def test_featured_has_lower_weight(self): - dist1 = string_dist("My Song blah Someone", "My Song") - dist2 = string_dist("My Song feat Someone", "My Song") - assert dist2 < dist1 - - def test_postfix_the(self): - dist = string_dist("The Song Title", "Song Title, The") - assert dist == 0.0 - - def test_postfix_a(self): - dist = string_dist("A Song Title", "Song Title, A") - assert dist == 0.0 - - def test_postfix_an(self): - dist = string_dist("An Album Title", "Album Title, An") - assert dist == 0.0 - - def test_empty_strings(self): - dist = string_dist("", "") - assert dist == 0.0 + @pytest.mark.parametrize( + "string1, string2, reference", + [ + ("XXX Band Name", "The Band Name", "Band Name"), + ("One .Two.", "One (Two)", "One"), + ("One .Two.", "One [Two]", "One"), + ("My Song blah Someone", "My Song feat Someone", "My Song"), + ], + ) + def test_relative_weights(self, string1, string2, reference): + assert string_dist(string2, reference) < string_dist(string1, reference) def test_solo_pattern(self): # Just make sure these don't crash. string_dist("The ", "") string_dist("(EP)", "(EP)") string_dist(", An", "") - - def test_heuristic_does_not_harm_distance(self): - dist = string_dist("Untitled", "[Untitled]") - assert dist == 0.0 - - def test_ampersand_expansion(self): - dist = string_dist("And", "&") - assert dist == 0.0 - - def test_accented_characters(self): - dist = string_dist("\xe9\xe1\xf1", "ean") - assert dist == 0.0 From 0da6192a4ab799196ff49e403e8cd91fbaaef4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 14:01:29 +0100 Subject: [PATCH 6/8] Test sanitize_pairs --- test/util/test_config.py | 47 ++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/test/util/test_config.py b/test/util/test_config.py index 0c49f85b1..7105844dd 100644 --- a/test/util/test_config.py +++ b/test/util/test_config.py @@ -1,15 +1,38 @@ -import unittest +import pytest -from beets.util.config import sanitize_choices +from beets.util.config import sanitize_choices, sanitize_pairs -class HelpersTest(unittest.TestCase): - def test_sanitize_choices(self): - assert sanitize_choices(["A", "Z"], ("A", "B")) == ["A"] - assert sanitize_choices(["A", "A"], ("A")) == ["A"] - assert sanitize_choices(["D", "*", "A"], ("A", "B", "C", "D")) == [ - "D", - "B", - "C", - "A", - ] +@pytest.mark.parametrize( + "input_choices, valid_choices, expected", + [ + (["A", "Z"], ("A", "B"), ["A"]), + (["A", "A"], ("A"), ["A"]), + (["D", "*", "A"], ("A", "B", "C", "D"), ["D", "B", "C", "A"]), + ], +) +def test_sanitize_choices(input_choices, valid_choices, expected): + assert sanitize_choices(input_choices, valid_choices) == expected + + +def test_sanitize_pairs(): + assert sanitize_pairs( + [ + ("foo", "baz bar"), + ("foo", "baz bar"), + ("key", "*"), + ("*", "*"), + ("discard", "bye"), + ], + [ + ("foo", "bar"), + ("foo", "baz"), + ("foo", "foobar"), + ("key", "value"), + ], + ) == [ + ("foo", "baz"), + ("foo", "bar"), + ("key", "value"), + ("foo", "foobar"), + ] From cb246c28bc69986a702fa683f1f7c3c1d94425c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Mon, 26 May 2025 11:49:20 +0100 Subject: [PATCH 7/8] Remove dead chartlyrics This integration test failed because `chartlyrics.com` website is no longer available, so I'm removing it. --- test/plugins/lyrics_pages.py | 39 ------------------------------------ 1 file changed, 39 deletions(-) diff --git a/test/plugins/lyrics_pages.py b/test/plugins/lyrics_pages.py index ef2eeb1a2..e1806b167 100644 --- a/test/plugins/lyrics_pages.py +++ b/test/plugins/lyrics_pages.py @@ -108,45 +108,6 @@ lyrics_pages = [ url_title="The Beatles - Lady Madonna Lyrics | AZLyrics.com", marks=[xfail_on_ci("AZLyrics is blocked by Cloudflare")], ), - LyricsPage.make( - "http://www.chartlyrics.com/_LsLsZ7P4EK-F-LD4dJgDQ/Lady+Madonna.aspx", - """ - Lady Madonna, - Children at your feet - Wonder how you manage to make ends meet. - - Who finds the money - When you pay the rent? - Did you think that money was heaven-sent? - - Friday night arrives without a suitcase. - Sunday morning creeping like a nun. - Monday's child has learned to tie his bootlace. - - See how they run. - - Lady Madonna, - Baby at your breast - Wonders how you manage to feed the rest. - - See how they run. - - Lady Madonna, - Lying on the bed. - Listen to the music playing in your head. - - Tuesday afternoon is never ending. - Wednesday morning papers didn't come. - Thursday night your stockings needed mending. - - See how they run. - - Lady Madonna, - Children at your feet - Wonder how you manage to make ends meet. - """, - url_title="The Beatles Lady Madonna lyrics", - ), LyricsPage.make( "https://www.dainuzodziai.lt/m/mergaites-nori-mylet-atlanta/", """ From 99f7e94b594314c39709125ba5e6f898c50592f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Mon, 26 May 2025 13:37:23 +0100 Subject: [PATCH 8/8] Add Distance and current_metadata to autotag.__init__ for backward compat --- beets/autotag/__init__.py | 8 ++++++-- beets/plugins.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/beets/autotag/__init__.py b/beets/autotag/__init__.py index 5b16b012e..8cfe534ab 100644 --- a/beets/autotag/__init__.py +++ b/beets/autotag/__init__.py @@ -19,10 +19,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, Union from beets import config, logging +from beets.util import get_most_common_tags as current_metadata # Parts of external interface. from beets.util import unique_list +from .distance import Distance from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch from .match import Proposal, Recommendation, tag_album, tag_item @@ -34,13 +36,15 @@ if TYPE_CHECKING: __all__ = [ "AlbumInfo", "AlbumMatch", - "TrackInfo", - "TrackMatch", + "Distance", # for backwards compatibility "Proposal", "Recommendation", + "TrackInfo", + "TrackMatch", "apply_album_metadata", "apply_item_metadata", "apply_metadata", + "current_metadata", # for backwards compatibility "tag_album", "tag_item", ] diff --git a/beets/plugins.py b/beets/plugins.py index cd66435b5..1ae672e20 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -37,7 +37,6 @@ import mediafile import beets from beets import logging -from beets.autotag.distance import Distance from beets.util.id_extractors import extract_release_id if TYPE_CHECKING: @@ -55,6 +54,7 @@ if TYPE_CHECKING: from confuse import ConfigView from beets.autotag import AlbumInfo, TrackInfo + from beets.autotag.distance import Distance from beets.dbcore import Query from beets.dbcore.db import FieldQueryType from beets.dbcore.types import Type @@ -225,6 +225,8 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every track comparison. """ + from beets.autotag.distance import Distance + return Distance() def album_distance( @@ -236,6 +238,8 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every album-level comparison. """ + from beets.autotag.distance import Distance + return Distance() def candidates( @@ -455,6 +459,8 @@ def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ + from beets.autotag.distance import Distance + dist = Distance() for plugin in find_plugins(): dist.update(plugin.track_distance(item, info)) @@ -467,6 +473,8 @@ def album_distance( mapping: dict[Item, TrackInfo], ) -> Distance: """Returns the album distance calculated by plugins.""" + from beets.autotag.distance import Distance + dist = Distance() for plugin in find_plugins(): dist.update(plugin.album_distance(items, album_info, mapping)) @@ -653,6 +661,8 @@ def get_distance( """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ + from beets.autotag.distance import Distance + dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number())