Move distance to a separate module

Šarūnas Nejus 2025-05-25 02:59:27 +01:00
parent 01b6ea7898
commit adbd50b237
10 changed files with 1028 additions and 1024 deletions


@@ -14,22 +14,26 @@
 """Facilities for automatically determining files' correct metadata."""
-from collections.abc import Mapping, Sequence
-from typing import Union
+from __future__ import annotations
+from typing import TYPE_CHECKING, Union
 from beets import config, logging
-from beets.library import Album, Item, LibModel
 # Parts of external interface.
 from beets.util import unique_list
-from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch
+from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch
 from .match import Proposal, Recommendation, tag_album, tag_item
+if TYPE_CHECKING:
+    from collections.abc import Mapping, Sequence
+    from beets.library import Album, Item, LibModel
 __all__ = [
     "AlbumInfo",
     "AlbumMatch",
-    "Distance",
     "TrackInfo",
     "TrackMatch",
     "Proposal",

beets/autotag/distance.py (new file, 531 lines)

@@ -0,0 +1,531 @@
from __future__ import annotations
import datetime
import re
from functools import cache, total_ordering
from typing import TYPE_CHECKING, Any
from jellyfish import levenshtein_distance
from unidecode import unidecode
from beets import config, plugins
from beets.util import as_string, cached_classproperty, get_most_common_tags
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from beets.library import Item
from .hooks import AlbumInfo, TrackInfo
# Candidate distance scoring.
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to to remove the penalty for
# differing artists.
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r"^the ", 0.1),
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
(r"\(.*?\)", 0.3),
(r"\[.*?\]", 0.3),
(r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str | None, str2: str | None) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(", %s" % word):
str1 = "{} {}".format(word, str1[: -len(word) - 2])
if str2.endswith(", %s" % word):
str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, "", str1)
case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self) -> None:
self._penalties: dict[str, list[float]] = {}
self.tracks: dict[TrackInfo, Distance] = {}
@cached_classproperty
def _weights(cls) -> dict[str, float]:
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self) -> list[tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self) -> int:
return id(self)
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self) -> float:
return self.distance
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self) -> Iterator[tuple[str, float]]:
return iter(self.items())
def __len__(self) -> int:
return len(self.items())
def keys(self) -> list[str]:
return [key for key, _ in self.items()]
def update(self, dist: Distance):
"""Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
"`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
return bool(value1.match(value2))
return value1 == value2
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
def add_equality(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(
self,
key: str,
number1: int | float,
number2: int | float,
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str | None, str2: str | None):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
@cache
def get_track_length_grace() -> float:
"""Get cached grace period for track length matching."""
return config["match"]["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
"""Get cached maximum track length for track length matching."""
return config["match"]["track_length_max"].as_number()
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = Distance()
# Length.
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())
# Title.
dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if (
incl_artist
and track_info.artist
and item.artist.lower() not in VA_ARTISTS
):
dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Penalize mismatching disc numbers.
if track_info.medium and item.disc:
dist.add_expr("medium", item.disc != track_info.medium)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(
items: Sequence[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = get_most_common_tags(items)
dist = Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
elif likelies["media"]:
dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
if likelies["disctotal"] and album_info.mediums:
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio("year", diff, diff_max)
# Year.
elif likelies["year"] and album_info.year:
if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
diff = abs(likelies["year"] - album_info.year)
diff_max = abs(
datetime.date.today().year - album_info.original_year
)
dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add("year", 1.0)
# Preferred countries.
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
elif likelies["country"] and album_info.country:
dist.add_string("country", likelies["country"], album_info.country)
# Label.
if likelies["label"] and album_info.label:
dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
if likelies["catalognum"] and album_info.catalognum:
dist.add_string(
"catalognum", likelies["catalognum"], album_info.catalognum
)
# Disambiguation.
if likelies["albumdisambig"] and album_info.albumdisambig:
dist.add_string(
"albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
)
# Album ID.
if likelies["mb_albumid"]:
dist.add_equality(
"album_id", likelies["mb_albumid"], album_info.album_id
)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add("missing_tracks", 1.0)
# Unmatched tracks.
for _ in range(len(items) - len(mapping)):
dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
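
A minimal usage sketch of the relocated helpers; the weight values are hypothetical overrides of the `match.distance_weights` configuration keys that back `Distance._weights` (the tests later in this commit reset that cache via `_clear_weights()`):

```python
from beets import config
from beets.autotag.distance import Distance, string_dist

# string_dist tolerates trailing ", The" and decorations such as "(EP)".
assert string_dist("The Song Title", "Song Title, The") == 0.0
assert string_dist("My Song (EP)", "My Song") == 0.0

# Distance normalizes each penalty by its configured weight.
config["match"]["distance_weights"]["album"] = 2.0   # hypothetical weights
config["match"]["distance_weights"]["medium"] = 1.0
dist = Distance()
dist.add("album", 0.5)    # raw contribution: 0.5 * 2.0 = 1.0
dist.add("medium", 0.25)  # raw contribution: 0.25 * 1.0 = 0.25
print(dist.distance)      # (1.0 + 0.25) / (2.0 + 1.0) ≈ 0.42
```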


@@ -16,21 +16,15 @@
 from __future__ import annotations
-import re
-from functools import total_ordering
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
-from jellyfish import levenshtein_distance
-from unidecode import unidecode
-from beets import config, logging
-from beets.util import as_string, cached_classproperty
+from beets import logging
 if TYPE_CHECKING:
-    from collections.abc import Iterator
     from beets.library import Item
+    from .distance import Distance
 log = logging.getLogger("beets")
 V = TypeVar("V")
@@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]):
        return dupe
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r"^the ", 0.1),
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
(r"\(.*?\)", 0.3),
(r"\[.*?\]", 0.3),
(r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str | None, str2: str | None) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(", %s" % word):
str1 = "{} {}".format(word, str1[: -len(word) - 2])
if str2.endswith(", %s" % word):
str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, "", str1)
case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self) -> None:
self._penalties: dict[str, list[float]] = {}
self.tracks: dict[TrackInfo, Distance] = {}
@cached_classproperty
def _weights(cls) -> dict[str, float]:
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self) -> list[tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self) -> int:
return id(self)
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self) -> float:
return self.distance
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self) -> Iterator[tuple[str, float]]:
return iter(self.items())
def __len__(self) -> int:
return len(self.items())
def keys(self) -> list[str]:
return [key for key, _ in self.items()]
def update(self, dist: Distance):
"""Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
"`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
return bool(value1.match(value2))
return value1 == value2
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
def add_equality(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(
self,
key: str,
number1: int | float,
number2: int | float,
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str | None, str2: str | None):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
# Structures that compose all the information for a candidate match.
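
In `hooks.py` the `Distance` name now survives only in annotations, so its import moves under `TYPE_CHECKING`; combined with `from __future__ import annotations`, the annotation stays lazy and no runtime circular import between `hooks` and `distance` is created. A minimal sketch of that pattern, with `make_penalty` as a hypothetical helper:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from beets.autotag.distance import Distance


def make_penalty(key: str, value: float) -> Distance:  # hypothetical helper
    # Import lazily at call time so module import stays cheap and acyclic.
    from beets.autotag.distance import Distance

    dist = Distance()
    dist.add(key, value)
    return dist
```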


@@ -18,37 +18,23 @@ releases and tracks.
 from __future__ import annotations
-import datetime
-import re
 from enum import IntEnum
-from functools import cache
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
 import lap
 import numpy as np
 from beets import config, logging, plugins
-from beets.autotag import (
-    AlbumInfo,
-    AlbumMatch,
-    Distance,
-    TrackInfo,
-    TrackMatch,
-    hooks,
-)
+from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks
 from beets.util import get_most_common_tags
+from .distance import VA_ARTISTS, distance, track_distance
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence
     from beets.library import Item
-# Artist signals that indicate "various artists". These are used at the
-# album level to determine whether a given release is likely a VA
-# release and also on the track level to to remove the penalty for
-# differing artists.
-VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
 # Global logger.
 log = logging.getLogger("beets")
@@ -112,191 +98,6 @@ def assign_items(
    return mapping, extra_items, extra_tracks
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
@cache
def get_track_length_grace() -> float:
"""Get cached grace period for track length matching."""
return config["match"]["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
"""Get cached maximum track length for track length matching."""
return config["match"]["track_length_max"].as_number()
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = hooks.Distance()
# Length.
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())
# Title.
dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if (
incl_artist
and track_info.artist
and item.artist.lower() not in VA_ARTISTS
):
dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Penalize mismatching disc numbers.
if track_info.medium and item.disc:
dist.add_expr("medium", item.disc != track_info.medium)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(
items: Sequence[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = get_most_common_tags(items)
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
elif likelies["media"]:
dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
if likelies["disctotal"] and album_info.mediums:
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio("year", diff, diff_max)
# Year.
elif likelies["year"] and album_info.year:
if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
diff = abs(likelies["year"] - album_info.year)
diff_max = abs(
datetime.date.today().year - album_info.original_year
)
dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add("year", 1.0)
# Preferred countries.
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
elif likelies["country"] and album_info.country:
dist.add_string("country", likelies["country"], album_info.country)
# Label.
if likelies["label"] and album_info.label:
dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
if likelies["catalognum"] and album_info.catalognum:
dist.add_string(
"catalognum", likelies["catalognum"], album_info.catalognum
)
# Disambiguation.
if likelies["albumdisambig"] and album_info.albumdisambig:
dist.add_string(
"albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
)
# Album ID.
if likelies["mb_albumid"]:
dist.add_equality(
"album_id", likelies["mb_albumid"], album_info.album_id
)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add("missing_tracks", 1.0)
# Unmatched tracks.
for _ in range(len(items) - len(mapping)):
dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
def match_by_id(items: Iterable[Item]) -> AlbumInfo | None:
    """If the items are tagged with an external source ID, return an
    AlbumInfo object for the corresponding album. Otherwise, returns
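
`match.py` now consumes the scoring helpers from the new module (`from .distance import VA_ARTISTS, distance, track_distance` above) instead of defining them. A minimal sketch of scoring a candidate the same way; `items` (library `Item` objects) and `album_info` (an `AlbumInfo` candidate) are hypothetical inputs, and the mapping is built the way the tests below build it:

```python
from beets.autotag.distance import distance, track_distance

# Pair each Item with its candidate TrackInfo (hypothetical inputs).
mapping = dict(zip(items, album_info.tracks))

# Album-level score: per-track distances plus album penalties
# (artist, album, media, year, missing and unmatched tracks, plugins).
album_dist = distance(items, album_info, mapping)
print(float(album_dist), album_dist.items())

# A single track comparison, e.g. for singleton matching.
print(float(track_distance(items[0], album_info.tracks[0], incl_artist=True)))
```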


@@ -37,6 +37,7 @@ import mediafile
 import beets
 from beets import logging
+from beets.autotag.distance import Distance
 from beets.util.id_extractors import extract_release_id
 if TYPE_CHECKING:
@@ -53,7 +54,7 @@ if TYPE_CHECKING:
     from confuse import ConfigView
-    from beets.autotag import AlbumInfo, Distance, TrackInfo
+    from beets.autotag import AlbumInfo, TrackInfo
     from beets.dbcore import Query
     from beets.dbcore.db import FieldQueryType
     from beets.dbcore.types import Type
@@ -224,8 +225,6 @@ class BeetsPlugin:
         """Should return a Distance object to be added to the
         distance for every track comparison.
         """
-        from beets.autotag.hooks import Distance
         return Distance()
     def album_distance(
@@ -237,8 +236,6 @@ class BeetsPlugin:
         """Should return a Distance object to be added to the
         distance for every album-level comparison.
         """
-        from beets.autotag.hooks import Distance
         return Distance()
     def candidates(
@@ -458,8 +455,6 @@ def track_distance(item: Item, info: TrackInfo) -> Distance:
     """Gets the track distance calculated by all loaded plugins.
     Returns a Distance object.
     """
-    from beets.autotag.hooks import Distance
     dist = Distance()
     for plugin in find_plugins():
         dist.update(plugin.track_distance(item, info))
@@ -472,8 +467,6 @@
     mapping: dict[Item, TrackInfo],
 ) -> Distance:
     """Returns the album distance calculated by plugins."""
-    from beets.autotag.hooks import Distance
     dist = Distance()
     for plugin in find_plugins():
         dist.update(plugin.album_distance(items, album_info, mapping))
@@ -660,8 +653,6 @@ def get_distance(
     """Returns the ``data_source`` weight and the maximum source weight
     for albums or individual tracks.
     """
-    from beets.autotag.hooks import Distance
     dist = Distance()
     if info.data_source == data_source:
         dist.add("source", config["source_weight"].as_number())


@@ -24,7 +24,7 @@ import acoustid
 import confuse
 from beets import config, plugins, ui, util
-from beets.autotag.hooks import Distance
+from beets.autotag.distance import Distance
 from beetsplug.musicbrainz import MusicBrainzPlugin
 API_KEY = "1vOwZtEn"


@@ -38,7 +38,8 @@ from typing_extensions import TypedDict
 import beets
 import beets.ui
 from beets import config
-from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist
+from beets.autotag.distance import string_dist
+from beets.autotag.hooks import AlbumInfo, TrackInfo
 from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
 from beets.util.id_extractors import extract_release_id


@@ -38,7 +38,7 @@ from unidecode import unidecode
 import beets
 from beets import plugins, ui
-from beets.autotag.hooks import string_dist
+from beets.autotag.distance import string_dist
 from beets.util.config import sanitize_choices
 if TYPE_CHECKING:


@@ -0,0 +1,476 @@
import re
import unittest
from beets import config
from beets.autotag import AlbumInfo, TrackInfo, match
from beets.autotag.distance import Distance, string_dist
from beets.library import Item
from beets.test.helper import BeetsTestCase
def _make_item(title, track, artist="some artist"):
return Item(
title=title,
track=track,
artist=artist,
album="some album",
length=1,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)
def _make_trackinfo():
return [
TrackInfo(
title="one", track_id=None, artist="some artist", length=1, index=1
),
TrackInfo(
title="two", track_id=None, artist="some artist", length=1, index=2
),
TrackInfo(
title="three",
track_id=None,
artist="some artist",
length=1,
index=3,
),
]
def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].cache = {}
class DistanceTest(BeetsTestCase):
def tearDown(self):
super().tearDown()
_clear_weights()
def test_add(self):
dist = Distance()
dist.add("add", 1.0)
assert dist._penalties == {"add": [1.0]}
def test_add_equality(self):
dist = Distance()
dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0]
dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0, 1.0]
dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
def test_add_expr(self):
dist = Distance()
dist.add_expr("expr", True)
assert dist._penalties["expr"] == [1.0]
dist.add_expr("expr", False)
assert dist._penalties["expr"] == [1.0, 0.0]
def test_add_number(self):
dist = Distance()
# Add a full penalty for each number of difference between two numbers.
dist.add_number("number", 1, 1)
assert dist._penalties["number"] == [0.0]
dist.add_number("number", 1, 2)
assert dist._penalties["number"] == [0.0, 1.0]
dist.add_number("number", 2, 1)
assert dist._penalties["number"] == [0.0, 1.0, 1.0]
dist.add_number("number", -1, 2)
assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
def test_add_priority(self):
dist = Distance()
dist.add_priority("priority", "abc", "abc")
assert dist._penalties["priority"] == [0.0]
dist.add_priority("priority", "def", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5]
dist.add_priority(
"priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
)
assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
dist.add_priority("priority", "xyz", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
def test_add_ratio(self):
dist = Distance()
dist.add_ratio("ratio", 25, 100)
assert dist._penalties["ratio"] == [0.25]
dist.add_ratio("ratio", 10, 5)
assert dist._penalties["ratio"] == [0.25, 1.0]
dist.add_ratio("ratio", -5, 5)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
dist.add_ratio("ratio", 5, 0)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
def test_add_string(self):
dist = Distance()
sdist = string_dist("abc", "bcd")
dist.add_string("string", "abc", "bcd")
assert dist._penalties["string"] == [sdist]
assert dist._penalties["string"] != [0]
def test_add_string_none(self):
dist = Distance()
dist.add_string("string", None, "string")
assert dist._penalties["string"] == [1]
def test_add_string_both_none(self):
dist = Distance()
dist.add_string("string", None, None)
assert dist._penalties["string"] == [0]
def test_distance(self):
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("media", 0.25)
dist.add("media", 0.75)
assert dist.distance == 0.5
# __getitem__()
assert dist["album"] == 0.25
assert dist["media"] == 0.25
def test_max_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.0)
dist.add("medium", 0.0)
assert dist.max_distance == 5.0
def test_operators(self):
config["match"]["distance_weights"]["source"] = 1.0
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("source", 0.0)
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.75)
assert len(dist) == 2
assert list(dist) == [("album", 0.2), ("medium", 0.2)]
assert dist == 0.4
assert dist < 1.0
assert dist > 0.0
assert dist - 0.4 == 0.0
assert 0.4 - dist == 0.0
assert float(dist) == 0.4
def test_raw_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.5)
assert dist.raw_distance == 2.25
def test_items(self):
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
_clear_weights()
dist = Distance()
dist.add("album", 0.1875)
dist.add("medium", 0.75)
assert dist.items() == [("medium", 0.25), ("album", 0.125)]
# Sort by key if distance is equal.
dist = Distance()
dist.add("album", 0.375)
dist.add("medium", 0.75)
assert dist.items() == [("album", 0.25), ("medium", 0.25)]
def test_update(self):
dist1 = Distance()
dist1.add("album", 0.5)
dist1.add("media", 1.0)
dist2 = Distance()
dist2.add("album", 0.75)
dist2.add("album", 0.25)
dist2.add("media", 0.05)
dist1.update(dist2)
assert dist1._penalties == {
"album": [0.5, 0.75, 0.25],
"media": [1.0, 0.05],
}
class TrackDistanceTest(BeetsTestCase):
def test_identical_tracks(self):
item = _make_item("one", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
def test_different_title(self):
item = _make_item("foo", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_different_artist(self):
item = _make_item("one", 1)
item.artist = "foo"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_various_artists_tolerated(self):
item = _make_item("one", 1)
item.artist = "Various Artists"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
class AlbumDistanceTest(BeetsTestCase):
def _mapping(self, items, info):
out = {}
for i, t in zip(items, info.tracks):
out[i] = t
return out
def _dist(self, items, info):
return match.distance(items, info, self._mapping(items, info))
def test_identical_albums(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) == 0
def test_incomplete_album(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert dist != 0
# Make sure the distance is not too great
assert dist < 0.2
def test_global_artists_differ(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="someone else",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) != 0
def test_comp_track_artists_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) == 0
def test_comp_no_track_artists(self):
# Some VA releases don't have track artists (incomplete metadata).
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
info.tracks[0].artist = None
info.tracks[1].artist = None
info.tracks[2].artist = None
assert self._dist(items, info) == 0
def test_comp_track_artists_do_not_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2, "someone else"))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) != 0
def test_tracks_out_of_order(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 2))
items.append(_make_item("two", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert 0 < dist < 0.2
def test_two_medium_release(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
def test_per_medium_track_numbers(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 1))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
class StringDistanceTest(unittest.TestCase):
def test_equal_strings(self):
dist = string_dist("Some String", "Some String")
assert dist == 0.0
def test_different_strings(self):
dist = string_dist("Some String", "Totally Different")
assert dist != 0.0
def test_punctuation_ignored(self):
dist = string_dist("Some String", "Some.String!")
assert dist == 0.0
def test_case_ignored(self):
dist = string_dist("Some String", "sOME sTring")
assert dist == 0.0
def test_leading_the_has_lower_weight(self):
dist1 = string_dist("XXX Band Name", "Band Name")
dist2 = string_dist("The Band Name", "Band Name")
assert dist2 < dist1
def test_parens_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One (Two)", "One")
assert dist2 < dist1
def test_brackets_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One [Two]", "One")
assert dist2 < dist1
def test_ep_label_has_zero_weight(self):
dist = string_dist("My Song (EP)", "My Song")
assert dist == 0.0
def test_featured_has_lower_weight(self):
dist1 = string_dist("My Song blah Someone", "My Song")
dist2 = string_dist("My Song feat Someone", "My Song")
assert dist2 < dist1
def test_postfix_the(self):
dist = string_dist("The Song Title", "Song Title, The")
assert dist == 0.0
def test_postfix_a(self):
dist = string_dist("A Song Title", "Song Title, A")
assert dist == 0.0
def test_postfix_an(self):
dist = string_dist("An Album Title", "Album Title, An")
assert dist == 0.0
def test_empty_strings(self):
dist = string_dist("", "")
assert dist == 0.0
def test_solo_pattern(self):
# Just make sure these don't crash.
string_dist("The ", "")
string_dist("(EP)", "(EP)")
string_dist(", An", "")
def test_heuristic_does_not_harm_distance(self):
dist = string_dist("Untitled", "[Untitled]")
assert dist == 0.0
def test_ampersand_expansion(self):
dist = string_dist("And", "&")
assert dist == 0.0
def test_accented_characters(self):
dist = string_dist("\xe9\xe1\xf1", "ean")
assert dist == 0.0
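
The moved tests pin down the scoring arithmetic; for example, `add_priority` spreads the penalty evenly across the option list and treats compiled regexes as equality. A small sketch of that behaviour (the "country" key is just an illustrative weight name):

```python
import re

from beets.autotag.distance import Distance

dist = Distance()
# Four options give steps of 0.25; the regex matches "GB" at index 1,
# so a 0.25 penalty is recorded.
dist.add_priority("country", "GB", ["US", re.compile("GB", re.I), "DE", "FR"])
assert dist._penalties["country"] == [0.25]

# No matching option yields the full 1.0 penalty.
dist.add_priority("country", "XW", ["US", "GB"])
assert dist._penalties["country"] == [0.25, 1.0]
```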


@@ -14,410 +14,14 @@
 """Tests for autotagging functionality."""
-import re
-import unittest
 import pytest
 from beets import autotag, config
 from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
-from beets.autotag.hooks import Distance, string_dist
 from beets.library import Item
 from beets.test.helper import BeetsTestCase, ConfigMixin
def _make_item(title, track, artist="some artist"):
return Item(
title=title,
track=track,
artist=artist,
album="some album",
length=1,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)
def _make_trackinfo():
return [
TrackInfo(
title="one", track_id=None, artist="some artist", length=1, index=1
),
TrackInfo(
title="two", track_id=None, artist="some artist", length=1, index=2
),
TrackInfo(
title="three",
track_id=None,
artist="some artist",
length=1,
index=3,
),
]
def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].cache = {}
class DistanceTest(BeetsTestCase):
def tearDown(self):
super().tearDown()
_clear_weights()
def test_add(self):
dist = Distance()
dist.add("add", 1.0)
assert dist._penalties == {"add": [1.0]}
def test_add_equality(self):
dist = Distance()
dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0]
dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0, 1.0]
dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
def test_add_expr(self):
dist = Distance()
dist.add_expr("expr", True)
assert dist._penalties["expr"] == [1.0]
dist.add_expr("expr", False)
assert dist._penalties["expr"] == [1.0, 0.0]
def test_add_number(self):
dist = Distance()
# Add a full penalty for each number of difference between two numbers.
dist.add_number("number", 1, 1)
assert dist._penalties["number"] == [0.0]
dist.add_number("number", 1, 2)
assert dist._penalties["number"] == [0.0, 1.0]
dist.add_number("number", 2, 1)
assert dist._penalties["number"] == [0.0, 1.0, 1.0]
dist.add_number("number", -1, 2)
assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
def test_add_priority(self):
dist = Distance()
dist.add_priority("priority", "abc", "abc")
assert dist._penalties["priority"] == [0.0]
dist.add_priority("priority", "def", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5]
dist.add_priority(
"priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
)
assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
dist.add_priority("priority", "xyz", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
def test_add_ratio(self):
dist = Distance()
dist.add_ratio("ratio", 25, 100)
assert dist._penalties["ratio"] == [0.25]
dist.add_ratio("ratio", 10, 5)
assert dist._penalties["ratio"] == [0.25, 1.0]
dist.add_ratio("ratio", -5, 5)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
dist.add_ratio("ratio", 5, 0)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
def test_add_string(self):
dist = Distance()
sdist = string_dist("abc", "bcd")
dist.add_string("string", "abc", "bcd")
assert dist._penalties["string"] == [sdist]
assert dist._penalties["string"] != [0]
def test_add_string_none(self):
dist = Distance()
dist.add_string("string", None, "string")
assert dist._penalties["string"] == [1]
def test_add_string_both_none(self):
dist = Distance()
dist.add_string("string", None, None)
assert dist._penalties["string"] == [0]
def test_distance(self):
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("media", 0.25)
dist.add("media", 0.75)
assert dist.distance == 0.5
# __getitem__()
assert dist["album"] == 0.25
assert dist["media"] == 0.25
def test_max_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.0)
dist.add("medium", 0.0)
assert dist.max_distance == 5.0
def test_operators(self):
config["match"]["distance_weights"]["source"] = 1.0
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("source", 0.0)
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.75)
assert len(dist) == 2
assert list(dist) == [("album", 0.2), ("medium", 0.2)]
assert dist == 0.4
assert dist < 1.0
assert dist > 0.0
assert dist - 0.4 == 0.0
assert 0.4 - dist == 0.0
assert float(dist) == 0.4
def test_raw_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.5)
assert dist.raw_distance == 2.25
def test_items(self):
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
_clear_weights()
dist = Distance()
dist.add("album", 0.1875)
dist.add("medium", 0.75)
assert dist.items() == [("medium", 0.25), ("album", 0.125)]
# Sort by key if distance is equal.
dist = Distance()
dist.add("album", 0.375)
dist.add("medium", 0.75)
assert dist.items() == [("album", 0.25), ("medium", 0.25)]
def test_update(self):
dist1 = Distance()
dist1.add("album", 0.5)
dist1.add("media", 1.0)
dist2 = Distance()
dist2.add("album", 0.75)
dist2.add("album", 0.25)
dist2.add("media", 0.05)
dist1.update(dist2)
assert dist1._penalties == {
"album": [0.5, 0.75, 0.25],
"media": [1.0, 0.05],
}
class TrackDistanceTest(BeetsTestCase):
def test_identical_tracks(self):
item = _make_item("one", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
def test_different_title(self):
item = _make_item("foo", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_different_artist(self):
item = _make_item("one", 1)
item.artist = "foo"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_various_artists_tolerated(self):
item = _make_item("one", 1)
item.artist = "Various Artists"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
class AlbumDistanceTest(BeetsTestCase):
def _mapping(self, items, info):
out = {}
for i, t in zip(items, info.tracks):
out[i] = t
return out
def _dist(self, items, info):
return match.distance(items, info, self._mapping(items, info))
def test_identical_albums(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) == 0
def test_incomplete_album(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert dist != 0
# Make sure the distance is not too great
assert dist < 0.2
def test_global_artists_differ(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="someone else",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) != 0
def test_comp_track_artists_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) == 0
def test_comp_no_track_artists(self):
# Some VA releases don't have track artists (incomplete metadata).
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
info.tracks[0].artist = None
info.tracks[1].artist = None
info.tracks[2].artist = None
assert self._dist(items, info) == 0
def test_comp_track_artists_do_not_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2, "someone else"))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) != 0
def test_tracks_out_of_order(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 2))
items.append(_make_item("two", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert 0 < dist < 0.2
def test_two_medium_release(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
def test_per_medium_track_numbers(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 1))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
class TestAssignment(ConfigMixin):
    A = "one"
    B = "two"
@@ -840,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil):
        assert self.items[1].comp
class StringDistanceTest(unittest.TestCase):
def test_equal_strings(self):
dist = string_dist("Some String", "Some String")
assert dist == 0.0
def test_different_strings(self):
dist = string_dist("Some String", "Totally Different")
assert dist != 0.0
def test_punctuation_ignored(self):
dist = string_dist("Some String", "Some.String!")
assert dist == 0.0
def test_case_ignored(self):
dist = string_dist("Some String", "sOME sTring")
assert dist == 0.0
def test_leading_the_has_lower_weight(self):
dist1 = string_dist("XXX Band Name", "Band Name")
dist2 = string_dist("The Band Name", "Band Name")
assert dist2 < dist1
def test_parens_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One (Two)", "One")
assert dist2 < dist1
def test_brackets_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One [Two]", "One")
assert dist2 < dist1
def test_ep_label_has_zero_weight(self):
dist = string_dist("My Song (EP)", "My Song")
assert dist == 0.0
def test_featured_has_lower_weight(self):
dist1 = string_dist("My Song blah Someone", "My Song")
dist2 = string_dist("My Song feat Someone", "My Song")
assert dist2 < dist1
def test_postfix_the(self):
dist = string_dist("The Song Title", "Song Title, The")
assert dist == 0.0
def test_postfix_a(self):
dist = string_dist("A Song Title", "Song Title, A")
assert dist == 0.0
def test_postfix_an(self):
dist = string_dist("An Album Title", "Album Title, An")
assert dist == 0.0
def test_empty_strings(self):
dist = string_dist("", "")
assert dist == 0.0
def test_solo_pattern(self):
# Just make sure these don't crash.
string_dist("The ", "")
string_dist("(EP)", "(EP)")
string_dist(", An", "")
def test_heuristic_does_not_harm_distance(self):
dist = string_dist("Untitled", "[Untitled]")
assert dist == 0.0
def test_ampersand_expansion(self):
dist = string_dist("And", "&")
assert dist == 0.0
def test_accented_characters(self):
dist = string_dist("\xe9\xe1\xf1", "ean")
assert dist == 0.0
@pytest.mark.parametrize(
    "single_field,list_field",
    [