Move distance to a separate module

This commit is contained in:
Šarūnas Nejus 2025-05-25 02:59:27 +01:00
parent 01b6ea7898
commit adbd50b237
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
10 changed files with 1028 additions and 1024 deletions

View file

@ -14,22 +14,26 @@
"""Facilities for automatically determining files' correct metadata."""
from collections.abc import Mapping, Sequence
from typing import Union
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from beets import config, logging
from beets.library import Album, Item, LibModel
# Parts of external interface.
from beets.util import unique_list
from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch
from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch
from .match import Proposal, Recommendation, tag_album, tag_item
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from beets.library import Album, Item, LibModel
__all__ = [
"AlbumInfo",
"AlbumMatch",
"Distance",
"TrackInfo",
"TrackMatch",
"Proposal",

531
beets/autotag/distance.py Normal file
View file

@ -0,0 +1,531 @@
from __future__ import annotations
import datetime
import re
from functools import cache, total_ordering
from typing import TYPE_CHECKING, Any
from jellyfish import levenshtein_distance
from unidecode import unidecode
from beets import config, plugins
from beets.util import as_string, cached_classproperty, get_most_common_tags
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from beets.library import Item
from .hooks import AlbumInfo, TrackInfo
# Candidate distance scoring.

# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to remove the penalty for
# differing artists. All entries are lowercase; `track_distance` lowercases
# the item artist before testing membership.
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")

# Parameters for string distance function.

# Words that can be moved to the end of a string using a comma
# (e.g. "something, the" is treated like "the something").
SD_END_WORDS = ["the", "a", "an"]

# Reduced weights for certain portions of the string. Each entry is a
# `(regex, weight)` pair: `string_dist` removes the regex match from both
# strings and, when that lowers the distance, charges only `weight` times
# the improvement as a penalty. Entries are applied in list order.
SD_PATTERNS = [
    (r"^the ", 0.1),
    (r"[\[\(]?(ep|single)[\]\)]?", 0.0),
    (r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
    (r"\(.*?\)", 0.3),
    (r"\[.*?\]", 0.3),
    (r"(, )?(pt\.|part) .+", 0.2),
]

# Replacements to use before testing distance. Each entry is a
# `(regex, replacement)` pair applied with `re.sub` to both strings.
SD_REPLACE = [
    (r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
    """Return the length-normalized Levenshtein distance of two strings.

    Each input is transliterated with `unidecode`, lowercased and stripped
    of every character outside ``a-z0-9`` before the comparison, so case
    and non-alphanumeric characters do not contribute to the result.

    The edit distance is divided by the length of the longer normalized
    string. When both normalized strings are empty the distance is 0.0.
    """
    assert isinstance(str1, str)
    assert isinstance(str2, str)

    def _normalize(text: str) -> str:
        # Transliterate first, then drop anything that is not a
        # lowercase ASCII letter or a digit.
        ascii_text = as_string(unidecode(text))
        return re.sub(r"[^a-z0-9]", "", ascii_text.lower())

    norm1 = _normalize(str1)
    norm2 = _normalize(str2)

    longest = max(len(norm1), len(norm2))
    if not longest:
        # Both strings normalized to nothing: treat them as identical.
        return 0.0
    return levenshtein_distance(norm1, norm2) / float(longest)
def string_dist(str1: str | None, str2: str | None) -> float:
    """Gives an "intuitive" edit distance between two strings. This is
    an edit distance, normalized by the string length, with a number of
    tweaks that reflect intuition about text.

    Returns 0.0 when both arguments are None and 1.0 when exactly one of
    them is None. Otherwise both strings are lowercased, trailing
    ", <word>" suffixes from `SD_END_WORDS` are moved to the front, the
    `SD_REPLACE` substitutions are applied, and the portions matched by
    `SD_PATTERNS` are re-weighted before the result is returned.
    """
    if str1 is None and str2 is None:
        return 0.0
    if str1 is None or str2 is None:
        return 1.0

    str1 = str1.lower()
    str2 = str2.lower()

    # Don't penalize strings that move certain words to the end. For
    # example, "the something" should be considered equal to
    # "something, the".
    for word in SD_END_WORDS:
        suffix = f", {word}"
        if str1.endswith(suffix):
            str1 = f"{word} {str1[: -len(suffix)]}"
        if str2.endswith(suffix):
            str2 = f"{word} {str2[: -len(suffix)]}"

    # Perform a couple of basic normalizing substitutions.
    for pat, repl in SD_REPLACE:
        str1 = re.sub(pat, repl, str1)
        str2 = re.sub(pat, repl, str2)

    # Change the weight for certain string portions matched by a set
    # of regular expressions. We gradually change the strings and build
    # up penalties associated with parts of the string that were
    # deleted.
    base_dist = _string_dist_basic(str1, str2)
    penalty = 0.0
    for pat, weight in SD_PATTERNS:
        # Get strings that drop the pattern.
        case_str1 = re.sub(pat, "", str1)
        case_str2 = re.sub(pat, "", str2)

        if case_str1 != str1 or case_str2 != str2:
            # If the pattern was present (i.e., it is deleted in the
            # current case), recalculate the distances for the
            # modified strings.
            case_dist = _string_dist_basic(case_str1, case_str2)
            case_delta = max(0.0, base_dist - case_dist)
            if case_delta == 0.0:
                # Dropping the pattern did not bring the strings any
                # closer, so keep the current baseline untouched.
                continue

            # Shift our baseline strings down (to avoid rematching the
            # same part of the string) and add a scaled distance
            # amount to the penalties.
            str1 = case_str1
            str2 = case_str2
            base_dist = case_dist
            penalty += weight * case_delta

    return base_dist + penalty
@total_ordering
class Distance:
    """Keeps track of multiple distance penalties. Provides a single
    weighted distance for all penalties as well as a weighted distance
    for each individual penalty.

    Penalties are stored per key as a list of floats between 0.0 and 1.0
    (enforced by `add`). Each key is scaled by the matching entry of
    `_weights` whenever a distance is computed.
    """

    def __init__(self) -> None:
        # Maps a penalty key to every penalty value recorded for it.
        self._penalties: dict[str, list[float]] = {}
        # Per-track distances. No method of this class populates it;
        # it is assigned from outside (see the album `distance` function).
        self.tracks: dict[TrackInfo, Distance] = {}

    @cached_classproperty
    def _weights(cls) -> dict[str, float]:
        """A dictionary from keys to floating-point weights, read from
        the ``match.distance_weights`` configuration section.
        """
        weights_view = config["match"]["distance_weights"]
        return {
            key: weights_view[key].as_number() for key in weights_view.keys()
        }

    # Access the components and their aggregates.

    @property
    def distance(self) -> float:
        """Return a weighted and normalized distance across all
        penalties, or 0.0 when the maximum distance is zero.
        """
        dist_max = self.max_distance
        if dist_max:
            # Reuse the value computed above instead of walking the
            # penalties a second time.
            return self.raw_distance / dist_max
        return 0.0

    @property
    def max_distance(self) -> float:
        """Return the maximum distance penalty (normalization factor)."""
        dist_max = 0.0
        for key, penalty in self._penalties.items():
            # Every recorded penalty could be at most 1.0.
            dist_max += len(penalty) * self._weights[key]
        return dist_max

    @property
    def raw_distance(self) -> float:
        """Return the raw (denormalized) distance."""
        dist_raw = 0.0
        for key, penalty in self._penalties.items():
            dist_raw += sum(penalty) * self._weights[key]
        return dist_raw

    def items(self) -> list[tuple[str, float]]:
        """Return a list of (key, dist) pairs, with `dist` being the
        weighted distance, sorted from highest to lowest. Does not
        include penalties with a zero value.
        """
        list_ = []
        for key in self._penalties:
            dist = self[key]
            if dist:
                list_.append((key, dist))
        # Convert distance into a negative float we can sort items in
        # ascending order (for keys, when the penalty is equal) and
        # still get the items with the biggest distance first.
        return sorted(
            list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
        )

    def __hash__(self) -> int:
        """Hash by identity so instances stay hashable even though
        `__eq__` is overridden.
        """
        return id(self)

    def __eq__(self, other) -> bool:
        """Compare the normalized distance with `other`."""
        return self.distance == other

    # Behave like a float.

    def __lt__(self, other) -> bool:
        """Order by the normalized distance; `total_ordering` derives
        the remaining comparison operators from this and `__eq__`.
        """
        return self.distance < other

    def __float__(self) -> float:
        """Return the normalized distance."""
        return self.distance

    def __sub__(self, other) -> float:
        """Return the normalized distance minus `other`."""
        return self.distance - other

    def __rsub__(self, other) -> float:
        """Return `other` minus the normalized distance."""
        return other - self.distance

    def __str__(self) -> str:
        """Return the normalized distance with two decimal places."""
        return f"{self.distance:.2f}"

    # Behave like a dict.

    def __getitem__(self, key) -> float:
        """Returns the weighted distance for a named penalty.

        Raises KeyError if no penalty was recorded under `key`.
        """
        dist = sum(self._penalties[key]) * self._weights[key]
        dist_max = self.max_distance
        if dist_max:
            return dist / dist_max
        return 0.0

    def __iter__(self) -> Iterator[tuple[str, float]]:
        """Iterate over the (key, dist) pairs returned by `items`."""
        return iter(self.items())

    def __len__(self) -> int:
        """Return the number of penalties with a non-zero distance."""
        return len(self.items())

    def keys(self) -> list[str]:
        """Return the keys of `items`, in the same order."""
        return [key for key, _ in self.items()]

    def update(self, dist: Distance) -> None:
        """Adds all the distance penalties from `dist`.

        Raises ValueError if `dist` is not a Distance instance.
        """
        if not isinstance(dist, Distance):
            raise ValueError(
                f"`dist` must be a Distance object, not {type(dist)}"
            )
        for key, penalties in dist._penalties.items():
            self._penalties.setdefault(key, []).extend(penalties)

    # Adding components.

    def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
        """Returns True if `value1` is equal to `value2`. `value1` may
        be a compiled regular expression, in which case it will be
        matched against `value2`.
        """
        if isinstance(value1, re.Pattern):
            return bool(value1.match(value2))
        return value1 == value2

    def add(self, key: str, dist: float) -> None:
        """Adds a distance penalty. `key` must correspond with a
        configured weight setting. `dist` must be a float between 0.0
        and 1.0, and will be added to any existing distance penalties
        for the same key.

        Raises ValueError if `dist` is outside that range.
        """
        if not 0.0 <= dist <= 1.0:
            raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
        self._penalties.setdefault(key, []).append(dist)

    def add_equality(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ) -> None:
        """Adds a distance penalty of 1.0 if `value` doesn't match any
        of the values in `options`. If an option is a compiled regular
        expression, it will be considered equal if it matches against
        `value`.

        A single option that is neither a list nor a tuple is accepted
        as well.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        for opt in options:
            if self._eq(opt, value):
                dist = 0.0
                break
        else:
            # No option matched.
            dist = 1.0
        self.add(key, dist)

    def add_expr(self, key: str, expr: bool) -> None:
        """Adds a distance penalty of 1.0 if `expr` evaluates to True,
        or 0.0.
        """
        self.add(key, 1.0 if expr else 0.0)

    def add_number(self, key: str, number1: int, number2: int) -> None:
        """Adds a distance penalty of 1.0 for each number of difference
        between `number1` and `number2`, or 0.0 when there is no
        difference. Use this when there is no upper limit on the
        difference between the two numbers.
        """
        diff = abs(number1 - number2)
        if diff:
            for _ in range(diff):
                self.add(key, 1.0)
        else:
            self.add(key, 0.0)

    def add_priority(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ) -> None:
        """Adds a distance penalty that corresponds to the position at
        which `value` appears in `options`. A distance penalty of 0.0
        for the first option, or 1.0 if there is no matching option. If
        an option is a compiled regular expression, it will be
        considered equal if it matches against `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        # Guard against an empty option list to avoid dividing by zero.
        unit = 1.0 / (len(options) or 1)
        for i, opt in enumerate(options):
            if self._eq(opt, value):
                dist = i * unit
                break
        else:
            dist = 1.0
        self.add(key, dist)

    def add_ratio(
        self,
        key: str,
        number1: int | float,
        number2: int | float,
    ) -> None:
        """Adds a distance penalty for `number1` as a ratio of `number2`.
        `number1` is bound at 0 and `number2`. A falsy `number2` yields
        a penalty of 0.0.
        """
        number = float(max(min(number1, number2), 0))
        if number2:
            dist = number / number2
        else:
            dist = 0.0
        self.add(key, dist)

    def add_string(self, key: str, str1: str | None, str2: str | None) -> None:
        """Adds a distance penalty based on the edit distance between
        `str1` and `str2`, as computed by `string_dist`.
        """
        dist = string_dist(str1, str2)
        self.add(key, dist)
@cache
def get_track_length_grace() -> float:
    """Return the ``match.track_length_grace`` setting as a number.

    The result is memoized by ``functools.cache``, so the configuration
    is only read on the first call.
    """
    match_config = config["match"]
    return match_config["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
    """Return the ``match.track_length_max`` setting as a number.

    The result is memoized by ``functools.cache``, so the configuration
    is only read on the first call.
    """
    match_config = config["match"]
    return match_config["track_length_max"].as_number()
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
    """Tell whether the item's track number disagrees with the track info.

    The item's ``track`` is accepted when it equals either the track's
    ``medium_index`` (per-disc numbering) or its ``index`` (per-release
    numbering); True is returned only when it matches neither.
    """
    accepted_indices = (track_info.medium_index, track_info.index)
    return item.track not in accepted_indices
def track_distance(
    item: Item,
    track_info: TrackInfo,
    incl_artist: bool = False,
) -> Distance:
    """Measure how much a track's metadata would change.

    Builds and returns a Distance object comparing `item` with
    `track_info`. When `incl_artist` is true a component for the track
    artist is included as well (i.e., for various-artist releases).

    The ``track_length_grace`` and ``track_length_max`` settings are read
    through cached helpers because this function is called many times
    during the matching process and their access comes with a performance
    overhead.
    """
    penalties = Distance()

    # Length: only the part of the difference beyond the grace period is
    # penalized, as a ratio of the configured maximum.
    if expected_length := track_info.length:
        excess = abs(item.length - expected_length) - get_track_length_grace()
        penalties.add_ratio("track_length", excess, get_track_length_max())

    # Title.
    penalties.add_string("track_title", item.title, track_info.title)

    # Artist: requested by the caller, present in the track data, and the
    # item's own artist is not one of the "various artists" markers.
    compare_artist = (
        incl_artist
        and track_info.artist
        and item.artist.lower() not in VA_ARTISTS
    )
    if compare_artist:
        penalties.add_string("track_artist", item.artist, track_info.artist)

    # Track index, when both sides have one.
    if track_info.index and item.track:
        index_differs = track_index_changed(item, track_info)
        penalties.add_expr("track_index", index_differs)

    # Track ID, when the item already carries one.
    if item.mb_trackid:
        id_differs = item.mb_trackid != track_info.track_id
        penalties.add_expr("track_id", id_differs)

    # Disc number, when both sides have one.
    if track_info.medium and item.disc:
        disc_differs = item.disc != track_info.medium
        penalties.add_expr("medium", disc_differs)

    # Contributions from plugins.
    plugin_penalties = plugins.track_distance(item, track_info)
    penalties.update(plugin_penalties)

    return penalties
def distance(
    items: Sequence[Item],
    album_info: AlbumInfo,
    mapping: dict[Item, TrackInfo],
) -> Distance:
    """Measure how "significant" an album metadata change would be.

    `items` is a sequence of all Item objects that will be matched (order
    is not important) and `album_info` is the AlbumInfo they are compared
    against. `mapping` is a dictionary mapping Items to TrackInfo objects;
    the keys are a subset of `items` and the values are a subset of
    `album_info.tracks`.

    Returns a Distance object whose ``tracks`` attribute holds the
    per-track distance of every mapped track.
    """
    common, _ = get_most_common_tags(items)
    preferred = config["match"]["preferred"]

    album_dist = Distance()

    # Artist, unless this is a various-artists release.
    if not album_info.va:
        album_dist.add_string("artist", common["artist"], album_info.artist)

    # Album title.
    album_dist.add_string("album", common["album"], album_info.album)

    # Media: rank against the preferred media when any are configured,
    # otherwise compare with the items' current media.
    if album_info.media:
        media_patterns: Sequence[str] = preferred["media"].as_str_seq()
        media_options = [
            re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
        ]
        if media_options:
            album_dist.add_priority("media", album_info.media, media_options)
        elif common["media"]:
            album_dist.add_equality("media", album_info.media, common["media"])

    # Number of mediums.
    if common["disctotal"] and album_info.mediums:
        album_dist.add_number(
            "mediums", common["disctotal"], album_info.mediums
        )

    if album_info.year and preferred["original_year"]:
        # Prefer the earliest release. Assume 1889 (earliest first
        # gramophone discs) if we don't know the original year.
        baseline_year = album_info.original_year or 1889
        year_gap = abs(album_info.year - baseline_year)
        year_gap_max = abs(datetime.date.today().year - baseline_year)
        album_dist.add_ratio("year", year_gap, year_gap_max)
    elif common["year"] and album_info.year:
        known_years = (album_info.year, album_info.original_year)
        if common["year"] in known_years:
            # No penalty for matching release or original year.
            album_dist.add("year", 0.0)
        elif album_info.original_year:
            # Prefer matches closest to the release year.
            year_gap = abs(common["year"] - album_info.year)
            year_gap_max = abs(
                datetime.date.today().year - album_info.original_year
            )
            album_dist.add_ratio("year", year_gap, year_gap_max)
        else:
            # Full penalty when there is no original year.
            album_dist.add("year", 1.0)

    # Country: rank against the preferred countries when any are
    # configured, otherwise compare with the items' current country.
    country_patterns: Sequence[str] = preferred["countries"].as_str_seq()
    country_options = [re.compile(pat, re.I) for pat in country_patterns]
    if album_info.country and country_options:
        album_dist.add_priority("country", album_info.country, country_options)
    elif common["country"] and album_info.country:
        album_dist.add_string("country", common["country"], album_info.country)

    # Label.
    if common["label"] and album_info.label:
        album_dist.add_string("label", common["label"], album_info.label)

    # Catalog number.
    if common["catalognum"] and album_info.catalognum:
        album_dist.add_string(
            "catalognum", common["catalognum"], album_info.catalognum
        )

    # Disambiguation.
    if common["albumdisambig"] and album_info.albumdisambig:
        album_dist.add_string(
            "albumdisambig", common["albumdisambig"], album_info.albumdisambig
        )

    # Album ID, when the items already carry one.
    if common["mb_albumid"]:
        album_dist.add_equality(
            "album_id", common["mb_albumid"], album_info.album_id
        )

    # Tracks: keep each per-track distance and fold it into the album.
    album_dist.tracks = {}
    for item, track in mapping.items():
        per_track = track_distance(item, track, album_info.va)
        album_dist.tracks[track] = per_track
        album_dist.add("tracks", per_track.distance)

    # Missing tracks: one full penalty per release track with no item.
    missing_count = len(album_info.tracks) - len(mapping)
    for _ in range(missing_count):
        album_dist.add("missing_tracks", 1.0)

    # Unmatched tracks: one full penalty per item with no release track.
    unmatched_count = len(items) - len(mapping)
    for _ in range(unmatched_count):
        album_dist.add("unmatched_tracks", 1.0)

    # Contributions from plugins.
    plugin_penalties = plugins.album_distance(items, album_info, mapping)
    album_dist.update(plugin_penalties)

    return album_dist

View file

@ -16,21 +16,15 @@
from __future__ import annotations
import re
from functools import total_ordering
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
from jellyfish import levenshtein_distance
from unidecode import unidecode
from beets import config, logging
from beets.util import as_string, cached_classproperty
from beets import logging
if TYPE_CHECKING:
from collections.abc import Iterator
from beets.library import Item
from .distance import Distance
log = logging.getLogger("beets")
V = TypeVar("V")
@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]):
return dupe
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r"^the ", 0.1),
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
(r"\(.*?\)", 0.3),
(r"\[.*?\]", 0.3),
(r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str | None, str2: str | None) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(", %s" % word):
str1 = "{} {}".format(word, str1[: -len(word) - 2])
if str2.endswith(", %s" % word):
str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, "", str1)
case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self) -> None:
self._penalties: dict[str, list[float]] = {}
self.tracks: dict[TrackInfo, Distance] = {}
@cached_classproperty
def _weights(cls) -> dict[str, float]:
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self) -> list[tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self) -> int:
return id(self)
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self) -> float:
return self.distance
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self) -> Iterator[tuple[str, float]]:
return iter(self.items())
def __len__(self) -> int:
return len(self.items())
def keys(self) -> list[str]:
return [key for key, _ in self.items()]
def update(self, dist: Distance):
"""Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
"`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
return bool(value1.match(value2))
return value1 == value2
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
def add_equality(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(
self,
key: str,
number1: int | float,
number2: int | float,
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str | None, str2: str | None):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
# Structures that compose all the information for a candidate match.

View file

@ -18,37 +18,23 @@ releases and tracks.
from __future__ import annotations
import datetime
import re
from enum import IntEnum
from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import lap
import numpy as np
from beets import config, logging, plugins
from beets.autotag import (
AlbumInfo,
AlbumMatch,
Distance,
TrackInfo,
TrackMatch,
hooks,
)
from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks
from beets.util import get_most_common_tags
from .distance import VA_ARTISTS, distance, track_distance
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from beets.library import Item
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to to remove the penalty for
# differing artists.
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
# Global logger.
log = logging.getLogger("beets")
@ -112,191 +98,6 @@ def assign_items(
return mapping, extra_items, extra_tracks
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
@cache
def get_track_length_grace() -> float:
"""Get cached grace period for track length matching."""
return config["match"]["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
"""Get cached maximum track length for track length matching."""
return config["match"]["track_length_max"].as_number()
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = hooks.Distance()
# Length.
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())
# Title.
dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if (
incl_artist
and track_info.artist
and item.artist.lower() not in VA_ARTISTS
):
dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Penalize mismatching disc numbers.
if track_info.medium and item.disc:
dist.add_expr("medium", item.disc != track_info.medium)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(
items: Sequence[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = get_most_common_tags(items)
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
elif likelies["media"]:
dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
if likelies["disctotal"] and album_info.mediums:
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio("year", diff, diff_max)
# Year.
elif likelies["year"] and album_info.year:
if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
diff = abs(likelies["year"] - album_info.year)
diff_max = abs(
datetime.date.today().year - album_info.original_year
)
dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add("year", 1.0)
# Preferred countries.
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
elif likelies["country"] and album_info.country:
dist.add_string("country", likelies["country"], album_info.country)
# Label.
if likelies["label"] and album_info.label:
dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
if likelies["catalognum"] and album_info.catalognum:
dist.add_string(
"catalognum", likelies["catalognum"], album_info.catalognum
)
# Disambiguation.
if likelies["albumdisambig"] and album_info.albumdisambig:
dist.add_string(
"albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
)
# Album ID.
if likelies["mb_albumid"]:
dist.add_equality(
"album_id", likelies["mb_albumid"], album_info.album_id
)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add("missing_tracks", 1.0)
# Unmatched tracks.
for _ in range(len(items) - len(mapping)):
dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
def match_by_id(items: Iterable[Item]) -> AlbumInfo | None:
"""If the items are tagged with an external source ID, return an
AlbumInfo object for the corresponding album. Otherwise, returns

View file

@ -37,6 +37,7 @@ import mediafile
import beets
from beets import logging
from beets.autotag.distance import Distance
from beets.util.id_extractors import extract_release_id
if TYPE_CHECKING:
@ -53,7 +54,7 @@ if TYPE_CHECKING:
from confuse import ConfigView
from beets.autotag import AlbumInfo, Distance, TrackInfo
from beets.autotag import AlbumInfo, TrackInfo
from beets.dbcore import Query
from beets.dbcore.db import FieldQueryType
from beets.dbcore.types import Type
@ -224,8 +225,6 @@ class BeetsPlugin:
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
from beets.autotag.hooks import Distance
return Distance()
def album_distance(
@ -237,8 +236,6 @@ class BeetsPlugin:
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
from beets.autotag.hooks import Distance
return Distance()
def candidates(
@ -458,8 +455,6 @@ def track_distance(item: Item, info: TrackInfo) -> Distance:
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.track_distance(item, info))
@ -472,8 +467,6 @@ def album_distance(
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
@ -660,8 +653,6 @@ def get_distance(
"""Returns the ``data_source`` weight and the maximum source weight
for albums or individual tracks.
"""
from beets.autotag.hooks import Distance
dist = Distance()
if info.data_source == data_source:
dist.add("source", config["source_weight"].as_number())

View file

@ -24,7 +24,7 @@ import acoustid
import confuse
from beets import config, plugins, ui, util
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
from beetsplug.musicbrainz import MusicBrainzPlugin
API_KEY = "1vOwZtEn"

View file

@ -38,7 +38,8 @@ from typing_extensions import TypedDict
import beets
import beets.ui
from beets import config
from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist
from beets.autotag.distance import string_dist
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
from beets.util.id_extractors import extract_release_id

View file

@ -38,7 +38,7 @@ from unidecode import unidecode
import beets
from beets import plugins, ui
from beets.autotag.hooks import string_dist
from beets.autotag.distance import string_dist
from beets.util.config import sanitize_choices
if TYPE_CHECKING:

View file

@ -0,0 +1,476 @@
import re
import unittest
from beets import config
from beets.autotag import AlbumInfo, TrackInfo, match
from beets.autotag.distance import Distance, string_dist
from beets.library import Item
from beets.test.helper import BeetsTestCase
def _make_item(title, track, artist="some artist"):
    """Build an ``Item`` fixture with the given title, track and artist.

    The album name, length and (empty) MusicBrainz IDs are fixed.
    """
    fields = {
        "title": title,
        "track": track,
        "artist": artist,
        "album": "some album",
        "length": 1,
        "mb_trackid": "",
        "mb_albumid": "",
        "mb_artistid": "",
    }
    return Item(**fields)
def _make_trackinfo():
    """Return three ``TrackInfo`` fixtures titled "one", "two" and "three".

    Each track shares the same artist and length; ``index`` runs from 1 to 3.
    """
    titles = ("one", "two", "three")
    return [
        TrackInfo(
            title=title,
            track_id=None,
            artist="some artist",
            length=1,
            index=position,
        )
        for position, title in enumerate(titles, start=1)
    ]
def _clear_weights():
    """Reset the cache held by the lazy ``Distance._weights`` descriptor.

    This is a hack: it reaches into the class namespace to get at the
    descriptor object itself rather than the value it produces.
    """
    weights_descriptor = vars(Distance)["_weights"]
    weights_descriptor.cache = {}
class DistanceTest(BeetsTestCase):
    """Exercise penalty bookkeeping and scoring of ``Distance``."""

    def tearDown(self):
        super().tearDown()
        # Reset the cached weights after every test.
        _clear_weights()

    def test_add(self):
        distance = Distance()
        distance.add("add", 1.0)
        assert distance._penalties == {"add": [1.0]}

    def test_add_equality(self):
        # (value, options, penalties expected after the call)
        cases = [
            ("ghi", ["abc", "def", "ghi"], [0.0]),
            ("xyz", ["abc", "def", "ghi"], [0.0, 1.0]),
            ("abc", re.compile(r"ABC", re.I), [0.0, 1.0, 0.0]),
        ]
        distance = Distance()
        for value, options, expected in cases:
            distance.add_equality("equality", value, options)
            assert distance._penalties["equality"] == expected

    def test_add_expr(self):
        distance = Distance()
        for condition, expected in ((True, [1.0]), (False, [1.0, 0.0])):
            distance.add_expr("expr", condition)
            assert distance._penalties["expr"] == expected

    def test_add_number(self):
        # Add a full penalty for each number of difference between two numbers.
        cases = [
            ((1, 1), [0.0]),
            ((1, 2), [0.0, 1.0]),
            ((2, 1), [0.0, 1.0, 1.0]),
            ((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
        ]
        distance = Distance()
        for (first, second), expected in cases:
            distance.add_number("number", first, second)
            assert distance._penalties["number"] == expected

    def test_add_priority(self):
        # (value, options, penalties expected after the call)
        cases = [
            ("abc", "abc", [0.0]),
            ("def", ["abc", "def"], [0.0, 0.5]),
            (
                "gh",
                ["ab", "cd", "ef", re.compile("GH", re.I)],
                [0.0, 0.5, 0.75],
            ),
            ("xyz", ["abc", "def"], [0.0, 0.5, 0.75, 1.0]),
        ]
        distance = Distance()
        for value, options, expected in cases:
            distance.add_priority("priority", value, options)
            assert distance._penalties["priority"] == expected

    def test_add_ratio(self):
        # (number1, number2, penalties expected after the call)
        cases = [
            (25, 100, [0.25]),
            (10, 5, [0.25, 1.0]),
            (-5, 5, [0.25, 1.0, 0.0]),
            (5, 0, [0.25, 1.0, 0.0, 0.0]),
        ]
        distance = Distance()
        for number1, number2, expected in cases:
            distance.add_ratio("ratio", number1, number2)
            assert distance._penalties["ratio"] == expected

    def test_add_string(self):
        expected = string_dist("abc", "bcd")
        distance = Distance()
        distance.add_string("string", "abc", "bcd")
        recorded = distance._penalties["string"]
        assert recorded == [expected]
        assert recorded != [0]

    def test_add_string_none(self):
        distance = Distance()
        distance.add_string("string", None, "string")
        assert distance._penalties["string"] == [1]

    def test_add_string_both_none(self):
        distance = Distance()
        distance.add_string("string", None, None)
        assert distance._penalties["string"] == [0]

    def test_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 2.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("media", 0.25), ("media", 0.75)):
            distance.add(key, penalty)
        assert distance.distance == 0.5

        # __getitem__()
        for key in ("album", "media"):
            assert distance[key] == 0.25

    def test_max_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 3.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.0), ("medium", 0.0)):
            distance.add(key, penalty)
        assert distance.max_distance == 5.0

    def test_operators(self):
        weights = config["match"]["distance_weights"]
        weights["source"] = 1.0
        weights["album"] = 2.0
        weights["medium"] = 1.0
        _clear_weights()

        penalties = (
            ("source", 0.0),
            ("album", 0.5),
            ("medium", 0.25),
            ("medium", 0.75),
        )
        distance = Distance()
        for key, penalty in penalties:
            distance.add(key, penalty)

        assert len(distance) == 2
        assert list(distance) == [("album", 0.2), ("medium", 0.2)]
        assert distance == 0.4
        assert distance < 1.0
        assert distance > 0.0
        assert distance - 0.4 == 0.0
        assert 0.4 - distance == 0.0
        assert float(distance) == 0.4

    def test_raw_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 3.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.25), ("medium", 0.5)):
            distance.add(key, penalty)
        assert distance.raw_distance == 2.25

    def test_items(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 4.0
        weights["medium"] = 2.0
        _clear_weights()

        distance = Distance()
        distance.add("album", 0.1875)
        distance.add("medium", 0.75)
        assert distance.items() == [("medium", 0.25), ("album", 0.125)]

        # Sort by key if distance is equal.
        tied = Distance()
        tied.add("album", 0.375)
        tied.add("medium", 0.75)
        assert tied.items() == [("album", 0.25), ("medium", 0.25)]

    def test_update(self):
        target = Distance()
        for key, penalty in (("album", 0.5), ("media", 1.0)):
            target.add(key, penalty)

        incoming = Distance()
        for key, penalty in (("album", 0.75), ("album", 0.25), ("media", 0.05)):
            incoming.add(key, penalty)

        target.update(incoming)

        assert target._penalties == {
            "album": [0.5, 0.75, 0.25],
            "media": [1.0, 0.05],
        }
class TrackDistanceTest(BeetsTestCase):
    """Compare single items against the first ``_make_trackinfo`` fixture."""

    def _dist(self, item):
        # Distance between ``item`` and the first fixture track, artist included.
        reference = _make_trackinfo()[0]
        return match.track_distance(item, reference, incl_artist=True)

    def test_identical_tracks(self):
        assert self._dist(_make_item("one", 1)) == 0.0

    def test_different_title(self):
        assert self._dist(_make_item("foo", 1)) != 0.0

    def test_different_artist(self):
        item = _make_item("one", 1)
        item.artist = "foo"
        assert self._dist(item) != 0.0

    def test_various_artists_tolerated(self):
        item = _make_item("one", 1)
        item.artist = "Various Artists"
        assert self._dist(item) == 0.0
class AlbumDistanceTest(BeetsTestCase):
    """Check album-level distances for several item/album combinations."""

    def _mapping(self, items, info):
        # Pair each item with the album track at the same position.
        return dict(zip(items, info.tracks))

    def _dist(self, items, info):
        return match.distance(items, info, self._mapping(items, info))

    @staticmethod
    def _make_albuminfo(artist, va):
        # Album fixture backed by a fresh set of fixture tracks.
        return AlbumInfo(
            artist=artist,
            album="some album",
            tracks=_make_trackinfo(),
            va=va,
        )

    def test_identical_albums(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        assert self._dist(items, info) == 0

    def test_incomplete_album(self):
        items = [
            _make_item("one", 1),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        dist = self._dist(items, info)
        assert dist != 0
        # Make sure the distance is not too great
        assert dist < 0.2

    def test_global_artists_differ(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("someone else", va=False)
        assert self._dist(items, info) != 0

    def test_comp_track_artists_match(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("should be ignored", va=True)
        assert self._dist(items, info) == 0

    def test_comp_no_track_artists(self):
        # Some VA releases don't have track artists (incomplete metadata).
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("should be ignored", va=True)
        for track in info.tracks:
            track.artist = None
        assert self._dist(items, info) == 0

    def test_comp_track_artists_do_not_match(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2, "someone else"),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=True)
        assert self._dist(items, info) != 0

    def test_tracks_out_of_order(self):
        items = [
            _make_item("one", 1),
            _make_item("three", 2),
            _make_item("two", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        dist = self._dist(items, info)
        assert 0 < dist < 0.2

    def test_two_medium_release(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        for track, medium_index in zip(info.tracks, (1, 2, 1)):
            track.medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0

    def test_per_medium_track_numbers(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 1),
        ]
        info = self._make_albuminfo("some artist", va=False)
        for track, medium_index in zip(info.tracks, (1, 2, 1)):
            track.medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0
class StringDistanceTest(unittest.TestCase):
    """Check the normalisation heuristics applied by ``string_dist``."""

    def test_equal_strings(self):
        assert string_dist("Some String", "Some String") == 0.0

    def test_different_strings(self):
        assert string_dist("Some String", "Totally Different") != 0.0

    def test_punctuation_ignored(self):
        assert string_dist("Some String", "Some.String!") == 0.0

    def test_case_ignored(self):
        assert string_dist("Some String", "sOME sTring") == 0.0

    def test_leading_the_has_lower_weight(self):
        arbitrary_prefix = string_dist("XXX Band Name", "Band Name")
        leading_the = string_dist("The Band Name", "Band Name")
        assert leading_the < arbitrary_prefix

    def test_parens_have_lower_weight(self):
        plain_suffix = string_dist("One .Two.", "One")
        parenthesised = string_dist("One (Two)", "One")
        assert parenthesised < plain_suffix

    def test_brackets_have_lower_weight(self):
        plain_suffix = string_dist("One .Two.", "One")
        bracketed = string_dist("One [Two]", "One")
        assert bracketed < plain_suffix

    def test_ep_label_has_zero_weight(self):
        assert string_dist("My Song (EP)", "My Song") == 0.0

    def test_featured_has_lower_weight(self):
        arbitrary_word = string_dist("My Song blah Someone", "My Song")
        featured = string_dist("My Song feat Someone", "My Song")
        assert featured < arbitrary_word

    def test_postfix_the(self):
        assert string_dist("The Song Title", "Song Title, The") == 0.0

    def test_postfix_a(self):
        assert string_dist("A Song Title", "Song Title, A") == 0.0

    def test_postfix_an(self):
        assert string_dist("An Album Title", "Album Title, An") == 0.0

    def test_empty_strings(self):
        assert string_dist("", "") == 0.0

    def test_solo_pattern(self):
        # Just make sure these don't crash.
        for left, right in (("The ", ""), ("(EP)", "(EP)"), (", An", "")):
            string_dist(left, right)

    def test_heuristic_does_not_harm_distance(self):
        assert string_dist("Untitled", "[Untitled]") == 0.0

    def test_ampersand_expansion(self):
        assert string_dist("And", "&") == 0.0

    def test_accented_characters(self):
        assert string_dist("\xe9\xe1\xf1", "ean") == 0.0

View file

@ -14,410 +14,14 @@
"""Tests for autotagging functionality."""
import re
import unittest
import pytest
from beets import autotag, config
from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
from beets.autotag.hooks import Distance, string_dist
from beets.library import Item
from beets.test.helper import BeetsTestCase, ConfigMixin
def _make_item(title, track, artist="some artist"):
    """Build an ``Item`` fixture with the given title, track and artist.

    The album name, length and (empty) MusicBrainz IDs are fixed.
    """
    fields = {
        "title": title,
        "track": track,
        "artist": artist,
        "album": "some album",
        "length": 1,
        "mb_trackid": "",
        "mb_albumid": "",
        "mb_artistid": "",
    }
    return Item(**fields)
def _make_trackinfo():
    """Return three ``TrackInfo`` fixtures titled "one", "two" and "three".

    Each track shares the same artist and length; ``index`` runs from 1 to 3.
    """
    titles = ("one", "two", "three")
    return [
        TrackInfo(
            title=title,
            track_id=None,
            artist="some artist",
            length=1,
            index=position,
        )
        for position, title in enumerate(titles, start=1)
    ]
def _clear_weights():
    """Reset the cache held by the lazy ``Distance._weights`` descriptor.

    This is a hack: it reaches into the class namespace to get at the
    descriptor object itself rather than the value it produces.
    """
    weights_descriptor = vars(Distance)["_weights"]
    weights_descriptor.cache = {}
class DistanceTest(BeetsTestCase):
    """Exercise penalty bookkeeping and scoring of ``Distance``."""

    def tearDown(self):
        super().tearDown()
        # Reset the cached weights after every test.
        _clear_weights()

    def test_add(self):
        distance = Distance()
        distance.add("add", 1.0)
        assert distance._penalties == {"add": [1.0]}

    def test_add_equality(self):
        # (value, options, penalties expected after the call)
        cases = [
            ("ghi", ["abc", "def", "ghi"], [0.0]),
            ("xyz", ["abc", "def", "ghi"], [0.0, 1.0]),
            ("abc", re.compile(r"ABC", re.I), [0.0, 1.0, 0.0]),
        ]
        distance = Distance()
        for value, options, expected in cases:
            distance.add_equality("equality", value, options)
            assert distance._penalties["equality"] == expected

    def test_add_expr(self):
        distance = Distance()
        for condition, expected in ((True, [1.0]), (False, [1.0, 0.0])):
            distance.add_expr("expr", condition)
            assert distance._penalties["expr"] == expected

    def test_add_number(self):
        # Add a full penalty for each number of difference between two numbers.
        cases = [
            ((1, 1), [0.0]),
            ((1, 2), [0.0, 1.0]),
            ((2, 1), [0.0, 1.0, 1.0]),
            ((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
        ]
        distance = Distance()
        for (first, second), expected in cases:
            distance.add_number("number", first, second)
            assert distance._penalties["number"] == expected

    def test_add_priority(self):
        # (value, options, penalties expected after the call)
        cases = [
            ("abc", "abc", [0.0]),
            ("def", ["abc", "def"], [0.0, 0.5]),
            (
                "gh",
                ["ab", "cd", "ef", re.compile("GH", re.I)],
                [0.0, 0.5, 0.75],
            ),
            ("xyz", ["abc", "def"], [0.0, 0.5, 0.75, 1.0]),
        ]
        distance = Distance()
        for value, options, expected in cases:
            distance.add_priority("priority", value, options)
            assert distance._penalties["priority"] == expected

    def test_add_ratio(self):
        # (number1, number2, penalties expected after the call)
        cases = [
            (25, 100, [0.25]),
            (10, 5, [0.25, 1.0]),
            (-5, 5, [0.25, 1.0, 0.0]),
            (5, 0, [0.25, 1.0, 0.0, 0.0]),
        ]
        distance = Distance()
        for number1, number2, expected in cases:
            distance.add_ratio("ratio", number1, number2)
            assert distance._penalties["ratio"] == expected

    def test_add_string(self):
        expected = string_dist("abc", "bcd")
        distance = Distance()
        distance.add_string("string", "abc", "bcd")
        recorded = distance._penalties["string"]
        assert recorded == [expected]
        assert recorded != [0]

    def test_add_string_none(self):
        distance = Distance()
        distance.add_string("string", None, "string")
        assert distance._penalties["string"] == [1]

    def test_add_string_both_none(self):
        distance = Distance()
        distance.add_string("string", None, None)
        assert distance._penalties["string"] == [0]

    def test_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 2.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("media", 0.25), ("media", 0.75)):
            distance.add(key, penalty)
        assert distance.distance == 0.5

        # __getitem__()
        for key in ("album", "media"):
            assert distance[key] == 0.25

    def test_max_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 3.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.0), ("medium", 0.0)):
            distance.add(key, penalty)
        assert distance.max_distance == 5.0

    def test_operators(self):
        weights = config["match"]["distance_weights"]
        weights["source"] = 1.0
        weights["album"] = 2.0
        weights["medium"] = 1.0
        _clear_weights()

        penalties = (
            ("source", 0.0),
            ("album", 0.5),
            ("medium", 0.25),
            ("medium", 0.75),
        )
        distance = Distance()
        for key, penalty in penalties:
            distance.add(key, penalty)

        assert len(distance) == 2
        assert list(distance) == [("album", 0.2), ("medium", 0.2)]
        assert distance == 0.4
        assert distance < 1.0
        assert distance > 0.0
        assert distance - 0.4 == 0.0
        assert 0.4 - distance == 0.0
        assert float(distance) == 0.4

    def test_raw_distance(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 3.0
        weights["medium"] = 1.0
        _clear_weights()

        distance = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.25), ("medium", 0.5)):
            distance.add(key, penalty)
        assert distance.raw_distance == 2.25

    def test_items(self):
        weights = config["match"]["distance_weights"]
        weights["album"] = 4.0
        weights["medium"] = 2.0
        _clear_weights()

        distance = Distance()
        distance.add("album", 0.1875)
        distance.add("medium", 0.75)
        assert distance.items() == [("medium", 0.25), ("album", 0.125)]

        # Sort by key if distance is equal.
        tied = Distance()
        tied.add("album", 0.375)
        tied.add("medium", 0.75)
        assert tied.items() == [("album", 0.25), ("medium", 0.25)]

    def test_update(self):
        target = Distance()
        for key, penalty in (("album", 0.5), ("media", 1.0)):
            target.add(key, penalty)

        incoming = Distance()
        for key, penalty in (("album", 0.75), ("album", 0.25), ("media", 0.05)):
            incoming.add(key, penalty)

        target.update(incoming)

        assert target._penalties == {
            "album": [0.5, 0.75, 0.25],
            "media": [1.0, 0.05],
        }
class TrackDistanceTest(BeetsTestCase):
    """Compare single items against the first ``_make_trackinfo`` fixture."""

    def _dist(self, item):
        # Distance between ``item`` and the first fixture track, artist included.
        reference = _make_trackinfo()[0]
        return match.track_distance(item, reference, incl_artist=True)

    def test_identical_tracks(self):
        assert self._dist(_make_item("one", 1)) == 0.0

    def test_different_title(self):
        assert self._dist(_make_item("foo", 1)) != 0.0

    def test_different_artist(self):
        item = _make_item("one", 1)
        item.artist = "foo"
        assert self._dist(item) != 0.0

    def test_various_artists_tolerated(self):
        item = _make_item("one", 1)
        item.artist = "Various Artists"
        assert self._dist(item) == 0.0
class AlbumDistanceTest(BeetsTestCase):
    """Check album-level distances for several item/album combinations."""

    def _mapping(self, items, info):
        # Pair each item with the album track at the same position.
        return dict(zip(items, info.tracks))

    def _dist(self, items, info):
        return match.distance(items, info, self._mapping(items, info))

    @staticmethod
    def _make_albuminfo(artist, va):
        # Album fixture backed by a fresh set of fixture tracks.
        return AlbumInfo(
            artist=artist,
            album="some album",
            tracks=_make_trackinfo(),
            va=va,
        )

    def test_identical_albums(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        assert self._dist(items, info) == 0

    def test_incomplete_album(self):
        items = [
            _make_item("one", 1),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        dist = self._dist(items, info)
        assert dist != 0
        # Make sure the distance is not too great
        assert dist < 0.2

    def test_global_artists_differ(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("someone else", va=False)
        assert self._dist(items, info) != 0

    def test_comp_track_artists_match(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("should be ignored", va=True)
        assert self._dist(items, info) == 0

    def test_comp_no_track_artists(self):
        # Some VA releases don't have track artists (incomplete metadata).
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("should be ignored", va=True)
        for track in info.tracks:
            track.artist = None
        assert self._dist(items, info) == 0

    def test_comp_track_artists_do_not_match(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2, "someone else"),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=True)
        assert self._dist(items, info) != 0

    def test_tracks_out_of_order(self):
        items = [
            _make_item("one", 1),
            _make_item("three", 2),
            _make_item("two", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        dist = self._dist(items, info)
        assert 0 < dist < 0.2

    def test_two_medium_release(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._make_albuminfo("some artist", va=False)
        for track, medium_index in zip(info.tracks, (1, 2, 1)):
            track.medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0

    def test_per_medium_track_numbers(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 1),
        ]
        info = self._make_albuminfo("some artist", va=False)
        for track, medium_index in zip(info.tracks, (1, 2, 1)):
            track.medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0
class TestAssignment(ConfigMixin):
A = "one"
B = "two"
@ -840,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil):
assert self.items[1].comp
class StringDistanceTest(unittest.TestCase):
    """Check the normalisation heuristics applied by ``string_dist``."""

    def test_equal_strings(self):
        assert string_dist("Some String", "Some String") == 0.0

    def test_different_strings(self):
        assert string_dist("Some String", "Totally Different") != 0.0

    def test_punctuation_ignored(self):
        assert string_dist("Some String", "Some.String!") == 0.0

    def test_case_ignored(self):
        assert string_dist("Some String", "sOME sTring") == 0.0

    def test_leading_the_has_lower_weight(self):
        arbitrary_prefix = string_dist("XXX Band Name", "Band Name")
        leading_the = string_dist("The Band Name", "Band Name")
        assert leading_the < arbitrary_prefix

    def test_parens_have_lower_weight(self):
        plain_suffix = string_dist("One .Two.", "One")
        parenthesised = string_dist("One (Two)", "One")
        assert parenthesised < plain_suffix

    def test_brackets_have_lower_weight(self):
        plain_suffix = string_dist("One .Two.", "One")
        bracketed = string_dist("One [Two]", "One")
        assert bracketed < plain_suffix

    def test_ep_label_has_zero_weight(self):
        assert string_dist("My Song (EP)", "My Song") == 0.0

    def test_featured_has_lower_weight(self):
        arbitrary_word = string_dist("My Song blah Someone", "My Song")
        featured = string_dist("My Song feat Someone", "My Song")
        assert featured < arbitrary_word

    def test_postfix_the(self):
        assert string_dist("The Song Title", "Song Title, The") == 0.0

    def test_postfix_a(self):
        assert string_dist("A Song Title", "Song Title, A") == 0.0

    def test_postfix_an(self):
        assert string_dist("An Album Title", "Album Title, An") == 0.0

    def test_empty_strings(self):
        assert string_dist("", "") == 0.0

    def test_solo_pattern(self):
        # Just make sure these don't crash.
        for left, right in (("The ", ""), ("(EP)", "(EP)"), (", An", "")):
            string_dist(left, right)

    def test_heuristic_does_not_harm_distance(self):
        assert string_dist("Untitled", "[Untitled]") == 0.0

    def test_ampersand_expansion(self):
        assert string_dist("And", "&") == 0.0

    def test_accented_characters(self):
        assert string_dist("\xe9\xe1\xf1", "ean") == 0.0
@pytest.mark.parametrize(
"single_field,list_field",
[