Move Distance to a dedicated module and refactor related tests (#5800)

This PR:

1. Reorganizes distance-related code by moving it from `hooks.py` and
`match.py` to a new dedicated `distance.py` module:
- The actual implementation logic and algorithms remain unchanged - code
is moved, not rewritten
- Distance class, string distance functions, and track/album distance
calculators are relocated intact
- Only imports and function references are updated to maintain
compatibility
- `current_metadata` function is replaced with equivalent
`get_most_common_tags` function for clarity

2. Refactors the distance testing code from unittest to pytest:
- Tests now use fixtures and parametrization while verifying the same
functionality
- The tested behaviors remain identical, just with improved test
structure
- Actually, `distance.py` coverage slightly increased since I included
an additional test

3. Adds a test for the `sanitize_pairs` function to complete config
utility test coverage

This is primarily a code organization improvement that follows better
separation of concerns, grouping related distance functionality in a
single module without changing how the distance calculations work. No
algorithm changes or behavior modifications were made to the core
distance calculation code - it was simply moved to a more appropriate
location.
This commit is contained in:
Šarūnas Nejus 2025-05-31 19:49:51 +01:00 committed by GitHub
commit 87701fd6f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1053 additions and 1261 deletions

View file

@ -14,36 +14,37 @@
"""Facilities for automatically determining files' correct metadata."""
from collections.abc import Mapping, Sequence
from typing import Union
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from beets import config, logging
from beets.library import Album, Item, LibModel
from beets.util import get_most_common_tags as current_metadata
# Parts of external interface.
from beets.util import unique_list
from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch
from .match import (
Proposal,
Recommendation,
current_metadata,
tag_album,
tag_item,
)
from .distance import Distance
from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch
from .match import Proposal, Recommendation, tag_album, tag_item
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from beets.library import Album, Item, LibModel
__all__ = [
"AlbumInfo",
"AlbumMatch",
"Distance",
"TrackInfo",
"TrackMatch",
"Distance", # for backwards compatibility
"Proposal",
"Recommendation",
"TrackInfo",
"TrackMatch",
"apply_album_metadata",
"apply_item_metadata",
"apply_metadata",
"current_metadata",
"current_metadata", # for backwards compatibility
"tag_album",
"tag_item",
]

531
beets/autotag/distance.py Normal file
View file

@ -0,0 +1,531 @@
from __future__ import annotations
import datetime
import re
from functools import cache, total_ordering
from typing import TYPE_CHECKING, Any
from jellyfish import levenshtein_distance
from unidecode import unidecode
from beets import config, plugins
from beets.util import as_string, cached_classproperty, get_most_common_tags
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from beets.library import Item
from .hooks import AlbumInfo, TrackInfo
# Candidate distance scoring.
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to to remove the penalty for
# differing artists.
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r"^the ", 0.1),
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
(r"\(.*?\)", 0.3),
(r"\[.*?\]", 0.3),
(r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str | None, str2: str | None) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(", %s" % word):
str1 = "{} {}".format(word, str1[: -len(word) - 2])
if str2.endswith(", %s" % word):
str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, "", str1)
case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self) -> None:
self._penalties: dict[str, list[float]] = {}
self.tracks: dict[TrackInfo, Distance] = {}
@cached_classproperty
def _weights(cls) -> dict[str, float]:
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self) -> list[tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self) -> int:
return id(self)
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self) -> float:
return self.distance
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self) -> Iterator[tuple[str, float]]:
return iter(self.items())
def __len__(self) -> int:
return len(self.items())
def keys(self) -> list[str]:
return [key for key, _ in self.items()]
def update(self, dist: Distance):
"""Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
"`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
return bool(value1.match(value2))
return value1 == value2
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
def add_equality(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(
self,
key: str,
number1: int | float,
number2: int | float,
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str | None, str2: str | None):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
@cache
def get_track_length_grace() -> float:
"""Get cached grace period for track length matching."""
return config["match"]["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
"""Get cached maximum track length for track length matching."""
return config["match"]["track_length_max"].as_number()
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = Distance()
# Length.
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())
# Title.
dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if (
incl_artist
and track_info.artist
and item.artist.lower() not in VA_ARTISTS
):
dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Penalize mismatching disc numbers.
if track_info.medium and item.disc:
dist.add_expr("medium", item.disc != track_info.medium)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(
items: Sequence[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = get_most_common_tags(items)
dist = Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
elif likelies["media"]:
dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
if likelies["disctotal"] and album_info.mediums:
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio("year", diff, diff_max)
# Year.
elif likelies["year"] and album_info.year:
if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
diff = abs(likelies["year"] - album_info.year)
diff_max = abs(
datetime.date.today().year - album_info.original_year
)
dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add("year", 1.0)
# Preferred countries.
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
elif likelies["country"] and album_info.country:
dist.add_string("country", likelies["country"], album_info.country)
# Label.
if likelies["label"] and album_info.label:
dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
if likelies["catalognum"] and album_info.catalognum:
dist.add_string(
"catalognum", likelies["catalognum"], album_info.catalognum
)
# Disambiguation.
if likelies["albumdisambig"] and album_info.albumdisambig:
dist.add_string(
"albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
)
# Album ID.
if likelies["mb_albumid"]:
dist.add_equality(
"album_id", likelies["mb_albumid"], album_info.album_id
)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add("missing_tracks", 1.0)
# Unmatched tracks.
for _ in range(len(items) - len(mapping)):
dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist

View file

@ -16,21 +16,15 @@
from __future__ import annotations
import re
from functools import total_ordering
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
from jellyfish import levenshtein_distance
from unidecode import unidecode
from beets import config, logging
from beets.util import as_string, cached_classproperty
from beets import logging
if TYPE_CHECKING:
from collections.abc import Iterator
from beets.library import Item
from .distance import Distance
log = logging.getLogger("beets")
V = TypeVar("V")
@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]):
return dupe
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ["the", "a", "an"]
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r"^the ", 0.1),
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
(r"\(.*?\)", 0.3),
(r"\[.*?\]", 0.3),
(r"(, )?(pt\.|part) .+", 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r"&", "and"),
]
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
assert isinstance(str1, str)
assert isinstance(str2, str)
str1 = as_string(unidecode(str1))
str2 = as_string(unidecode(str2))
str1 = re.sub(r"[^a-z0-9]", "", str1.lower())
str2 = re.sub(r"[^a-z0-9]", "", str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str | None, str2: str | None) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(", %s" % word):
str1 = "{} {}".format(word, str1[: -len(word) - 2])
if str2.endswith(", %s" % word):
str2 = "{} {}".format(word, str2[: -len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, "", str1)
case_str2 = re.sub(pat, "", str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self) -> None:
self._penalties: dict[str, list[float]] = {}
self.tracks: dict[TrackInfo, Distance] = {}
@cached_classproperty
def _weights(cls) -> dict[str, float]:
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor)."""
dist_max = 0.0
for key, penalty in self._penalties.items():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance."""
dist_raw = 0.0
for key, penalty in self._penalties.items():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self) -> list[tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert distance into a negative float we can sort items in
# ascending order (for keys, when the penalty is equal) and
# still get the items with the biggest distance first.
return sorted(
list_, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self) -> int:
return id(self)
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self) -> float:
return self.distance
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty."""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self) -> Iterator[tuple[str, float]]:
return iter(self.items())
def __len__(self) -> int:
return len(self.items())
def keys(self) -> list[str]:
return [key for key, _ in self.items()]
def update(self, dist: Distance):
"""Adds all the distance penalties from `dist`."""
if not isinstance(dist, Distance):
raise ValueError(
"`dist` must be a Distance object, not {}".format(type(dist))
)
for key, penalties in dist._penalties.items():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
return bool(value1.match(value2))
return value1 == value2
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
self._penalties.setdefault(key, []).append(dist)
def add_equality(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(
self,
key: str,
value: Any,
options: list[Any] | tuple[Any, ...] | Any,
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(
self,
key: str,
number1: int | float,
number2: int | float,
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str | None, str2: str | None):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
# Structures that compose all the information for a candidate match.

View file

@ -18,37 +18,23 @@ releases and tracks.
from __future__ import annotations
import datetime
import re
from enum import IntEnum
from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import lap
import numpy as np
from beets import config, logging, plugins
from beets.autotag import (
AlbumInfo,
AlbumMatch,
Distance,
TrackInfo,
TrackMatch,
hooks,
)
from beets.util import plurality
from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks
from beets.util import get_most_common_tags
from .distance import VA_ARTISTS, distance, track_distance
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from beets.library import Item
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to to remove the penalty for
# differing artists.
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
# Global logger.
log = logging.getLogger("beets")
@ -80,44 +66,6 @@ class Proposal(NamedTuple):
# Primary matching functionality.
def current_metadata(
items: Iterable[Item],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
- Whether each field's value was unanimous (values are booleans).
"""
assert items # Must be nonempty.
likelies = {}
consensus = {}
fields = [
"artist",
"album",
"albumartist",
"year",
"disctotal",
"mb_albumid",
"label",
"barcode",
"catalognum",
"country",
"media",
"albumdisambig",
]
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
consensus[field] = freq == len(values)
# If there's an album artist consensus, use this for the artist.
if consensus["albumartist"] and likelies["albumartist"]:
likelies["artist"] = likelies["albumartist"]
return likelies, consensus
def assign_items(
items: Sequence[Item],
tracks: Sequence[TrackInfo],
@ -150,191 +98,6 @@ def assign_items(
return mapping, extra_items, extra_tracks
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
@cache
def get_track_length_grace() -> float:
"""Get cached grace period for track length matching."""
return config["match"]["track_length_grace"].as_number()
@cache
def get_track_length_max() -> float:
"""Get cached maximum track length for track length matching."""
return config["match"]["track_length_max"].as_number()
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = hooks.Distance()
# Length.
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())
# Title.
dist.add_string("track_title", item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if (
incl_artist
and track_info.artist
and item.artist.lower() not in VA_ARTISTS
):
dist.add_string("track_artist", item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr("track_index", track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr("track_id", item.mb_trackid != track_info.track_id)
# Penalize mismatching disc numbers.
if track_info.medium and item.disc:
dist.add_expr("medium", item.disc != track_info.medium)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(
items: Sequence[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = current_metadata(items)
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string("artist", likelies["artist"], album_info.artist)
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
elif likelies["media"]:
dist.add_equality("media", album_info.media, likelies["media"])
# Mediums.
if likelies["disctotal"] and album_info.mediums:
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio("year", diff, diff_max)
# Year.
elif likelies["year"] and album_info.year:
if likelies["year"] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add("year", 0.0)
elif album_info.original_year:
# Prefer matchest closest to the release year.
diff = abs(likelies["year"] - album_info.year)
diff_max = abs(
datetime.date.today().year - album_info.original_year
)
dist.add_ratio("year", diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add("year", 1.0)
# Preferred countries.
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
elif likelies["country"] and album_info.country:
dist.add_string("country", likelies["country"], album_info.country)
# Label.
if likelies["label"] and album_info.label:
dist.add_string("label", likelies["label"], album_info.label)
# Catalog number.
if likelies["catalognum"] and album_info.catalognum:
dist.add_string(
"catalognum", likelies["catalognum"], album_info.catalognum
)
# Disambiguation.
if likelies["albumdisambig"] and album_info.albumdisambig:
dist.add_string(
"albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
)
# Album ID.
if likelies["mb_albumid"]:
dist.add_equality(
"album_id", likelies["mb_albumid"], album_info.album_id
)
# Tracks.
dist.tracks = {}
for item, track in mapping.items():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add("tracks", dist.tracks[track].distance)
# Missing tracks.
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add("missing_tracks", 1.0)
# Unmatched tracks.
for _ in range(len(items) - len(mapping)):
dist.add("unmatched_tracks", 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
def match_by_id(items: Iterable[Item]) -> AlbumInfo | None:
"""If the items are tagged with an external source ID, return an
AlbumInfo object for the corresponding album. Otherwise, returns
@ -499,7 +262,7 @@ def tag_album(
candidates.
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
likelies, consensus = get_most_common_tags(items)
cur_artist: str = likelies["artist"]
cur_album: str = likelies["album"]
log.debug("Tagging {0} - {1}", cur_artist, cur_album)

View file

@ -228,7 +228,7 @@ class ImportTask(BaseImportTask):
or APPLY (in which case the data comes from the choice).
"""
if self.choice_flag in (Action.ASIS, Action.RETAG):
likelies, consensus = autotag.current_metadata(self.items)
likelies, consensus = util.get_most_common_tags(self.items)
return likelies
elif self.choice_flag is Action.APPLY and self.match:
return self.match.info.copy()

View file

@ -53,7 +53,8 @@ if TYPE_CHECKING:
from confuse import ConfigView
from beets.autotag import AlbumInfo, Distance, TrackInfo
from beets.autotag import AlbumInfo, TrackInfo
from beets.autotag.distance import Distance
from beets.dbcore import Query
from beets.dbcore.db import FieldQueryType
from beets.dbcore.types import Type
@ -224,7 +225,7 @@ class BeetsPlugin:
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
return Distance()
@ -237,7 +238,7 @@ class BeetsPlugin:
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
return Distance()
@ -458,7 +459,7 @@ def track_distance(item: Item, info: TrackInfo) -> Distance:
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
dist = Distance()
for plugin in find_plugins():
@ -472,7 +473,7 @@ def album_distance(
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
dist = Distance()
for plugin in find_plugins():
@ -654,73 +655,13 @@ def feat_tokens(for_artist: bool = True) -> str:
)
def sanitize_choices(
choices: Sequence[str], choices_all: Sequence[str]
) -> list[str]:
"""Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order.
"""
seen: set[str] = set()
others = [x for x in choices_all if x not in choices]
res: list[str] = []
for s in choices:
if s not in seen:
if s in list(choices_all):
res.append(s)
elif s == "*":
res.extend(others)
seen.add(s)
return res
def sanitize_pairs(
pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]]
) -> list[tuple[str, str]]:
"""Clean up a single-element mapping configuration attribute as returned
by Confuse's `Pairs` template: keep only two-element tuples present in
pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*')
wildcards while keeping the original order. Note that ('*', '*') and
('*', 'whatever') have the same effect.
For example,
>>> sanitize_pairs(
... [('foo', 'baz bar'), ('key', '*'), ('*', '*')],
... [('foo', 'bar'), ('foo', 'baz'), ('foo', 'foobar'),
... ('key', 'value')]
... )
[('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')]
"""
pairs_all = list(pairs_all)
seen: set[tuple[str, str]] = set()
others = [x for x in pairs_all if x not in pairs]
res: list[tuple[str, str]] = []
for k, values in pairs:
for v in values.split():
x = (k, v)
if x in pairs_all:
if x not in seen:
seen.add(x)
res.append(x)
elif k == "*":
new = [o for o in others if o not in seen]
seen.update(new)
res.extend(new)
elif v == "*":
new = [o for o in others if o not in seen and o[0] == k]
seen.update(new)
res.extend(new)
return res
def get_distance(
config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo
) -> Distance:
"""Returns the ``data_source`` weight and the maximum source weight
for albums or individual tracks.
"""
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
dist = Distance()
if info.data_source == data_source:

View file

@ -56,6 +56,8 @@ if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from logging import Logger
from beets.library import Item
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
@ -814,6 +816,44 @@ def plurality(objs: Iterable[T]) -> tuple[T, int]:
return c.most_common(1)[0]
def get_most_common_tags(
items: Sequence[Item],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
- Whether each field's value was unanimous (values are booleans).
"""
assert items # Must be nonempty.
likelies = {}
consensus = {}
fields = [
"artist",
"album",
"albumartist",
"year",
"disctotal",
"mb_albumid",
"label",
"barcode",
"catalognum",
"country",
"media",
"albumdisambig",
]
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
consensus[field] = freq == len(values)
# If there's an album artist consensus, use this for the artist.
if consensus["albumartist"] and likelies["albumartist"]:
likelies["artist"] = likelies["albumartist"]
return likelies, consensus
# stdout and stderr as bytes
class CommandOutput(NamedTuple):
stdout: bytes

66
beets/util/config.py Normal file
View file

@ -0,0 +1,66 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
def sanitize_choices(
choices: Sequence[str], choices_all: Collection[str]
) -> list[str]:
"""Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order.
"""
seen: set[str] = set()
others = [x for x in choices_all if x not in choices]
res: list[str] = []
for s in choices:
if s not in seen:
if s in list(choices_all):
res.append(s)
elif s == "*":
res.extend(others)
seen.add(s)
return res
def sanitize_pairs(
pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]]
) -> list[tuple[str, str]]:
"""Clean up a single-element mapping configuration attribute as returned
by Confuse's `Pairs` template: keep only two-element tuples present in
pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*')
wildcards while keeping the original order. Note that ('*', '*') and
('*', 'whatever') have the same effect.
For example,
>>> sanitize_pairs(
... [('foo', 'baz bar'), ('key', '*'), ('*', '*')],
... [('foo', 'bar'), ('foo', 'baz'), ('foo', 'foobar'),
... ('key', 'value')]
... )
[('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')]
"""
pairs_all = list(pairs_all)
seen: set[tuple[str, str]] = set()
others = [x for x in pairs_all if x not in pairs]
res: list[tuple[str, str]] = []
for k, values in pairs:
for v in values.split():
x = (k, v)
if x in pairs_all:
if x not in seen:
seen.add(x)
res.append(x)
elif k == "*":
new = [o for o in others if o not in seen]
seen.update(new)
res.extend(new)
elif v == "*":
new = [o for o in others if o not in seen and o[0] == k]
seen.update(new)
res.extend(new)
return res

View file

@ -24,7 +24,7 @@ import acoustid
import confuse
from beets import config, plugins, ui, util
from beets.autotag.hooks import Distance
from beets.autotag.distance import Distance
from beetsplug.musicbrainz import MusicBrainzPlugin
API_KEY = "1vOwZtEn"

View file

@ -38,7 +38,8 @@ from typing_extensions import TypedDict
import beets
import beets.ui
from beets import config
from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist
from beets.autotag.distance import string_dist
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
from beets.util.id_extractors import extract_release_id

View file

@ -32,6 +32,7 @@ from mediafile import image_mime_type
from beets import config, importer, plugins, ui, util
from beets.util import bytestring_path, get_temp_filename, sorted_walk, syspath
from beets.util.artresizer import ArtResizer
from beets.util.config import sanitize_pairs
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
@ -1396,7 +1397,7 @@ class FetchArtPlugin(plugins.BeetsPlugin, RequestMixin):
if s_cls.available(self._log, self.config)
for c in s_cls.VALID_MATCHING_CRITERIA
]
sources = plugins.sanitize_pairs(
sources = sanitize_pairs(
self.config["sources"].as_pairs(default_value="*"),
available_sources,
)

View file

@ -38,7 +38,8 @@ from unidecode import unidecode
import beets
from beets import plugins, ui
from beets.autotag.hooks import string_dist
from beets.autotag.distance import string_dist
from beets.util.config import sanitize_choices
if TYPE_CHECKING:
from logging import Logger
@ -957,7 +958,7 @@ class LyricsPlugin(RequestHandler, plugins.BeetsPlugin):
def backends(self) -> list[Backend]:
user_sources = self.config["sources"].get()
chosen = plugins.sanitize_choices(user_sources, self.BACKEND_BY_NAME)
chosen = sanitize_choices(user_sources, self.BACKEND_BY_NAME)
if "google" in chosen and not self.config["google_API_key"].get():
self.warn("Disabling Google source: no API key configured.")
chosen.remove("google")

View file

@ -0,0 +1,299 @@
import re
import pytest
from beets.autotag import AlbumInfo, TrackInfo
from beets.autotag.distance import (
Distance,
distance,
string_dist,
track_distance,
)
from beets.library import Item
from beets.test.helper import ConfigMixin
_p = pytest.param
class TestDistance:
@pytest.fixture(scope="class")
def config(self):
return ConfigMixin().config
@pytest.fixture
def dist(self, config):
config["match"]["distance_weights"]["source"] = 2.0
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
Distance.__dict__["_weights"].cache = {}
return Distance()
def test_add(self, dist):
dist.add("add", 1.0)
assert dist._penalties == {"add": [1.0]}
@pytest.mark.parametrize(
"key, args_with_expected",
[
(
"equality",
[
(("ghi", ["abc", "def", "ghi"]), [0.0]),
(("xyz", ["abc", "def", "ghi"]), [0.0, 1.0]),
(("abc", re.compile(r"ABC", re.I)), [0.0, 1.0, 0.0]),
],
),
("expr", [((True,), [1.0]), ((False,), [1.0, 0.0])]),
(
"number",
[
((1, 1), [0.0]),
((1, 2), [0.0, 1.0]),
((2, 1), [0.0, 1.0, 1.0]),
((-1, 2), [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
],
),
(
"priority",
[
(("abc", "abc"), [0.0]),
(("def", ["abc", "def"]), [0.0, 0.5]),
(("gh", ["ab", "cd", "ef", re.compile("GH", re.I)]), [0.0, 0.5, 0.75]), # noqa: E501
(("xyz", ["abc", "def"]), [0.0, 0.5, 0.75, 1.0]),
],
),
(
"ratio",
[
((25, 100), [0.25]),
((10, 5), [0.25, 1.0]),
((-5, 5), [0.25, 1.0, 0.0]),
((5, 0), [0.25, 1.0, 0.0, 0.0]),
],
),
(
"string",
[
(("abc", "bcd"), [2 / 3]),
(("abc", None), [2 / 3, 1]),
((None, None), [2 / 3, 1, 0]),
],
),
],
) # fmt: skip
def test_add_methods(self, dist, key, args_with_expected):
method = getattr(dist, f"add_{key}")
for arg_set, expected in args_with_expected:
method(key, *arg_set)
assert dist._penalties[key] == expected
def test_distance(self, dist):
dist.add("album", 0.5)
dist.add("media", 0.25)
dist.add("media", 0.75)
assert dist.distance == 0.5
assert dist.max_distance == 6.0
assert dist.raw_distance == 3.0
assert dist["album"] == 1 / 3
assert dist["media"] == 1 / 6
def test_operators(self, dist):
dist.add("source", 0.0)
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.75)
assert len(dist) == 2
assert list(dist) == [("album", 0.2), ("medium", 0.2)]
assert dist == 0.4
assert dist < 1.0
assert dist > 0.0
assert dist - 0.4 == 0.0
assert 0.4 - dist == 0.0
assert float(dist) == 0.4
def test_penalties_sort(self, dist):
dist.add("album", 0.1875)
dist.add("medium", 0.75)
assert dist.items() == [("medium", 0.25), ("album", 0.125)]
# Sort by key if distance is equal.
dist = Distance()
dist.add("album", 0.375)
dist.add("medium", 0.75)
assert dist.items() == [("album", 0.25), ("medium", 0.25)]
def test_update(self, dist):
dist1 = dist
dist1.add("album", 0.5)
dist1.add("media", 1.0)
dist2 = Distance()
dist2.add("album", 0.75)
dist2.add("album", 0.25)
dist2.add("media", 0.05)
dist1.update(dist2)
assert dist1._penalties == {
"album": [0.5, 0.75, 0.25],
"media": [1.0, 0.05],
}
class TestTrackDistance:
@pytest.fixture(scope="class")
def info(self):
return TrackInfo(title="title", artist="artist")
@pytest.mark.parametrize(
"title, artist, expected_penalty",
[
_p("title", "artist", False, id="identical"),
_p("title", "Various Artists", False, id="tolerate-va"),
_p("title", "different artist", True, id="different-artist"),
_p("different title", "artist", True, id="different-title"),
],
)
def test_track_distance(self, info, title, artist, expected_penalty):
item = Item(artist=artist, title=title)
assert (
bool(track_distance(item, info, incl_artist=True))
== expected_penalty
)
class TestAlbumDistance:
@pytest.fixture(scope="class")
def items(self):
return [
Item(
title=title,
track=track,
artist="artist",
album="album",
length=1,
)
for title, track in [("one", 1), ("two", 2), ("three", 3)]
]
@pytest.fixture
def get_dist(self, items):
def inner(info: AlbumInfo):
return distance(items, info, dict(zip(items, info.tracks)))
return inner
@pytest.fixture
def info(self, items):
return AlbumInfo(
artist="artist",
album="album",
tracks=[
TrackInfo(
title=i.title,
artist=i.artist,
index=i.track,
length=i.length,
)
for i in items
],
va=False,
)
def test_identical_albums(self, get_dist, info):
assert get_dist(info) == 0
def test_incomplete_album(self, get_dist, info):
info.tracks.pop(2)
assert 0 < float(get_dist(info)) < 0.2
def test_overly_complete_album(self, get_dist, info):
info.tracks.append(
Item(index=4, title="four", artist="artist", length=1)
)
assert 0 < float(get_dist(info)) < 0.2
@pytest.mark.parametrize("va", [True, False])
def test_albumartist(self, get_dist, info, va):
info.artist = "another artist"
info.va = va
assert bool(get_dist(info)) is not va
def test_comp_no_track_artists(self, get_dist, info):
# Some VA releases don't have track artists (incomplete metadata).
info.artist = "another artist"
info.va = True
for track in info.tracks:
track.artist = None
assert get_dist(info) == 0
def test_comp_track_artists_do_not_match(self, get_dist, info):
info.va = True
info.tracks[0].artist = "another artist"
assert get_dist(info) != 0
def test_tracks_out_of_order(self, get_dist, info):
tracks = info.tracks
tracks[1].title, tracks[2].title = tracks[2].title, tracks[1].title
assert 0 < float(get_dist(info)) < 0.2
def test_two_medium_release(self, get_dist, info):
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
assert get_dist(info) == 0
class TestStringDistance:
@pytest.mark.parametrize(
"string1, string2",
[
("Some String", "Some String"),
("Some String", "Some.String!"),
("Some String", "sOME sTring"),
("My Song (EP)", "My Song"),
("The Song Title", "Song Title, The"),
("A Song Title", "Song Title, A"),
("An Album Title", "Album Title, An"),
("", ""),
("Untitled", "[Untitled]"),
("And", "&"),
("\xe9\xe1\xf1", "ean"),
],
)
def test_matching_distance(self, string1, string2):
assert string_dist(string1, string2) == 0.0
def test_different_distance(self):
assert string_dist("Some String", "Totally Different") != 0.0
@pytest.mark.parametrize(
"string1, string2, reference",
[
("XXX Band Name", "The Band Name", "Band Name"),
("One .Two.", "One (Two)", "One"),
("One .Two.", "One [Two]", "One"),
("My Song blah Someone", "My Song feat Someone", "My Song"),
],
)
def test_relative_weights(self, string1, string2, reference):
assert string_dist(string2, reference) < string_dist(string1, reference)
def test_solo_pattern(self):
# Just make sure these don't crash.
string_dist("The ", "")
string_dist("(EP)", "(EP)")
string_dist(", An", "")

View file

@ -108,45 +108,6 @@ lyrics_pages = [
url_title="The Beatles - Lady Madonna Lyrics | AZLyrics.com",
marks=[xfail_on_ci("AZLyrics is blocked by Cloudflare")],
),
LyricsPage.make(
"http://www.chartlyrics.com/_LsLsZ7P4EK-F-LD4dJgDQ/Lady+Madonna.aspx",
"""
Lady Madonna,
Children at your feet
Wonder how you manage to make ends meet.
Who finds the money
When you pay the rent?
Did you think that money was heaven-sent?
Friday night arrives without a suitcase.
Sunday morning creeping like a nun.
Monday's child has learned to tie his bootlace.
See how they run.
Lady Madonna,
Baby at your breast
Wonders how you manage to feed the rest.
See how they run.
Lady Madonna,
Lying on the bed.
Listen to the music playing in your head.
Tuesday afternoon is never ending.
Wednesday morning papers didn't come.
Thursday night your stockings needed mending.
See how they run.
Lady Madonna,
Children at your feet
Wonder how you manage to make ends meet.
""",
url_title="The Beatles Lady Madonna lyrics",
),
LyricsPage.make(
"https://www.dainuzodziai.lt/m/mergaites-nori-mylet-atlanta/",
"""

View file

@ -14,488 +14,12 @@
"""Tests for autotagging functionality."""
import re
import unittest
import pytest
from beets import autotag, config
from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
from beets.autotag.hooks import Distance, string_dist
from beets.library import Item
from beets.test.helper import BeetsTestCase, ConfigMixin
from beets.util import plurality
class PluralityTest(BeetsTestCase):
def test_plurality_consensus(self):
objs = [1, 1, 1, 1]
obj, freq = plurality(objs)
assert obj == 1
assert freq == 4
def test_plurality_near_consensus(self):
objs = [1, 1, 2, 1]
obj, freq = plurality(objs)
assert obj == 1
assert freq == 3
def test_plurality_conflict(self):
objs = [1, 1, 2, 2, 3]
obj, freq = plurality(objs)
assert obj in (1, 2)
assert freq == 2
def test_plurality_empty_sequence_raises_error(self):
with pytest.raises(ValueError, match="must be non-empty"):
plurality([])
def test_current_metadata_finds_pluralities(self):
items = [
Item(artist="The Beetles", album="The White Album"),
Item(artist="The Beatles", album="The White Album"),
Item(artist="The Beatles", album="Teh White Album"),
]
likelies, consensus = match.current_metadata(items)
assert likelies["artist"] == "The Beatles"
assert likelies["album"] == "The White Album"
assert not consensus["artist"]
def test_current_metadata_artist_consensus(self):
items = [
Item(artist="The Beatles", album="The White Album"),
Item(artist="The Beatles", album="The White Album"),
Item(artist="The Beatles", album="Teh White Album"),
]
likelies, consensus = match.current_metadata(items)
assert likelies["artist"] == "The Beatles"
assert likelies["album"] == "The White Album"
assert consensus["artist"]
def test_albumartist_consensus(self):
items = [
Item(artist="tartist1", album="album", albumartist="aartist"),
Item(artist="tartist2", album="album", albumartist="aartist"),
Item(artist="tartist3", album="album", albumartist="aartist"),
]
likelies, consensus = match.current_metadata(items)
assert likelies["artist"] == "aartist"
assert not consensus["artist"]
def test_current_metadata_likelies(self):
fields = [
"artist",
"album",
"albumartist",
"year",
"disctotal",
"mb_albumid",
"label",
"barcode",
"catalognum",
"country",
"media",
"albumdisambig",
]
items = [Item(**{f: f"{f}_{i or 1}" for f in fields}) for i in range(5)]
likelies, _ = match.current_metadata(items)
for f in fields:
if isinstance(likelies[f], int):
assert likelies[f] == 0
else:
assert likelies[f] == f"{f}_1"
def _make_item(title, track, artist="some artist"):
return Item(
title=title,
track=track,
artist=artist,
album="some album",
length=1,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)
def _make_trackinfo():
return [
TrackInfo(
title="one", track_id=None, artist="some artist", length=1, index=1
),
TrackInfo(
title="two", track_id=None, artist="some artist", length=1, index=2
),
TrackInfo(
title="three",
track_id=None,
artist="some artist",
length=1,
index=3,
),
]
def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].cache = {}
class DistanceTest(BeetsTestCase):
def tearDown(self):
super().tearDown()
_clear_weights()
def test_add(self):
dist = Distance()
dist.add("add", 1.0)
assert dist._penalties == {"add": [1.0]}
def test_add_equality(self):
dist = Distance()
dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0]
dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
assert dist._penalties["equality"] == [0.0, 1.0]
dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
def test_add_expr(self):
dist = Distance()
dist.add_expr("expr", True)
assert dist._penalties["expr"] == [1.0]
dist.add_expr("expr", False)
assert dist._penalties["expr"] == [1.0, 0.0]
def test_add_number(self):
dist = Distance()
# Add a full penalty for each number of difference between two numbers.
dist.add_number("number", 1, 1)
assert dist._penalties["number"] == [0.0]
dist.add_number("number", 1, 2)
assert dist._penalties["number"] == [0.0, 1.0]
dist.add_number("number", 2, 1)
assert dist._penalties["number"] == [0.0, 1.0, 1.0]
dist.add_number("number", -1, 2)
assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
def test_add_priority(self):
dist = Distance()
dist.add_priority("priority", "abc", "abc")
assert dist._penalties["priority"] == [0.0]
dist.add_priority("priority", "def", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5]
dist.add_priority(
"priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
)
assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
dist.add_priority("priority", "xyz", ["abc", "def"])
assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
def test_add_ratio(self):
dist = Distance()
dist.add_ratio("ratio", 25, 100)
assert dist._penalties["ratio"] == [0.25]
dist.add_ratio("ratio", 10, 5)
assert dist._penalties["ratio"] == [0.25, 1.0]
dist.add_ratio("ratio", -5, 5)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
dist.add_ratio("ratio", 5, 0)
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
def test_add_string(self):
dist = Distance()
sdist = string_dist("abc", "bcd")
dist.add_string("string", "abc", "bcd")
assert dist._penalties["string"] == [sdist]
assert dist._penalties["string"] != [0]
def test_add_string_none(self):
dist = Distance()
dist.add_string("string", None, "string")
assert dist._penalties["string"] == [1]
def test_add_string_both_none(self):
dist = Distance()
dist.add_string("string", None, None)
assert dist._penalties["string"] == [0]
def test_distance(self):
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("media", 0.25)
dist.add("media", 0.75)
assert dist.distance == 0.5
# __getitem__()
assert dist["album"] == 0.25
assert dist["media"] == 0.25
def test_max_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.0)
dist.add("medium", 0.0)
assert dist.max_distance == 5.0
def test_operators(self):
config["match"]["distance_weights"]["source"] = 1.0
config["match"]["distance_weights"]["album"] = 2.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("source", 0.0)
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.75)
assert len(dist) == 2
assert list(dist) == [("album", 0.2), ("medium", 0.2)]
assert dist == 0.4
assert dist < 1.0
assert dist > 0.0
assert dist - 0.4 == 0.0
assert 0.4 - dist == 0.0
assert float(dist) == 0.4
def test_raw_distance(self):
config["match"]["distance_weights"]["album"] = 3.0
config["match"]["distance_weights"]["medium"] = 1.0
_clear_weights()
dist = Distance()
dist.add("album", 0.5)
dist.add("medium", 0.25)
dist.add("medium", 0.5)
assert dist.raw_distance == 2.25
def test_items(self):
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0
_clear_weights()
dist = Distance()
dist.add("album", 0.1875)
dist.add("medium", 0.75)
assert dist.items() == [("medium", 0.25), ("album", 0.125)]
# Sort by key if distance is equal.
dist = Distance()
dist.add("album", 0.375)
dist.add("medium", 0.75)
assert dist.items() == [("album", 0.25), ("medium", 0.25)]
def test_update(self):
dist1 = Distance()
dist1.add("album", 0.5)
dist1.add("media", 1.0)
dist2 = Distance()
dist2.add("album", 0.75)
dist2.add("album", 0.25)
dist2.add("media", 0.05)
dist1.update(dist2)
assert dist1._penalties == {
"album": [0.5, 0.75, 0.25],
"media": [1.0, 0.05],
}
class TrackDistanceTest(BeetsTestCase):
def test_identical_tracks(self):
item = _make_item("one", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
def test_different_title(self):
item = _make_item("foo", 1)
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_different_artist(self):
item = _make_item("one", 1)
item.artist = "foo"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist != 0.0
def test_various_artists_tolerated(self):
item = _make_item("one", 1)
item.artist = "Various Artists"
info = _make_trackinfo()[0]
dist = match.track_distance(item, info, incl_artist=True)
assert dist == 0.0
class AlbumDistanceTest(BeetsTestCase):
def _mapping(self, items, info):
out = {}
for i, t in zip(items, info.tracks):
out[i] = t
return out
def _dist(self, items, info):
return match.distance(items, info, self._mapping(items, info))
def test_identical_albums(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) == 0
def test_incomplete_album(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert dist != 0
# Make sure the distance is not too great
assert dist < 0.2
def test_global_artists_differ(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="someone else",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
assert self._dist(items, info) != 0
def test_comp_track_artists_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) == 0
def test_comp_no_track_artists(self):
# Some VA releases don't have track artists (incomplete metadata).
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="should be ignored",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
info.tracks[0].artist = None
info.tracks[1].artist = None
info.tracks[2].artist = None
assert self._dist(items, info) == 0
def test_comp_track_artists_do_not_match(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2, "someone else"))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=True,
)
assert self._dist(items, info) != 0
def test_tracks_out_of_order(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("three", 2))
items.append(_make_item("two", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
dist = self._dist(items, info)
assert 0 < dist < 0.2
def test_two_medium_release(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 3))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
def test_per_medium_track_numbers(self):
items = []
items.append(_make_item("one", 1))
items.append(_make_item("two", 2))
items.append(_make_item("three", 1))
info = AlbumInfo(
artist="some artist",
album="some album",
tracks=_make_trackinfo(),
va=False,
)
info.tracks[0].medium_index = 1
info.tracks[1].medium_index = 2
info.tracks[2].medium_index = 1
dist = self._dist(items, info)
assert dist == 0
class TestAssignment(ConfigMixin):
@ -920,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil):
assert self.items[1].comp
class StringDistanceTest(unittest.TestCase):
def test_equal_strings(self):
dist = string_dist("Some String", "Some String")
assert dist == 0.0
def test_different_strings(self):
dist = string_dist("Some String", "Totally Different")
assert dist != 0.0
def test_punctuation_ignored(self):
dist = string_dist("Some String", "Some.String!")
assert dist == 0.0
def test_case_ignored(self):
dist = string_dist("Some String", "sOME sTring")
assert dist == 0.0
def test_leading_the_has_lower_weight(self):
dist1 = string_dist("XXX Band Name", "Band Name")
dist2 = string_dist("The Band Name", "Band Name")
assert dist2 < dist1
def test_parens_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One (Two)", "One")
assert dist2 < dist1
def test_brackets_have_lower_weight(self):
dist1 = string_dist("One .Two.", "One")
dist2 = string_dist("One [Two]", "One")
assert dist2 < dist1
def test_ep_label_has_zero_weight(self):
dist = string_dist("My Song (EP)", "My Song")
assert dist == 0.0
def test_featured_has_lower_weight(self):
dist1 = string_dist("My Song blah Someone", "My Song")
dist2 = string_dist("My Song feat Someone", "My Song")
assert dist2 < dist1
def test_postfix_the(self):
dist = string_dist("The Song Title", "Song Title, The")
assert dist == 0.0
def test_postfix_a(self):
dist = string_dist("A Song Title", "Song Title, A")
assert dist == 0.0
def test_postfix_an(self):
dist = string_dist("An Album Title", "Album Title, An")
assert dist == 0.0
def test_empty_strings(self):
dist = string_dist("", "")
assert dist == 0.0
def test_solo_pattern(self):
# Just make sure these don't crash.
string_dist("The ", "")
string_dist("(EP)", "(EP)")
string_dist(", An", "")
def test_heuristic_does_not_harm_distance(self):
dist = string_dist("Untitled", "[Untitled]")
assert dist == 0.0
def test_ampersand_expansion(self):
dist = string_dist("And", "&")
assert dist == 0.0
def test_accented_characters(self):
dist = string_dist("\xe9\xe1\xf1", "ean")
assert dist == 0.0
@pytest.mark.parametrize(
"single_field,list_field",
[

View file

@ -15,7 +15,6 @@
import itertools
import os
import unittest
from unittest.mock import ANY, Mock, patch
import pytest
@ -215,15 +214,6 @@ class EventsTest(PluginImportTestCase):
]
class HelpersTest(unittest.TestCase):
def test_sanitize_choices(self):
assert plugins.sanitize_choices(["A", "Z"], ("A", "B")) == ["A"]
assert plugins.sanitize_choices(["A", "A"], ("A")) == ["A"]
assert plugins.sanitize_choices(
["D", "*", "A"], ("A", "B", "C", "D")
) == ["D", "B", "C", "A"]
class ListenersTest(PluginLoaderTestCase):
def test_register(self):
class DummyPlugin(plugins.BeetsPlugin):

View file

@ -24,6 +24,7 @@ from unittest.mock import Mock, patch
import pytest
from beets import util
from beets.library import Item
from beets.test import _common
@ -217,3 +218,41 @@ class TestPathLegalization:
expected_path,
expected_truncated,
)
class TestPlurality:
@pytest.mark.parametrize(
"objs, expected_obj, expected_freq",
[
pytest.param([1, 1, 1, 1], 1, 4, id="consensus"),
pytest.param([1, 1, 2, 1], 1, 3, id="near consensus"),
pytest.param([1, 1, 2, 2, 3], 1, 2, id="conflict-first-wins"),
],
)
def test_plurality(self, objs, expected_obj, expected_freq):
assert (expected_obj, expected_freq) == util.plurality(objs)
def test_empty_sequence_raises_error(self):
with pytest.raises(ValueError, match="must be non-empty"):
util.plurality([])
def test_get_most_common_tags(self):
items = [
Item(albumartist="aartist", label="label 1", album="album"),
Item(albumartist="aartist", label="label 2", album="album"),
Item(albumartist="aartist", label="label 3", album="another album"),
]
likelies, consensus = util.get_most_common_tags(items)
assert likelies["albumartist"] == "aartist"
assert likelies["album"] == "album"
# albumartist consensus overrides artist
assert likelies["artist"] == "aartist"
assert likelies["label"] == "label 1"
assert likelies["year"] == 0
assert consensus["year"]
assert consensus["albumartist"]
assert not consensus["album"]
assert not consensus["label"]

38
test/util/test_config.py Normal file
View file

@ -0,0 +1,38 @@
import pytest
from beets.util.config import sanitize_choices, sanitize_pairs
@pytest.mark.parametrize(
"input_choices, valid_choices, expected",
[
(["A", "Z"], ("A", "B"), ["A"]),
(["A", "A"], ("A"), ["A"]),
(["D", "*", "A"], ("A", "B", "C", "D"), ["D", "B", "C", "A"]),
],
)
def test_sanitize_choices(input_choices, valid_choices, expected):
assert sanitize_choices(input_choices, valid_choices) == expected
def test_sanitize_pairs():
assert sanitize_pairs(
[
("foo", "baz bar"),
("foo", "baz bar"),
("key", "*"),
("*", "*"),
("discard", "bye"),
],
[
("foo", "bar"),
("foo", "baz"),
("foo", "foobar"),
("key", "value"),
],
) == [
("foo", "baz"),
("foo", "bar"),
("key", "value"),
("foo", "foobar"),
]