Upgrade musicbrainz and discogs to SearchApiMetadataSourcePlugin

And centralise common search functionality inside the parent class
This commit is contained in:
Šarūnas Nejus 2025-11-11 05:20:20 +00:00
parent 3a72d85c5e
commit 36a6ee8efd
No known key found for this signature in database
7 changed files with 146 additions and 242 deletions

View file

@ -10,11 +10,17 @@ from __future__ import annotations
import abc import abc
import re import re
from functools import cache, cached_property from functools import cache, cached_property
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar from typing import (
TYPE_CHECKING,
Generic,
Literal,
NamedTuple,
TypedDict,
TypeVar,
)
import unidecode import unidecode
from confuse import NotFoundError from confuse import NotFoundError
from typing_extensions import NotRequired
from beets.util import cached_classproperty from beets.util import cached_classproperty
from beets.util.id_extractors import extract_release_id from beets.util.id_extractors import extract_release_id
@ -26,6 +32,8 @@ if TYPE_CHECKING:
from .autotag.hooks import AlbumInfo, Item, TrackInfo from .autotag.hooks import AlbumInfo, Item, TrackInfo
QueryType = Literal["album", "track"]
@cache @cache
def find_metadata_source_plugins() -> list[MetadataSourcePlugin]: def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
@ -169,7 +177,7 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
""" """
raise NotImplementedError raise NotImplementedError
def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]: def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]:
"""Batch lookup of album metadata for a list of album IDs. """Batch lookup of album metadata for a list of album IDs.
Given a list of album identifiers, yields corresponding AlbumInfo objects. Given a list of album identifiers, yields corresponding AlbumInfo objects.
@ -180,7 +188,7 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta):
return (self.album_for_id(id) for id in ids) return (self.album_for_id(id) for id in ids)
def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]: def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]:
"""Batch lookup of track metadata for a list of track IDs. """Batch lookup of track metadata for a list of track IDs.
Given a list of track identifiers, yields corresponding TrackInfo objects. Given a list of track identifiers, yields corresponding TrackInfo objects.
@ -254,14 +262,15 @@ class IDResponse(TypedDict):
id: str id: str
class SearchFilter(TypedDict):
artist: NotRequired[str]
album: NotRequired[str]
R = TypeVar("R", bound=IDResponse) R = TypeVar("R", bound=IDResponse)
class SearchParams(NamedTuple):
    """Bundle of arguments describing one search request.

    Instances are built by ``_search_api`` and handed to
    ``get_search_response``.
    """

    # Kind of entity being searched for: "album" or "track".
    query_type: QueryType
    # Query string to send to the search backend.
    query: str
    # Extra request parameters; ``_search_api`` adds a "limit" entry to them.
    filters: dict[str, str]
class SearchApiMetadataSourcePlugin( class SearchApiMetadataSourcePlugin(
Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta
): ):
@ -282,12 +291,26 @@ class SearchApiMetadataSourcePlugin(
} }
) )
def get_search_filters(
    self,
    query_type: QueryType,
    items: Sequence[Item],
    artist: str,
    name: str,
    va_likely: bool,
) -> tuple[str, dict[str, str]]:
    """Build the default query string and extra filters for a search.

    Album searches use an ``album:"<name>"`` clause, track searches use the
    bare name. An ``artist:"<artist>"`` clause is appended for track
    searches and for album searches that are not likely various-artists
    releases. Empty values do not produce a clause, so the query never
    contains an empty ``album:""`` or ``artist:""`` filter.

    :param query_type: ``"album"`` or ``"track"``.
    :param items: Items being matched; unused by this default implementation.
    :param artist: Artist name to search for.
    :param name: Album or track title to search for.
    :param va_likely: Whether the album is likely a various-artists release.
    :return: The query string and a dictionary of additional filters, which
        is empty in this default implementation.
    """
    clauses = []
    if name:
        clauses.append(f'album:"{name}"' if query_type == "album" else name)

    if artist and (query_type == "track" or not va_likely):
        clauses.append(f'artist:"{artist}"')

    return " ".join(clauses), {}
@abc.abstractmethod
def get_search_response(self, params: SearchParams) -> Sequence[R]:
    """Send the search described by ``params`` and return the raw results.

    Called by ``_search_api``, which logs any exception raised here and
    returns an empty result instead. The returned sequence must support
    ``len()``, and its entries are looked up by their ``"id"`` key.

    :param params: Query type, query string and filters for the request.
    :return: Sequence of result entries from the data source.
    """
    raise NotImplementedError
def _search_api(
    self, query_type: QueryType, query: str, filters: dict[str, str]
) -> Sequence[R]:
    """Perform a search on the API and return the raw results.

    The query is transliterated to ASCII when the ``search_query_ascii``
    option is enabled, and the configured ``search_limit`` is added to the
    filters under the ``"limit"`` key before the request is delegated to
    ``get_search_response``.

    :param query_type: ``"album"`` or ``"track"``.
    :param query: Query string to search for.
    :param filters: Additional request parameters. The mapping passed in is
        left untouched.
    :return: Results from ``get_search_response``, or an empty tuple when
        that call raises.
    """
    if self.config["search_query_ascii"].get():
        query = unidecode.unidecode(query)

    # Build a new dict rather than inserting "limit" into the caller's one.
    filters = {**filters, "limit": str(self.config["search_limit"].get())}
    params = SearchParams(query_type, query, filters)

    self._log.debug("Searching for '{}' with {}", query, filters)
    try:
        response_data = self.get_search_response(params)
    except Exception:
        # Deliberately broad: a failing data source must not abort the
        # lookup, so the error is logged and no results are reported.
        self._log.error("Error fetching data", exc_info=True)
        return ()

    self._log.debug("Found {} result(s)", len(response_data))
    return response_data
def _get_candidates(
self, query_type: QueryType, *args, **kwargs
) -> Sequence[R]:
return self._search_api(
query_type, *self.get_search_filters(query_type, *args, **kwargs)
)
def candidates( def candidates(
self, self,
@ -306,54 +350,11 @@ class SearchApiMetadataSourcePlugin(
album: str, album: str,
va_likely: bool, va_likely: bool,
) -> Iterable[AlbumInfo]: ) -> Iterable[AlbumInfo]:
query_filters: SearchFilter = {} results = self._get_candidates("album", items, artist, album, va_likely)
if album: return filter(None, self.albums_for_ids(r["id"] for r in results))
query_filters["album"] = album
if not va_likely:
query_filters["artist"] = artist
results = self._search_api("album", query_filters)
if not results:
return []
return filter(
None, self.albums_for_ids([result["id"] for result in results])
)
def item_candidates(
    self, item: Item, artist: str, title: str
) -> Iterable[TrackInfo]:
    """Return track candidates for a single item.

    Runs a ``"track"`` search for the artist and title, resolves the ID of
    every result through ``tracks_for_ids`` and drops falsy entries.

    :param item: Item to find candidates for.
    :param artist: Artist name to search for.
    :param title: Track title to search for.
    :return: Lazy iterable over the resolved tracks.
    """
    matches = self._get_candidates("track", [item], artist, title, False)
    track_ids = (match["id"] for match in matches)
    return filter(None, self.tracks_for_ids(track_ids))
def _construct_search_query(
self, filters: SearchFilter, query_string: str
) -> str:
"""Construct a query string with the specified filters and keywords to
be provided to the spotify (or similar) search API.
The returned format was initially designed for spotify's search API but
we found is also useful with other APIs that support similar query structures.
see `spotify <https://developer.spotify.com/documentation/web-api/reference/search>`_
and `deezer <https://developers.deezer.com/api/search>`_.
:param filters: Field filters to apply.
:param query_string: Query keywords to use.
:return: Query string to be provided to the search API.
"""
components = [query_string, *(f"{k}:'{v}'" for k, v in filters.items())]
query = " ".join(filter(None, components))
if self.config["search_query_ascii"].get():
query = unidecode.unidecode(query)
return query

View file

@ -18,23 +18,20 @@ from __future__ import annotations
import collections import collections
import time import time
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING
import requests import requests
from beets import ui from beets import ui
from beets.autotag import AlbumInfo, TrackInfo from beets.autotag import AlbumInfo, TrackInfo
from beets.dbcore import types from beets.dbcore import types
from beets.metadata_plugins import ( from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
IDResponse,
SearchApiMetadataSourcePlugin,
SearchFilter,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
from beets.library import Item, Library from beets.library import Item, Library
from beets.metadata_plugins import SearchParams
from ._typing import JSONDict from ._typing import JSONDict
@ -220,58 +217,12 @@ class DeezerPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
deezer_updated=time.time(), deezer_updated=time.time(),
) )
def get_search_response(self, params: SearchParams) -> list[IDResponse]:
    """Query the Deezer search endpoint and return the matching entries.

    The query type is appended to ``self.search_url``; the filters and the
    query string (as ``q``) are sent as request parameters.

    :param params: Query type, query string and filters for the request.
    :return: The ``"data"`` list of the JSON response, or an empty list
        when the response has no such key.
    :raises requests.exceptions.RequestException: If the request fails or
        the server answers with an HTTP error status.
    """
    response = requests.get(
        f"{self.search_url}{params.query_type}",
        params={**params.filters, "q": params.query},
        timeout=10,
    )
    # Report HTTP errors as such instead of failing later on a missing key.
    response.raise_for_status()
    # A payload without "data" means no results, as in the previous
    # implementation.
    return response.json().get("data", [])
def deezerupdate(self, items: Sequence[Item], write: bool): def deezerupdate(self, items: Sequence[Item], write: bool):
"""Obtain rank information from Deezer.""" """Obtain rank information from Deezer."""

View file

@ -40,12 +40,13 @@ import beets.ui
from beets import config from beets import config
from beets.autotag.distance import string_dist from beets.autotag.distance import string_dist
from beets.autotag.hooks import AlbumInfo, TrackInfo from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.metadata_plugins import MetadataSourcePlugin from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from beets.library import Item from beets.library import Item
from beets.metadata_plugins import QueryType, SearchParams
USER_AGENT = f"beets/{beets.__version__} +https://beets.io/" USER_AGENT = f"beets/{beets.__version__} +https://beets.io/"
API_KEY = "rAzVUQYRaoFjeBjyWuWZ" API_KEY = "rAzVUQYRaoFjeBjyWuWZ"
@ -121,7 +122,7 @@ class IntermediateTrackInfo(TrackInfo):
super().__init__(**kwargs) super().__init__(**kwargs)
class DiscogsPlugin(MetadataSourcePlugin): class DiscogsPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.config.add( self.config.add(
@ -211,11 +212,6 @@ class DiscogsPlugin(MetadataSourcePlugin):
return token, secret return token, secret
def candidates(
self, items: Sequence[Item], artist: str, album: str, va_likely: bool
) -> Iterable[AlbumInfo]:
return self.get_albums(f"{artist} {album}" if va_likely else album)
def get_track_from_album( def get_track_from_album(
self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float] self, album_info: AlbumInfo, compare: Callable[[TrackInfo], float]
) -> TrackInfo | None: ) -> TrackInfo | None:
@ -232,21 +228,19 @@ class DiscogsPlugin(MetadataSourcePlugin):
def item_candidates(
    self, item: Item, artist: str, title: str
) -> Iterator[TrackInfo]:
    """Return track candidates for a single item.

    Searches for albums using the track title as the album name, then asks
    ``get_track_from_album`` for a track from each album, handing it a
    function that measures the string distance between a track's title and
    ``title``. Falsy results are dropped.

    :param item: Item to find candidates for.
    :param artist: Artist name to search for.
    :param title: Track title to search for.
    :return: Lazy iterator over the selected tracks.
    """

    def title_distance(track_info: TrackInfo) -> float:
        return string_dist(track_info.title, title)

    found_albums = self.candidates([item], artist, title, False)
    return filter(
        None,
        (
            self.get_track_from_album(found, title_distance)
            for found in found_albums
        ),
    )
def album_for_id(self, album_id: str) -> AlbumInfo | None: def album_for_id(self, album_id: str) -> AlbumInfo | None:
"""Fetches an album by its Discogs ID and returns an AlbumInfo object """Fetches an album by its Discogs ID and returns an AlbumInfo object
or None if the album is not found. or None if the album is not found.
""" """
self._log.debug("Searching for release {}", album_id)
discogs_id = self._extract_id(album_id) discogs_id = self._extract_id(album_id)
if not discogs_id: if not discogs_id:
@ -280,29 +274,25 @@ class DiscogsPlugin(MetadataSourcePlugin):
return None return None
def get_search_filters(
    self,
    query_type: QueryType,
    items: Sequence[Item],
    artist: str,
    name: str,
    va_likely: bool,
) -> tuple[str, dict[str, str]]:
    """Build the Discogs search query and filters.

    The query combines the artist and the name; for likely various-artists
    albums the artist of the first item is used instead of ``artist``. The
    search is restricted to releases.

    :param query_type: ``"album"`` or ``"track"``; unused here.
    :param items: Items being matched.
    :param artist: Artist name to search for.
    :param name: Album or track title to search for.
    :param va_likely: Whether the album is likely a various-artists release.
    :return: The sanitised query string and the ``type`` filter.
    """
    # Guard against an empty item list instead of raising IndexError.
    if va_likely and items:
        artist = items[0].artist

    query = f"{artist} - {name}"
    # Strip non-word characters from query. Things like "!" and "-" can
    # cause a query to return no results, even if they match the artist or
    # album title. Use `re.UNICODE` flag to avoid stripping non-english
    # word characters.
    query = re.sub(r"(?u)\W+", " ", query)
    # Strip medium information from query, Things like "CD1" and "disk 1"
    # can also negate an otherwise positive result.
    query = re.sub(r"(?i)\b(CD|disc|vinyl)\s*\d+", "", query)
    return query.strip(), {"type": "release"}
def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
    """Search Discogs and return the data of the first page of results.

    The ``"limit"`` filter is not sent as a search parameter; it is applied
    as the page size of the result set instead. All other filters are
    passed to the client's ``search`` call.

    :param params: Query type, query string and filters for the request.
    :return: The ``data`` attribute of every result on the first page.
    """
    # Copy the filters so that removing "limit" does not alter ``params``.
    filters = dict(params.filters)
    limit = filters.pop("limit", None)

    results = self.discogs_client.search(params.query, **filters)
    if limit is not None:
        # The limit arrives as a string; the page size is set as a number.
        results.per_page = int(limit)

    return [r.data for r in results.page(1)]
@cache @cache
def get_master_year(self, master_id: str) -> int | None: def get_master_year(self, master_id: str) -> int | None:

View file

@ -30,14 +30,14 @@ from confuse.exceptions import NotFoundError
import beets import beets
import beets.autotag.hooks import beets.autotag.hooks
from beets import config, plugins, util from beets import config, plugins, util
from beets.metadata_plugins import MetadataSourcePlugin from beets.metadata_plugins import IDResponse, SearchApiMetadataSourcePlugin
from beets.util.id_extractors import extract_release_id from beets.util.id_extractors import extract_release_id
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable, Sequence from collections.abc import Sequence
from typing import Literal
from beets.library import Item from beets.library import Item
from beets.metadata_plugins import QueryType, SearchParams
from ._typing import JSONDict from ._typing import JSONDict
@ -369,7 +369,7 @@ def _merge_pseudo_and_actual_album(
return merged return merged
class MusicBrainzPlugin(MetadataSourcePlugin): class MusicBrainzPlugin(SearchApiMetadataSourcePlugin[IDResponse]):
def __init__(self): def __init__(self):
"""Set up the python-musicbrainz-ngs module according to settings """Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup. from the beets configuration. This should be called at startup.
@ -798,52 +798,27 @@ class MusicBrainzPlugin(MetadataSourcePlugin):
return criteria return criteria
def get_search_filters(
    self,
    query_type: QueryType,
    items: Sequence[Item],
    artist: str,
    name: str,
    va_likely: bool,
) -> tuple[str, dict[str, str]]:
    """Build the MusicBrainz search criteria.

    Album searches take their criteria from ``get_album_criteria``; track
    searches match the name against both ``recording`` and ``alias``.
    Values are lower-cased and stripped, and criteria that end up empty
    are left out.

    :param query_type: ``"album"`` or ``"track"``.
    :param items: Items being matched.
    :param artist: Artist name to search for.
    :param name: Album or track title to search for.
    :param va_likely: Whether the album is likely a various-artists release.
    :return: An empty query string and the normalised criteria.
    """
    criteria = (
        self.get_album_criteria(items, artist, name, va_likely)
        if query_type == "album"
        else {"artist": artist, "recording": name, "alias": name}
    )

    normalized: dict[str, str] = {}
    for field, raw_value in criteria.items():
        value = raw_value.lower().strip()
        if value:
            normalized[field] = value

    return "", normalized
def get_search_response(self, params: SearchParams) -> Sequence[IDResponse]:
    """Run a MusicBrainz search and return the list of matching entities.

    Album queries search releases and every other query type searches
    recordings. The filters are passed as keyword arguments to the
    matching ``musicbrainzngs.search_<entity>s`` function.

    :param params: Query type, query string and filters for the request.
    :return: The ``"<entity>-list"`` entry of the search result.
    """
    if params.query_type == "album":
        mb_entity = "release"
    else:
        mb_entity = "recording"

    search = getattr(musicbrainzngs, f"search_{mb_entity}s")
    result = search(**params.filters)
    return result[f"{mb_entity}-list"]
def album_for_id( def album_for_id(
self, album_id: str self, album_id: str

View file

@ -39,7 +39,7 @@ from beets.library import Library
from beets.metadata_plugins import ( from beets.metadata_plugins import (
IDResponse, IDResponse,
SearchApiMetadataSourcePlugin, SearchApiMetadataSourcePlugin,
SearchFilter, SearchParams,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -447,11 +447,8 @@ class SpotifyPlugin(
track.medium_total = medium_total track.medium_total = medium_total
return track return track
def get_search_response(
    self, params: SearchParams
) -> Sequence[SearchResponseAlbums | SearchResponseTracks]:
    """Query the Spotify Search API and return the matching items.

    The filters are sent as request parameters together with the query
    (``q``) and the query type (``type``). When Spotify answers with
    401, the plugin re-authenticates and repeats the request once.

    :param params: Query type, query string and filters for the request.
    :return: The ``items`` list found under the ``"<query_type>s"`` key of
        the JSON response, or an empty list when it is missing.
    :raises requests.exceptions.HTTPError: If the response has an error
        status, including a 401 that persists after re-authenticating.
    """
    # One attempt plus a single retry: a token that keeps being rejected
    # must not lead to unbounded recursion.
    for attempt in range(2):
        response = requests.get(
            self.search_url,
            headers={"Authorization": f"Bearer {self.access_token}"},
            params={
                **params.filters,
                "q": params.query,
                "type": params.query_type,
            },
            timeout=10,
        )
        if response.status_code == 401 and attempt == 0:
            # NOTE(review): presumably refreshes ``self.access_token``,
            # which is read again for the retry -- confirm.
            self._authenticate()
            continue

        response.raise_for_status()
        break

    return response.json().get(f"{params.query_type}s", {}).get("items", [])
def commands(self) -> list[ui.Subcommand]: def commands(self) -> list[ui.Subcommand]:
# autotagger import command # autotagger import command
@ -600,22 +590,14 @@ class SpotifyPlugin(
query_string = item["title"] query_string = item["title"]
# Query the Web API for each track, look for the items' JSON data # Query the Web API for each track, look for the items' JSON data
query_filters: SearchFilter = {} query = query_string
if artist: if artist:
query_filters["artist"] = artist query += f" artist:'{artist}'"
if album: if album:
query_filters["album"] = album query += f" album:'{album}'"
response_data_tracks = self._search_api( response_data_tracks = self._search_api("track", query, {})
query_type="track",
query_string=query_string,
filters=query_filters,
)
if not response_data_tracks: if not response_data_tracks:
query = self._construct_search_query(
query_string=query_string, filters=query_filters
)
failures.append(query) failures.append(query)
continue continue

View file

@ -990,7 +990,7 @@ class TestMusicBrainzPlugin(PluginMixin):
plugin = "musicbrainz" plugin = "musicbrainz"
mbid = "d2a6f856-b553-40a0-ac54-a321e8e2da99" mbid = "d2a6f856-b553-40a0-ac54-a321e8e2da99"
RECORDING = {"title": "foo", "id": "bar", "length": 42} RECORDING = {"title": "foo", "id": mbid, "length": 42}
@pytest.fixture @pytest.fixture
def plugin_config(self): def plugin_config(self):
@ -1035,6 +1035,10 @@ class TestMusicBrainzPlugin(PluginMixin):
"musicbrainzngs.search_recordings", "musicbrainzngs.search_recordings",
lambda *_, **__: {"recording-list": [self.RECORDING]}, lambda *_, **__: {"recording-list": [self.RECORDING]},
) )
monkeypatch.setattr(
"musicbrainzngs.get_recording_by_id",
lambda *_, **__: {"recording": self.RECORDING},
)
candidates = list(mb.item_candidates(Item(), "hello", "there")) candidates = list(mb.item_candidates(Item(), "hello", "there"))

View file

@ -81,6 +81,7 @@ class SpotifyPluginTest(PluginTestCase):
params = _params(responses.calls[0].request.url) params = _params(responses.calls[0].request.url)
query = params["q"][0] query = params["q"][0]
print(query)
assert "duifhjslkef" in query assert "duifhjslkef" in query
assert "artist:'ujydfsuihse'" in query assert "artist:'ujydfsuihse'" in query
assert "album:'lkajsdflakjsd'" in query assert "album:'lkajsdflakjsd'" in query