From adbd50b2374edb4fe9f9455da9401dc1225a4818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 25 May 2025 02:59:27 +0100 Subject: [PATCH] Move distance to a separate module --- beets/autotag/__init__.py | 14 +- beets/autotag/distance.py | 531 ++++++++++++++++++++++++++++++++++ beets/autotag/hooks.py | 334 +-------------------- beets/autotag/match.py | 205 +------------ beets/plugins.py | 13 +- beetsplug/chroma.py | 2 +- beetsplug/discogs.py | 3 +- beetsplug/lyrics.py | 2 +- test/autotag/test_distance.py | 476 ++++++++++++++++++++++++++++++ test/test_autotag.py | 472 ------------------------------ 10 files changed, 1028 insertions(+), 1024 deletions(-) create mode 100644 beets/autotag/distance.py create mode 100644 test/autotag/test_distance.py diff --git a/beets/autotag/__init__.py b/beets/autotag/__init__.py index 5b6a11195..5b16b012e 100644 --- a/beets/autotag/__init__.py +++ b/beets/autotag/__init__.py @@ -14,22 +14,26 @@ """Facilities for automatically determining files' correct metadata.""" -from collections.abc import Mapping, Sequence -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Union from beets import config, logging -from beets.library import Album, Item, LibModel # Parts of external interface. from beets.util import unique_list -from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch +from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch from .match import Proposal, Recommendation, tag_album, tag_item +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from beets.library import Album, Item, LibModel + __all__ = [ "AlbumInfo", "AlbumMatch", - "Distance", "TrackInfo", "TrackMatch", "Proposal", diff --git a/beets/autotag/distance.py b/beets/autotag/distance.py new file mode 100644 index 000000000..d146c27f0 --- /dev/null +++ b/beets/autotag/distance.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import datetime +import re +from functools import cache, total_ordering +from typing import TYPE_CHECKING, Any + +from jellyfish import levenshtein_distance +from unidecode import unidecode + +from beets import config, plugins +from beets.util import as_string, cached_classproperty, get_most_common_tags + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from beets.library import Item + + from .hooks import AlbumInfo, TrackInfo + +# Candidate distance scoring. + +# Artist signals that indicate "various artists". These are used at the +# album level to determine whether a given release is likely a VA +# release and also on the track level to to remove the penalty for +# differing artists. +VA_ARTISTS = ("", "various artists", "various", "va", "unknown") + +# Parameters for string distance function. +# Words that can be moved to the end of a string using a comma. +SD_END_WORDS = ["the", "a", "an"] +# Reduced weights for certain portions of the string. +SD_PATTERNS = [ + (r"^the ", 0.1), + (r"[\[\(]?(ep|single)[\]\)]?", 0.0), + (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1), + (r"\(.*?\)", 0.3), + (r"\[.*?\]", 0.3), + (r"(, )?(pt\.|part) .+", 0.2), +] +# Replacements to use before testing distance. +SD_REPLACE = [ + (r"&", "and"), +] + + +def _string_dist_basic(str1: str, str2: str) -> float: + """Basic edit distance between two strings, ignoring + non-alphanumeric characters and case. Comparisons are based on a + transliteration/lowering to ASCII characters. Normalized by string + length. + """ + assert isinstance(str1, str) + assert isinstance(str2, str) + str1 = as_string(unidecode(str1)) + str2 = as_string(unidecode(str2)) + str1 = re.sub(r"[^a-z0-9]", "", str1.lower()) + str2 = re.sub(r"[^a-z0-9]", "", str2.lower()) + if not str1 and not str2: + return 0.0 + return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2))) + + +def string_dist(str1: str | None, str2: str | None) -> float: + """Gives an "intuitive" edit distance between two strings. This is + an edit distance, normalized by the string length, with a number of + tweaks that reflect intuition about text. + """ + if str1 is None and str2 is None: + return 0.0 + if str1 is None or str2 is None: + return 1.0 + + str1 = str1.lower() + str2 = str2.lower() + + # Don't penalize strings that move certain words to the end. For + # example, "the something" should be considered equal to + # "something, the". + for word in SD_END_WORDS: + if str1.endswith(", %s" % word): + str1 = "{} {}".format(word, str1[: -len(word) - 2]) + if str2.endswith(", %s" % word): + str2 = "{} {}".format(word, str2[: -len(word) - 2]) + + # Perform a couple of basic normalizing substitutions. + for pat, repl in SD_REPLACE: + str1 = re.sub(pat, repl, str1) + str2 = re.sub(pat, repl, str2) + + # Change the weight for certain string portions matched by a set + # of regular expressions. We gradually change the strings and build + # up penalties associated with parts of the string that were + # deleted. + base_dist = _string_dist_basic(str1, str2) + penalty = 0.0 + for pat, weight in SD_PATTERNS: + # Get strings that drop the pattern. + case_str1 = re.sub(pat, "", str1) + case_str2 = re.sub(pat, "", str2) + + if case_str1 != str1 or case_str2 != str2: + # If the pattern was present (i.e., it is deleted in the + # the current case), recalculate the distances for the + # modified strings. + case_dist = _string_dist_basic(case_str1, case_str2) + case_delta = max(0.0, base_dist - case_dist) + if case_delta == 0.0: + continue + + # Shift our baseline strings down (to avoid rematching the + # same part of the string) and add a scaled distance + # amount to the penalties. + str1 = case_str1 + str2 = case_str2 + base_dist = case_dist + penalty += weight * case_delta + + return base_dist + penalty + + +@total_ordering +class Distance: + """Keeps track of multiple distance penalties. Provides a single + weighted distance for all penalties as well as a weighted distance + for each individual penalty. + """ + + def __init__(self) -> None: + self._penalties: dict[str, list[float]] = {} + self.tracks: dict[TrackInfo, Distance] = {} + + @cached_classproperty + def _weights(cls) -> dict[str, float]: + """A dictionary from keys to floating-point weights.""" + weights_view = config["match"]["distance_weights"] + weights = {} + for key in weights_view.keys(): + weights[key] = weights_view[key].as_number() + return weights + + # Access the components and their aggregates. + + @property + def distance(self) -> float: + """Return a weighted and normalized distance across all + penalties. + """ + dist_max = self.max_distance + if dist_max: + return self.raw_distance / self.max_distance + return 0.0 + + @property + def max_distance(self) -> float: + """Return the maximum distance penalty (normalization factor).""" + dist_max = 0.0 + for key, penalty in self._penalties.items(): + dist_max += len(penalty) * self._weights[key] + return dist_max + + @property + def raw_distance(self) -> float: + """Return the raw (denormalized) distance.""" + dist_raw = 0.0 + for key, penalty in self._penalties.items(): + dist_raw += sum(penalty) * self._weights[key] + return dist_raw + + def items(self) -> list[tuple[str, float]]: + """Return a list of (key, dist) pairs, with `dist` being the + weighted distance, sorted from highest to lowest. Does not + include penalties with a zero value. + """ + list_ = [] + for key in self._penalties: + dist = self[key] + if dist: + list_.append((key, dist)) + # Convert distance into a negative float we can sort items in + # ascending order (for keys, when the penalty is equal) and + # still get the items with the biggest distance first. + return sorted( + list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0]) + ) + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other) -> bool: + return self.distance == other + + # Behave like a float. + + def __lt__(self, other) -> bool: + return self.distance < other + + def __float__(self) -> float: + return self.distance + + def __sub__(self, other) -> float: + return self.distance - other + + def __rsub__(self, other) -> float: + return other - self.distance + + def __str__(self) -> str: + return f"{self.distance:.2f}" + + # Behave like a dict. + + def __getitem__(self, key) -> float: + """Returns the weighted distance for a named penalty.""" + dist = sum(self._penalties[key]) * self._weights[key] + dist_max = self.max_distance + if dist_max: + return dist / dist_max + return 0.0 + + def __iter__(self) -> Iterator[tuple[str, float]]: + return iter(self.items()) + + def __len__(self) -> int: + return len(self.items()) + + def keys(self) -> list[str]: + return [key for key, _ in self.items()] + + def update(self, dist: Distance): + """Adds all the distance penalties from `dist`.""" + if not isinstance(dist, Distance): + raise ValueError( + "`dist` must be a Distance object, not {}".format(type(dist)) + ) + for key, penalties in dist._penalties.items(): + self._penalties.setdefault(key, []).extend(penalties) + + # Adding components. + + def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool: + """Returns True if `value1` is equal to `value2`. `value1` may + be a compiled regular expression, in which case it will be + matched against `value2`. + """ + if isinstance(value1, re.Pattern): + return bool(value1.match(value2)) + return value1 == value2 + + def add(self, key: str, dist: float): + """Adds a distance penalty. `key` must correspond with a + configured weight setting. `dist` must be a float between 0.0 + and 1.0, and will be added to any existing distance penalties + for the same key. + """ + if not 0.0 <= dist <= 1.0: + raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}") + self._penalties.setdefault(key, []).append(dist) + + def add_equality( + self, + key: str, + value: Any, + options: list[Any] | tuple[Any, ...] | Any, + ): + """Adds a distance penalty of 1.0 if `value` doesn't match any + of the values in `options`. If an option is a compiled regular + expression, it will be considered equal if it matches against + `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + for opt in options: + if self._eq(opt, value): + dist = 0.0 + break + else: + dist = 1.0 + self.add(key, dist) + + def add_expr(self, key: str, expr: bool): + """Adds a distance penalty of 1.0 if `expr` evaluates to True, + or 0.0. + """ + if expr: + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_number(self, key: str, number1: int, number2: int): + """Adds a distance penalty of 1.0 for each number of difference + between `number1` and `number2`, or 0.0 when there is no + difference. Use this when there is no upper limit on the + difference between the two numbers. + """ + diff = abs(number1 - number2) + if diff: + for i in range(diff): + self.add(key, 1.0) + else: + self.add(key, 0.0) + + def add_priority( + self, + key: str, + value: Any, + options: list[Any] | tuple[Any, ...] | Any, + ): + """Adds a distance penalty that corresponds to the position at + which `value` appears in `options`. A distance penalty of 0.0 + for the first option, or 1.0 if there is no matching option. If + an option is a compiled regular expression, it will be + considered equal if it matches against `value`. + """ + if not isinstance(options, (list, tuple)): + options = [options] + unit = 1.0 / (len(options) or 1) + for i, opt in enumerate(options): + if self._eq(opt, value): + dist = i * unit + break + else: + dist = 1.0 + self.add(key, dist) + + def add_ratio( + self, + key: str, + number1: int | float, + number2: int | float, + ): + """Adds a distance penalty for `number1` as a ratio of `number2`. + `number1` is bound at 0 and `number2`. + """ + number = float(max(min(number1, number2), 0)) + if number2: + dist = number / number2 + else: + dist = 0.0 + self.add(key, dist) + + def add_string(self, key: str, str1: str | None, str2: str | None): + """Adds a distance penalty based on the edit distance between + `str1` and `str2`. + """ + dist = string_dist(str1, str2) + self.add(key, dist) + + +@cache +def get_track_length_grace() -> float: + """Get cached grace period for track length matching.""" + return config["match"]["track_length_grace"].as_number() + + +@cache +def get_track_length_max() -> float: + """Get cached maximum track length for track length matching.""" + return config["match"]["track_length_max"].as_number() + + +def track_index_changed(item: Item, track_info: TrackInfo) -> bool: + """Returns True if the item and track info index is different. Tolerates + per disc and per release numbering. + """ + return item.track not in (track_info.medium_index, track_info.index) + + +def track_distance( + item: Item, + track_info: TrackInfo, + incl_artist: bool = False, +) -> Distance: + """Determines the significance of a track metadata change. Returns a + Distance object. `incl_artist` indicates that a distance component should + be included for the track artist (i.e., for various-artist releases). + + ``track_length_grace`` and ``track_length_max`` configuration options are + cached because this function is called many times during the matching + process and their access comes with a performance overhead. + """ + dist = Distance() + + # Length. + if info_length := track_info.length: + diff = abs(item.length - info_length) - get_track_length_grace() + dist.add_ratio("track_length", diff, get_track_length_max()) + + # Title. + dist.add_string("track_title", item.title, track_info.title) + + # Artist. Only check if there is actually an artist in the track data. + if ( + incl_artist + and track_info.artist + and item.artist.lower() not in VA_ARTISTS + ): + dist.add_string("track_artist", item.artist, track_info.artist) + + # Track index. + if track_info.index and item.track: + dist.add_expr("track_index", track_index_changed(item, track_info)) + + # Track ID. + if item.mb_trackid: + dist.add_expr("track_id", item.mb_trackid != track_info.track_id) + + # Penalize mismatching disc numbers. + if track_info.medium and item.disc: + dist.add_expr("medium", item.disc != track_info.medium) + + # Plugins. + dist.update(plugins.track_distance(item, track_info)) + + return dist + + +def distance( + items: Sequence[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], +) -> Distance: + """Determines how "significant" an album metadata change would be. + Returns a Distance object. `album_info` is an AlbumInfo object + reflecting the album to be compared. `items` is a sequence of all + Item objects that will be matched (order is not important). + `mapping` is a dictionary mapping Items to TrackInfo objects; the + keys are a subset of `items` and the values are a subset of + `album_info.tracks`. + """ + likelies, _ = get_most_common_tags(items) + + dist = Distance() + + # Artist, if not various. + if not album_info.va: + dist.add_string("artist", likelies["artist"], album_info.artist) + + # Album. + dist.add_string("album", likelies["album"], album_info.album) + + preferred_config = config["match"]["preferred"] + # Current or preferred media. + if album_info.media: + # Preferred media options. + media_patterns: Sequence[str] = preferred_config["media"].as_str_seq() + options = [ + re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns + ] + if options: + dist.add_priority("media", album_info.media, options) + # Current media. + elif likelies["media"]: + dist.add_equality("media", album_info.media, likelies["media"]) + + # Mediums. + if likelies["disctotal"] and album_info.mediums: + dist.add_number("mediums", likelies["disctotal"], album_info.mediums) + + # Prefer earliest release. + if album_info.year and preferred_config["original_year"]: + # Assume 1889 (earliest first gramophone discs) if we don't know the + # original year. + original = album_info.original_year or 1889 + diff = abs(album_info.year - original) + diff_max = abs(datetime.date.today().year - original) + dist.add_ratio("year", diff, diff_max) + # Year. + elif likelies["year"] and album_info.year: + if likelies["year"] in (album_info.year, album_info.original_year): + # No penalty for matching release or original year. + dist.add("year", 0.0) + elif album_info.original_year: + # Prefer matchest closest to the release year. + diff = abs(likelies["year"] - album_info.year) + diff_max = abs( + datetime.date.today().year - album_info.original_year + ) + dist.add_ratio("year", diff, diff_max) + else: + # Full penalty when there is no original year. + dist.add("year", 1.0) + + # Preferred countries. + country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq() + options = [re.compile(pat, re.I) for pat in country_patterns] + if album_info.country and options: + dist.add_priority("country", album_info.country, options) + # Country. + elif likelies["country"] and album_info.country: + dist.add_string("country", likelies["country"], album_info.country) + + # Label. + if likelies["label"] and album_info.label: + dist.add_string("label", likelies["label"], album_info.label) + + # Catalog number. + if likelies["catalognum"] and album_info.catalognum: + dist.add_string( + "catalognum", likelies["catalognum"], album_info.catalognum + ) + + # Disambiguation. + if likelies["albumdisambig"] and album_info.albumdisambig: + dist.add_string( + "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig + ) + + # Album ID. + if likelies["mb_albumid"]: + dist.add_equality( + "album_id", likelies["mb_albumid"], album_info.album_id + ) + + # Tracks. + dist.tracks = {} + for item, track in mapping.items(): + dist.tracks[track] = track_distance(item, track, album_info.va) + dist.add("tracks", dist.tracks[track].distance) + + # Missing tracks. + for _ in range(len(album_info.tracks) - len(mapping)): + dist.add("missing_tracks", 1.0) + + # Unmatched tracks. + for _ in range(len(items) - len(mapping)): + dist.add("unmatched_tracks", 1.0) + + # Plugins. + dist.update(plugins.album_distance(items, album_info, mapping)) + + return dist diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index 641a6cb4f..7cd215fc4 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -16,21 +16,15 @@ from __future__ import annotations -import re -from functools import total_ordering from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar -from jellyfish import levenshtein_distance -from unidecode import unidecode - -from beets import config, logging -from beets.util import as_string, cached_classproperty +from beets import logging if TYPE_CHECKING: - from collections.abc import Iterator - from beets.library import Item + from .distance import Distance + log = logging.getLogger("beets") V = TypeVar("V") @@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]): return dupe -# Candidate distance scoring. - -# Parameters for string distance function. -# Words that can be moved to the end of a string using a comma. -SD_END_WORDS = ["the", "a", "an"] -# Reduced weights for certain portions of the string. -SD_PATTERNS = [ - (r"^the ", 0.1), - (r"[\[\(]?(ep|single)[\]\)]?", 0.0), - (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1), - (r"\(.*?\)", 0.3), - (r"\[.*?\]", 0.3), - (r"(, )?(pt\.|part) .+", 0.2), -] -# Replacements to use before testing distance. -SD_REPLACE = [ - (r"&", "and"), -] - - -def _string_dist_basic(str1: str, str2: str) -> float: - """Basic edit distance between two strings, ignoring - non-alphanumeric characters and case. Comparisons are based on a - transliteration/lowering to ASCII characters. Normalized by string - length. - """ - assert isinstance(str1, str) - assert isinstance(str2, str) - str1 = as_string(unidecode(str1)) - str2 = as_string(unidecode(str2)) - str1 = re.sub(r"[^a-z0-9]", "", str1.lower()) - str2 = re.sub(r"[^a-z0-9]", "", str2.lower()) - if not str1 and not str2: - return 0.0 - return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2))) - - -def string_dist(str1: str | None, str2: str | None) -> float: - """Gives an "intuitive" edit distance between two strings. This is - an edit distance, normalized by the string length, with a number of - tweaks that reflect intuition about text. - """ - if str1 is None and str2 is None: - return 0.0 - if str1 is None or str2 is None: - return 1.0 - - str1 = str1.lower() - str2 = str2.lower() - - # Don't penalize strings that move certain words to the end. For - # example, "the something" should be considered equal to - # "something, the". - for word in SD_END_WORDS: - if str1.endswith(", %s" % word): - str1 = "{} {}".format(word, str1[: -len(word) - 2]) - if str2.endswith(", %s" % word): - str2 = "{} {}".format(word, str2[: -len(word) - 2]) - - # Perform a couple of basic normalizing substitutions. - for pat, repl in SD_REPLACE: - str1 = re.sub(pat, repl, str1) - str2 = re.sub(pat, repl, str2) - - # Change the weight for certain string portions matched by a set - # of regular expressions. We gradually change the strings and build - # up penalties associated with parts of the string that were - # deleted. - base_dist = _string_dist_basic(str1, str2) - penalty = 0.0 - for pat, weight in SD_PATTERNS: - # Get strings that drop the pattern. - case_str1 = re.sub(pat, "", str1) - case_str2 = re.sub(pat, "", str2) - - if case_str1 != str1 or case_str2 != str2: - # If the pattern was present (i.e., it is deleted in the - # the current case), recalculate the distances for the - # modified strings. - case_dist = _string_dist_basic(case_str1, case_str2) - case_delta = max(0.0, base_dist - case_dist) - if case_delta == 0.0: - continue - - # Shift our baseline strings down (to avoid rematching the - # same part of the string) and add a scaled distance - # amount to the penalties. - str1 = case_str1 - str2 = case_str2 - base_dist = case_dist - penalty += weight * case_delta - - return base_dist + penalty - - -@total_ordering -class Distance: - """Keeps track of multiple distance penalties. Provides a single - weighted distance for all penalties as well as a weighted distance - for each individual penalty. - """ - - def __init__(self) -> None: - self._penalties: dict[str, list[float]] = {} - self.tracks: dict[TrackInfo, Distance] = {} - - @cached_classproperty - def _weights(cls) -> dict[str, float]: - """A dictionary from keys to floating-point weights.""" - weights_view = config["match"]["distance_weights"] - weights = {} - for key in weights_view.keys(): - weights[key] = weights_view[key].as_number() - return weights - - # Access the components and their aggregates. - - @property - def distance(self) -> float: - """Return a weighted and normalized distance across all - penalties. - """ - dist_max = self.max_distance - if dist_max: - return self.raw_distance / self.max_distance - return 0.0 - - @property - def max_distance(self) -> float: - """Return the maximum distance penalty (normalization factor).""" - dist_max = 0.0 - for key, penalty in self._penalties.items(): - dist_max += len(penalty) * self._weights[key] - return dist_max - - @property - def raw_distance(self) -> float: - """Return the raw (denormalized) distance.""" - dist_raw = 0.0 - for key, penalty in self._penalties.items(): - dist_raw += sum(penalty) * self._weights[key] - return dist_raw - - def items(self) -> list[tuple[str, float]]: - """Return a list of (key, dist) pairs, with `dist` being the - weighted distance, sorted from highest to lowest. Does not - include penalties with a zero value. - """ - list_ = [] - for key in self._penalties: - dist = self[key] - if dist: - list_.append((key, dist)) - # Convert distance into a negative float we can sort items in - # ascending order (for keys, when the penalty is equal) and - # still get the items with the biggest distance first. - return sorted( - list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0]) - ) - - def __hash__(self) -> int: - return id(self) - - def __eq__(self, other) -> bool: - return self.distance == other - - # Behave like a float. - - def __lt__(self, other) -> bool: - return self.distance < other - - def __float__(self) -> float: - return self.distance - - def __sub__(self, other) -> float: - return self.distance - other - - def __rsub__(self, other) -> float: - return other - self.distance - - def __str__(self) -> str: - return f"{self.distance:.2f}" - - # Behave like a dict. - - def __getitem__(self, key) -> float: - """Returns the weighted distance for a named penalty.""" - dist = sum(self._penalties[key]) * self._weights[key] - dist_max = self.max_distance - if dist_max: - return dist / dist_max - return 0.0 - - def __iter__(self) -> Iterator[tuple[str, float]]: - return iter(self.items()) - - def __len__(self) -> int: - return len(self.items()) - - def keys(self) -> list[str]: - return [key for key, _ in self.items()] - - def update(self, dist: Distance): - """Adds all the distance penalties from `dist`.""" - if not isinstance(dist, Distance): - raise ValueError( - "`dist` must be a Distance object, not {}".format(type(dist)) - ) - for key, penalties in dist._penalties.items(): - self._penalties.setdefault(key, []).extend(penalties) - - # Adding components. - - def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool: - """Returns True if `value1` is equal to `value2`. `value1` may - be a compiled regular expression, in which case it will be - matched against `value2`. - """ - if isinstance(value1, re.Pattern): - return bool(value1.match(value2)) - return value1 == value2 - - def add(self, key: str, dist: float): - """Adds a distance penalty. `key` must correspond with a - configured weight setting. `dist` must be a float between 0.0 - and 1.0, and will be added to any existing distance penalties - for the same key. - """ - if not 0.0 <= dist <= 1.0: - raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}") - self._penalties.setdefault(key, []).append(dist) - - def add_equality( - self, - key: str, - value: Any, - options: list[Any] | tuple[Any, ...] | Any, - ): - """Adds a distance penalty of 1.0 if `value` doesn't match any - of the values in `options`. If an option is a compiled regular - expression, it will be considered equal if it matches against - `value`. - """ - if not isinstance(options, (list, tuple)): - options = [options] - for opt in options: - if self._eq(opt, value): - dist = 0.0 - break - else: - dist = 1.0 - self.add(key, dist) - - def add_expr(self, key: str, expr: bool): - """Adds a distance penalty of 1.0 if `expr` evaluates to True, - or 0.0. - """ - if expr: - self.add(key, 1.0) - else: - self.add(key, 0.0) - - def add_number(self, key: str, number1: int, number2: int): - """Adds a distance penalty of 1.0 for each number of difference - between `number1` and `number2`, or 0.0 when there is no - difference. Use this when there is no upper limit on the - difference between the two numbers. - """ - diff = abs(number1 - number2) - if diff: - for i in range(diff): - self.add(key, 1.0) - else: - self.add(key, 0.0) - - def add_priority( - self, - key: str, - value: Any, - options: list[Any] | tuple[Any, ...] | Any, - ): - """Adds a distance penalty that corresponds to the position at - which `value` appears in `options`. A distance penalty of 0.0 - for the first option, or 1.0 if there is no matching option. If - an option is a compiled regular expression, it will be - considered equal if it matches against `value`. - """ - if not isinstance(options, (list, tuple)): - options = [options] - unit = 1.0 / (len(options) or 1) - for i, opt in enumerate(options): - if self._eq(opt, value): - dist = i * unit - break - else: - dist = 1.0 - self.add(key, dist) - - def add_ratio( - self, - key: str, - number1: int | float, - number2: int | float, - ): - """Adds a distance penalty for `number1` as a ratio of `number2`. - `number1` is bound at 0 and `number2`. - """ - number = float(max(min(number1, number2), 0)) - if number2: - dist = number / number2 - else: - dist = 0.0 - self.add(key, dist) - - def add_string(self, key: str, str1: str | None, str2: str | None): - """Adds a distance penalty based on the edit distance between - `str1` and `str2`. - """ - dist = string_dist(str1, str2) - self.add(key, dist) - - # Structures that compose all the information for a candidate match. diff --git a/beets/autotag/match.py b/beets/autotag/match.py index 4dc4c1052..64572cf3b 100644 --- a/beets/autotag/match.py +++ b/beets/autotag/match.py @@ -18,37 +18,23 @@ releases and tracks. from __future__ import annotations -import datetime -import re from enum import IntEnum -from functools import cache from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar import lap import numpy as np from beets import config, logging, plugins -from beets.autotag import ( - AlbumInfo, - AlbumMatch, - Distance, - TrackInfo, - TrackMatch, - hooks, -) +from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks from beets.util import get_most_common_tags +from .distance import VA_ARTISTS, distance, track_distance + if TYPE_CHECKING: from collections.abc import Iterable, Sequence from beets.library import Item -# Artist signals that indicate "various artists". These are used at the -# album level to determine whether a given release is likely a VA -# release and also on the track level to to remove the penalty for -# differing artists. -VA_ARTISTS = ("", "various artists", "various", "va", "unknown") - # Global logger. log = logging.getLogger("beets") @@ -112,191 +98,6 @@ def assign_items( return mapping, extra_items, extra_tracks -def track_index_changed(item: Item, track_info: TrackInfo) -> bool: - """Returns True if the item and track info index is different. Tolerates - per disc and per release numbering. - """ - return item.track not in (track_info.medium_index, track_info.index) - - -@cache -def get_track_length_grace() -> float: - """Get cached grace period for track length matching.""" - return config["match"]["track_length_grace"].as_number() - - -@cache -def get_track_length_max() -> float: - """Get cached maximum track length for track length matching.""" - return config["match"]["track_length_max"].as_number() - - -def track_distance( - item: Item, - track_info: TrackInfo, - incl_artist: bool = False, -) -> Distance: - """Determines the significance of a track metadata change. Returns a - Distance object. `incl_artist` indicates that a distance component should - be included for the track artist (i.e., for various-artist releases). - - ``track_length_grace`` and ``track_length_max`` configuration options are - cached because this function is called many times during the matching - process and their access comes with a performance overhead. - """ - dist = hooks.Distance() - - # Length. - if info_length := track_info.length: - diff = abs(item.length - info_length) - get_track_length_grace() - dist.add_ratio("track_length", diff, get_track_length_max()) - - # Title. - dist.add_string("track_title", item.title, track_info.title) - - # Artist. Only check if there is actually an artist in the track data. - if ( - incl_artist - and track_info.artist - and item.artist.lower() not in VA_ARTISTS - ): - dist.add_string("track_artist", item.artist, track_info.artist) - - # Track index. - if track_info.index and item.track: - dist.add_expr("track_index", track_index_changed(item, track_info)) - - # Track ID. - if item.mb_trackid: - dist.add_expr("track_id", item.mb_trackid != track_info.track_id) - - # Penalize mismatching disc numbers. - if track_info.medium and item.disc: - dist.add_expr("medium", item.disc != track_info.medium) - - # Plugins. - dist.update(plugins.track_distance(item, track_info)) - - return dist - - -def distance( - items: Sequence[Item], - album_info: AlbumInfo, - mapping: dict[Item, TrackInfo], -) -> Distance: - """Determines how "significant" an album metadata change would be. - Returns a Distance object. `album_info` is an AlbumInfo object - reflecting the album to be compared. `items` is a sequence of all - Item objects that will be matched (order is not important). - `mapping` is a dictionary mapping Items to TrackInfo objects; the - keys are a subset of `items` and the values are a subset of - `album_info.tracks`. - """ - likelies, _ = get_most_common_tags(items) - - dist = hooks.Distance() - - # Artist, if not various. - if not album_info.va: - dist.add_string("artist", likelies["artist"], album_info.artist) - - # Album. - dist.add_string("album", likelies["album"], album_info.album) - - preferred_config = config["match"]["preferred"] - # Current or preferred media. - if album_info.media: - # Preferred media options. - media_patterns: Sequence[str] = preferred_config["media"].as_str_seq() - options = [ - re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns - ] - if options: - dist.add_priority("media", album_info.media, options) - # Current media. - elif likelies["media"]: - dist.add_equality("media", album_info.media, likelies["media"]) - - # Mediums. - if likelies["disctotal"] and album_info.mediums: - dist.add_number("mediums", likelies["disctotal"], album_info.mediums) - - # Prefer earliest release. - if album_info.year and preferred_config["original_year"]: - # Assume 1889 (earliest first gramophone discs) if we don't know the - # original year. - original = album_info.original_year or 1889 - diff = abs(album_info.year - original) - diff_max = abs(datetime.date.today().year - original) - dist.add_ratio("year", diff, diff_max) - # Year. - elif likelies["year"] and album_info.year: - if likelies["year"] in (album_info.year, album_info.original_year): - # No penalty for matching release or original year. - dist.add("year", 0.0) - elif album_info.original_year: - # Prefer matchest closest to the release year. - diff = abs(likelies["year"] - album_info.year) - diff_max = abs( - datetime.date.today().year - album_info.original_year - ) - dist.add_ratio("year", diff, diff_max) - else: - # Full penalty when there is no original year. - dist.add("year", 1.0) - - # Preferred countries. - country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq() - options = [re.compile(pat, re.I) for pat in country_patterns] - if album_info.country and options: - dist.add_priority("country", album_info.country, options) - # Country. - elif likelies["country"] and album_info.country: - dist.add_string("country", likelies["country"], album_info.country) - - # Label. - if likelies["label"] and album_info.label: - dist.add_string("label", likelies["label"], album_info.label) - - # Catalog number. - if likelies["catalognum"] and album_info.catalognum: - dist.add_string( - "catalognum", likelies["catalognum"], album_info.catalognum - ) - - # Disambiguation. - if likelies["albumdisambig"] and album_info.albumdisambig: - dist.add_string( - "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig - ) - - # Album ID. - if likelies["mb_albumid"]: - dist.add_equality( - "album_id", likelies["mb_albumid"], album_info.album_id - ) - - # Tracks. - dist.tracks = {} - for item, track in mapping.items(): - dist.tracks[track] = track_distance(item, track, album_info.va) - dist.add("tracks", dist.tracks[track].distance) - - # Missing tracks. - for _ in range(len(album_info.tracks) - len(mapping)): - dist.add("missing_tracks", 1.0) - - # Unmatched tracks. - for _ in range(len(items) - len(mapping)): - dist.add("unmatched_tracks", 1.0) - - # Plugins. - dist.update(plugins.album_distance(items, album_info, mapping)) - - return dist - - def match_by_id(items: Iterable[Item]) -> AlbumInfo | None: """If the items are tagged with an external source ID, return an AlbumInfo object for the corresponding album. Otherwise, returns diff --git a/beets/plugins.py b/beets/plugins.py index 6d3a8447e..cd66435b5 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -37,6 +37,7 @@ import mediafile import beets from beets import logging +from beets.autotag.distance import Distance from beets.util.id_extractors import extract_release_id if TYPE_CHECKING: @@ -53,7 +54,7 @@ if TYPE_CHECKING: from confuse import ConfigView - from beets.autotag import AlbumInfo, Distance, TrackInfo + from beets.autotag import AlbumInfo, TrackInfo from beets.dbcore import Query from beets.dbcore.db import FieldQueryType from beets.dbcore.types import Type @@ -224,8 +225,6 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every track comparison. """ - from beets.autotag.hooks import Distance - return Distance() def album_distance( @@ -237,8 +236,6 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every album-level comparison. """ - from beets.autotag.hooks import Distance - return Distance() def candidates( @@ -458,8 +455,6 @@ def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ - from beets.autotag.hooks import Distance - dist = Distance() for plugin in find_plugins(): dist.update(plugin.track_distance(item, info)) @@ -472,8 +467,6 @@ def album_distance( mapping: dict[Item, TrackInfo], ) -> Distance: """Returns the album distance calculated by plugins.""" - from beets.autotag.hooks import Distance - dist = Distance() for plugin in find_plugins(): dist.update(plugin.album_distance(items, album_info, mapping)) @@ -660,8 +653,6 @@ def get_distance( """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ - from beets.autotag.hooks import Distance - dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) diff --git a/beetsplug/chroma.py b/beetsplug/chroma.py index 518a41776..5c718154b 100644 --- a/beetsplug/chroma.py +++ b/beetsplug/chroma.py @@ -24,7 +24,7 @@ import acoustid import confuse from beets import config, plugins, ui, util -from beets.autotag.hooks import Distance +from beets.autotag.distance import Distance from beetsplug.musicbrainz import MusicBrainzPlugin API_KEY = "1vOwZtEn" diff --git a/beetsplug/discogs.py b/beetsplug/discogs.py index 696f1d1ac..2408f3498 100644 --- a/beetsplug/discogs.py +++ b/beetsplug/discogs.py @@ -38,7 +38,8 @@ from typing_extensions import TypedDict import beets import beets.ui from beets import config -from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist +from beets.autotag.distance import string_dist +from beets.autotag.hooks import AlbumInfo, TrackInfo from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance from beets.util.id_extractors import extract_release_id diff --git a/beetsplug/lyrics.py b/beetsplug/lyrics.py index e2c0c7fd2..f1c40ab24 100644 --- a/beetsplug/lyrics.py +++ b/beetsplug/lyrics.py @@ -38,7 +38,7 @@ from unidecode import unidecode import beets from beets import plugins, ui -from beets.autotag.hooks import string_dist +from beets.autotag.distance import string_dist from beets.util.config import sanitize_choices if TYPE_CHECKING: diff --git a/test/autotag/test_distance.py b/test/autotag/test_distance.py new file mode 100644 index 000000000..ec00ebcdf --- /dev/null +++ b/test/autotag/test_distance.py @@ -0,0 +1,476 @@ +import re +import unittest + +from beets import config +from beets.autotag import AlbumInfo, TrackInfo, match +from beets.autotag.distance import Distance, string_dist +from beets.library import Item +from beets.test.helper import BeetsTestCase + + +def _make_item(title, track, artist="some artist"): + return Item( + title=title, + track=track, + artist=artist, + album="some album", + length=1, + mb_trackid="", + mb_albumid="", + mb_artistid="", + ) + + +def _make_trackinfo(): + return [ + TrackInfo( + title="one", track_id=None, artist="some artist", length=1, index=1 + ), + TrackInfo( + title="two", track_id=None, artist="some artist", length=1, index=2 + ), + TrackInfo( + title="three", + track_id=None, + artist="some artist", + length=1, + index=3, + ), + ] + + +def _clear_weights(): + """Hack around the lazy descriptor used to cache weights for + Distance calculations. + """ + Distance.__dict__["_weights"].cache = {} + + +class DistanceTest(BeetsTestCase): + def tearDown(self): + super().tearDown() + _clear_weights() + + def test_add(self): + dist = Distance() + dist.add("add", 1.0) + assert dist._penalties == {"add": [1.0]} + + def test_add_equality(self): + dist = Distance() + dist.add_equality("equality", "ghi", ["abc", "def", "ghi"]) + assert dist._penalties["equality"] == [0.0] + + dist.add_equality("equality", "xyz", ["abc", "def", "ghi"]) + assert dist._penalties["equality"] == [0.0, 1.0] + + dist.add_equality("equality", "abc", re.compile(r"ABC", re.I)) + assert dist._penalties["equality"] == [0.0, 1.0, 0.0] + + def test_add_expr(self): + dist = Distance() + dist.add_expr("expr", True) + assert dist._penalties["expr"] == [1.0] + + dist.add_expr("expr", False) + assert dist._penalties["expr"] == [1.0, 0.0] + + def test_add_number(self): + dist = Distance() + # Add a full penalty for each number of difference between two numbers. + + dist.add_number("number", 1, 1) + assert dist._penalties["number"] == [0.0] + + dist.add_number("number", 1, 2) + assert dist._penalties["number"] == [0.0, 1.0] + + dist.add_number("number", 2, 1) + assert dist._penalties["number"] == [0.0, 1.0, 1.0] + + dist.add_number("number", -1, 2) + assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0] + + def test_add_priority(self): + dist = Distance() + dist.add_priority("priority", "abc", "abc") + assert dist._penalties["priority"] == [0.0] + + dist.add_priority("priority", "def", ["abc", "def"]) + assert dist._penalties["priority"] == [0.0, 0.5] + + dist.add_priority( + "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)] + ) + assert dist._penalties["priority"] == [0.0, 0.5, 0.75] + + dist.add_priority("priority", "xyz", ["abc", "def"]) + assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0] + + def test_add_ratio(self): + dist = Distance() + dist.add_ratio("ratio", 25, 100) + assert dist._penalties["ratio"] == [0.25] + + dist.add_ratio("ratio", 10, 5) + assert dist._penalties["ratio"] == [0.25, 1.0] + + dist.add_ratio("ratio", -5, 5) + assert dist._penalties["ratio"] == [0.25, 1.0, 0.0] + + dist.add_ratio("ratio", 5, 0) + assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0] + + def test_add_string(self): + dist = Distance() + sdist = string_dist("abc", "bcd") + dist.add_string("string", "abc", "bcd") + assert dist._penalties["string"] == [sdist] + assert dist._penalties["string"] != [0] + + def test_add_string_none(self): + dist = Distance() + dist.add_string("string", None, "string") + assert dist._penalties["string"] == [1] + + def test_add_string_both_none(self): + dist = Distance() + dist.add_string("string", None, None) + assert dist._penalties["string"] == [0] + + def test_distance(self): + config["match"]["distance_weights"]["album"] = 2.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("media", 0.25) + dist.add("media", 0.75) + assert dist.distance == 0.5 + + # __getitem__() + assert dist["album"] == 0.25 + assert dist["media"] == 0.25 + + def test_max_distance(self): + config["match"]["distance_weights"]["album"] = 3.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("medium", 0.0) + dist.add("medium", 0.0) + assert dist.max_distance == 5.0 + + def test_operators(self): + config["match"]["distance_weights"]["source"] = 1.0 + config["match"]["distance_weights"]["album"] = 2.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("source", 0.0) + dist.add("album", 0.5) + dist.add("medium", 0.25) + dist.add("medium", 0.75) + assert len(dist) == 2 + assert list(dist) == [("album", 0.2), ("medium", 0.2)] + assert dist == 0.4 + assert dist < 1.0 + assert dist > 0.0 + assert dist - 0.4 == 0.0 + assert 0.4 - dist == 0.0 + assert float(dist) == 0.4 + + def test_raw_distance(self): + config["match"]["distance_weights"]["album"] = 3.0 + config["match"]["distance_weights"]["medium"] = 1.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.5) + dist.add("medium", 0.25) + dist.add("medium", 0.5) + assert dist.raw_distance == 2.25 + + def test_items(self): + config["match"]["distance_weights"]["album"] = 4.0 + config["match"]["distance_weights"]["medium"] = 2.0 + _clear_weights() + + dist = Distance() + dist.add("album", 0.1875) + dist.add("medium", 0.75) + assert dist.items() == [("medium", 0.25), ("album", 0.125)] + + # Sort by key if distance is equal. + dist = Distance() + dist.add("album", 0.375) + dist.add("medium", 0.75) + assert dist.items() == [("album", 0.25), ("medium", 0.25)] + + def test_update(self): + dist1 = Distance() + dist1.add("album", 0.5) + dist1.add("media", 1.0) + + dist2 = Distance() + dist2.add("album", 0.75) + dist2.add("album", 0.25) + dist2.add("media", 0.05) + + dist1.update(dist2) + + assert dist1._penalties == { + "album": [0.5, 0.75, 0.25], + "media": [1.0, 0.05], + } + + +class TrackDistanceTest(BeetsTestCase): + def test_identical_tracks(self): + item = _make_item("one", 1) + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist == 0.0 + + def test_different_title(self): + item = _make_item("foo", 1) + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist != 0.0 + + def test_different_artist(self): + item = _make_item("one", 1) + item.artist = "foo" + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist != 0.0 + + def test_various_artists_tolerated(self): + item = _make_item("one", 1) + item.artist = "Various Artists" + info = _make_trackinfo()[0] + dist = match.track_distance(item, info, incl_artist=True) + assert dist == 0.0 + + +class AlbumDistanceTest(BeetsTestCase): + def _mapping(self, items, info): + out = {} + for i, t in zip(items, info.tracks): + out[i] = t + return out + + def _dist(self, items, info): + return match.distance(items, info, self._mapping(items, info)) + + def test_identical_albums(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + assert self._dist(items, info) == 0 + + def test_incomplete_album(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + dist = self._dist(items, info) + assert dist != 0 + # Make sure the distance is not too great + assert dist < 0.2 + + def test_global_artists_differ(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="someone else", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + assert self._dist(items, info) != 0 + + def test_comp_track_artists_match(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="should be ignored", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + assert self._dist(items, info) == 0 + + def test_comp_no_track_artists(self): + # Some VA releases don't have track artists (incomplete metadata). + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="should be ignored", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + info.tracks[0].artist = None + info.tracks[1].artist = None + info.tracks[2].artist = None + assert self._dist(items, info) == 0 + + def test_comp_track_artists_do_not_match(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2, "someone else")) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=True, + ) + assert self._dist(items, info) != 0 + + def test_tracks_out_of_order(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("three", 2)) + items.append(_make_item("two", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + dist = self._dist(items, info) + assert 0 < dist < 0.2 + + def test_two_medium_release(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 3)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + info.tracks[0].medium_index = 1 + info.tracks[1].medium_index = 2 + info.tracks[2].medium_index = 1 + dist = self._dist(items, info) + assert dist == 0 + + def test_per_medium_track_numbers(self): + items = [] + items.append(_make_item("one", 1)) + items.append(_make_item("two", 2)) + items.append(_make_item("three", 1)) + info = AlbumInfo( + artist="some artist", + album="some album", + tracks=_make_trackinfo(), + va=False, + ) + info.tracks[0].medium_index = 1 + info.tracks[1].medium_index = 2 + info.tracks[2].medium_index = 1 + dist = self._dist(items, info) + assert dist == 0 + + +class StringDistanceTest(unittest.TestCase): + def test_equal_strings(self): + dist = string_dist("Some String", "Some String") + assert dist == 0.0 + + def test_different_strings(self): + dist = string_dist("Some String", "Totally Different") + assert dist != 0.0 + + def test_punctuation_ignored(self): + dist = string_dist("Some String", "Some.String!") + assert dist == 0.0 + + def test_case_ignored(self): + dist = string_dist("Some String", "sOME sTring") + assert dist == 0.0 + + def test_leading_the_has_lower_weight(self): + dist1 = string_dist("XXX Band Name", "Band Name") + dist2 = string_dist("The Band Name", "Band Name") + assert dist2 < dist1 + + def test_parens_have_lower_weight(self): + dist1 = string_dist("One .Two.", "One") + dist2 = string_dist("One (Two)", "One") + assert dist2 < dist1 + + def test_brackets_have_lower_weight(self): + dist1 = string_dist("One .Two.", "One") + dist2 = string_dist("One [Two]", "One") + assert dist2 < dist1 + + def test_ep_label_has_zero_weight(self): + dist = string_dist("My Song (EP)", "My Song") + assert dist == 0.0 + + def test_featured_has_lower_weight(self): + dist1 = string_dist("My Song blah Someone", "My Song") + dist2 = string_dist("My Song feat Someone", "My Song") + assert dist2 < dist1 + + def test_postfix_the(self): + dist = string_dist("The Song Title", "Song Title, The") + assert dist == 0.0 + + def test_postfix_a(self): + dist = string_dist("A Song Title", "Song Title, A") + assert dist == 0.0 + + def test_postfix_an(self): + dist = string_dist("An Album Title", "Album Title, An") + assert dist == 0.0 + + def test_empty_strings(self): + dist = string_dist("", "") + assert dist == 0.0 + + def test_solo_pattern(self): + # Just make sure these don't crash. + string_dist("The ", "") + string_dist("(EP)", "(EP)") + string_dist(", An", "") + + def test_heuristic_does_not_harm_distance(self): + dist = string_dist("Untitled", "[Untitled]") + assert dist == 0.0 + + def test_ampersand_expansion(self): + dist = string_dist("And", "&") + assert dist == 0.0 + + def test_accented_characters(self): + dist = string_dist("\xe9\xe1\xf1", "ean") + assert dist == 0.0 diff --git a/test/test_autotag.py b/test/test_autotag.py index bd4205806..8d467e5ed 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -14,410 +14,14 @@ """Tests for autotagging functionality.""" -import re -import unittest - import pytest from beets import autotag, config from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match -from beets.autotag.hooks import Distance, string_dist from beets.library import Item from beets.test.helper import BeetsTestCase, ConfigMixin -def _make_item(title, track, artist="some artist"): - return Item( - title=title, - track=track, - artist=artist, - album="some album", - length=1, - mb_trackid="", - mb_albumid="", - mb_artistid="", - ) - - -def _make_trackinfo(): - return [ - TrackInfo( - title="one", track_id=None, artist="some artist", length=1, index=1 - ), - TrackInfo( - title="two", track_id=None, artist="some artist", length=1, index=2 - ), - TrackInfo( - title="three", - track_id=None, - artist="some artist", - length=1, - index=3, - ), - ] - - -def _clear_weights(): - """Hack around the lazy descriptor used to cache weights for - Distance calculations. - """ - Distance.__dict__["_weights"].cache = {} - - -class DistanceTest(BeetsTestCase): - def tearDown(self): - super().tearDown() - _clear_weights() - - def test_add(self): - dist = Distance() - dist.add("add", 1.0) - assert dist._penalties == {"add": [1.0]} - - def test_add_equality(self): - dist = Distance() - dist.add_equality("equality", "ghi", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0] - - dist.add_equality("equality", "xyz", ["abc", "def", "ghi"]) - assert dist._penalties["equality"] == [0.0, 1.0] - - dist.add_equality("equality", "abc", re.compile(r"ABC", re.I)) - assert dist._penalties["equality"] == [0.0, 1.0, 0.0] - - def test_add_expr(self): - dist = Distance() - dist.add_expr("expr", True) - assert dist._penalties["expr"] == [1.0] - - dist.add_expr("expr", False) - assert dist._penalties["expr"] == [1.0, 0.0] - - def test_add_number(self): - dist = Distance() - # Add a full penalty for each number of difference between two numbers. - - dist.add_number("number", 1, 1) - assert dist._penalties["number"] == [0.0] - - dist.add_number("number", 1, 2) - assert dist._penalties["number"] == [0.0, 1.0] - - dist.add_number("number", 2, 1) - assert dist._penalties["number"] == [0.0, 1.0, 1.0] - - dist.add_number("number", -1, 2) - assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0] - - def test_add_priority(self): - dist = Distance() - dist.add_priority("priority", "abc", "abc") - assert dist._penalties["priority"] == [0.0] - - dist.add_priority("priority", "def", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5] - - dist.add_priority( - "priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)] - ) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75] - - dist.add_priority("priority", "xyz", ["abc", "def"]) - assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0] - - def test_add_ratio(self): - dist = Distance() - dist.add_ratio("ratio", 25, 100) - assert dist._penalties["ratio"] == [0.25] - - dist.add_ratio("ratio", 10, 5) - assert dist._penalties["ratio"] == [0.25, 1.0] - - dist.add_ratio("ratio", -5, 5) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0] - - dist.add_ratio("ratio", 5, 0) - assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0] - - def test_add_string(self): - dist = Distance() - sdist = string_dist("abc", "bcd") - dist.add_string("string", "abc", "bcd") - assert dist._penalties["string"] == [sdist] - assert dist._penalties["string"] != [0] - - def test_add_string_none(self): - dist = Distance() - dist.add_string("string", None, "string") - assert dist._penalties["string"] == [1] - - def test_add_string_both_none(self): - dist = Distance() - dist.add_string("string", None, None) - assert dist._penalties["string"] == [0] - - def test_distance(self): - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("media", 0.25) - dist.add("media", 0.75) - assert dist.distance == 0.5 - - # __getitem__() - assert dist["album"] == 0.25 - assert dist["media"] == 0.25 - - def test_max_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.0) - dist.add("medium", 0.0) - assert dist.max_distance == 5.0 - - def test_operators(self): - config["match"]["distance_weights"]["source"] = 1.0 - config["match"]["distance_weights"]["album"] = 2.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("source", 0.0) - dist.add("album", 0.5) - dist.add("medium", 0.25) - dist.add("medium", 0.75) - assert len(dist) == 2 - assert list(dist) == [("album", 0.2), ("medium", 0.2)] - assert dist == 0.4 - assert dist < 1.0 - assert dist > 0.0 - assert dist - 0.4 == 0.0 - assert 0.4 - dist == 0.0 - assert float(dist) == 0.4 - - def test_raw_distance(self): - config["match"]["distance_weights"]["album"] = 3.0 - config["match"]["distance_weights"]["medium"] = 1.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.5) - dist.add("medium", 0.25) - dist.add("medium", 0.5) - assert dist.raw_distance == 2.25 - - def test_items(self): - config["match"]["distance_weights"]["album"] = 4.0 - config["match"]["distance_weights"]["medium"] = 2.0 - _clear_weights() - - dist = Distance() - dist.add("album", 0.1875) - dist.add("medium", 0.75) - assert dist.items() == [("medium", 0.25), ("album", 0.125)] - - # Sort by key if distance is equal. - dist = Distance() - dist.add("album", 0.375) - dist.add("medium", 0.75) - assert dist.items() == [("album", 0.25), ("medium", 0.25)] - - def test_update(self): - dist1 = Distance() - dist1.add("album", 0.5) - dist1.add("media", 1.0) - - dist2 = Distance() - dist2.add("album", 0.75) - dist2.add("album", 0.25) - dist2.add("media", 0.05) - - dist1.update(dist2) - - assert dist1._penalties == { - "album": [0.5, 0.75, 0.25], - "media": [1.0, 0.05], - } - - -class TrackDistanceTest(BeetsTestCase): - def test_identical_tracks(self): - item = _make_item("one", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 - - def test_different_title(self): - item = _make_item("foo", 1) - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 - - def test_different_artist(self): - item = _make_item("one", 1) - item.artist = "foo" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist != 0.0 - - def test_various_artists_tolerated(self): - item = _make_item("one", 1) - item.artist = "Various Artists" - info = _make_trackinfo()[0] - dist = match.track_distance(item, info, incl_artist=True) - assert dist == 0.0 - - -class AlbumDistanceTest(BeetsTestCase): - def _mapping(self, items, info): - out = {} - for i, t in zip(items, info.tracks): - out[i] = t - return out - - def _dist(self, items, info): - return match.distance(items, info, self._mapping(items, info)) - - def test_identical_albums(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - assert self._dist(items, info) == 0 - - def test_incomplete_album(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - dist = self._dist(items, info) - assert dist != 0 - # Make sure the distance is not too great - assert dist < 0.2 - - def test_global_artists_differ(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="someone else", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - assert self._dist(items, info) != 0 - - def test_comp_track_artists_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) == 0 - - def test_comp_no_track_artists(self): - # Some VA releases don't have track artists (incomplete metadata). - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="should be ignored", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - info.tracks[0].artist = None - info.tracks[1].artist = None - info.tracks[2].artist = None - assert self._dist(items, info) == 0 - - def test_comp_track_artists_do_not_match(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2, "someone else")) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=True, - ) - assert self._dist(items, info) != 0 - - def test_tracks_out_of_order(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("three", 2)) - items.append(_make_item("two", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - dist = self._dist(items, info) - assert 0 < dist < 0.2 - - def test_two_medium_release(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 3)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - info.tracks[0].medium_index = 1 - info.tracks[1].medium_index = 2 - info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 - - def test_per_medium_track_numbers(self): - items = [] - items.append(_make_item("one", 1)) - items.append(_make_item("two", 2)) - items.append(_make_item("three", 1)) - info = AlbumInfo( - artist="some artist", - album="some album", - tracks=_make_trackinfo(), - va=False, - ) - info.tracks[0].medium_index = 1 - info.tracks[1].medium_index = 2 - info.tracks[2].medium_index = 1 - dist = self._dist(items, info) - assert dist == 0 - - class TestAssignment(ConfigMixin): A = "one" B = "two" @@ -840,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil): assert self.items[1].comp -class StringDistanceTest(unittest.TestCase): - def test_equal_strings(self): - dist = string_dist("Some String", "Some String") - assert dist == 0.0 - - def test_different_strings(self): - dist = string_dist("Some String", "Totally Different") - assert dist != 0.0 - - def test_punctuation_ignored(self): - dist = string_dist("Some String", "Some.String!") - assert dist == 0.0 - - def test_case_ignored(self): - dist = string_dist("Some String", "sOME sTring") - assert dist == 0.0 - - def test_leading_the_has_lower_weight(self): - dist1 = string_dist("XXX Band Name", "Band Name") - dist2 = string_dist("The Band Name", "Band Name") - assert dist2 < dist1 - - def test_parens_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One (Two)", "One") - assert dist2 < dist1 - - def test_brackets_have_lower_weight(self): - dist1 = string_dist("One .Two.", "One") - dist2 = string_dist("One [Two]", "One") - assert dist2 < dist1 - - def test_ep_label_has_zero_weight(self): - dist = string_dist("My Song (EP)", "My Song") - assert dist == 0.0 - - def test_featured_has_lower_weight(self): - dist1 = string_dist("My Song blah Someone", "My Song") - dist2 = string_dist("My Song feat Someone", "My Song") - assert dist2 < dist1 - - def test_postfix_the(self): - dist = string_dist("The Song Title", "Song Title, The") - assert dist == 0.0 - - def test_postfix_a(self): - dist = string_dist("A Song Title", "Song Title, A") - assert dist == 0.0 - - def test_postfix_an(self): - dist = string_dist("An Album Title", "Album Title, An") - assert dist == 0.0 - - def test_empty_strings(self): - dist = string_dist("", "") - assert dist == 0.0 - - def test_solo_pattern(self): - # Just make sure these don't crash. - string_dist("The ", "") - string_dist("(EP)", "(EP)") - string_dist(", An", "") - - def test_heuristic_does_not_harm_distance(self): - dist = string_dist("Untitled", "[Untitled]") - assert dist == 0.0 - - def test_ampersand_expansion(self): - dist = string_dist("And", "&") - assert dist == 0.0 - - def test_accented_characters(self): - dist = string_dist("\xe9\xe1\xf1", "ean") - assert dist == 0.0 - - @pytest.mark.parametrize( "single_field,list_field", [