mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 08:39:17 +01:00
Move distance to a separate module
This commit is contained in:
parent
01b6ea7898
commit
adbd50b237
10 changed files with 1028 additions and 1024 deletions
|
|
@ -14,22 +14,26 @@
|
|||
|
||||
"""Facilities for automatically determining files' correct metadata."""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from beets import config, logging
|
||||
from beets.library import Album, Item, LibModel
|
||||
|
||||
# Parts of external interface.
|
||||
from beets.util import unique_list
|
||||
|
||||
from .hooks import AlbumInfo, AlbumMatch, Distance, TrackInfo, TrackMatch
|
||||
from .hooks import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch
|
||||
from .match import Proposal, Recommendation, tag_album, tag_item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from beets.library import Album, Item, LibModel
|
||||
|
||||
__all__ = [
|
||||
"AlbumInfo",
|
||||
"AlbumMatch",
|
||||
"Distance",
|
||||
"TrackInfo",
|
||||
"TrackMatch",
|
||||
"Proposal",
|
||||
|
|
|
|||
531
beets/autotag/distance.py
Normal file
531
beets/autotag/distance.py
Normal file
|
|
@ -0,0 +1,531 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import re
|
||||
from functools import cache, total_ordering
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from jellyfish import levenshtein_distance
|
||||
from unidecode import unidecode
|
||||
|
||||
from beets import config, plugins
|
||||
from beets.util import as_string, cached_classproperty, get_most_common_tags
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
from beets.library import Item
|
||||
|
||||
from .hooks import AlbumInfo, TrackInfo
|
||||
|
||||
# Candidate distance scoring.
|
||||
|
||||
# Artist signals that indicate "various artists". These are used at the
|
||||
# album level to determine whether a given release is likely a VA
|
||||
# release and also on the track level to remove the penalty for
|
||||
# differing artists.
|
||||
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
|
||||
|
||||
# Parameters for string distance function.
|
||||
# Words that can be moved to the end of a string using a comma.
|
||||
SD_END_WORDS = ["the", "a", "an"]
|
||||
# Reduced weights for certain portions of the string.
|
||||
SD_PATTERNS = [
|
||||
(r"^the ", 0.1),
|
||||
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
|
||||
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
|
||||
(r"\(.*?\)", 0.3),
|
||||
(r"\[.*?\]", 0.3),
|
||||
(r"(, )?(pt\.|part) .+", 0.2),
|
||||
]
|
||||
# Replacements to use before testing distance.
|
||||
SD_REPLACE = [
|
||||
(r"&", "and"),
|
||||
]
|
||||
|
||||
|
||||
def _string_dist_basic(str1: str, str2: str) -> float:
    """Return a length-normalized Levenshtein distance between two strings.

    Each input is transliterated with ``unidecode``, lowercased, and stripped
    of every character outside ``a-z`` and ``0-9`` before the comparison. The
    edit distance is divided by the length of the longer normalized string;
    two strings that normalize to nothing have a distance of 0.0.
    """
    assert isinstance(str1, str)
    assert isinstance(str2, str)

    def _normalize(text: str) -> str:
        # ASCII transliteration first, then keep only lowercase alphanumerics.
        return re.sub(r"[^a-z0-9]", "", as_string(unidecode(text)).lower())

    left = _normalize(str1)
    right = _normalize(str2)
    longest = max(len(left), len(right))
    if not longest:
        # Both strings are empty after normalization.
        return 0.0
    return levenshtein_distance(left, right) / float(longest)
|
||||
|
||||
|
||||
def string_dist(str1: str | None, str2: str | None) -> float:
    """Return an "intuitive" normalized edit distance between two strings.

    The basic length-normalized edit distance is adjusted with a few
    text-aware tweaks: trailing ", the"/", a"/", an" are moved to the front,
    the ``SD_REPLACE`` substitutions are applied, and portions matched by
    ``SD_PATTERNS`` only contribute a reduced share of the distance.

    Two ``None`` values give 0.0; exactly one ``None`` gives 1.0.
    """
    if str1 is None or str2 is None:
        # `str1 is str2` is only true here when both are None.
        return 0.0 if str1 is str2 else 1.0

    left = str1.lower()
    right = str2.lower()

    # "something, the" is treated the same as "the something".
    for word in SD_END_WORDS:
        suffix = f", {word}"
        if left.endswith(suffix):
            left = f"{word} {left[: -len(suffix)]}"
        if right.endswith(suffix):
            right = f"{word} {right[: -len(suffix)]}"

    # Basic normalizing substitutions.
    for pattern, replacement in SD_REPLACE:
        left = re.sub(pattern, replacement, left)
        right = re.sub(pattern, replacement, right)

    # Drop each weighted pattern in turn. Whenever dropping it brings the
    # strings closer, keep the trimmed strings as the new baseline and
    # charge only a scaled share of the improvement as a penalty.
    base_dist = _string_dist_basic(left, right)
    penalty = 0.0
    for pattern, weight in SD_PATTERNS:
        trimmed_left = re.sub(pattern, "", left)
        trimmed_right = re.sub(pattern, "", right)
        if trimmed_left == left and trimmed_right == right:
            # Pattern is absent from both strings.
            continue

        trimmed_dist = _string_dist_basic(trimmed_left, trimmed_right)
        gain = max(0.0, base_dist - trimmed_dist)
        if gain == 0.0:
            continue

        # Shift the baseline so the same portion is not matched again.
        left, right = trimmed_left, trimmed_right
        base_dist = trimmed_dist
        penalty += weight * gain

    return base_dist + penalty
|
||||
|
||||
|
||||
@total_ordering
class Distance:
    """Keep track of multiple named distance penalties.

    Penalties are stored per key as lists of floats between 0.0 and 1.0 and
    are scaled by the weights configured under ``match.distance_weights``.
    The object provides a single weighted, normalized distance across all
    penalties as well as a weighted distance for each individual key. It can
    be compared and subtracted like a float and iterated like a mapping of
    its non-zero penalties.
    """

    def __init__(self) -> None:
        # Raw penalties, keyed by the name of the configured weight.
        self._penalties: dict[str, list[float]] = {}
        # Per-track distances; filled in by the album-level `distance()`.
        self.tracks: dict[TrackInfo, Distance] = {}

    @cached_classproperty
    def _weights(cls) -> dict[str, float]:
        """A dictionary from keys to floating-point weights."""
        weights_view = config["match"]["distance_weights"]
        return {
            key: weights_view[key].as_number() for key in weights_view.keys()
        }

    # Access the components and their aggregates.

    @property
    def distance(self) -> float:
        """Return a weighted and normalized distance across all penalties.

        Returns 0.0 when the normalization factor is zero.
        """
        dist_max = self.max_distance
        if dist_max:
            # Reuse the value computed above instead of re-evaluating the
            # `max_distance` property.
            return self.raw_distance / dist_max
        return 0.0

    @property
    def max_distance(self) -> float:
        """Return the maximum distance penalty (normalization factor)."""
        dist_max = 0.0
        for key, penalty in self._penalties.items():
            dist_max += len(penalty) * self._weights[key]
        return dist_max

    @property
    def raw_distance(self) -> float:
        """Return the raw (denormalized) distance."""
        dist_raw = 0.0
        for key, penalty in self._penalties.items():
            dist_raw += sum(penalty) * self._weights[key]
        return dist_raw

    def items(self) -> list[tuple[str, float]]:
        """Return a list of (key, dist) pairs, with `dist` being the
        weighted distance, sorted from highest to lowest. Does not
        include penalties with a zero value.
        """
        nonzero = [
            (key, dist) for key in self._penalties if (dist := self[key])
        ]
        # Negate the distance so that a single ascending sort puts the
        # biggest distances first while still ordering keys alphabetically
        # when the penalties are equal.
        return sorted(
            nonzero, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
        )

    def __hash__(self) -> int:
        # Identity-based hash: objects with equal distances still hash
        # differently.
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self.distance == other

    # Behave like a float.

    def __lt__(self, other) -> bool:
        return self.distance < other

    def __float__(self) -> float:
        return self.distance

    def __sub__(self, other) -> float:
        return self.distance - other

    def __rsub__(self, other) -> float:
        return other - self.distance

    def __str__(self) -> str:
        return f"{self.distance:.2f}"

    # Behave like a dict.

    def __getitem__(self, key) -> float:
        """Return the weighted distance for a named penalty.

        Raises KeyError if no penalty has been recorded for `key`.
        """
        dist = sum(self._penalties[key]) * self._weights[key]
        dist_max = self.max_distance
        if dist_max:
            return dist / dist_max
        return 0.0

    def __iter__(self) -> Iterator[tuple[str, float]]:
        return iter(self.items())

    def __len__(self) -> int:
        return len(self.items())

    def keys(self) -> list[str]:
        """Return the keys with a non-zero penalty, in `items()` order."""
        return [key for key, _ in self.items()]

    def update(self, dist: Distance):
        """Add all the distance penalties from `dist`.

        Raises ValueError if `dist` is not a Distance object.
        """
        if not isinstance(dist, Distance):
            raise ValueError(
                f"`dist` must be a Distance object, not {type(dist)}"
            )
        for key, penalties in dist._penalties.items():
            self._penalties.setdefault(key, []).extend(penalties)

    # Adding components.

    def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
        """Return True if `value1` is equal to `value2`. `value1` may
        be a compiled regular expression, in which case it will be
        matched against `value2`.
        """
        if isinstance(value1, re.Pattern):
            return bool(value1.match(value2))
        return value1 == value2

    def add(self, key: str, dist: float):
        """Add a distance penalty. `key` must correspond with a
        configured weight setting. `dist` must be a float between 0.0
        and 1.0, and will be added to any existing distance penalties
        for the same key.

        Raises ValueError if `dist` is outside that range.
        """
        if not 0.0 <= dist <= 1.0:
            raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
        self._penalties.setdefault(key, []).append(dist)

    def add_equality(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ):
        """Add a distance penalty of 1.0 if `value` doesn't match any
        of the values in `options`. If an option is a compiled regular
        expression, it will be considered equal if it matches against
        `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        matched = any(self._eq(opt, value) for opt in options)
        self.add(key, 0.0 if matched else 1.0)

    def add_expr(self, key: str, expr: bool):
        """Add a distance penalty of 1.0 if `expr` evaluates to True,
        or 0.0.
        """
        self.add(key, 1.0 if expr else 0.0)

    def add_number(self, key: str, number1: int, number2: int):
        """Add a distance penalty of 1.0 for each number of difference
        between `number1` and `number2`, or 0.0 when there is no
        difference. Use this when there is no upper limit on the
        difference between the two numbers.
        """
        diff = abs(number1 - number2)
        if diff:
            for _ in range(diff):
                self.add(key, 1.0)
        else:
            self.add(key, 0.0)

    def add_priority(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ):
        """Add a distance penalty that corresponds to the position at
        which `value` appears in `options`. A distance penalty of 0.0
        for the first option, or 1.0 if there is no matching option. If
        an option is a compiled regular expression, it will be
        considered equal if it matches against `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        # `or 1` guards against division by zero for an empty sequence.
        unit = 1.0 / (len(options) or 1)
        dist = next(
            (i * unit for i, opt in enumerate(options) if self._eq(opt, value)),
            1.0,
        )
        self.add(key, dist)

    def add_ratio(
        self,
        key: str,
        number1: int | float,
        number2: int | float,
    ):
        """Add a distance penalty for `number1` as a ratio of `number2`.
        `number1` is bound at 0 and `number2`. A zero `number2` yields a
        penalty of 0.0.
        """
        number = float(max(min(number1, number2), 0))
        dist = number / number2 if number2 else 0.0
        self.add(key, dist)

    def add_string(self, key: str, str1: str | None, str2: str | None):
        """Add a distance penalty based on the edit distance between
        `str1` and `str2`.
        """
        self.add(key, string_dist(str1, str2))
|
||||
|
||||
|
||||
@cache
def get_track_length_grace() -> float:
    """Return the ``match.track_length_grace`` configuration value.

    The result is memoized with ``functools.cache``, so the configuration is
    only read on the first call.
    """
    grace_view = config["match"]["track_length_grace"]
    return grace_view.as_number()
|
||||
|
||||
|
||||
@cache
def get_track_length_max() -> float:
    """Return the ``match.track_length_max`` configuration value.

    The result is memoized with ``functools.cache``, so the configuration is
    only read on the first call.
    """
    max_view = config["match"]["track_length_max"]
    return max_view.as_number()
|
||||
|
||||
|
||||
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
    """Return True if the item's track number differs from the track info.

    Both per-disc (``medium_index``) and per-release (``index``) numbering
    are accepted as a match.
    """
    accepted_indices = (track_info.medium_index, track_info.index)
    return item.track not in accepted_indices
|
||||
|
||||
|
||||
def track_distance(
    item: Item,
    track_info: TrackInfo,
    incl_artist: bool = False,
) -> Distance:
    """Measure how significant a track metadata change would be.

    Returns a Distance object with penalties for length, title, optionally
    artist, track index, track ID and medium, plus any penalties contributed
    by plugins. `incl_artist` requests a track-artist component (i.e., for
    various-artist releases).

    The ``track_length_grace`` and ``track_length_max`` options are read
    through cached helpers because this function is called many times
    during matching and accessing them has a performance overhead.
    """
    result = Distance()

    # Length: only the part of the difference beyond the grace period counts.
    track_length = track_info.length
    if track_length:
        overshoot = abs(item.length - track_length) - get_track_length_grace()
        result.add_ratio("track_length", overshoot, get_track_length_max())

    # Title.
    result.add_string("track_title", item.title, track_info.title)

    # Artist: requires an artist in the track data and an item artist that
    # is not one of the "various artists" markers.
    compare_artist = (
        incl_artist
        and track_info.artist
        and item.artist.lower() not in VA_ARTISTS
    )
    if compare_artist:
        result.add_string("track_artist", item.artist, track_info.artist)

    # Track index.
    if track_info.index and item.track:
        index_differs = track_index_changed(item, track_info)
        result.add_expr("track_index", index_differs)

    # Track ID.
    if item.mb_trackid:
        result.add_expr("track_id", item.mb_trackid != track_info.track_id)

    # Disc number.
    if track_info.medium and item.disc:
        result.add_expr("medium", item.disc != track_info.medium)

    # Plugins.
    plugin_dist = plugins.track_distance(item, track_info)
    result.update(plugin_dist)

    return result
|
||||
|
||||
|
||||
def distance(
    items: Sequence[Item],
    album_info: AlbumInfo,
    mapping: dict[Item, TrackInfo],
) -> Distance:
    """Measure how "significant" an album metadata change would be.

    Returns a Distance object. `album_info` is the AlbumInfo candidate being
    compared. `items` is a sequence of all Item objects that will be matched
    (order is not important). `mapping` maps Items to TrackInfo objects; its
    keys are a subset of `items` and its values are a subset of
    `album_info.tracks`.
    """
    likelies, _ = get_most_common_tags(items)

    result = Distance()

    # Artist, unless the candidate is a various-artists release.
    if not album_info.va:
        result.add_string("artist", likelies["artist"], album_info.artist)

    # Album.
    result.add_string("album", likelies["album"], album_info.album)

    preferred = config["match"]["preferred"]

    # Media: use the preferred media list when one is configured, otherwise
    # compare against the most common media of the items.
    if album_info.media:
        media_patterns: Sequence[str] = preferred["media"].as_str_seq()
        media_options = [
            re.compile(r"(\d+x)?(%s)" % pattern, re.I)
            for pattern in media_patterns
        ]
        if media_options:
            result.add_priority("media", album_info.media, media_options)
        elif likelies["media"]:
            result.add_equality("media", album_info.media, likelies["media"])

    # Mediums.
    if likelies["disctotal"] and album_info.mediums:
        result.add_number("mediums", likelies["disctotal"], album_info.mediums)

    if album_info.year and preferred["original_year"]:
        # Prefer the earliest release. Assume 1889 (earliest first
        # gramophone discs) if the original year is unknown.
        original = album_info.original_year or 1889
        year_gap = abs(album_info.year - original)
        year_gap_max = abs(datetime.date.today().year - original)
        result.add_ratio("year", year_gap, year_gap_max)
    elif likelies["year"] and album_info.year:
        if likelies["year"] in (album_info.year, album_info.original_year):
            # No penalty for matching release or original year.
            result.add("year", 0.0)
        elif album_info.original_year:
            # Prefer matches closest to the release year.
            year_gap = abs(likelies["year"] - album_info.year)
            year_gap_max = abs(
                datetime.date.today().year - album_info.original_year
            )
            result.add_ratio("year", year_gap, year_gap_max)
        else:
            # Full penalty when there is no original year.
            result.add("year", 1.0)

    # Country: use the preferred country list when one is configured,
    # otherwise compare against the most common country of the items.
    country_patterns: Sequence[str] = preferred["countries"].as_str_seq()
    country_options = [
        re.compile(pattern, re.I) for pattern in country_patterns
    ]
    if album_info.country and country_options:
        result.add_priority("country", album_info.country, country_options)
    elif likelies["country"] and album_info.country:
        result.add_string("country", likelies["country"], album_info.country)

    # Label.
    if likelies["label"] and album_info.label:
        result.add_string("label", likelies["label"], album_info.label)

    # Catalog number.
    if likelies["catalognum"] and album_info.catalognum:
        result.add_string(
            "catalognum", likelies["catalognum"], album_info.catalognum
        )

    # Disambiguation.
    if likelies["albumdisambig"] and album_info.albumdisambig:
        result.add_string(
            "albumdisambig", likelies["albumdisambig"], album_info.albumdisambig
        )

    # Album ID.
    if likelies["mb_albumid"]:
        result.add_equality(
            "album_id", likelies["mb_albumid"], album_info.album_id
        )

    # Tracks: one "tracks" penalty per mapped item.
    result.tracks = {}
    for item, track in mapping.items():
        per_track = track_distance(item, track, album_info.va)
        result.tracks[track] = per_track
        result.add("tracks", per_track.distance)

    # Missing tracks: candidate tracks with no item mapped to them.
    missing_count = len(album_info.tracks) - len(mapping)
    for _ in range(missing_count):
        result.add("missing_tracks", 1.0)

    # Unmatched tracks: items with no candidate track.
    unmatched_count = len(items) - len(mapping)
    for _ in range(unmatched_count):
        result.add("unmatched_tracks", 1.0)

    # Plugins.
    result.update(plugins.album_distance(items, album_info, mapping))

    return result
|
||||
|
|
@ -16,21 +16,15 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from functools import total_ordering
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
|
||||
|
||||
from jellyfish import levenshtein_distance
|
||||
from unidecode import unidecode
|
||||
|
||||
from beets import config, logging
|
||||
from beets.util import as_string, cached_classproperty
|
||||
from beets import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from beets.library import Item
|
||||
|
||||
from .distance import Distance
|
||||
|
||||
log = logging.getLogger("beets")
|
||||
|
||||
V = TypeVar("V")
|
||||
|
|
@ -254,328 +248,6 @@ class TrackInfo(AttrDict[Any]):
|
|||
return dupe
|
||||
|
||||
|
||||
# Candidate distance scoring.
|
||||
|
||||
# Parameters for string distance function.
|
||||
# Words that can be moved to the end of a string using a comma.
|
||||
SD_END_WORDS = ["the", "a", "an"]
|
||||
# Reduced weights for certain portions of the string.
|
||||
SD_PATTERNS = [
|
||||
(r"^the ", 0.1),
|
||||
(r"[\[\(]?(ep|single)[\]\)]?", 0.0),
|
||||
(r"[\[\(]?(featuring|feat|ft)[\. :].+", 0.1),
|
||||
(r"\(.*?\)", 0.3),
|
||||
(r"\[.*?\]", 0.3),
|
||||
(r"(, )?(pt\.|part) .+", 0.2),
|
||||
]
|
||||
# Replacements to use before testing distance.
|
||||
SD_REPLACE = [
|
||||
(r"&", "and"),
|
||||
]
|
||||
|
||||
|
||||
def _string_dist_basic(str1: str, str2: str) -> float:
    """Return a length-normalized Levenshtein distance between two strings.

    Each input is transliterated with ``unidecode``, lowercased, and stripped
    of every character outside ``a-z`` and ``0-9`` before the comparison. The
    edit distance is divided by the length of the longer normalized string;
    two strings that normalize to nothing have a distance of 0.0.
    """
    assert isinstance(str1, str)
    assert isinstance(str2, str)

    def _normalize(text: str) -> str:
        # ASCII transliteration first, then keep only lowercase alphanumerics.
        return re.sub(r"[^a-z0-9]", "", as_string(unidecode(text)).lower())

    left = _normalize(str1)
    right = _normalize(str2)
    longest = max(len(left), len(right))
    if not longest:
        # Both strings are empty after normalization.
        return 0.0
    return levenshtein_distance(left, right) / float(longest)
|
||||
|
||||
|
||||
def string_dist(str1: str | None, str2: str | None) -> float:
    """Return an "intuitive" normalized edit distance between two strings.

    The basic length-normalized edit distance is adjusted with a few
    text-aware tweaks: trailing ", the"/", a"/", an" are moved to the front,
    the ``SD_REPLACE`` substitutions are applied, and portions matched by
    ``SD_PATTERNS`` only contribute a reduced share of the distance.

    Two ``None`` values give 0.0; exactly one ``None`` gives 1.0.
    """
    if str1 is None or str2 is None:
        # `str1 is str2` is only true here when both are None.
        return 0.0 if str1 is str2 else 1.0

    left = str1.lower()
    right = str2.lower()

    # "something, the" is treated the same as "the something".
    for word in SD_END_WORDS:
        suffix = f", {word}"
        if left.endswith(suffix):
            left = f"{word} {left[: -len(suffix)]}"
        if right.endswith(suffix):
            right = f"{word} {right[: -len(suffix)]}"

    # Basic normalizing substitutions.
    for pattern, replacement in SD_REPLACE:
        left = re.sub(pattern, replacement, left)
        right = re.sub(pattern, replacement, right)

    # Drop each weighted pattern in turn. Whenever dropping it brings the
    # strings closer, keep the trimmed strings as the new baseline and
    # charge only a scaled share of the improvement as a penalty.
    base_dist = _string_dist_basic(left, right)
    penalty = 0.0
    for pattern, weight in SD_PATTERNS:
        trimmed_left = re.sub(pattern, "", left)
        trimmed_right = re.sub(pattern, "", right)
        if trimmed_left == left and trimmed_right == right:
            # Pattern is absent from both strings.
            continue

        trimmed_dist = _string_dist_basic(trimmed_left, trimmed_right)
        gain = max(0.0, base_dist - trimmed_dist)
        if gain == 0.0:
            continue

        # Shift the baseline so the same portion is not matched again.
        left, right = trimmed_left, trimmed_right
        base_dist = trimmed_dist
        penalty += weight * gain

    return base_dist + penalty
|
||||
|
||||
|
||||
@total_ordering
class Distance:
    """Keep track of multiple named distance penalties.

    Penalties are stored per key as lists of floats between 0.0 and 1.0 and
    are scaled by the weights configured under ``match.distance_weights``.
    The object provides a single weighted, normalized distance across all
    penalties as well as a weighted distance for each individual key. It can
    be compared and subtracted like a float and iterated like a mapping of
    its non-zero penalties.
    """

    def __init__(self) -> None:
        # Raw penalties, keyed by the name of the configured weight.
        self._penalties: dict[str, list[float]] = {}
        # Per-track distances, keyed by TrackInfo.
        self.tracks: dict[TrackInfo, Distance] = {}

    @cached_classproperty
    def _weights(cls) -> dict[str, float]:
        """A dictionary from keys to floating-point weights."""
        weights_view = config["match"]["distance_weights"]
        return {
            key: weights_view[key].as_number() for key in weights_view.keys()
        }

    # Access the components and their aggregates.

    @property
    def distance(self) -> float:
        """Return a weighted and normalized distance across all penalties.

        Returns 0.0 when the normalization factor is zero.
        """
        dist_max = self.max_distance
        if dist_max:
            # Reuse the value computed above instead of re-evaluating the
            # `max_distance` property.
            return self.raw_distance / dist_max
        return 0.0

    @property
    def max_distance(self) -> float:
        """Return the maximum distance penalty (normalization factor)."""
        dist_max = 0.0
        for key, penalty in self._penalties.items():
            dist_max += len(penalty) * self._weights[key]
        return dist_max

    @property
    def raw_distance(self) -> float:
        """Return the raw (denormalized) distance."""
        dist_raw = 0.0
        for key, penalty in self._penalties.items():
            dist_raw += sum(penalty) * self._weights[key]
        return dist_raw

    def items(self) -> list[tuple[str, float]]:
        """Return a list of (key, dist) pairs, with `dist` being the
        weighted distance, sorted from highest to lowest. Does not
        include penalties with a zero value.
        """
        nonzero = [
            (key, dist) for key in self._penalties if (dist := self[key])
        ]
        # Negate the distance so that a single ascending sort puts the
        # biggest distances first while still ordering keys alphabetically
        # when the penalties are equal.
        return sorted(
            nonzero, key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
        )

    def __hash__(self) -> int:
        # Identity-based hash: objects with equal distances still hash
        # differently.
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self.distance == other

    # Behave like a float.

    def __lt__(self, other) -> bool:
        return self.distance < other

    def __float__(self) -> float:
        return self.distance

    def __sub__(self, other) -> float:
        return self.distance - other

    def __rsub__(self, other) -> float:
        return other - self.distance

    def __str__(self) -> str:
        return f"{self.distance:.2f}"

    # Behave like a dict.

    def __getitem__(self, key) -> float:
        """Return the weighted distance for a named penalty.

        Raises KeyError if no penalty has been recorded for `key`.
        """
        dist = sum(self._penalties[key]) * self._weights[key]
        dist_max = self.max_distance
        if dist_max:
            return dist / dist_max
        return 0.0

    def __iter__(self) -> Iterator[tuple[str, float]]:
        return iter(self.items())

    def __len__(self) -> int:
        return len(self.items())

    def keys(self) -> list[str]:
        """Return the keys with a non-zero penalty, in `items()` order."""
        return [key for key, _ in self.items()]

    def update(self, dist: Distance):
        """Add all the distance penalties from `dist`.

        Raises ValueError if `dist` is not a Distance object.
        """
        if not isinstance(dist, Distance):
            raise ValueError(
                f"`dist` must be a Distance object, not {type(dist)}"
            )
        for key, penalties in dist._penalties.items():
            self._penalties.setdefault(key, []).extend(penalties)

    # Adding components.

    def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
        """Return True if `value1` is equal to `value2`. `value1` may
        be a compiled regular expression, in which case it will be
        matched against `value2`.
        """
        if isinstance(value1, re.Pattern):
            return bool(value1.match(value2))
        return value1 == value2

    def add(self, key: str, dist: float):
        """Add a distance penalty. `key` must correspond with a
        configured weight setting. `dist` must be a float between 0.0
        and 1.0, and will be added to any existing distance penalties
        for the same key.

        Raises ValueError if `dist` is outside that range.
        """
        if not 0.0 <= dist <= 1.0:
            raise ValueError(f"`dist` must be between 0.0 and 1.0, not {dist}")
        self._penalties.setdefault(key, []).append(dist)

    def add_equality(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ):
        """Add a distance penalty of 1.0 if `value` doesn't match any
        of the values in `options`. If an option is a compiled regular
        expression, it will be considered equal if it matches against
        `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        matched = any(self._eq(opt, value) for opt in options)
        self.add(key, 0.0 if matched else 1.0)

    def add_expr(self, key: str, expr: bool):
        """Add a distance penalty of 1.0 if `expr` evaluates to True,
        or 0.0.
        """
        self.add(key, 1.0 if expr else 0.0)

    def add_number(self, key: str, number1: int, number2: int):
        """Add a distance penalty of 1.0 for each number of difference
        between `number1` and `number2`, or 0.0 when there is no
        difference. Use this when there is no upper limit on the
        difference between the two numbers.
        """
        diff = abs(number1 - number2)
        if diff:
            for _ in range(diff):
                self.add(key, 1.0)
        else:
            self.add(key, 0.0)

    def add_priority(
        self,
        key: str,
        value: Any,
        options: list[Any] | tuple[Any, ...] | Any,
    ):
        """Add a distance penalty that corresponds to the position at
        which `value` appears in `options`. A distance penalty of 0.0
        for the first option, or 1.0 if there is no matching option. If
        an option is a compiled regular expression, it will be
        considered equal if it matches against `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        # `or 1` guards against division by zero for an empty sequence.
        unit = 1.0 / (len(options) or 1)
        dist = next(
            (i * unit for i, opt in enumerate(options) if self._eq(opt, value)),
            1.0,
        )
        self.add(key, dist)

    def add_ratio(
        self,
        key: str,
        number1: int | float,
        number2: int | float,
    ):
        """Add a distance penalty for `number1` as a ratio of `number2`.
        `number1` is bound at 0 and `number2`. A zero `number2` yields a
        penalty of 0.0.
        """
        number = float(max(min(number1, number2), 0))
        dist = number / number2 if number2 else 0.0
        self.add(key, dist)

    def add_string(self, key: str, str1: str | None, str2: str | None):
        """Add a distance penalty based on the edit distance between
        `str1` and `str2`.
        """
        self.add(key, string_dist(str1, str2))
|
||||
|
||||
|
||||
# Structures that compose all the information for a candidate match.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,37 +18,23 @@ releases and tracks.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
|
||||
|
||||
import lap
|
||||
import numpy as np
|
||||
|
||||
from beets import config, logging, plugins
|
||||
from beets.autotag import (
|
||||
AlbumInfo,
|
||||
AlbumMatch,
|
||||
Distance,
|
||||
TrackInfo,
|
||||
TrackMatch,
|
||||
hooks,
|
||||
)
|
||||
from beets.autotag import AlbumInfo, AlbumMatch, TrackInfo, TrackMatch, hooks
|
||||
from beets.util import get_most_common_tags
|
||||
|
||||
from .distance import VA_ARTISTS, distance, track_distance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from beets.library import Item
|
||||
|
||||
# Artist signals that indicate "various artists". These are used at the
|
||||
# album level to determine whether a given release is likely a VA
|
||||
# release and also on the track level to remove the penalty for
|
||||
# differing artists.
|
||||
VA_ARTISTS = ("", "various artists", "various", "va", "unknown")
|
||||
|
||||
# Global logger.
|
||||
log = logging.getLogger("beets")
|
||||
|
||||
|
|
@ -112,191 +98,6 @@ def assign_items(
|
|||
return mapping, extra_items, extra_tracks
|
||||
|
||||
|
||||
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
    """Report whether the item's track number disagrees with the
    candidate's.

    The number is accepted if it equals either `track_info.medium_index`
    or `track_info.index`, so both per-disc and per-release numbering
    count as unchanged.
    """
    accepted_positions = (track_info.medium_index, track_info.index)
    return item.track not in accepted_positions
|
||||
|
||||
|
||||
@cache
def get_track_length_grace() -> float:
    """Return the `match.track_length_grace` setting as a number.

    The result is memoised, so the config is only read on the first call.
    """
    grace_view = config["match"]["track_length_grace"]
    return grace_view.as_number()
|
||||
|
||||
|
||||
@cache
def get_track_length_max() -> float:
    """Return the `match.track_length_max` setting as a number.

    The result is memoised, so the config is only read on the first call.
    """
    max_view = config["match"]["track_length_max"]
    return max_view.as_number()
|
||||
|
||||
|
||||
def track_distance(
    item: Item,
    track_info: TrackInfo,
    incl_artist: bool = False,
) -> Distance:
    """Score how much `item` differs from the candidate `track_info`.

    Returns a Distance object with penalties for length, title, track
    index, track ID and medium, followed by whatever the plugins
    contribute. Pass `incl_artist=True` (used for various-artist
    releases) to also compare the track artist.

    The ``track_length_grace`` and ``track_length_max`` settings are read
    through cached helpers: this function runs many times per match and
    config access is comparatively slow.
    """
    result = hooks.Distance()

    # Length: only scored when the candidate actually has one. The grace
    # period is subtracted before the difference is scaled.
    candidate_length = track_info.length
    if candidate_length:
        overshoot = (
            abs(item.length - candidate_length) - get_track_length_grace()
        )
        result.add_ratio("track_length", overshoot, get_track_length_max())

    # Title.
    result.add_string("track_title", item.title, track_info.title)

    # Artist: skipped unless requested, the candidate names an artist,
    # and the item's own artist is not a "various artists" placeholder.
    compare_artist = (
        incl_artist
        and track_info.artist
        and item.artist.lower() not in VA_ARTISTS
    )
    if compare_artist:
        result.add_string("track_artist", item.artist, track_info.artist)

    # Track index, when both sides have one.
    if track_info.index and item.track:
        result.add_expr("track_index", track_index_changed(item, track_info))

    # Track ID, when the item is already tagged with one.
    if item.mb_trackid:
        result.add_expr("track_id", item.mb_trackid != track_info.track_id)

    # Disc number mismatch.
    if track_info.medium and item.disc:
        result.add_expr("medium", item.disc != track_info.medium)

    # Plugin-provided penalties.
    result.update(plugins.track_distance(item, track_info))

    return result
|
||||
|
||||
|
||||
def distance(
    items: Sequence[Item],
    album_info: AlbumInfo,
    mapping: dict[Item, TrackInfo],
) -> Distance:
    """Score how "significant" retagging `items` as `album_info` would be.

    `items` holds every Item being matched (order does not matter),
    `album_info` is the candidate release, and `mapping` pairs a subset
    of `items` with a subset of `album_info.tracks`. Returns a Distance
    object whose `tracks` attribute maps each mapped TrackInfo to its
    own per-track Distance.
    """
    common, _ = get_most_common_tags(items)
    total = hooks.Distance()

    # Artist, unless the candidate is a various-artists release.
    if not album_info.va:
        total.add_string("artist", common["artist"], album_info.artist)

    # Album title.
    total.add_string("album", common["album"], album_info.album)

    preferences = config["match"]["preferred"]

    # Media: rank against the preferred list when one is configured,
    # otherwise compare with the items' current media.
    if album_info.media:
        media_patterns: Sequence[str] = preferences["media"].as_str_seq()
        media_template = r"(\d+x)?(%s)"
        media_options = [
            re.compile(media_template % pattern, re.I)
            for pattern in media_patterns
        ]
        if media_options:
            total.add_priority("media", album_info.media, media_options)
        elif common["media"]:
            total.add_equality("media", album_info.media, common["media"])

    # Number of mediums.
    if common["disctotal"] and album_info.mediums:
        total.add_number("mediums", common["disctotal"], album_info.mediums)

    if album_info.year and preferences["original_year"]:
        # Prefer the earliest release. With no known original year,
        # assume 1889 (earliest first gramophone discs).
        baseline = album_info.original_year or 1889
        year_gap = abs(album_info.year - baseline)
        widest_gap = abs(datetime.date.today().year - baseline)
        total.add_ratio("year", year_gap, widest_gap)
    elif common["year"] and album_info.year:
        if common["year"] in (album_info.year, album_info.original_year):
            # Matching either the release or the original year is free.
            total.add("year", 0.0)
        elif album_info.original_year:
            # Prefer matches closest to the release year.
            year_gap = abs(common["year"] - album_info.year)
            widest_gap = abs(
                datetime.date.today().year - album_info.original_year
            )
            total.add_ratio("year", year_gap, widest_gap)
        else:
            # Full penalty when there is no original year.
            total.add("year", 1.0)

    # Country: rank against the preferred list when one is configured,
    # otherwise compare with the items' current country.
    country_patterns: Sequence[str] = preferences["countries"].as_str_seq()
    country_options = [
        re.compile(pattern, re.I) for pattern in country_patterns
    ]
    if album_info.country and country_options:
        total.add_priority("country", album_info.country, country_options)
    elif common["country"] and album_info.country:
        total.add_string("country", common["country"], album_info.country)

    # Label, catalog number and disambiguation are each compared only
    # when both the items and the candidate provide a value.
    for field in ("label", "catalognum", "albumdisambig"):
        if common[field] and getattr(album_info, field):
            total.add_string(field, common[field], getattr(album_info, field))

    # Album ID.
    if common["mb_albumid"]:
        total.add_equality("album_id", common["mb_albumid"], album_info.album_id)

    # Per-track distances.
    total.tracks = {}
    for item, track in mapping.items():
        per_track = track_distance(item, track, album_info.va)
        total.tracks[track] = per_track
        total.add("tracks", per_track.distance)

    # Candidate tracks with no item mapped to them.
    for _ in range(len(album_info.tracks) - len(mapping)):
        total.add("missing_tracks", 1.0)

    # Items that were not mapped to any candidate track.
    for _ in range(len(items) - len(mapping)):
        total.add("unmatched_tracks", 1.0)

    # Plugin-provided penalties.
    total.update(plugins.album_distance(items, album_info, mapping))

    return total
|
||||
|
||||
|
||||
def match_by_id(items: Iterable[Item]) -> AlbumInfo | None:
|
||||
"""If the items are tagged with an external source ID, return an
|
||||
AlbumInfo object for the corresponding album. Otherwise, returns
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ import mediafile
|
|||
|
||||
import beets
|
||||
from beets import logging
|
||||
from beets.autotag.distance import Distance
|
||||
from beets.util.id_extractors import extract_release_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -53,7 +54,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from confuse import ConfigView
|
||||
|
||||
from beets.autotag import AlbumInfo, Distance, TrackInfo
|
||||
from beets.autotag import AlbumInfo, TrackInfo
|
||||
from beets.dbcore import Query
|
||||
from beets.dbcore.db import FieldQueryType
|
||||
from beets.dbcore.types import Type
|
||||
|
|
@ -224,8 +225,6 @@ class BeetsPlugin:
|
|||
"""Should return a Distance object to be added to the
|
||||
distance for every track comparison.
|
||||
"""
|
||||
from beets.autotag.hooks import Distance
|
||||
|
||||
return Distance()
|
||||
|
||||
def album_distance(
|
||||
|
|
@ -237,8 +236,6 @@ class BeetsPlugin:
|
|||
"""Should return a Distance object to be added to the
|
||||
distance for every album-level comparison.
|
||||
"""
|
||||
from beets.autotag.hooks import Distance
|
||||
|
||||
return Distance()
|
||||
|
||||
def candidates(
|
||||
|
|
@ -458,8 +455,6 @@ def track_distance(item: Item, info: TrackInfo) -> Distance:
|
|||
"""Gets the track distance calculated by all loaded plugins.
|
||||
Returns a Distance object.
|
||||
"""
|
||||
from beets.autotag.hooks import Distance
|
||||
|
||||
dist = Distance()
|
||||
for plugin in find_plugins():
|
||||
dist.update(plugin.track_distance(item, info))
|
||||
|
|
@ -472,8 +467,6 @@ def album_distance(
|
|||
mapping: dict[Item, TrackInfo],
|
||||
) -> Distance:
|
||||
"""Returns the album distance calculated by plugins."""
|
||||
from beets.autotag.hooks import Distance
|
||||
|
||||
dist = Distance()
|
||||
for plugin in find_plugins():
|
||||
dist.update(plugin.album_distance(items, album_info, mapping))
|
||||
|
|
@ -660,8 +653,6 @@ def get_distance(
|
|||
"""Returns the ``data_source`` weight and the maximum source weight
|
||||
for albums or individual tracks.
|
||||
"""
|
||||
from beets.autotag.hooks import Distance
|
||||
|
||||
dist = Distance()
|
||||
if info.data_source == data_source:
|
||||
dist.add("source", config["source_weight"].as_number())
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import acoustid
|
|||
import confuse
|
||||
|
||||
from beets import config, plugins, ui, util
|
||||
from beets.autotag.hooks import Distance
|
||||
from beets.autotag.distance import Distance
|
||||
from beetsplug.musicbrainz import MusicBrainzPlugin
|
||||
|
||||
API_KEY = "1vOwZtEn"
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ from typing_extensions import TypedDict
|
|||
import beets
|
||||
import beets.ui
|
||||
from beets import config
|
||||
from beets.autotag.hooks import AlbumInfo, TrackInfo, string_dist
|
||||
from beets.autotag.distance import string_dist
|
||||
from beets.autotag.hooks import AlbumInfo, TrackInfo
|
||||
from beets.plugins import BeetsPlugin, MetadataSourcePlugin, get_distance
|
||||
from beets.util.id_extractors import extract_release_id
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from unidecode import unidecode
|
|||
|
||||
import beets
|
||||
from beets import plugins, ui
|
||||
from beets.autotag.hooks import string_dist
|
||||
from beets.autotag.distance import string_dist
|
||||
from beets.util.config import sanitize_choices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
|||
476
test/autotag/test_distance.py
Normal file
476
test/autotag/test_distance.py
Normal file
|
|
@ -0,0 +1,476 @@
|
|||
import re
|
||||
import unittest
|
||||
|
||||
from beets import config
|
||||
from beets.autotag import AlbumInfo, TrackInfo, match
|
||||
from beets.autotag.distance import Distance, string_dist
|
||||
from beets.library import Item
|
||||
from beets.test.helper import BeetsTestCase
|
||||
|
||||
|
||||
def _make_item(title, track, artist="some artist"):
    """Build an Item with the given title, track number and artist.

    The album is always "some album", the length 1, and the MusicBrainz
    ID fields are empty strings.
    """
    fields = dict(
        title=title,
        track=track,
        artist=artist,
        album="some album",
        length=1,
        mb_trackid="",
        mb_albumid="",
        mb_artistid="",
    )
    return Item(**fields)
|
||||
|
||||
|
||||
def _make_trackinfo():
    """Return fresh TrackInfo objects titled "one", "two" and "three",
    indexed 1 to 3, all by "some artist" with length 1 and no track ID.
    """
    return [
        TrackInfo(
            title=title,
            track_id=None,
            artist="some artist",
            length=1,
            index=index,
        )
        for index, title in enumerate(("one", "two", "three"), start=1)
    ]
|
||||
|
||||
|
||||
def _clear_weights():
    """Reset the cache of the lazy `_weights` descriptor on Distance.

    This is a hack: the descriptor is fetched from the class dict so it
    is not triggered, then its cache is replaced with an empty dict.
    """
    weights_descriptor = vars(Distance)["_weights"]
    weights_descriptor.cache = {}
|
||||
|
||||
|
||||
class DistanceTest(BeetsTestCase):
    """Tests for the penalty bookkeeping and arithmetic of `Distance`."""

    def tearDown(self):
        super().tearDown()
        _clear_weights()

    @staticmethod
    def _use_weights(**weights):
        """Override the given distance weights, then drop the cached copy."""
        for name, weight in weights.items():
            config["match"]["distance_weights"][name] = weight
        _clear_weights()

    def test_add(self):
        penalties = Distance()
        penalties.add("add", 1.0)
        assert penalties._penalties == {"add": [1.0]}

    def test_add_equality(self):
        penalties = Distance()
        # Each step adds one penalty; `expected` is the running list.
        steps = [
            ("ghi", ["abc", "def", "ghi"], [0.0]),
            ("xyz", ["abc", "def", "ghi"], [0.0, 1.0]),
            ("abc", re.compile(r"ABC", re.I), [0.0, 1.0, 0.0]),
        ]
        for value, options, expected in steps:
            penalties.add_equality("equality", value, options)
            assert penalties._penalties["equality"] == expected

    def test_add_expr(self):
        penalties = Distance()
        for flag, expected in ((True, [1.0]), (False, [1.0, 0.0])):
            penalties.add_expr("expr", flag)
            assert penalties._penalties["expr"] == expected

    def test_add_number(self):
        penalties = Distance()
        # A full penalty is added for each unit of difference.
        steps = [
            (1, 1, [0.0]),
            (1, 2, [0.0, 1.0]),
            (2, 1, [0.0, 1.0, 1.0]),
            (-1, 2, [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
        ]
        for first, second, expected in steps:
            penalties.add_number("number", first, second)
            assert penalties._penalties["number"] == expected

    def test_add_priority(self):
        penalties = Distance()
        steps = [
            ("abc", "abc", [0.0]),
            ("def", ["abc", "def"], [0.0, 0.5]),
            (
                "gh",
                ["ab", "cd", "ef", re.compile("GH", re.I)],
                [0.0, 0.5, 0.75],
            ),
            ("xyz", ["abc", "def"], [0.0, 0.5, 0.75, 1.0]),
        ]
        for value, options, expected in steps:
            penalties.add_priority("priority", value, options)
            assert penalties._penalties["priority"] == expected

    def test_add_ratio(self):
        penalties = Distance()
        steps = [
            (25, 100, [0.25]),
            (10, 5, [0.25, 1.0]),
            (-5, 5, [0.25, 1.0, 0.0]),
            (5, 0, [0.25, 1.0, 0.0, 0.0]),
        ]
        for first, second, expected in steps:
            penalties.add_ratio("ratio", first, second)
            assert penalties._penalties["ratio"] == expected

    def test_add_string(self):
        penalties = Distance()
        expected = string_dist("abc", "bcd")
        penalties.add_string("string", "abc", "bcd")
        recorded = penalties._penalties["string"]
        assert recorded == [expected]
        assert recorded != [0]

    def test_add_string_none(self):
        penalties = Distance()
        penalties.add_string("string", None, "string")
        assert penalties._penalties["string"] == [1]

    def test_add_string_both_none(self):
        penalties = Distance()
        penalties.add_string("string", None, None)
        assert penalties._penalties["string"] == [0]

    def test_distance(self):
        self._use_weights(album=2.0, medium=1.0)

        total = Distance()
        for key, penalty in (("album", 0.5), ("media", 0.25), ("media", 0.75)):
            total.add(key, penalty)
        assert total.distance == 0.5

        # __getitem__()
        assert total["album"] == 0.25
        assert total["media"] == 0.25

    def test_max_distance(self):
        self._use_weights(album=3.0, medium=1.0)

        total = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.0), ("medium", 0.0)):
            total.add(key, penalty)
        assert total.max_distance == 5.0

    def test_operators(self):
        self._use_weights(source=1.0, album=2.0, medium=1.0)

        total = Distance()
        recorded = (
            ("source", 0.0),
            ("album", 0.5),
            ("medium", 0.25),
            ("medium", 0.75),
        )
        for key, penalty in recorded:
            total.add(key, penalty)
        assert len(total) == 2
        assert list(total) == [("album", 0.2), ("medium", 0.2)]
        assert total == 0.4
        assert total < 1.0
        assert total > 0.0
        assert total - 0.4 == 0.0
        assert 0.4 - total == 0.0
        assert float(total) == 0.4

    def test_raw_distance(self):
        self._use_weights(album=3.0, medium=1.0)

        total = Distance()
        for key, penalty in (("album", 0.5), ("medium", 0.25), ("medium", 0.5)):
            total.add(key, penalty)
        assert total.raw_distance == 2.25

    def test_items(self):
        self._use_weights(album=4.0, medium=2.0)

        ranked = Distance()
        ranked.add("album", 0.1875)
        ranked.add("medium", 0.75)
        assert ranked.items() == [("medium", 0.25), ("album", 0.125)]

        # Sort by key if distance is equal.
        tied = Distance()
        tied.add("album", 0.375)
        tied.add("medium", 0.75)
        assert tied.items() == [("album", 0.25), ("medium", 0.25)]

    def test_update(self):
        target = Distance()
        target.add("album", 0.5)
        target.add("media", 1.0)

        extra = Distance()
        extra.add("album", 0.75)
        extra.add("album", 0.25)
        extra.add("media", 0.05)

        target.update(extra)

        assert target._penalties == {
            "album": [0.5, 0.75, 0.25],
            "media": [1.0, 0.05],
        }
|
||||
|
||||
|
||||
class TrackDistanceTest(BeetsTestCase):
    """Tests comparing a single item against the first fixture track."""

    @staticmethod
    def _against_first_track(item):
        """Distance, artist included, from `item` to the "one" track."""
        first_track = _make_trackinfo()[0]
        return match.track_distance(item, first_track, incl_artist=True)

    def test_identical_tracks(self):
        item = _make_item("one", 1)
        assert self._against_first_track(item) == 0.0

    def test_different_title(self):
        item = _make_item("foo", 1)
        assert self._against_first_track(item) != 0.0

    def test_different_artist(self):
        item = _make_item("one", 1)
        item.artist = "foo"
        assert self._against_first_track(item) != 0.0

    def test_various_artists_tolerated(self):
        item = _make_item("one", 1)
        item.artist = "Various Artists"
        assert self._against_first_track(item) == 0.0
|
||||
|
||||
|
||||
class AlbumDistanceTest(BeetsTestCase):
    """Tests for album-level distance between items and a candidate."""

    def _mapping(self, items, info):
        """Pair items with the candidate's tracks positionally."""
        return dict(zip(items, info.tracks))

    def _dist(self, items, info):
        return match.distance(items, info, self._mapping(items, info))

    @staticmethod
    def _in_order_items():
        """Items "one", "two", "three" numbered 1 to 3."""
        return [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]

    @staticmethod
    def _album(artist="some artist", va=False):
        """Candidate "some album" carrying the three fixture tracks."""
        return AlbumInfo(
            artist=artist,
            album="some album",
            tracks=_make_trackinfo(),
            va=va,
        )

    def test_identical_albums(self):
        items = self._in_order_items()
        info = self._album()
        assert self._dist(items, info) == 0

    def test_incomplete_album(self):
        items = [_make_item("one", 1), _make_item("three", 3)]
        info = self._album()
        dist = self._dist(items, info)
        assert dist != 0
        # Make sure the distance is not too great
        assert dist < 0.2

    def test_global_artists_differ(self):
        items = self._in_order_items()
        info = self._album(artist="someone else")
        assert self._dist(items, info) != 0

    def test_comp_track_artists_match(self):
        items = self._in_order_items()
        info = self._album(artist="should be ignored", va=True)
        assert self._dist(items, info) == 0

    def test_comp_no_track_artists(self):
        # Some VA releases don't have track artists (incomplete metadata).
        items = self._in_order_items()
        info = self._album(artist="should be ignored", va=True)
        for track in info.tracks[:3]:
            track.artist = None
        assert self._dist(items, info) == 0

    def test_comp_track_artists_do_not_match(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2, "someone else"),
            _make_item("three", 3),
        ]
        info = self._album(va=True)
        assert self._dist(items, info) != 0

    def test_tracks_out_of_order(self):
        items = [
            _make_item("one", 1),
            _make_item("three", 2),
            _make_item("two", 3),
        ]
        info = self._album()
        dist = self._dist(items, info)
        assert 0 < dist < 0.2

    def test_two_medium_release(self):
        items = self._in_order_items()
        info = self._album()
        for position, medium_index in enumerate((1, 2, 1)):
            info.tracks[position].medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0

    def test_per_medium_track_numbers(self):
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 1),
        ]
        info = self._album()
        for position, medium_index in enumerate((1, 2, 1)):
            info.tracks[position].medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0
|
||||
|
||||
|
||||
class StringDistanceTest(unittest.TestCase):
    """Tests for the normalisation heuristics of `string_dist`."""

    def test_equal_strings(self):
        assert string_dist("Some String", "Some String") == 0.0

    def test_different_strings(self):
        assert string_dist("Some String", "Totally Different") != 0.0

    def test_punctuation_ignored(self):
        assert string_dist("Some String", "Some.String!") == 0.0

    def test_case_ignored(self):
        assert string_dist("Some String", "sOME sTring") == 0.0

    def test_leading_the_has_lower_weight(self):
        generic = string_dist("XXX Band Name", "Band Name")
        article = string_dist("The Band Name", "Band Name")
        assert article < generic

    def test_parens_have_lower_weight(self):
        generic = string_dist("One .Two.", "One")
        bracketed = string_dist("One (Two)", "One")
        assert bracketed < generic

    def test_brackets_have_lower_weight(self):
        generic = string_dist("One .Two.", "One")
        bracketed = string_dist("One [Two]", "One")
        assert bracketed < generic

    def test_ep_label_has_zero_weight(self):
        assert string_dist("My Song (EP)", "My Song") == 0.0

    def test_featured_has_lower_weight(self):
        generic = string_dist("My Song blah Someone", "My Song")
        featured = string_dist("My Song feat Someone", "My Song")
        assert featured < generic

    def test_postfix_the(self):
        assert string_dist("The Song Title", "Song Title, The") == 0.0

    def test_postfix_a(self):
        assert string_dist("A Song Title", "Song Title, A") == 0.0

    def test_postfix_an(self):
        assert string_dist("An Album Title", "Album Title, An") == 0.0

    def test_empty_strings(self):
        assert string_dist("", "") == 0.0

    def test_solo_pattern(self):
        # Just make sure these don't crash.
        for left, right in (("The ", ""), ("(EP)", "(EP)"), (", An", "")):
            string_dist(left, right)

    def test_heuristic_does_not_harm_distance(self):
        assert string_dist("Untitled", "[Untitled]") == 0.0

    def test_ampersand_expansion(self):
        assert string_dist("And", "&") == 0.0

    def test_accented_characters(self):
        assert string_dist("\xe9\xe1\xf1", "ean") == 0.0
|
||||
|
|
@ -14,410 +14,14 @@
|
|||
|
||||
"""Tests for autotagging functionality."""
|
||||
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from beets import autotag, config
|
||||
from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
|
||||
from beets.autotag.hooks import Distance, string_dist
|
||||
from beets.library import Item
|
||||
from beets.test.helper import BeetsTestCase, ConfigMixin
|
||||
|
||||
|
||||
def _make_item(title, track, artist="some artist"):
    """Build an Item with the given title, track number and artist.

    The album is always "some album", the length 1, and the MusicBrainz
    ID fields are empty strings.
    """
    fields = dict(
        title=title,
        track=track,
        artist=artist,
        album="some album",
        length=1,
        mb_trackid="",
        mb_albumid="",
        mb_artistid="",
    )
    return Item(**fields)
|
||||
|
||||
|
||||
def _make_trackinfo():
    """Return fresh TrackInfo objects titled "one", "two" and "three",
    indexed 1 to 3, all by "some artist" with length 1 and no track ID.
    """
    return [
        TrackInfo(
            title=title,
            track_id=None,
            artist="some artist",
            length=1,
            index=index,
        )
        for index, title in enumerate(("one", "two", "three"), start=1)
    ]
|
||||
|
||||
|
||||
def _clear_weights():
    """Reset the cache of the lazy `_weights` descriptor on Distance.

    This is a hack: the descriptor is fetched from the class dict so it
    is not triggered, then its cache is replaced with an empty dict.
    """
    weights_descriptor = vars(Distance)["_weights"]
    weights_descriptor.cache = {}
|
||||
|
||||
|
||||
class DistanceTest(BeetsTestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
_clear_weights()
|
||||
|
||||
def test_add(self):
|
||||
dist = Distance()
|
||||
dist.add("add", 1.0)
|
||||
assert dist._penalties == {"add": [1.0]}
|
||||
|
||||
def test_add_equality(self):
|
||||
dist = Distance()
|
||||
dist.add_equality("equality", "ghi", ["abc", "def", "ghi"])
|
||||
assert dist._penalties["equality"] == [0.0]
|
||||
|
||||
dist.add_equality("equality", "xyz", ["abc", "def", "ghi"])
|
||||
assert dist._penalties["equality"] == [0.0, 1.0]
|
||||
|
||||
dist.add_equality("equality", "abc", re.compile(r"ABC", re.I))
|
||||
assert dist._penalties["equality"] == [0.0, 1.0, 0.0]
|
||||
|
||||
def test_add_expr(self):
|
||||
dist = Distance()
|
||||
dist.add_expr("expr", True)
|
||||
assert dist._penalties["expr"] == [1.0]
|
||||
|
||||
dist.add_expr("expr", False)
|
||||
assert dist._penalties["expr"] == [1.0, 0.0]
|
||||
|
||||
def test_add_number(self):
|
||||
dist = Distance()
|
||||
# Add a full penalty for each number of difference between two numbers.
|
||||
|
||||
dist.add_number("number", 1, 1)
|
||||
assert dist._penalties["number"] == [0.0]
|
||||
|
||||
dist.add_number("number", 1, 2)
|
||||
assert dist._penalties["number"] == [0.0, 1.0]
|
||||
|
||||
dist.add_number("number", 2, 1)
|
||||
assert dist._penalties["number"] == [0.0, 1.0, 1.0]
|
||||
|
||||
dist.add_number("number", -1, 2)
|
||||
assert dist._penalties["number"] == [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
def test_add_priority(self):
|
||||
dist = Distance()
|
||||
dist.add_priority("priority", "abc", "abc")
|
||||
assert dist._penalties["priority"] == [0.0]
|
||||
|
||||
dist.add_priority("priority", "def", ["abc", "def"])
|
||||
assert dist._penalties["priority"] == [0.0, 0.5]
|
||||
|
||||
dist.add_priority(
|
||||
"priority", "gh", ["ab", "cd", "ef", re.compile("GH", re.I)]
|
||||
)
|
||||
assert dist._penalties["priority"] == [0.0, 0.5, 0.75]
|
||||
|
||||
dist.add_priority("priority", "xyz", ["abc", "def"])
|
||||
assert dist._penalties["priority"] == [0.0, 0.5, 0.75, 1.0]
|
||||
|
||||
def test_add_ratio(self):
|
||||
dist = Distance()
|
||||
dist.add_ratio("ratio", 25, 100)
|
||||
assert dist._penalties["ratio"] == [0.25]
|
||||
|
||||
dist.add_ratio("ratio", 10, 5)
|
||||
assert dist._penalties["ratio"] == [0.25, 1.0]
|
||||
|
||||
dist.add_ratio("ratio", -5, 5)
|
||||
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0]
|
||||
|
||||
dist.add_ratio("ratio", 5, 0)
|
||||
assert dist._penalties["ratio"] == [0.25, 1.0, 0.0, 0.0]
|
||||
|
||||
def test_add_string(self):
|
||||
dist = Distance()
|
||||
sdist = string_dist("abc", "bcd")
|
||||
dist.add_string("string", "abc", "bcd")
|
||||
assert dist._penalties["string"] == [sdist]
|
||||
assert dist._penalties["string"] != [0]
|
||||
|
||||
def test_add_string_none(self):
|
||||
dist = Distance()
|
||||
dist.add_string("string", None, "string")
|
||||
assert dist._penalties["string"] == [1]
|
||||
|
||||
def test_add_string_both_none(self):
|
||||
dist = Distance()
|
||||
dist.add_string("string", None, None)
|
||||
assert dist._penalties["string"] == [0]
|
||||
|
||||
def test_distance(self):
|
||||
config["match"]["distance_weights"]["album"] = 2.0
|
||||
config["match"]["distance_weights"]["medium"] = 1.0
|
||||
_clear_weights()
|
||||
|
||||
dist = Distance()
|
||||
dist.add("album", 0.5)
|
||||
dist.add("media", 0.25)
|
||||
dist.add("media", 0.75)
|
||||
assert dist.distance == 0.5
|
||||
|
||||
# __getitem__()
|
||||
assert dist["album"] == 0.25
|
||||
assert dist["media"] == 0.25
|
||||
|
||||
def test_max_distance(self):
    """``max_distance`` sums the weight of every recorded penalty.

    One "album" penalty at weight 3.0 plus two "medium" penalties at
    weight 1.0 give 5.0, regardless of the penalty values themselves.
    """
    weights = config["match"]["distance_weights"]
    weights["album"] = 3.0
    weights["medium"] = 1.0
    # NOTE(review): presumably invalidates cached weights -- confirm.
    _clear_weights()

    distance = Distance()
    for key, penalty in [("album", 0.5), ("medium", 0.0), ("medium", 0.0)]:
        distance.add(key, penalty)

    assert distance.max_distance == 5.0
|
||||
|
||||
def test_operators(self):
    """Exercise the container, comparison and arithmetic protocols."""
    weights = config["match"]["distance_weights"]
    weights["source"] = 1.0
    weights["album"] = 2.0
    weights["medium"] = 1.0
    # NOTE(review): presumably invalidates cached weights -- confirm.
    _clear_weights()

    recorded = [
        ("source", 0.0),
        ("album", 0.5),
        ("medium", 0.25),
        ("medium", 0.75),
    ]
    dist = Distance()
    for key, penalty in recorded:
        dist.add(key, penalty)

    # "source" contributed a zero penalty, so it is absent from the
    # length and from iteration.
    assert len(dist) == 2
    assert list(dist) == [("album", 0.2), ("medium", 0.2)]

    # Comparisons against plain floats.
    assert dist == 0.4
    assert dist < 1.0
    assert dist > 0.0

    # Subtraction with the distance on either side, and float conversion.
    assert dist - 0.4 == 0.0
    assert 0.4 - dist == 0.0
    assert float(dist) == 0.4
|
||||
|
||||
def test_raw_distance(self):
    """``raw_distance`` is the weighted sum of penalties, not normalized.

    ``0.5 * 3.0 + (0.25 + 0.5) * 1.0 == 2.25``.
    """
    weights = config["match"]["distance_weights"]
    weights["album"] = 3.0
    weights["medium"] = 1.0
    # NOTE(review): presumably invalidates cached weights -- confirm.
    _clear_weights()

    distance = Distance()
    for key, penalty in [("album", 0.5), ("medium", 0.25), ("medium", 0.5)]:
        distance.add(key, penalty)

    assert distance.raw_distance == 2.25
|
||||
|
||||
def test_items(self):
    """``items()`` is ordered by descending distance, then by key."""
    weights = config["match"]["distance_weights"]
    weights["album"] = 4.0
    weights["medium"] = 2.0
    # NOTE(review): presumably invalidates cached weights -- confirm.
    _clear_weights()

    scenarios = [
        # Larger distance first.
        (
            [("album", 0.1875), ("medium", 0.75)],
            [("medium", 0.25), ("album", 0.125)],
        ),
        # Sort by key if distance is equal.
        (
            [("album", 0.375), ("medium", 0.75)],
            [("album", 0.25), ("medium", 0.25)],
        ),
    ]
    for penalties, expected in scenarios:
        distance = Distance()
        for key, penalty in penalties:
            distance.add(key, penalty)
        assert distance.items() == expected
|
||||
|
||||
def test_update(self):
    """``update`` appends the other distance's penalties key by key."""
    target = Distance()
    for key, penalty in [("album", 0.5), ("media", 1.0)]:
        target.add(key, penalty)

    other = Distance()
    for key, penalty in [("album", 0.75), ("album", 0.25), ("media", 0.05)]:
        other.add(key, penalty)

    target.update(other)

    # Existing penalties come first, followed by the merged ones in order.
    expected = {
        "album": [0.5, 0.75, 0.25],
        "media": [1.0, 0.05],
    }
    assert target._penalties == expected
|
||||
|
||||
|
||||
class TrackDistanceTest(BeetsTestCase):
    """Tests for ``match.track_distance`` with the artist included."""

    def _distance_to_first_track(self, item):
        """Return the distance between ``item`` and the first stock track."""
        first_track = _make_trackinfo()[0]
        return match.track_distance(item, first_track, incl_artist=True)

    def test_identical_tracks(self):
        """An item matching the first stock track has zero distance."""
        item = _make_item("one", 1)
        assert self._distance_to_first_track(item) == 0.0

    def test_different_title(self):
        """A differing title produces a non-zero distance."""
        item = _make_item("foo", 1)
        assert self._distance_to_first_track(item) != 0.0

    def test_different_artist(self):
        """A differing artist produces a non-zero distance."""
        item = _make_item("one", 1)
        item.artist = "foo"
        assert self._distance_to_first_track(item) != 0.0

    def test_various_artists_tolerated(self):
        """A "Various Artists" item artist is not penalized."""
        item = _make_item("one", 1)
        item.artist = "Various Artists"
        assert self._distance_to_first_track(item) == 0.0
|
||||
|
||||
|
||||
class AlbumDistanceTest(BeetsTestCase):
    """Tests for ``match.distance`` between a set of items and an album."""

    def _mapping(self, items, info):
        """Pair each item with the album track at the same position."""
        return dict(zip(items, info.tracks))

    def _dist(self, items, info):
        """Compute the album distance using the positional mapping."""
        return match.distance(items, info, self._mapping(items, info))

    def _album_info(self, artist="some artist", va=False):
        """Build an ``AlbumInfo`` for "some album" with the stock tracks."""
        return AlbumInfo(
            artist=artist,
            album="some album",
            tracks=_make_trackinfo(),
            va=va,
        )

    def test_identical_albums(self):
        """Items matching the album exactly have zero distance."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._album_info()
        assert self._dist(items, info) == 0

    def test_incomplete_album(self):
        """A missing track is penalized, but only mildly."""
        items = [
            _make_item("one", 1),
            _make_item("three", 3),
        ]
        info = self._album_info()
        dist = self._dist(items, info)
        assert dist != 0
        # Make sure the distance is not too great
        assert dist < 0.2

    def test_global_artists_differ(self):
        """A different album artist yields a non-zero distance."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._album_info(artist="someone else")
        assert self._dist(items, info) != 0

    def test_comp_track_artists_match(self):
        """For a VA album, the album-level artist is not compared."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._album_info(artist="should be ignored", va=True)
        assert self._dist(items, info) == 0

    def test_comp_no_track_artists(self):
        """VA tracks without artists do not add any distance."""
        # Some VA releases don't have track artists (incomplete metadata).
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._album_info(artist="should be ignored", va=True)
        for position in range(3):
            info.tracks[position].artist = None
        assert self._dist(items, info) == 0

    def test_comp_track_artists_do_not_match(self):
        """For a VA album, a mismatching track artist is penalized."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2, "someone else"),
            _make_item("three", 3),
        ]
        info = self._album_info(va=True)
        assert self._dist(items, info) != 0

    def test_tracks_out_of_order(self):
        """Swapped tracks give a small, non-zero distance."""
        items = [
            _make_item("one", 1),
            _make_item("three", 2),
            _make_item("two", 3),
        ]
        info = self._album_info()
        dist = self._dist(items, info)
        assert 0 < dist < 0.2

    def test_two_medium_release(self):
        """Album-wide item track numbers match a two-medium release."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 3),
        ]
        info = self._album_info()
        for position, medium_index in enumerate((1, 2, 1)):
            info.tracks[position].medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0

    def test_per_medium_track_numbers(self):
        """Per-medium item track numbers match a two-medium release."""
        items = [
            _make_item("one", 1),
            _make_item("two", 2),
            _make_item("three", 1),
        ]
        info = self._album_info()
        for position, medium_index in enumerate((1, 2, 1)):
            info.tracks[position].medium_index = medium_index
        dist = self._dist(items, info)
        assert dist == 0
|
||||
|
||||
|
||||
class TestAssignment(ConfigMixin):
|
||||
A = "one"
|
||||
B = "two"
|
||||
|
|
@ -840,82 +444,6 @@ class ApplyCompilationTest(BeetsTestCase, ApplyTestUtil):
|
|||
assert self.items[1].comp
|
||||
|
||||
|
||||
class StringDistanceTest(unittest.TestCase):
    """Tests for the normalization heuristics applied by ``string_dist``."""

    def test_equal_strings(self):
        """Identical strings are at distance zero."""
        assert string_dist("Some String", "Some String") == 0.0

    def test_different_strings(self):
        """Unrelated strings are at a non-zero distance."""
        assert string_dist("Some String", "Totally Different") != 0.0

    def test_punctuation_ignored(self):
        """Punctuation does not contribute to the distance."""
        assert string_dist("Some String", "Some.String!") == 0.0

    def test_case_ignored(self):
        """Letter case does not contribute to the distance."""
        assert string_dist("Some String", "sOME sTring") == 0.0

    def test_leading_the_has_lower_weight(self):
        """A leading "The" costs less than an arbitrary leading word."""
        arbitrary_prefix = string_dist("XXX Band Name", "Band Name")
        article_prefix = string_dist("The Band Name", "Band Name")
        assert article_prefix < arbitrary_prefix

    def test_parens_have_lower_weight(self):
        """Parenthesized text costs less than the same text unwrapped."""
        plain_suffix = string_dist("One .Two.", "One")
        parenthesized = string_dist("One (Two)", "One")
        assert parenthesized < plain_suffix

    def test_brackets_have_lower_weight(self):
        """Bracketed text costs less than the same text unwrapped."""
        plain_suffix = string_dist("One .Two.", "One")
        bracketed = string_dist("One [Two]", "One")
        assert bracketed < plain_suffix

    def test_ep_label_has_zero_weight(self):
        """A trailing "(EP)" label is free."""
        assert string_dist("My Song (EP)", "My Song") == 0.0

    def test_featured_has_lower_weight(self):
        """A "feat" suffix costs less than an arbitrary suffix."""
        arbitrary_suffix = string_dist("My Song blah Someone", "My Song")
        featured_suffix = string_dist("My Song feat Someone", "My Song")
        assert featured_suffix < arbitrary_suffix

    def test_postfix_the(self):
        """A trailing ", The" equals a leading "The"."""
        assert string_dist("The Song Title", "Song Title, The") == 0.0

    def test_postfix_a(self):
        """A trailing ", A" equals a leading "A"."""
        assert string_dist("A Song Title", "Song Title, A") == 0.0

    def test_postfix_an(self):
        """A trailing ", An" equals a leading "An"."""
        assert string_dist("An Album Title", "Album Title, An") == 0.0

    def test_empty_strings(self):
        """Two empty strings are at distance zero."""
        assert string_dist("", "") == 0.0

    def test_solo_pattern(self):
        """Strings consisting only of a special pattern are handled."""
        # Just make sure these don't crash.
        for left, right in [("The ", ""), ("(EP)", "(EP)"), (", An", "")]:
            string_dist(left, right)

    def test_heuristic_does_not_harm_distance(self):
        """Wrapping a whole title in brackets does not add distance."""
        assert string_dist("Untitled", "[Untitled]") == 0.0

    def test_ampersand_expansion(self):
        """An ampersand is equivalent to the word "And"."""
        assert string_dist("And", "&") == 0.0

    def test_accented_characters(self):
        """Accented characters match their unaccented counterparts."""
        assert string_dist("\xe9\xe1\xf1", "ean") == 0.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"single_field,list_field",
|
||||
[
|
||||
|
|
|
|||
Loading…
Reference in a new issue