Type MusicBrainzAPI properly

This commit is contained in:
Šarūnas Nejus 2025-12-25 22:20:44 +00:00
parent 55b9c1c145
commit 59b02bc49b
No known key found for this signature in database
4 changed files with 129 additions and 36 deletions

View file

@ -12,17 +12,20 @@ from __future__ import annotations
import operator
from dataclasses import dataclass, field
from functools import cached_property, singledispatchmethod
from functools import cached_property, singledispatchmethod, wraps
from itertools import groupby
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict, TypeVar
from requests_ratelimiter import LimiterMixin
from typing_extensions import NotRequired, Unpack
from beets import config, logging
from .requests import RequestHandler, TimeoutAndRetrySession
if TYPE_CHECKING:
from collections.abc import Callable
from requests import Response
from .._typing import JSONDict
@ -34,11 +37,80 @@ class LimiterTimeoutSession(LimiterMixin, TimeoutAndRetrySession):
"""HTTP session that enforces rate limits."""
Entity = Literal[
"area",
"artist",
"collection",
"event",
"genre",
"instrument",
"label",
"place",
"recording",
"release",
"release-group",
"series",
"work",
"url",
]
class LookupKwargs(TypedDict, total=False):
includes: NotRequired[list[str]]
class PagingKwargs(TypedDict, total=False):
limit: NotRequired[int]
offset: NotRequired[int]
class SearchKwargs(PagingKwargs):
query: NotRequired[str]
class BrowseKwargs(LookupKwargs, PagingKwargs, total=False):
pass
class BrowseReleaseGroupsKwargs(BrowseKwargs, total=False):
artist: NotRequired[str]
collection: NotRequired[str]
release: NotRequired[str]
class BrowseRecordingsKwargs(BrowseReleaseGroupsKwargs, total=False):
work: NotRequired[str]
P = ParamSpec("P")
R = TypeVar("R")
def require_one_of(*keys: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
required = frozenset(keys)
def deco(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# kwargs is a real dict at runtime; safe to inspect here
if not required & kwargs.keys():
required_str = ", ".join(sorted(required))
raise ValueError(
f"At least one of {required_str} filter is required"
)
return func(*args, **kwargs)
return wrapper
return deco
@dataclass
class MusicBrainzAPI(RequestHandler):
"""High-level interface to the MusicBrainz WS/2 API.
Responsibilities:
- Configure the API host and request rate from application configuration.
- Offer helpers to fetch common entity types and to run searches.
- Normalize MusicBrainz responses so relation lists are grouped by target
@ -85,10 +157,10 @@ class MusicBrainzAPI(RequestHandler):
kwargs["params"]["fmt"] = "json"
return super().request(*args, **kwargs)
def get_entity(
self, entity: str, includes: list[str] | None = None, **kwargs
def _get_resource(
self, resource: str, includes: list[str] | None = None, **kwargs
) -> JSONDict:
"""Retrieve and normalize data from the API entity endpoint.
"""Retrieve and normalize data from the API resource endpoint.
If requested, includes are appended to the request. The response is
passed through a normalizer that groups relation entries by their
@ -98,11 +170,22 @@ class MusicBrainzAPI(RequestHandler):
kwargs["inc"] = "+".join(includes)
return self._group_relations(
self.get_json(f"{self.api_root}/{entity}", params=kwargs)
self.get_json(f"{self.api_root}/{resource}", params=kwargs)
)
def search_entity(
self, entity: str, filters: dict[str, str], **kwargs
def _lookup(
self, entity: Entity, id_: str, **kwargs: Unpack[LookupKwargs]
) -> JSONDict:
return self._get_resource(f"{entity}/{id_}", **kwargs)
def _browse(self, entity: Entity, **kwargs) -> list[JSONDict]:
return self._get_resource(entity, **kwargs).get(f"{entity}s", [])
def search(
self,
entity: Entity,
filters: dict[str, str],
**kwargs: Unpack[SearchKwargs],
) -> list[JSONDict]:
"""Search for MusicBrainz entities matching the given filters.
@ -119,22 +202,41 @@ class MusicBrainzAPI(RequestHandler):
)
log.debug("Searching for MusicBrainz {}s with: {!r}", entity, query)
kwargs["query"] = query
return self.get_entity(entity, **kwargs)[f"{entity}s"]
return self._get_resource(entity, **kwargs)[f"{entity}s"]
def get_release(self, id_: str, **kwargs) -> JSONDict:
return self.get_entity(f"release/{id_}", **kwargs)
def get_release(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict:
"""Retrieve a release by its MusicBrainz ID."""
return self._lookup("release", id_, **kwargs)
def get_recording(self, id_: str, **kwargs) -> JSONDict:
return self.get_entity(f"recording/{id_}", **kwargs)
def get_recording(
self, id_: str, **kwargs: Unpack[LookupKwargs]
) -> JSONDict:
"""Retrieve a recording by its MusicBrainz ID."""
return self._lookup("recording", id_, **kwargs)
def get_work(self, id_: str, **kwargs) -> JSONDict:
return self.get_entity(f"work/{id_}", **kwargs)
def get_work(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict:
"""Retrieve a work by its MusicBrainz ID."""
return self._lookup("work", id_, **kwargs)
def browse_recordings(self, **kwargs) -> list[JSONDict]:
return self.get_entity("recording", **kwargs)["recordings"]
@require_one_of("artist", "collection", "release", "work")
def browse_recordings(
self, **kwargs: Unpack[BrowseRecordingsKwargs]
) -> list[JSONDict]:
"""Browse recordings related to the given entities.
def browse_release_groups(self, **kwargs) -> list[JSONDict]:
return self.get_entity("release-group", **kwargs)["release-groups"]
At least one of artist, collection, release, or work must be provided.
"""
return self._browse("recording", **kwargs)
@require_one_of("artist", "collection", "release")
def browse_release_groups(
self, **kwargs: Unpack[BrowseReleaseGroupsKwargs]
) -> list[JSONDict]:
"""Browse release groups related to the given entities.
At least one of artist, collection, or release must be provided.
"""
return self._get_resource("release-group", **kwargs)["release-groups"]
@singledispatchmethod
@classmethod

View file

@ -132,7 +132,7 @@ class ListenBrainzPlugin(MusicBrainzAPIMixin, BeetsPlugin):
def get_mb_recording_id(self, track) -> str | None:
"""Returns the MusicBrainz recording ID for a track."""
results = self.mb_api.search_entity(
results = self.mb_api.search(
"recording",
{
"": track["track_metadata"].get("track_name"),

View file

@ -58,15 +58,12 @@ class MusicBrainzUserAPI(MusicBrainzAPI):
auth: HTTPDigestAuth = field(init=False)
@cached_property
def user(self) -> str:
return config["musicbrainz"]["user"].as_str()
def __post_init__(self) -> None:
super().__post_init__()
config["musicbrainz"]["pass"].redact = True
self.auth = HTTPDigestAuth(
self.user, config["musicbrainz"]["pass"].as_str()
config["musicbrainz"]["user"].as_str(),
config["musicbrainz"]["pass"].as_str(),
)
def request(self, *args, **kwargs) -> Response:
@ -76,15 +73,9 @@ class MusicBrainzUserAPI(MusicBrainzAPI):
kwargs["auth"] = self.auth
return super().request(*args, **kwargs)
def get_collections(self) -> list[JSONDict]:
"""Get all collections for the authenticated user.
Note that both URL parameters must be included to retrieve private
collections.
"""
return self.get_entity(
"collection", editor=self.user, includes=["user-collections"]
).get("collections", [])
def browse_collections(self) -> list[JSONDict]:
"""Get all collections for the authenticated user."""
return self._browse("collection")
@dataclass
@ -183,7 +174,7 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
@cached_property
def collection(self) -> MBCollection:
if not (collections := self.mb_api.get_collections()):
if not (collections := self.mb_api.browse_collections()):
raise ui.UserError("no collections exist for user")
# Get all release collection IDs, avoiding event collections

View file

@ -751,7 +751,7 @@ class MusicBrainzPlugin(MusicBrainzAPIMixin, MetadataSourcePlugin):
using the provided criteria. Handles API errors by converting them into
MusicBrainzAPIError exceptions with contextual information.
"""
return self.mb_api.search_entity(
return self.mb_api.search(
query_type, filters, limit=self.config["search_limit"].get()
)