diff --git a/beetsplug/_utils/requests.py b/beetsplug/_utils/requests.py index a9a1af372..b45efd780 100644 --- a/beetsplug/_utils/requests.py +++ b/beetsplug/_utils/requests.py @@ -1,38 +1,149 @@ +from __future__ import annotations + import atexit +import threading +from contextlib import contextmanager +from functools import cached_property from http import HTTPStatus +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar import requests from beets import __version__ - -class HTTPNotFoundError(requests.exceptions.HTTPError): - pass +if TYPE_CHECKING: + from collections.abc import Iterator -class CaptchaError(requests.exceptions.HTTPError): - pass +class BeetsHTTPError(requests.exceptions.HTTPError): + STATUS: ClassVar[HTTPStatus] + + def __init__(self, *args, **kwargs) -> None: + super().__init__( + f"HTTP Error: {self.STATUS.value} {self.STATUS.phrase}", + *args, + **kwargs, + ) -class TimeoutSession(requests.Session): +class HTTPNotFoundError(BeetsHTTPError): + STATUS = HTTPStatus.NOT_FOUND + + +class Closeable(Protocol): + """Protocol for objects that have a close method.""" + + def close(self) -> None: ... + + +C = TypeVar("C", bound=Closeable) + + +class SingletonMeta(type, Generic[C]): + """Metaclass ensuring a single shared instance per class. + + Creates one instance per class type on first instantiation, reusing it + for all subsequent calls. Automatically registers cleanup on program exit + for proper resource management. + """ + + _instances: ClassVar[dict[type[Any], Any]] = {} + _lock: ClassVar[threading.Lock] = threading.Lock() + + def __call__(cls, *args: Any, **kwargs: Any) -> C: + if cls not in cls._instances: + with cls._lock: + if cls not in SingletonMeta._instances: + instance = super().__call__(*args, **kwargs) + SingletonMeta._instances[cls] = instance + atexit.register(instance.close) + return SingletonMeta._instances[cls] + + +class TimeoutSession(requests.Session, metaclass=SingletonMeta): + """HTTP session with automatic timeout and status checking. + + Extends requests.Session to provide sensible defaults for beets HTTP + requests: automatic timeout enforcement, status code validation, and + proper user agent identification. + """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/" - @atexit.register - def close_session(): - """Close the requests session on shut down.""" - self.close() - def request(self, *args, **kwargs): - """Wrap the request method to raise an exception on HTTP errors.""" + """Execute HTTP request with automatic timeout and status validation. + + Ensures all requests have a timeout (defaults to 10 seconds) and raises + an exception for HTTP error status codes. + """ kwargs.setdefault("timeout", 10) r = super().request(*args, **kwargs) - if r.status_code == HTTPStatus.NOT_FOUND: - raise HTTPNotFoundError("HTTP Error: Not Found", response=r) - if 300 <= r.status_code < 400: - raise CaptchaError("Captcha is required", response=r) - r.raise_for_status() return r + + +class RequestHandler: + """Manages HTTP requests with custom error handling and session management. + + Provides a reusable interface for making HTTP requests with automatic + conversion of standard HTTP errors to beets-specific exceptions. Supports + custom session types and error mappings that can be overridden by + subclasses. + """ + + session_type: ClassVar[type[TimeoutSession]] = TimeoutSession + explicit_http_errors: ClassVar[list[type[BeetsHTTPError]]] = [ + HTTPNotFoundError + ] + + @cached_property + def session(self) -> Any: + """Lazily initialize and cache the HTTP session.""" + return self.session_type() + + def status_to_error( + self, code: int + ) -> type[requests.exceptions.HTTPError] | None: + """Map HTTP status codes to beets-specific exception types. + + Searches the configured explicit HTTP errors for a matching status code. + Returns None if no specific error type is registered for the given code. + """ + return next( + (e for e in self.explicit_http_errors if e.STATUS == code), None + ) + + @contextmanager + def handle_http_error(self) -> Iterator[None]: + """Convert standard HTTP errors to beets-specific exceptions. + + Wraps operations that may raise HTTPError, automatically translating + recognized status codes into their corresponding beets exception types. + Unrecognized errors are re-raised unchanged. + """ + try: + yield + except requests.exceptions.HTTPError as e: + if beets_error := self.status_to_error(e.response.status_code): + raise beets_error(response=e.response) from e + raise + + def request(self, *args, **kwargs) -> requests.Response: + """Perform HTTP request using the session with automatic error handling. + + Delegates to the underlying session method while converting recognized + HTTP errors to beets-specific exceptions through the error handler. + """ + with self.handle_http_error(): + return self.session.request(*args, **kwargs) + + def get(self, *args, **kwargs) -> requests.Response: + """Perform HTTP GET request with automatic error handling.""" + return self.request("get", *args, **kwargs) + + def get_json(self, *args, **kwargs): + """Fetch and parse JSON data from an HTTP endpoint.""" + return self.get(*args, **kwargs).json() diff --git a/beetsplug/lyrics.py b/beetsplug/lyrics.py index 8b28a6179..d6e14c175 100644 --- a/beetsplug/lyrics.py +++ b/beetsplug/lyrics.py @@ -34,16 +34,17 @@ import requests from bs4 import BeautifulSoup from unidecode import unidecode -import beets from beets import plugins, ui from beets.autotag.distance import string_dist from beets.util.config import sanitize_choices -from ._utils.requests import CaptchaError, HTTPNotFoundError, TimeoutSession +from ._utils.requests import HTTPNotFoundError, RequestHandler if TYPE_CHECKING: from collections.abc import Iterable, Iterator + import confuse + from beets.importer import ImportTask from beets.library import Item, Library from beets.logging import BeetsLogger as Logger @@ -59,7 +60,9 @@ if TYPE_CHECKING: INSTRUMENTAL_LYRICS = "[Instrumental]" -r_session = TimeoutSession() +class CaptchaError(requests.exceptions.HTTPError): + def __init__(self, *args, **kwargs) -> None: + super().__init__("Captcha is required", *args, **kwargs) # Utilities. @@ -155,9 +158,18 @@ def slug(text: str) -> str: return re.sub(r"\W+", "-", unidecode(text).lower().strip()).strip("-") -class RequestHandler: +class LyricsRequestHandler(RequestHandler): _log: Logger + def status_to_error(self, code: int) -> type[requests.HTTPError] | None: + if err := super().status_to_error(code): + return err + + if 300 <= code < 400: + return CaptchaError + + return None + def debug(self, message: str, *args) -> None: """Log a debug message with the class name.""" self._log.debug(f"{self.__class__.__name__}: {message}", *args) @@ -177,7 +189,7 @@ class RequestHandler: return f"{url}?{urlencode(params)}" - def fetch_text( + def get_text( self, url: str, params: JSONDict | None = None, **kwargs ) -> str: """Return text / HTML data from the given URL. @@ -187,21 +199,21 @@ class RequestHandler: """ url = self.format_url(url, params) self.debug("Fetching HTML from {}", url) - r = r_session.get(url, **kwargs) + r = self.get(url, **kwargs) r.encoding = None return r.text - def fetch_json(self, url: str, params: JSONDict | None = None, **kwargs): + def get_json(self, url: str, params: JSONDict | None = None, **kwargs): """Return JSON data from the given URL.""" url = self.format_url(url, params) self.debug("Fetching JSON from {}", url) - return r_session.get(url, **kwargs).json() + return super().get_json(url, **kwargs) def post_json(self, url: str, params: JSONDict | None = None, **kwargs): """Send POST request and return JSON response.""" url = self.format_url(url, params) self.debug("Posting JSON to {}", url) - return r_session.post(url, **kwargs).json() + return self.request("post", url, **kwargs).json() @contextmanager def handle_request(self) -> Iterator[None]: @@ -220,8 +232,10 @@ class BackendClass(type): return cls.__name__.lower() -class Backend(RequestHandler, metaclass=BackendClass): - def __init__(self, config, log): +class Backend(LyricsRequestHandler, metaclass=BackendClass): + config: confuse.Subview + + def __init__(self, config: confuse.Subview, log: Logger) -> None: self._log = log self.config = config @@ -325,10 +339,10 @@ class LRCLib(Backend): if album: get_params["album_name"] = album - yield self.fetch_json(self.SEARCH_URL, params=base_params) + yield self.get_json(self.SEARCH_URL, params=base_params) with suppress(HTTPNotFoundError): - yield [self.fetch_json(self.GET_URL, params=get_params)] + yield [self.get_json(self.GET_URL, params=get_params)] @classmethod def pick_best_match(cls, lyrics: Iterable[LRCLyrics]) -> LRCLyrics | None: @@ -376,7 +390,7 @@ class MusiXmatch(Backend): def fetch(self, artist: str, title: str, *_) -> tuple[str, str] | None: url = self.build_url(artist, title) - html = self.fetch_text(url) + html = self.get_text(url) if "We detected that your IP is blocked" in html: self.warn("Failed: Blocked IP address") return None @@ -501,7 +515,7 @@ class SearchBackend(SoupMixin, Backend): def fetch(self, artist: str, title: str, *_) -> tuple[str, str] | None: """Fetch lyrics for the given artist and title.""" for result in self.get_results(artist, title): - if (html := self.fetch_text(result.url)) and ( + if (html := self.get_text(result.url)) and ( lyrics := self.scrape(html) ): return lyrics, result.url @@ -531,7 +545,7 @@ class Genius(SearchBackend): return {"Authorization": f"Bearer {self.config['genius_api_key']}"} def search(self, artist: str, title: str) -> Iterable[SearchResult]: - search_data: GeniusAPI.Search = self.fetch_json( + search_data: GeniusAPI.Search = self.get_json( self.SEARCH_URL, params={"q": f"{artist} {title}"}, headers=self.headers, @@ -560,7 +574,7 @@ class Tekstowo(SearchBackend): return self.SEARCH_URL.format(quote_plus(unidecode(artistitle))) def search(self, artist: str, title: str) -> Iterable[SearchResult]: - if html := self.fetch_text(self.build_url(title, artist)): + if html := self.get_text(self.build_url(title, artist)): soup = self.get_soup(html) for tag in soup.select("div[class=flex-group] > a[title*=' - ']"): artist, title = str(tag["title"]).split(" - ", 1) @@ -626,12 +640,12 @@ class Google(SearchBackend): html = Html.remove_ads(super().pre_process_html(html)) return Html.remove_formatting(Html.merge_paragraphs(html)) - def fetch_text(self, *args, **kwargs) -> str: + def get_text(self, *args, **kwargs) -> str: """Handle an error so that we can continue with the next URL.""" kwargs.setdefault("allow_redirects", False) with self.handle_request(): try: - return super().fetch_text(*args, **kwargs) + return super().get_text(*args, **kwargs) except CaptchaError: self.ignored_domains.add(urlparse(args[0]).netloc) raise @@ -687,7 +701,7 @@ class Google(SearchBackend): "excludeTerms": ", ".join(self.EXCLUDE_PAGES), } - data: GoogleCustomSearchAPI.Response = self.fetch_json( + data: GoogleCustomSearchAPI.Response = self.get_json( self.SEARCH_URL, params=params ) for item in data.get("items", []): @@ -712,7 +726,7 @@ class Google(SearchBackend): @dataclass -class Translator(RequestHandler): +class Translator(LyricsRequestHandler): TRANSLATE_URL = "https://api.cognitive.microsofttranslator.com/translate" LINE_PARTS_RE = re.compile(r"^(\[\d\d:\d\d.\d\d\]|) *(.*)$") SEPARATOR = " | " @@ -922,7 +936,7 @@ class RestFiles: ui.print_(textwrap.dedent(text)) -class LyricsPlugin(RequestHandler, plugins.BeetsPlugin): +class LyricsPlugin(LyricsRequestHandler, plugins.BeetsPlugin): BACKEND_BY_NAME = { b.name: b for b in [LRCLib, Google, Genius, Tekstowo, MusiXmatch] } diff --git a/beetsplug/musicbrainz.py b/beetsplug/musicbrainz.py index e777a5d18..91d829dcc 100644 --- a/beetsplug/musicbrainz.py +++ b/beetsplug/musicbrainz.py @@ -35,7 +35,7 @@ from beets.metadata_plugins import MetadataSourcePlugin from beets.util.deprecation import deprecate_for_user from beets.util.id_extractors import extract_release_id -from ._utils.requests import HTTPNotFoundError, TimeoutSession +from ._utils.requests import HTTPNotFoundError, RequestHandler, TimeoutSession if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -61,10 +61,6 @@ FIELDS_TO_MB_KEYS = { } -class LimiterTimeoutSession(LimiterMixin, TimeoutSession): - pass - - RELEASE_INCLUDES = [ "artists", "media", @@ -103,32 +99,39 @@ BROWSE_CHUNKSIZE = 100 BROWSE_MAXTRACKS = 500 +class LimiterTimeoutSession(LimiterMixin, TimeoutSession): + pass + + @dataclass -class MusicBrainzAPI: +class MusicBrainzAPI(RequestHandler): + session_type = LimiterTimeoutSession + api_host: str rate_limit: float @cached_property def session(self) -> LimiterTimeoutSession: - return LimiterTimeoutSession(per_second=self.rate_limit) + return self.session_type(per_second=self.rate_limit) - def _get(self, entity: str, **kwargs) -> JSONDict: - return self.session.get( - f"{self.api_host}/ws/2/{entity}", params={**kwargs, "fmt": "json"} - ).json() - - def get_release(self, id_: str) -> JSONDict: + def get_entity(self, entity: str, **kwargs) -> JSONDict: return self._group_relations( - self._get(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES)) + self.get_json( + f"{self.api_host}/ws/2/{entity}", + params={**kwargs, "fmt": "json"}, + ) ) + def get_release(self, id_: str) -> JSONDict: + return self.get_entity(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES)) + def get_recording(self, id_: str) -> JSONDict: - return self._get(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES)) + return self.get_entity(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES)) def browse_recordings(self, **kwargs) -> list[JSONDict]: kwargs.setdefault("limit", BROWSE_CHUNKSIZE) kwargs.setdefault("inc", BROWSE_INCLUDES) - return self._get("recording", **kwargs)["recordings"] + return self.get_entity("recording", **kwargs)["recordings"] @singledispatchmethod @classmethod @@ -202,7 +205,7 @@ def _preferred_alias( if ( alias["locale"] == locale and alias.get("primary") - and alias.get("type", "").lower() not in ignored_alias_types + and (alias.get("type") or "").lower() not in ignored_alias_types ): matches.append(alias) @@ -852,7 +855,7 @@ class MusicBrainzPlugin(MetadataSourcePlugin): self._log.debug( "Searching for MusicBrainz {}s with: {!r}", query_type, query ) - return self.api._get( + return self.api.get_entity( query_type, query=query, limit=self.config["search_limit"].get() )[f"{query_type}s"] diff --git a/test/plugins/test_musicbrainz.py b/test/plugins/test_musicbrainz.py index a81c85c4d..0a3155430 100644 --- a/test/plugins/test_musicbrainz.py +++ b/test/plugins/test_musicbrainz.py @@ -1052,7 +1052,7 @@ class TestMusicBrainzPlugin(PluginMixin): def test_item_candidates(self, monkeypatch, mb): monkeypatch.setattr( - "beetsplug.musicbrainz.MusicBrainzAPI._get", + "beetsplug.musicbrainz.MusicBrainzAPI.get_json", lambda *_, **__: {"recordings": [self.RECORDING]}, ) @@ -1063,7 +1063,7 @@ class TestMusicBrainzPlugin(PluginMixin): def test_candidates(self, monkeypatch, mb): monkeypatch.setattr( - "beetsplug.musicbrainz.MusicBrainzAPI._get", + "beetsplug.musicbrainz.MusicBrainzAPI.get_json", lambda *_, **__: {"releases": [{"id": self.mbid}]}, ) monkeypatch.setattr(