mirror of
https://github.com/beetbox/beets.git
synced 2026-02-05 23:14:07 +01:00
Fix data source penalty application logic
The data_source penalty was not being calculated correctly because `_get_distance` was being called for **all** enabled metadata plugins, which ultimately meant that matches were being penalized needlessly.

This commit refactors the distance calculation to:

- Remove the plugin-based track_distance() and album_distance() methods that were applying penalties incorrectly
- Calculate data_source penalties directly in the track_distance() and distance() functions when sources don't match
- Use a centralized get_penalty() function to retrieve plugin-specific penalty values via a registry with O(1) lookup
- Change the default data_source_mismatch_penalty from 0.0 to 0.5 to ensure mismatches are penalized by default
- Add data_source to get_most_common_tags() to determine the likely original source for comparison

This ensures that tracks and albums from different data sources are properly penalized during matching, improving match quality and preventing cross-source matches.
This commit is contained in:
parent
96670cf971
commit
455d620ae0
10 changed files with 108 additions and 86 deletions
|
|
@ -409,7 +409,10 @@ def track_distance(
|
|||
dist.add_expr("medium", item.disc != track_info.medium)
|
||||
|
||||
# Plugins.
|
||||
dist.update(metadata_plugins.track_distance(item, track_info))
|
||||
if (original := item.get("data_source")) and (
|
||||
actual := track_info.data_source
|
||||
) != original:
|
||||
dist.add("data_source", metadata_plugins.get_penalty(actual))
|
||||
|
||||
return dist
|
||||
|
||||
|
|
@ -526,6 +529,9 @@ def distance(
|
|||
dist.add("unmatched_tracks", 1.0)
|
||||
|
||||
# Plugins.
|
||||
dist.update(metadata_plugins.album_distance(items, album_info, mapping))
|
||||
|
||||
if (
|
||||
likelies["data_source"]
|
||||
and (data_source := album_info.data_source) != likelies["data_source"]
|
||||
):
|
||||
dist.add("data_source", metadata_plugins.get_penalty(data_source))
|
||||
return dist
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from __future__ import annotations
|
|||
|
||||
import abc
|
||||
import re
|
||||
from functools import cache
|
||||
from functools import cache, cached_property
|
||||
from typing import TYPE_CHECKING, Generic, Literal, Sequence, TypedDict, TypeVar
|
||||
|
||||
import unidecode
|
||||
|
|
@ -23,9 +23,6 @@ from .plugins import BeetsPlugin, find_plugins, notify_info_yielded, send
|
|||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from confuse import ConfigView
|
||||
|
||||
from .autotag import Distance
|
||||
from .autotag.hooks import AlbumInfo, Item, TrackInfo
|
||||
|
||||
|
||||
|
|
@ -76,48 +73,17 @@ def track_for_id(_id: str) -> TrackInfo | None:
|
|||
return None
|
||||
|
||||
|
||||
def track_distance(item: Item, info: TrackInfo) -> Distance:
|
||||
"""Returns the track distance for an item and trackinfo.
|
||||
|
||||
Returns a Distance object is populated by all metadata source plugins
|
||||
that implement the :py:meth:`MetadataSourcePlugin.track_distance` method.
|
||||
"""
|
||||
from beets.autotag.distance import Distance
|
||||
|
||||
dist = Distance()
|
||||
for plugin in find_metadata_source_plugins():
|
||||
dist.update(plugin.track_distance(item, info))
|
||||
return dist
|
||||
|
||||
|
||||
def album_distance(
|
||||
items: Sequence[Item],
|
||||
album_info: AlbumInfo,
|
||||
mapping: dict[Item, TrackInfo],
|
||||
) -> Distance:
|
||||
"""Returns the album distance calculated by plugins."""
|
||||
from beets.autotag.distance import Distance
|
||||
|
||||
dist = Distance()
|
||||
for plugin in find_metadata_source_plugins():
|
||||
dist.update(plugin.album_distance(items, album_info, mapping))
|
||||
return dist
|
||||
|
||||
|
||||
def _get_distance(
|
||||
config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo
|
||||
) -> Distance:
|
||||
"""Returns the ``data_source`` weight and the maximum source weight
|
||||
for albums or individual tracks.
|
||||
"""
|
||||
from beets.autotag.distance import Distance
|
||||
|
||||
dist = Distance()
|
||||
if info.data_source == data_source:
|
||||
dist.add(
|
||||
"data_source", config["data_source_mismatch_penalty"].as_number()
|
||||
)
|
||||
return dist
|
||||
@cache
def get_penalty(data_source: str | None) -> float:
    """Return the mismatch penalty configured for ``data_source``.

    Walks the metadata source plugins and returns the
    ``data_source_mismatch_penalty`` of the first plugin whose
    ``data_source`` equals the given name. When no plugin matches,
    ``MetadataSourcePlugin.DEFAULT_DATA_SOURCE_MISMATCH_PENALTY`` is
    returned instead.

    Results are memoized per ``data_source`` value; call
    ``get_penalty.cache_clear()`` to force a fresh lookup.
    """
    plugins = find_metadata_source_plugins()
    fallback = MetadataSourcePlugin.DEFAULT_DATA_SOURCE_MISMATCH_PENALTY
    for plugin in plugins:
        if plugin.data_source == data_source:
            return plugin.data_source_mismatch_penalty
    return fallback
|
||||
|
||||
|
||||
class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
|
||||
|
|
@ -128,12 +94,26 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
|
|||
and tracks, and to retrieve album and track information by ID.
|
||||
"""
|
||||
|
||||
DEFAULT_DATA_SOURCE_MISMATCH_PENALTY = 0.5
|
||||
|
||||
@cached_classproperty
def data_source(cls) -> str:
    """Name of the data source provided by this plugin class.

    Derived from the class name by removing every occurrence of the
    substring ``"Plugin"``.
    """
    plugin_class_name: str = cls.__name__  # type: ignore[attr-defined]
    return plugin_class_name.replace("Plugin", "")
|
||||
|
||||
@cached_property
def data_source_mismatch_penalty(self) -> float:
    """Numeric value of this plugin's ``data_source_mismatch_penalty`` option.

    Read from ``self.config`` on first access; ``cached_property`` then
    stores the result on the instance.
    """
    penalty_view = self.config["data_source_mismatch_penalty"]
    return penalty_view.as_number()
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config.add(
|
||||
{
|
||||
"search_limit": 5,
|
||||
"data_source_mismatch_penalty": 0.0,
|
||||
"data_source_mismatch_penalty": self.DEFAULT_DATA_SOURCE_MISMATCH_PENALTY, # noqa: E501
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -207,35 +187,6 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
|
|||
|
||||
return (self.track_for_id(id) for id in ids)
|
||||
|
||||
def album_distance(
|
||||
self,
|
||||
items: Sequence[Item],
|
||||
album_info: AlbumInfo,
|
||||
mapping: dict[Item, TrackInfo],
|
||||
) -> Distance:
|
||||
"""Calculate the distance for an album based on its items and album info."""
|
||||
return _get_distance(
|
||||
data_source=self.data_source, info=album_info, config=self.config
|
||||
)
|
||||
|
||||
def track_distance(
|
||||
self,
|
||||
item: Item,
|
||||
info: TrackInfo,
|
||||
) -> Distance:
|
||||
"""Calculate the distance for a track based on its item and track info."""
|
||||
return _get_distance(
|
||||
data_source=self.data_source, info=info, config=self.config
|
||||
)
|
||||
|
||||
@cached_classproperty
|
||||
def data_source(cls) -> str:
|
||||
"""The data source name for this plugin.
|
||||
|
||||
This is inferred from the plugin name.
|
||||
"""
|
||||
return cls.__name__.replace("Plugin", "") # type: ignore[attr-defined]
|
||||
|
||||
def _extract_id(self, url: str) -> str | None:
|
||||
"""Extract an ID from a URL for this metadata source plugin.
|
||||
|
||||
|
|
|
|||
|
|
@ -836,9 +836,10 @@ def get_most_common_tags(
|
|||
"country",
|
||||
"media",
|
||||
"albumdisambig",
|
||||
"data_source",
|
||||
]
|
||||
for field in fields:
|
||||
values = [item[field] for item in items if item]
|
||||
values = [item.get(field) for item in items if item]
|
||||
likelies[field], freq = plurality(values)
|
||||
consensus[field] = freq == len(values)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ Bug fixes:
|
|||
an import of another :class:`beets.plugins.BeetsPlugin` class. :bug:`6033`
|
||||
- :doc:`/plugins/fromfilename`: Fix :bug:`5218`, improve the code (refactor
|
||||
regexps, allow for more cases, add some logging), add tests.
|
||||
- Metadata source plugins: Fixed data source penalty calculation that was
|
||||
incorrectly applied during import matching. The ``source_weight``
|
||||
configuration option has been renamed to ``data_source_mismatch_penalty`` to
|
||||
better reflect its purpose, and its default has changed from ``0.0`` to ``0.5``. :bug:`6066`
|
||||
|
||||
For packagers:
|
||||
|
||||
|
|
@ -75,6 +79,13 @@ For developers and plugin authors:
|
|||
- Typing improvements in ``beets/logging.py``: ``getLogger`` now returns
|
||||
``BeetsLogger`` when called with a name, or ``RootLogger`` when called without
|
||||
a name.
|
||||
- The ``track_distance()`` and ``album_distance()`` methods have been removed
|
||||
from ``MetadataSourcePlugin``. Distance calculation for data source mismatches
|
||||
is now handled automatically by the core matching logic. This change
|
||||
simplifies the plugin architecture and fixes incorrect penalty calculations.
|
||||
:bug:`6066`
|
||||
- Metadata source plugins are now registered globally when instantiated, which
|
||||
makes their handling slightly more efficient.
|
||||
|
||||
2.4.0 (September 13, 2025)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ Default
|
|||
.. code-block:: yaml
|
||||
|
||||
deezer:
|
||||
data_source_mismatch_penalty: 0.0
|
||||
data_source_mismatch_penalty: 0.5
|
||||
search_limit: 5
|
||||
search_query_ascii: no
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ Default
|
|||
.. code-block:: yaml
|
||||
|
||||
discogs:
|
||||
data_source_mismatch_penalty: 0.0
|
||||
data_source_mismatch_penalty: 0.5
|
||||
search_limit: 5
|
||||
apikey: REDACTED
|
||||
apisecret: REDACTED
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ We provide several :ref:`autotagger_extensions` that fetch metadata from online
|
|||
databases. They share the following configuration options:
|
||||
|
||||
- **data_source_mismatch_penalty**: Penalty applied during import to a match whose data source differs from the data source already recorded on your items.
|
||||
Default: ``0.0`` (no penalty).
|
||||
Any decimal number between 0 and 1. Default: ``0.5``.
|
||||
|
||||
Penalize this data source to prioritize others. For example, to prefer Discogs
|
||||
over MusicBrainz:
|
||||
|
|
@ -64,7 +64,7 @@ databases. They share the following configuration options:
|
|||
data_source_mismatch_penalty: 2.0
|
||||
|
||||
By default, all sources are equally preferred with each having
|
||||
``data_source_mismatch_penalty`` set to ``0.0``.
|
||||
``data_source_mismatch_penalty`` set to ``0.5``.
|
||||
|
||||
- **search_limit**: Maximum number of search results to consider. Default:
|
||||
``5``.
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Default
|
|||
.. code-block:: yaml
|
||||
|
||||
musicbrainz:
|
||||
data_source_mismatch_penalty: 0.0
|
||||
data_source_mismatch_penalty: 0.5
|
||||
search_limit: 5
|
||||
host: musicbrainz.org
|
||||
https: no
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ Default
|
|||
.. code-block:: yaml
|
||||
|
||||
spotify:
|
||||
data_source_mismatch_penalty: 0.0
|
||||
data_source_mismatch_penalty: 0.5
|
||||
search_limit: 5
|
||||
mode: list
|
||||
region_filter:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from beets.autotag.distance import (
|
|||
track_distance,
|
||||
)
|
||||
from beets.library import Item
|
||||
from beets.metadata_plugins import MetadataSourcePlugin, get_penalty
|
||||
from beets.test.helper import ConfigMixin
|
||||
|
||||
_p = pytest.param
|
||||
|
|
@ -297,3 +298,55 @@ class TestStringDistance:
|
|||
string_dist("The ", "")
|
||||
string_dist("(EP)", "(EP)")
|
||||
string_dist(", An", "")
|
||||
|
||||
|
||||
class TestDataSourceDistance:
    """Check the ``data_source`` penalty contributed by ``track_distance``.

    Each case builds an item and a track info with the given data sources,
    with the "Other" plugin's penalty and the ``data_source`` distance
    weight supplied by the parametrization.
    """

    # Expected distance when no data_source penalty is applied.
    MATCH = 0.0
    # Expected distance for a mismatch with penalty 0.5 and weight 1.0.
    MISMATCH = 0.125

    @pytest.fixture(autouse=True)
    def setup(self, monkeypatch, penalty, weight):
        """Install two stub source plugins and the requested weight.

        ``get_penalty`` is memoized, so its cache is cleared before the
        test (to drop values from earlier cases) and again afterwards, so
        penalties computed against the stub plugins cannot leak into
        tests that run later.
        """
        monkeypatch.setitem(Distance._weights, "data_source", weight)
        get_penalty.cache_clear()

        class TestMetadataSourcePlugin(MetadataSourcePlugin):
            def album_for_id(self, *args, **kwargs): ...
            def track_for_id(self, *args, **kwargs): ...
            def candidates(self, *args, **kwargs): ...
            def item_candidates(self, *args, **kwargs): ...

        class OriginalPlugin(TestMetadataSourcePlugin):
            pass

        class OtherPlugin(TestMetadataSourcePlugin):
            # Override the config-backed value with the parametrized one.
            @property
            def data_source_mismatch_penalty(self):
                return penalty

        monkeypatch.setattr(
            "beets.metadata_plugins.find_metadata_source_plugins",
            lambda: [OriginalPlugin(), OtherPlugin()],
        )

        yield

        # Teardown: forget penalties looked up through the stub plugins.
        get_penalty.cache_clear()

    @pytest.mark.parametrize(
        "item,info,penalty,weight,expected_distance",
        [
            _p("Original", "Original", 0.5, 1.0, MATCH, id="match"),
            _p("Original", "Other", 0.5, 1.0, MISMATCH, id="mismatch"),
            _p("Original", "unknown", 0.5, 1.0, MISMATCH, id="mismatch-unknown"),  # noqa: E501
            _p("Original", None, 0.5, 1.0, MISMATCH, id="mismatch-no-info"),
            _p(None, "Other", 0.5, 1.0, MATCH, id="match-no-original"),
            _p("unknown", "unknown", 0.5, 1.0, MATCH, id="match-unknown"),
            _p("Original", "Other", 1.0, 1.0, 0.25, id="mismatch-max-penalty"),
            _p("Original", "Other", 0.5, 5.0, 0.3125, id="mismatch-high-weight"),  # noqa: E501
            _p("Original", "Other", 0.0, 1.0, MATCH, id="match-no-penalty"),
            _p("Original", "Other", 0.5, 0.0, MATCH, id="match-no-weight"),
        ],
    )  # fmt: skip
    def test_distance(self, item, info, expected_distance):
        """Compare the track distance with the expected value for the case."""
        item = Item(data_source=item)
        info = TrackInfo(data_source=info, title="")

        dist = track_distance(item, info)

        assert dist.distance == expected_distance
|
||||
|
|
|
|||
Loading…
Reference in a new issue