mirror of
https://github.com/beetbox/beets.git
synced 2026-01-13 19:52:48 +01:00
Type MusicBrainzAPI properly
This commit is contained in:
parent
55b9c1c145
commit
59b02bc49b
4 changed files with 129 additions and 36 deletions
|
|
@ -12,17 +12,20 @@ from __future__ import annotations
|
|||
|
||||
import operator
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, singledispatchmethod
|
||||
from functools import cached_property, singledispatchmethod, wraps
|
||||
from itertools import groupby
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict, TypeVar
|
||||
|
||||
from requests_ratelimiter import LimiterMixin
|
||||
from typing_extensions import NotRequired, Unpack
|
||||
|
||||
from beets import config, logging
|
||||
|
||||
from .requests import RequestHandler, TimeoutAndRetrySession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from requests import Response
|
||||
|
||||
from .._typing import JSONDict
|
||||
|
|
@ -34,11 +37,80 @@ class LimiterTimeoutSession(LimiterMixin, TimeoutAndRetrySession):
|
|||
"""HTTP session that enforces rate limits."""
|
||||
|
||||
|
||||
Entity = Literal[
|
||||
"area",
|
||||
"artist",
|
||||
"collection",
|
||||
"event",
|
||||
"genre",
|
||||
"instrument",
|
||||
"label",
|
||||
"place",
|
||||
"recording",
|
||||
"release",
|
||||
"release-group",
|
||||
"series",
|
||||
"work",
|
||||
"url",
|
||||
]
|
||||
|
||||
|
||||
class LookupKwargs(TypedDict, total=False):
|
||||
includes: NotRequired[list[str]]
|
||||
|
||||
|
||||
class PagingKwargs(TypedDict, total=False):
|
||||
limit: NotRequired[int]
|
||||
offset: NotRequired[int]
|
||||
|
||||
|
||||
class SearchKwargs(PagingKwargs):
|
||||
query: NotRequired[str]
|
||||
|
||||
|
||||
class BrowseKwargs(LookupKwargs, PagingKwargs, total=False):
|
||||
pass
|
||||
|
||||
|
||||
class BrowseReleaseGroupsKwargs(BrowseKwargs, total=False):
|
||||
artist: NotRequired[str]
|
||||
collection: NotRequired[str]
|
||||
release: NotRequired[str]
|
||||
|
||||
|
||||
class BrowseRecordingsKwargs(BrowseReleaseGroupsKwargs, total=False):
|
||||
work: NotRequired[str]
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def require_one_of(*keys: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
required = frozenset(keys)
|
||||
|
||||
def deco(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# kwargs is a real dict at runtime; safe to inspect here
|
||||
if not required & kwargs.keys():
|
||||
required_str = ", ".join(sorted(required))
|
||||
raise ValueError(
|
||||
f"At least one of {required_str} filter is required"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
@dataclass
|
||||
class MusicBrainzAPI(RequestHandler):
|
||||
"""High-level interface to the MusicBrainz WS/2 API.
|
||||
|
||||
Responsibilities:
|
||||
|
||||
- Configure the API host and request rate from application configuration.
|
||||
- Offer helpers to fetch common entity types and to run searches.
|
||||
- Normalize MusicBrainz responses so relation lists are grouped by target
|
||||
|
|
@ -85,10 +157,10 @@ class MusicBrainzAPI(RequestHandler):
|
|||
kwargs["params"]["fmt"] = "json"
|
||||
return super().request(*args, **kwargs)
|
||||
|
||||
def get_entity(
|
||||
self, entity: str, includes: list[str] | None = None, **kwargs
|
||||
def _get_resource(
|
||||
self, resource: str, includes: list[str] | None = None, **kwargs
|
||||
) -> JSONDict:
|
||||
"""Retrieve and normalize data from the API entity endpoint.
|
||||
"""Retrieve and normalize data from the API resource endpoint.
|
||||
|
||||
If requested, includes are appended to the request. The response is
|
||||
passed through a normalizer that groups relation entries by their
|
||||
|
|
@ -98,11 +170,22 @@ class MusicBrainzAPI(RequestHandler):
|
|||
kwargs["inc"] = "+".join(includes)
|
||||
|
||||
return self._group_relations(
|
||||
self.get_json(f"{self.api_root}/{entity}", params=kwargs)
|
||||
self.get_json(f"{self.api_root}/{resource}", params=kwargs)
|
||||
)
|
||||
|
||||
def search_entity(
|
||||
self, entity: str, filters: dict[str, str], **kwargs
|
||||
def _lookup(
|
||||
self, entity: Entity, id_: str, **kwargs: Unpack[LookupKwargs]
|
||||
) -> JSONDict:
|
||||
return self._get_resource(f"{entity}/{id_}", **kwargs)
|
||||
|
||||
def _browse(self, entity: Entity, **kwargs) -> list[JSONDict]:
|
||||
return self._get_resource(entity, **kwargs).get(f"{entity}s", [])
|
||||
|
||||
def search(
|
||||
self,
|
||||
entity: Entity,
|
||||
filters: dict[str, str],
|
||||
**kwargs: Unpack[SearchKwargs],
|
||||
) -> list[JSONDict]:
|
||||
"""Search for MusicBrainz entities matching the given filters.
|
||||
|
||||
|
|
@ -119,22 +202,41 @@ class MusicBrainzAPI(RequestHandler):
|
|||
)
|
||||
log.debug("Searching for MusicBrainz {}s with: {!r}", entity, query)
|
||||
kwargs["query"] = query
|
||||
return self.get_entity(entity, **kwargs)[f"{entity}s"]
|
||||
return self._get_resource(entity, **kwargs)[f"{entity}s"]
|
||||
|
||||
def get_release(self, id_: str, **kwargs) -> JSONDict:
|
||||
return self.get_entity(f"release/{id_}", **kwargs)
|
||||
def get_release(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict:
|
||||
"""Retrieve a release by its MusicBrainz ID."""
|
||||
return self._lookup("release", id_, **kwargs)
|
||||
|
||||
def get_recording(self, id_: str, **kwargs) -> JSONDict:
|
||||
return self.get_entity(f"recording/{id_}", **kwargs)
|
||||
def get_recording(
|
||||
self, id_: str, **kwargs: Unpack[LookupKwargs]
|
||||
) -> JSONDict:
|
||||
"""Retrieve a recording by its MusicBrainz ID."""
|
||||
return self._lookup("recording", id_, **kwargs)
|
||||
|
||||
def get_work(self, id_: str, **kwargs) -> JSONDict:
|
||||
return self.get_entity(f"work/{id_}", **kwargs)
|
||||
def get_work(self, id_: str, **kwargs: Unpack[LookupKwargs]) -> JSONDict:
|
||||
"""Retrieve a work by its MusicBrainz ID."""
|
||||
return self._lookup("work", id_, **kwargs)
|
||||
|
||||
def browse_recordings(self, **kwargs) -> list[JSONDict]:
|
||||
return self.get_entity("recording", **kwargs)["recordings"]
|
||||
@require_one_of("artist", "collection", "release", "work")
|
||||
def browse_recordings(
|
||||
self, **kwargs: Unpack[BrowseRecordingsKwargs]
|
||||
) -> list[JSONDict]:
|
||||
"""Browse recordings related to the given entities.
|
||||
|
||||
def browse_release_groups(self, **kwargs) -> list[JSONDict]:
|
||||
return self.get_entity("release-group", **kwargs)["release-groups"]
|
||||
At least one of artist, collection, release, or work must be provided.
|
||||
"""
|
||||
return self._browse("recording", **kwargs)
|
||||
|
||||
@require_one_of("artist", "collection", "release")
|
||||
def browse_release_groups(
|
||||
self, **kwargs: Unpack[BrowseReleaseGroupsKwargs]
|
||||
) -> list[JSONDict]:
|
||||
"""Browse release groups related to the given entities.
|
||||
|
||||
At least one of artist, collection, or release must be provided.
|
||||
"""
|
||||
return self._get_resource("release-group", **kwargs)["release-groups"]
|
||||
|
||||
@singledispatchmethod
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ class ListenBrainzPlugin(MusicBrainzAPIMixin, BeetsPlugin):
|
|||
|
||||
def get_mb_recording_id(self, track) -> str | None:
|
||||
"""Returns the MusicBrainz recording ID for a track."""
|
||||
results = self.mb_api.search_entity(
|
||||
results = self.mb_api.search(
|
||||
"recording",
|
||||
{
|
||||
"": track["track_metadata"].get("track_name"),
|
||||
|
|
|
|||
|
|
@ -58,15 +58,12 @@ class MusicBrainzUserAPI(MusicBrainzAPI):
|
|||
|
||||
auth: HTTPDigestAuth = field(init=False)
|
||||
|
||||
@cached_property
|
||||
def user(self) -> str:
|
||||
return config["musicbrainz"]["user"].as_str()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
config["musicbrainz"]["pass"].redact = True
|
||||
self.auth = HTTPDigestAuth(
|
||||
self.user, config["musicbrainz"]["pass"].as_str()
|
||||
config["musicbrainz"]["user"].as_str(),
|
||||
config["musicbrainz"]["pass"].as_str(),
|
||||
)
|
||||
|
||||
def request(self, *args, **kwargs) -> Response:
|
||||
|
|
@ -76,15 +73,9 @@ class MusicBrainzUserAPI(MusicBrainzAPI):
|
|||
kwargs["auth"] = self.auth
|
||||
return super().request(*args, **kwargs)
|
||||
|
||||
def get_collections(self) -> list[JSONDict]:
|
||||
"""Get all collections for the authenticated user.
|
||||
|
||||
Note that both URL parameters must be included to retrieve private
|
||||
collections.
|
||||
"""
|
||||
return self.get_entity(
|
||||
"collection", editor=self.user, includes=["user-collections"]
|
||||
).get("collections", [])
|
||||
def browse_collections(self) -> list[JSONDict]:
|
||||
"""Get all collections for the authenticated user."""
|
||||
return self._browse("collection")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -183,7 +174,7 @@ class MusicBrainzCollectionPlugin(BeetsPlugin):
|
|||
|
||||
@cached_property
|
||||
def collection(self) -> MBCollection:
|
||||
if not (collections := self.mb_api.get_collections()):
|
||||
if not (collections := self.mb_api.browse_collections()):
|
||||
raise ui.UserError("no collections exist for user")
|
||||
|
||||
# Get all release collection IDs, avoiding event collections
|
||||
|
|
|
|||
|
|
@ -751,7 +751,7 @@ class MusicBrainzPlugin(MusicBrainzAPIMixin, MetadataSourcePlugin):
|
|||
using the provided criteria. Handles API errors by converting them into
|
||||
MusicBrainzAPIError exceptions with contextual information.
|
||||
"""
|
||||
return self.mb_api.search_entity(
|
||||
return self.mb_api.search(
|
||||
query_type, filters, limit=self.config["search_limit"].get()
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue