From 59b02bc49b60ae41a040b51fc4bf783804f876b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Thu, 25 Dec 2025 22:20:44 +0000 Subject: [PATCH] Type MusicBrainzAPI properly --- beetsplug/_utils/musicbrainz.py | 140 +++++++++++++++++++++++++++----- beetsplug/listenbrainz.py | 2 +- beetsplug/mbcollection.py | 21 ++--- beetsplug/musicbrainz.py | 2 +- 4 files changed, 129 insertions(+), 36 deletions(-) diff --git a/beetsplug/_utils/musicbrainz.py b/beetsplug/_utils/musicbrainz.py index 47a2550f0..2fc821df9 100644 --- a/beetsplug/_utils/musicbrainz.py +++ b/beetsplug/_utils/musicbrainz.py @@ -12,17 +12,20 @@ from __future__ import annotations import operator from dataclasses import dataclass, field -from functools import cached_property, singledispatchmethod +from functools import cached_property, singledispatchmethod, wraps from itertools import groupby -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict, TypeVar from requests_ratelimiter import LimiterMixin +from typing_extensions import NotRequired, Unpack from beets import config, logging from .requests import RequestHandler, TimeoutAndRetrySession if TYPE_CHECKING: + from collections.abc import Callable + from requests import Response from .._typing import JSONDict @@ -34,11 +37,80 @@ class LimiterTimeoutSession(LimiterMixin, TimeoutAndRetrySession): """HTTP session that enforces rate limits.""" +Entity = Literal[ + "area", + "artist", + "collection", + "event", + "genre", + "instrument", + "label", + "place", + "recording", + "release", + "release-group", + "series", + "work", + "url", +] + + +class LookupKwargs(TypedDict, total=False): + includes: NotRequired[list[str]] + + +class PagingKwargs(TypedDict, total=False): + limit: NotRequired[int] + offset: NotRequired[int] + + +class SearchKwargs(PagingKwargs): + query: NotRequired[str] + + +class BrowseKwargs(LookupKwargs, PagingKwargs, total=False): + pass + + +class BrowseReleaseGroupsKwargs(BrowseKwargs, total=False): + artist: NotRequired[str] + collection: NotRequired[str] + release: NotRequired[str] + + +class BrowseRecordingsKwargs(BrowseReleaseGroupsKwargs, total=False): + work: NotRequired[str] + + +P = ParamSpec("P") +R = TypeVar("R") + + +def require_one_of(*keys: str) -> Callable[[Callable[P, R]], Callable[P, R]]: + required = frozenset(keys) + + def deco(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # kwargs is a real dict at runtime; safe to inspect here + if not required & kwargs.keys(): + required_str = ", ".join(sorted(required)) + raise ValueError( + f"At least one of {required_str} filter is required" + ) + return func(*args, **kwargs) + + return wrapper + + return deco + + @dataclass class MusicBrainzAPI(RequestHandler): """High-level interface to the MusicBrainz WS/2 API. Responsibilities: + - Configure the API host and request rate from application configuration. - Offer helpers to fetch common entity types and to run searches. - Normalize MusicBrainz responses so relation lists are grouped by target @@ -85,10 +157,10 @@ class MusicBrainzAPI(RequestHandler): kwargs["params"]["fmt"] = "json" return super().request(*args, **kwargs) - def get_entity( - self, entity: str, includes: list[str] | None = None, **kwargs + def _get_resource( + self, resource: str, includes: list[str] | None = None, **kwargs ) -> JSONDict: - """Retrieve and normalize data from the API entity endpoint. + """Retrieve and normalize data from the API resource endpoint. If requested, includes are appended to the request. The response is passed through a normalizer that groups relation entries by their @@ -98,11 +170,22 @@ class MusicBrainzAPI(RequestHandler): kwargs["inc"] = "+".join(includes) return self._group_relations( - self.get_json(f"{self.api_root}/{entity}", params=kwargs) + self.get_json(f"{self.api_root}/{resource}", params=kwargs) ) - def search_entity( - self, entity: str, filters: dict[str, str], **kwargs + def _lookup( + self, entity: Entity, id_: str, **kwargs: Unpack[LookupKwargs] + ) -> JSONDict: + return self._get_resource(f"{entity}/{id_}", **kwargs) + + def _browse(self, entity: Entity, **kwargs) -> list[JSONDict]: + return self._get_resource(entity, **kwargs).get(f"{entity}s", []) + + def search( + self, + entity: Entity, + filters: dict[str, str], + **kwargs: Unpack[SearchKwargs], ) -> list[JSONDict]: """Search for MusicBrainz entities matching the given filters. @@ -119,22 +202,41 @@ class MusicBrainzAPI(RequestHandler): ) log.debug("Searching for MusicBrainz {}s with: {!r}", entity, query) kwargs["query"] = query - return self.get_entity(entity, **kwargs)[f"{entity}s"] + return self._get_resource(entity, **kwargs)[f"{entity}s"] - def get_release(self, id_: str, **kwargs) -> JSONDict: - return self.get_entity(f"release/{id_}", **kwargs) + def get_release(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict: + """Retrieve a release by its MusicBrainz ID.""" + return self._lookup("release", id_, **kwargs) - def get_recording(self, id_: str, **kwargs) -> JSONDict: - return self.get_entity(f"recording/{id_}", **kwargs) + def get_recording( + self, id_: str, **kwargs: Unpack[LookupKwargs] + ) -> JSONDict: + """Retrieve a recording by its MusicBrainz ID.""" + return self._lookup("recording", id_, **kwargs) - def get_work(self, id_: str, **kwargs) -> JSONDict: - return self.get_entity(f"work/{id_}", **kwargs) + def get_work(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict: + """Retrieve a work by its MusicBrainz ID.""" + return self._lookup("work", id_, **kwargs) - def browse_recordings(self, **kwargs) -> list[JSONDict]: - return self.get_entity("recording", **kwargs)["recordings"] + @require_one_of("artist", "collection", "release", "work") + def browse_recordings( + self, **kwargs: Unpack[BrowseRecordingsKwargs] + ) -> list[JSONDict]: + """Browse recordings related to the given entities. - def browse_release_groups(self, **kwargs) -> list[JSONDict]: - return self.get_entity("release-group", **kwargs)["release-groups"] + At least one of artist, collection, release, or work must be provided. + """ + return self._browse("recording", **kwargs) + + @require_one_of("artist", "collection", "release") + def browse_release_groups( + self, **kwargs: Unpack[BrowseReleaseGroupsKwargs] + ) -> list[JSONDict]: + """Browse release groups related to the given entities. + + At least one of artist, collection, or release must be provided. + """ + return self._get_resource("release-group", **kwargs)["release-groups"] @singledispatchmethod @classmethod diff --git a/beetsplug/listenbrainz.py b/beetsplug/listenbrainz.py index d054a00cc..fa73bd6b8 100644 --- a/beetsplug/listenbrainz.py +++ b/beetsplug/listenbrainz.py @@ -132,7 +132,7 @@ class ListenBrainzPlugin(MusicBrainzAPIMixin, BeetsPlugin): def get_mb_recording_id(self, track) -> str | None: """Returns the MusicBrainz recording ID for a track.""" - results = self.mb_api.search_entity( + results = self.mb_api.search( "recording", { "": track["track_metadata"].get("track_name"), diff --git a/beetsplug/mbcollection.py b/beetsplug/mbcollection.py index 25f16228a..f89670dd3 100644 --- a/beetsplug/mbcollection.py +++ b/beetsplug/mbcollection.py @@ -58,15 +58,12 @@ class MusicBrainzUserAPI(MusicBrainzAPI): auth: HTTPDigestAuth = field(init=False) - @cached_property - def user(self) -> str: - return config["musicbrainz"]["user"].as_str() - def __post_init__(self) -> None: super().__post_init__() config["musicbrainz"]["pass"].redact = True self.auth = HTTPDigestAuth( - self.user, config["musicbrainz"]["pass"].as_str() + config["musicbrainz"]["user"].as_str(), + config["musicbrainz"]["pass"].as_str(), ) def request(self, *args, **kwargs) -> Response: @@ -76,15 +73,9 @@ class MusicBrainzUserAPI(MusicBrainzAPI): kwargs["auth"] = self.auth return super().request(*args, **kwargs) - def get_collections(self) -> list[JSONDict]: - """Get all collections for the authenticated user. - - Note that both URL parameters must be included to retrieve private - collections. - """ - return self.get_entity( - "collection", editor=self.user, includes=["user-collections"] - ).get("collections", []) + def browse_collections(self) -> list[JSONDict]: + """Get all collections for the authenticated user.""" + return self._browse("collection") @dataclass @@ -183,7 +174,7 @@ class MusicBrainzCollectionPlugin(BeetsPlugin): @cached_property def collection(self) -> MBCollection: - if not (collections := self.mb_api.get_collections()): + if not (collections := self.mb_api.browse_collections()): raise ui.UserError("no collections exist for user") # Get all release collection IDs, avoiding event collections diff --git a/beetsplug/musicbrainz.py b/beetsplug/musicbrainz.py index 990f21351..3e194c067 100644 --- a/beetsplug/musicbrainz.py +++ b/beetsplug/musicbrainz.py @@ -751,7 +751,7 @@ class MusicBrainzPlugin(MusicBrainzAPIMixin, MetadataSourcePlugin): using the provided criteria. Handles API errors by converting them into MusicBrainzAPIError exceptions with contextual information. """ - return self.mb_api.search_entity( + return self.mb_api.search( query_type, filters, limit=self.config["search_limit"].get() )