Upgrade musicbrainz and discogs to SearchApiMetadataSourcePlugin

And centralise common search functionality inside the parent class
This commit is contained in:
Šarūnas Nejus 2025-11-11 05:20:20 +00:00
parent 3a72d85c5e
commit 36a6ee8efd
No known key found for this signature in database
7 changed files with 146 additions and 242 deletions

View file

@ -10,11 +10,17 @@ from __future__ import annotations
import abc
import re
from functools import cache, cached_property
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Generic,
Literal,
NamedTuple,
TypedDict,
TypeVar,
)
import unidecode
from confuse import NotFoundError
from typing_extensions import NotRequired
from beets.util import cached_classproperty
from beets.util.id_extractors import extract_release_id
@ -26,6 +32,8 @@ if TYPE_CHECKING:
from .autotag.hooks import AlbumInfo, Item, TrackInfo
QueryType = Literal["album", "track"]
@cache
def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
@ -169,7 +177,7 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
"""
raise NotImplementedError
def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:
def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]:
"""Batch lookup of album metadata for a list of album IDs.
Given a list of album identifiers, yields corresponding AlbumInfo objects.
@ -180,7 +188,7 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
return (self.album_for_id(id) for id in ids)
def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]:
def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]:
"""Batch lookup of track metadata for a list of track IDs.
Given a list of track identifiers, yields corresponding TrackInfo objects.
@ -254,14 +262,15 @@ class IDResponse(TypedDict):
id: str
class SearchFilter(TypedDict):
artist: NotRequired[str]
album: NotRequired[str]
R = TypeVar("R", bound=IDResponse)
class SearchParams(NamedTuple):
query_type: QueryType
query: str
filters: dict[str, str]
class SearchApiMetadataSourcePlugin(
Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta
):
@ -282,12 +291,26 @@ class SearchApiMetadataSourcePlugin(
}
)
@abc.abstractmethod
def _search_api(
def get_search_filters(
self,
query_type: Literal["album", "track"],
filters: SearchFilter,
query_string: str = "",
query_type: QueryType,
items: Sequence[Item],
artist: str,
name: str,
va_likely: bool,
) -> tuple[str, dict[str, str]]:
query = f'album:"{name}"' if query_type == "album" else name
if query_type == "track" or not va_likely:
query += f' artist:"{artist}"'
return query, {}
@abc.abstractmethod
def get_search_response(self, params: SearchParams) -> Sequence[R]:
raise NotImplementedError
def _search_api(
self, query_type: QueryType, query: str, filters: dict[str, str]
) -> Sequence[R]:
"""Perform a search on the API.
@ -297,7 +320,28 @@ class SearchApiMetadataSourcePlugin(
Should return a list of identifiers for the requested type (album or track).
"""
raise NotImplementedError
if self.config["search_query_ascii"].get():
query = unidecode.unidecode(query)
filters["limit"] = str(self.config["search_limit"].get())
params = SearchParams(query_type, query, filters)
self._log.debug("Searching for '{}' with {}", query, filters)
try:
response_data = self.get_search_response(params)
except Exception:
self._log.error("Error fetching data", exc_info=True)
return ()
self._log.debug("Found {} result(s)", len(response_data))
return response_data
def _get_candidates(
self, query_type: QueryType, *args, **kwargs
) -> Sequence[R]:
return self._search_api(
query_type, *self.get_search_filters(query_type, *args, **kwargs)
)
def candidates(
self,
@ -306,54 +350,11 @@ class SearchApiMetadataSourcePlugin(
album: str,
va_likely: bool,
) -> Iterable[AlbumInfo]:
query_filters: SearchFilter = {}
if album:
query_filters["album"] = album
if not va_likely:
query_filters["artist"] = artist
results = self._search_api("album", query_filters)
if not results:
return []
return filter(
None, self.albums_for_ids([result["id"] for result in results])
)
results = self._get_candidates("album", items, artist, album, va_likely)
return filter(None, self.albums_for_ids(r["id"] for r in results))
def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
results = self._search_api(
"track", {"artist": artist}, query_string=title
)
if not results:
return []
return filter(
None,
self.tracks_for_ids([result["id"] for result in results if result]),
)
def _construct_search_query(
self, filters: SearchFilter, query_string: str
) -> str:
"""Construct a query string with the specified filters and keywords to
be provided to the spotify (or similar) search API.
The returned format was initially designed for spotify's search API but
we found is also useful with other APIs that support similar query structures.
see `spotify <https://developer.spotify.com/documentation/web-api/reference/search>`_
and `deezer <https://developers.deezer.com/api/search>`_.
:param filters: Field filters to apply.
:param query_string: Query keywords to use.
:return: Query string to be provided to the search API.
"""
components = [query_string, *(f"{k}:'{v}'" for k, v in filters.items())]
query = " ".join(filter(None, components))
if self.config["search_query_ascii"].get():
query = unidecode.unidecode(query)
return query
results = self._get_candidates("track", [item], artist, title, False)
return filter(None, self.tracks_for_ids(r["id"] for r in results))

View file

@ -18,23 +18,20 @@ from __future__ import annotations
import collections
import time
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING
import requests
from beets import ui
from beets.autotag import AlbumInfo, TrackInfo
from beets.dbcore import types
from beets.metadata_plugins import (
IDResponse,
SearchApiMetadataSourcePlugin,
SearchFilter,
)
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
if TYPE_CHECKING:
from collections.abc import Sequence
from beets.library import Item, Library
from beets.metadata_plugins import SearchParams
from ._typing import JSONDict
@ -220,58 +217,12 @@ class DeezerPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
deezer_updated=time.time(),
)
def _search_api(
self,
query_type: Literal[
"album",
"track",
"artist",
"history",
"playlist",
"podcast",
"radio",
"user",
],
filters: SearchFilter,
query_string: str = "",
) -> Sequence[IDResponse]:
"""Query the Deezer Search API for the specified ``query_string``, applying
the provided ``filters``.
:param filters: Field filters to apply.
:param query_string: Additional query to include in the search.
:return: JSON data for the class:`Response <Response>` object or None
if no search results are returned.
"""
query = self._construct_search_query(
query_string=query_string, filters=filters
)
self._log.debug("Searching {.data_source} for '{}'", self, query)
try:
response = requests.get(
f"{self.search_url}{query_type}",
params={
"q": query,
"limit": self.config["search_limit"].get(),
},
timeout=10,
)
response.raise_for_status()
except requests.exceptions.RequestException as e:
self._log.error(
"Error fetching data from {.data_source} API\n Error: {}",
self,
e,
)
return ()
response_data: Sequence[IDResponse] = response.json().get("data", [])
self._log.debug(
"Found {} result(s) from {.data_source} for '{}'",
len(response_data),
self,
query,
)
return response_data
def get_search_response(self, params: SearchParams) -> list[IDResponse]:
return requests.get(
f"{self.search_url}{params.query_type}",
params={**params.filters, "q": params.query},
timeout=10,
).json()["data"]
def deezerupdate(self, items: Sequence[Item], write: bool):
"""Obtain rank information from Deezer."""

View file

@ -40,12 +40,13 @@ import beets.ui
from beets import config
from beets.autotag.distance import string_dist
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.metadata_plugins import MetadataSourcePlugin
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from beets.library import Item
from beets.metadata_plugins import QueryType, SearchParams
USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
API_KEY = "rAzVUQYRaoFjeBjyWuWZ"
@ -121,7 +122,7 @@ class IntermediateTrackInfo(TrackInfo):
super().__init__(**kwargs)
class DiscogsPlugin(MetadataSourcePlugin):
class DiscogsPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
def __init__(self):
super().__init__()
self.config.add(
@ -211,11 +212,6 @@ class DiscogsPlugin(MetadataSourcePlugin):
return token, secret
def candidates(
self, items: Sequence[Item], artist: str, album: str, va_likely: bool
) -> Iterable[AlbumInfo]:
return self.get_albums(f"{artist} {album}" if va_likely else album)
def get_track_from_album(
self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float]
) -> TrackInfo | None:
@ -232,21 +228,19 @@ class DiscogsPlugin(MetadataSourcePlugin):
def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
) -> Iterator[TrackInfo]:
albums = self.candidates([item], artist, title, False)
def compare_func(track_info: TrackInfo) -> float:
return string_dist(track_info.title, title)
tracks = (self.get_track_from_album(a, compare_func) for a in albums)
return list(filter(None, tracks))
return filter(None, tracks)
def album_for_id(self, album_id: str) -> AlbumInfo | None:
"""Fetches an album by its Discogs ID and returns an AlbumInfo object
or None if the album is not found.
"""
self._log.debug("Searching for release {}", album_id)
discogs_id = self._extract_id(album_id)
if not discogs_id:
@ -280,29 +274,25 @@ class DiscogsPlugin(MetadataSourcePlugin):
return None
def get_albums(self, query: str) -> Iterable[AlbumInfo]:
"""Returns a list of AlbumInfo objects for a discogs search query."""
# Strip non-word characters from query. Things like "!" and "-" can
# cause a query to return no results, even if they match the artist or
# album title. Use `re.UNICODE` flag to avoid stripping non-english
# word characters.
query = re.sub(r"(?u)\W+", " ", query)
# Strip medium information from query, Things like "CD1" and "disk 1"
# can also negate an otherwise positive result.
query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query)
def get_search_filters(
self,
query_type: QueryType,
items: Sequence[Item],
artist: str,
name: str,
va_likely: bool,
) -> tuple[str, dict[str, str]]:
if va_likely:
artist = items[0].artist
try:
results = self.discogs_client.search(query, type="release")
results.per_page = self.config["search_limit"].get()
releases = results.page(1)
except CONNECTION_ERRORS:
self._log.debug(
"Communication error while searching for {0!r}",
query,
exc_info=True,
)
return []
return filter(None, map(self.get_album_info, releases))
return f"{artist} - {name}", {"type": "release"}
def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
"""Returns a list of AlbumInfo objects for a discogs search query."""
limit = params.filters.pop("limit")
results = self.discogs_client.search(params.query, **params.filters)
results.per_page = limit
return [r.data for r in results.page(1)]
@cache
def get_master_year(self, master_id: str) -> int | None:

View file

@ -30,14 +30,14 @@ from confuse.exceptions import NotFoundError
import beets
import beets.autotag.hooks
from beets import config, plugins, util
from beets.metadata_plugins import MetadataSourcePlugin
from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
from beets.util.id_extractors import extract_release_id
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing import Literal
from collections.abc import Sequence
from beets.library import Item
from beets.metadata_plugins import QueryType, SearchParams
from ._typing import JSONDict
@ -369,7 +369,7 @@ def _merge_pseudo_and_actual_album(
return merged
class MusicBrainzPlugin(MetadataSourcePlugin):
class MusicBrainzPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
def __init__(self):
"""Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup.
@ -798,52 +798,27 @@ class MusicBrainzPlugin(MetadataSourcePlugin):
return criteria
def _search_api(
self,
query_type: Literal["recording", "release"],
filters: dict[str, str],
) -> list[JSONDict]:
"""Perform MusicBrainz API search and return results.
Execute a search against the MusicBrainz API for recordings or releases
using the provided criteria. Handles API errors by converting them into
MusicBrainzAPIError exceptions with contextual information.
"""
filters = {
k: _v for k, v in filters.items() if (_v := v.lower().strip())
}
self._log.debug(
"Searching for MusicBrainz {}s with: {!r}", query_type, filters
)
try:
method = getattr(musicbrainzngs, f"search_{query_type}s")
res = method(limit=self.config["search_limit"].get(), **filters)
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(
exc, f"{query_type} search", filters, traceback.format_exc()
)
return res[f"{query_type}-list"]
def candidates(
def get_search_filters(
self,
query_type: QueryType,
items: Sequence[Item],
artist: str,
album: str,
name: str,
va_likely: bool,
) -> Iterable[beets.autotag.hooks.AlbumInfo]:
criteria = self.get_album_criteria(items, artist, album, va_likely)
release_ids = (r["id"] for r in self._search_api("release", criteria))
) -> tuple[str, dict[str, str]]:
if query_type == "album":
criteria = self.get_album_criteria(items, artist, name, va_likely)
else:
criteria = {"artist": artist, "recording": name, "alias": name}
yield from filter(None, map(self.album_for_id, release_ids))
return "", {
k: _v for k, v in criteria.items() if (_v := v.lower().strip())
}
def item_candidates(
self, item: Item, artist: str, title: str
) -> Iterable[beets.autotag.hooks.TrackInfo]:
criteria = {"artist": artist, "recording": title, "alias": title}
yield from filter(
None, map(self.track_info, self._search_api("recording", criteria))
)
def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
mb_entity = "release" if params.query_type == "album" else "recording"
method = getattr(musicbrainzngs, f"search_{mb_entity}s")
return method(**params.filters)[f"{mb_entity}-list"]
def album_for_id(
self, album_id: str

View file

@ -39,7 +39,7 @@ from beets.library import Library
from beets.metadata_plugins import (
IDResponse,
SearchApiMetadataSourcePlugin,
SearchFilter,
SearchParams,
)
if TYPE_CHECKING:
@ -447,11 +447,8 @@ class SpotifyPlugin(
track.medium_total = medium_total
return track
def _search_api(
self,
query_type: Literal["album", "track"],
filters: SearchFilter,
query_string: str = "",
def get_search_response(
self, params: SearchParams
) -> Sequence[SearchResponseAlbums | SearchResponseTracks]:
"""Query the Spotify Search API for the specified ``query_string``,
applying the provided ``filters``.
@ -460,34 +457,27 @@ class SpotifyPlugin(
'artist', 'playlist', and 'track'.
:param filters: Field filters to apply.
:param query_string: Additional query to include in the search.
"""
query = self._construct_search_query(
filters=filters, query_string=query_string
response = requests.get(
self.search_url,
headers={"Authorization": f"Bearer {self.access_token}"},
params={
**params.filters,
"q": params.query,
"type": params.query_type,
},
timeout=10,
)
self._log.debug("Searching {.data_source} for '{}'", self, query)
try:
response = self._handle_response(
"get",
self.search_url,
params={
"q": query,
"type": query_type,
"limit": self.config["search_limit"].get(),
},
)
except APIError as e:
self._log.debug("Spotify API error: {}", e)
return ()
response_data = response.get(f"{query_type}s", {}).get("items", [])
self._log.debug(
"Found {} result(s) from {.data_source} for '{}'",
len(response_data),
self,
query,
)
return response_data
response.raise_for_status()
except requests.exceptions.HTTPError:
if response.status_code == 401:
self._authenticate()
return self.get_search_response(params)
raise
return response.json().get(f"{params.query_type}s", {}).get("items", [])
def commands(self) -> list[ui.Subcommand]:
# autotagger import command
@ -600,22 +590,14 @@ class SpotifyPlugin(
query_string = item["title"]
# Query the Web API for each track, look for the items' JSON data
query_filters: SearchFilter = {}
query = query_string
if artist:
query_filters["artist"] = artist
query += f" artist:'{artist}'"
if album:
query_filters["album"] = album
query += f" album:'{album}'"
response_data_tracks = self._search_api(
query_type="track",
query_string=query_string,
filters=query_filters,
)
response_data_tracks = self._search_api("track", query, {})
if not response_data_tracks:
query = self._construct_search_query(
query_string=query_string, filters=query_filters
)
failures.append(query)
continue

View file

@ -990,7 +990,7 @@ class TestMusicBrainzPlugin(PluginMixin):
plugin = "musicbrainz"
mbid = "d2a6f856-b553-40a0-ac54-a321e8e2da99"
RECORDING = {"title": "foo", "id": "bar", "length": 42}
RECORDING = {"title": "foo", "id": mbid, "length": 42}
@pytest.fixture
def plugin_config(self):
@ -1035,6 +1035,10 @@ class TestMusicBrainzPlugin(PluginMixin):
"musicbrainzngs.search_recordings",
lambda *_, **__: {"recording-list": [self.RECORDING]},
)
monkeypatch.setattr(
"musicbrainzngs.get_recording_by_id",
lambda *_, **__: {"recording": self.RECORDING},
)
candidates = list(mb.item_candidates(Item(), "hello", "there"))

View file

@ -81,6 +81,7 @@ class SpotifyPluginTest(PluginTestCase):
params = _params(responses.calls[0].request.url)
query = params["q"][0]
print(query)
assert "duifhjslkef" in query
assert "artist:'ujydfsuihse'" in query
assert "album:'lkajsdflakjsd'" in query