mirror of
https://github.com/beetbox/beets.git
synced 2025-12-22 00:23:33 +01:00
Refactor HTTP request handling with RequestHandler base class
Introduce a new RequestHandler base class to introduce a shared session, centralize HTTP request management and error handling across plugins. Key changes: - Add RequestHandler base class with a shared/cached session - Convert TimeoutSession to use SingletonMeta for proper resource management - Create LyricsRequestHandler subclass with lyrics-specific error handling - Update MusicBrainzAPI to inherit from RequestHandler
This commit is contained in:
parent
041d4b8036
commit
72f7d6ebe3
4 changed files with 187 additions and 59 deletions
|
|
@ -1,38 +1,149 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from functools import cached_property
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
|
||||
|
||||
import requests
|
||||
|
||||
from beets import __version__
|
||||
|
||||
|
||||
class HTTPNotFoundError(requests.exceptions.HTTPError):
|
||||
pass
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
class CaptchaError(requests.exceptions.HTTPError):
|
||||
pass
|
||||
class BeetsHTTPError(requests.exceptions.HTTPError):
|
||||
STATUS: ClassVar[HTTPStatus]
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(
|
||||
f"HTTP Error: {self.STATUS.value} {self.STATUS.phrase}",
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class TimeoutSession(requests.Session):
|
||||
class HTTPNotFoundError(BeetsHTTPError):
|
||||
STATUS = HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
class Closeable(Protocol):
|
||||
"""Protocol for objects that have a close method."""
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
C = TypeVar("C", bound=Closeable)
|
||||
|
||||
|
||||
class SingletonMeta(type, Generic[C]):
|
||||
"""Metaclass ensuring a single shared instance per class.
|
||||
|
||||
Creates one instance per class type on first instantiation, reusing it
|
||||
for all subsequent calls. Automatically registers cleanup on program exit
|
||||
for proper resource management.
|
||||
"""
|
||||
|
||||
_instances: ClassVar[dict[type[Any], Any]] = {}
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> C:
|
||||
if cls not in cls._instances:
|
||||
with cls._lock:
|
||||
if cls not in SingletonMeta._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
SingletonMeta._instances[cls] = instance
|
||||
atexit.register(instance.close)
|
||||
return SingletonMeta._instances[cls]
|
||||
|
||||
|
||||
class TimeoutSession(requests.Session, metaclass=SingletonMeta):
|
||||
"""HTTP session with automatic timeout and status checking.
|
||||
|
||||
Extends requests.Session to provide sensible defaults for beets HTTP
|
||||
requests: automatic timeout enforcement, status code validation, and
|
||||
proper user agent identification.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/"
|
||||
|
||||
@atexit.register
|
||||
def close_session():
|
||||
"""Close the requests session on shut down."""
|
||||
self.close()
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
"""Wrap the request method to raise an exception on HTTP errors."""
|
||||
"""Execute HTTP request with automatic timeout and status validation.
|
||||
|
||||
Ensures all requests have a timeout (defaults to 10 seconds) and raises
|
||||
an exception for HTTP error status codes.
|
||||
"""
|
||||
kwargs.setdefault("timeout", 10)
|
||||
r = super().request(*args, **kwargs)
|
||||
if r.status_code == HTTPStatus.NOT_FOUND:
|
||||
raise HTTPNotFoundError("HTTP Error: Not Found", response=r)
|
||||
if 300 <= r.status_code < 400:
|
||||
raise CaptchaError("Captcha is required", response=r)
|
||||
|
||||
r.raise_for_status()
|
||||
|
||||
return r
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
"""Manages HTTP requests with custom error handling and session management.
|
||||
|
||||
Provides a reusable interface for making HTTP requests with automatic
|
||||
conversion of standard HTTP errors to beets-specific exceptions. Supports
|
||||
custom session types and error mappings that can be overridden by
|
||||
subclasses.
|
||||
"""
|
||||
|
||||
session_type: ClassVar[type[TimeoutSession]] = TimeoutSession
|
||||
explicit_http_errors: ClassVar[list[type[BeetsHTTPError]]] = [
|
||||
HTTPNotFoundError
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def session(self) -> Any:
|
||||
"""Lazily initialize and cache the HTTP session."""
|
||||
return self.session_type()
|
||||
|
||||
def status_to_error(
|
||||
self, code: int
|
||||
) -> type[requests.exceptions.HTTPError] | None:
|
||||
"""Map HTTP status codes to beets-specific exception types.
|
||||
|
||||
Searches the configured explicit HTTP errors for a matching status code.
|
||||
Returns None if no specific error type is registered for the given code.
|
||||
"""
|
||||
return next(
|
||||
(e for e in self.explicit_http_errors if e.STATUS == code), None
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def handle_http_error(self) -> Iterator[None]:
|
||||
"""Convert standard HTTP errors to beets-specific exceptions.
|
||||
|
||||
Wraps operations that may raise HTTPError, automatically translating
|
||||
recognized status codes into their corresponding beets exception types.
|
||||
Unrecognized errors are re-raised unchanged.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if beets_error := self.status_to_error(e.response.status_code):
|
||||
raise beets_error(response=e.response) from e
|
||||
raise
|
||||
|
||||
def request(self, *args, **kwargs) -> requests.Response:
|
||||
"""Perform HTTP request using the session with automatic error handling.
|
||||
|
||||
Delegates to the underlying session method while converting recognized
|
||||
HTTP errors to beets-specific exceptions through the error handler.
|
||||
"""
|
||||
with self.handle_http_error():
|
||||
return self.session.request(*args, **kwargs)
|
||||
|
||||
def get(self, *args, **kwargs) -> requests.Response:
|
||||
"""Perform HTTP GET request with automatic error handling."""
|
||||
return self.request("get", *args, **kwargs)
|
||||
|
||||
def get_json(self, *args, **kwargs):
|
||||
"""Fetch and parse JSON data from an HTTP endpoint."""
|
||||
return self.get(*args, **kwargs).json()
|
||||
|
|
|
|||
|
|
@ -34,16 +34,17 @@ import requests
|
|||
from bs4 import BeautifulSoup
|
||||
from unidecode import unidecode
|
||||
|
||||
import beets
|
||||
from beets import plugins, ui
|
||||
from beets.autotag.distance import string_dist
|
||||
from beets.util.config import sanitize_choices
|
||||
|
||||
from ._utils.requests import CaptchaError, HTTPNotFoundError, TimeoutSession
|
||||
from ._utils.requests import HTTPNotFoundError, RequestHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
import confuse
|
||||
|
||||
from beets.importer import ImportTask
|
||||
from beets.library import Item, Library
|
||||
from beets.logging import BeetsLogger as Logger
|
||||
|
|
@ -59,7 +60,9 @@ if TYPE_CHECKING:
|
|||
INSTRUMENTAL_LYRICS = "[Instrumental]"
|
||||
|
||||
|
||||
r_session = TimeoutSession()
|
||||
class CaptchaError(requests.exceptions.HTTPError):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__("Captcha is required", *args, **kwargs)
|
||||
|
||||
|
||||
# Utilities.
|
||||
|
|
@ -155,9 +158,18 @@ def slug(text: str) -> str:
|
|||
return re.sub(r"\W+", "-", unidecode(text).lower().strip()).strip("-")
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
class LyricsRequestHandler(RequestHandler):
|
||||
_log: Logger
|
||||
|
||||
def status_to_error(self, code: int) -> type[requests.HTTPError] | None:
|
||||
if err := super().status_to_error(code):
|
||||
return err
|
||||
|
||||
if 300 <= code < 400:
|
||||
return CaptchaError
|
||||
|
||||
return None
|
||||
|
||||
def debug(self, message: str, *args) -> None:
|
||||
"""Log a debug message with the class name."""
|
||||
self._log.debug(f"{self.__class__.__name__}: {message}", *args)
|
||||
|
|
@ -177,7 +189,7 @@ class RequestHandler:
|
|||
|
||||
return f"{url}?{urlencode(params)}"
|
||||
|
||||
def fetch_text(
|
||||
def get_text(
|
||||
self, url: str, params: JSONDict | None = None, **kwargs
|
||||
) -> str:
|
||||
"""Return text / HTML data from the given URL.
|
||||
|
|
@ -187,21 +199,21 @@ class RequestHandler:
|
|||
"""
|
||||
url = self.format_url(url, params)
|
||||
self.debug("Fetching HTML from {}", url)
|
||||
r = r_session.get(url, **kwargs)
|
||||
r = self.get(url, **kwargs)
|
||||
r.encoding = None
|
||||
return r.text
|
||||
|
||||
def fetch_json(self, url: str, params: JSONDict | None = None, **kwargs):
|
||||
def get_json(self, url: str, params: JSONDict | None = None, **kwargs):
|
||||
"""Return JSON data from the given URL."""
|
||||
url = self.format_url(url, params)
|
||||
self.debug("Fetching JSON from {}", url)
|
||||
return r_session.get(url, **kwargs).json()
|
||||
return super().get_json(url, **kwargs)
|
||||
|
||||
def post_json(self, url: str, params: JSONDict | None = None, **kwargs):
|
||||
"""Send POST request and return JSON response."""
|
||||
url = self.format_url(url, params)
|
||||
self.debug("Posting JSON to {}", url)
|
||||
return r_session.post(url, **kwargs).json()
|
||||
return self.request("post", url, **kwargs).json()
|
||||
|
||||
@contextmanager
|
||||
def handle_request(self) -> Iterator[None]:
|
||||
|
|
@ -220,8 +232,10 @@ class BackendClass(type):
|
|||
return cls.__name__.lower()
|
||||
|
||||
|
||||
class Backend(RequestHandler, metaclass=BackendClass):
|
||||
def __init__(self, config, log):
|
||||
class Backend(LyricsRequestHandler, metaclass=BackendClass):
|
||||
config: confuse.Subview
|
||||
|
||||
def __init__(self, config: confuse.Subview, log: Logger) -> None:
|
||||
self._log = log
|
||||
self.config = config
|
||||
|
||||
|
|
@ -325,10 +339,10 @@ class LRCLib(Backend):
|
|||
if album:
|
||||
get_params["album_name"] = album
|
||||
|
||||
yield self.fetch_json(self.SEARCH_URL, params=base_params)
|
||||
yield self.get_json(self.SEARCH_URL, params=base_params)
|
||||
|
||||
with suppress(HTTPNotFoundError):
|
||||
yield [self.fetch_json(self.GET_URL, params=get_params)]
|
||||
yield [self.get_json(self.GET_URL, params=get_params)]
|
||||
|
||||
@classmethod
|
||||
def pick_best_match(cls, lyrics: Iterable[LRCLyrics]) -> LRCLyrics | None:
|
||||
|
|
@ -376,7 +390,7 @@ class MusiXmatch(Backend):
|
|||
def fetch(self, artist: str, title: str, *_) -> tuple[str, str] | None:
|
||||
url = self.build_url(artist, title)
|
||||
|
||||
html = self.fetch_text(url)
|
||||
html = self.get_text(url)
|
||||
if "We detected that your IP is blocked" in html:
|
||||
self.warn("Failed: Blocked IP address")
|
||||
return None
|
||||
|
|
@ -501,7 +515,7 @@ class SearchBackend(SoupMixin, Backend):
|
|||
def fetch(self, artist: str, title: str, *_) -> tuple[str, str] | None:
|
||||
"""Fetch lyrics for the given artist and title."""
|
||||
for result in self.get_results(artist, title):
|
||||
if (html := self.fetch_text(result.url)) and (
|
||||
if (html := self.get_text(result.url)) and (
|
||||
lyrics := self.scrape(html)
|
||||
):
|
||||
return lyrics, result.url
|
||||
|
|
@ -531,7 +545,7 @@ class Genius(SearchBackend):
|
|||
return {"Authorization": f"Bearer {self.config['genius_api_key']}"}
|
||||
|
||||
def search(self, artist: str, title: str) -> Iterable[SearchResult]:
|
||||
search_data: GeniusAPI.Search = self.fetch_json(
|
||||
search_data: GeniusAPI.Search = self.get_json(
|
||||
self.SEARCH_URL,
|
||||
params={"q": f"{artist} {title}"},
|
||||
headers=self.headers,
|
||||
|
|
@ -560,7 +574,7 @@ class Tekstowo(SearchBackend):
|
|||
return self.SEARCH_URL.format(quote_plus(unidecode(artistitle)))
|
||||
|
||||
def search(self, artist: str, title: str) -> Iterable[SearchResult]:
|
||||
if html := self.fetch_text(self.build_url(title, artist)):
|
||||
if html := self.get_text(self.build_url(title, artist)):
|
||||
soup = self.get_soup(html)
|
||||
for tag in soup.select("div[class=flex-group] > a[title*=' - ']"):
|
||||
artist, title = str(tag["title"]).split(" - ", 1)
|
||||
|
|
@ -626,12 +640,12 @@ class Google(SearchBackend):
|
|||
html = Html.remove_ads(super().pre_process_html(html))
|
||||
return Html.remove_formatting(Html.merge_paragraphs(html))
|
||||
|
||||
def fetch_text(self, *args, **kwargs) -> str:
|
||||
def get_text(self, *args, **kwargs) -> str:
|
||||
"""Handle an error so that we can continue with the next URL."""
|
||||
kwargs.setdefault("allow_redirects", False)
|
||||
with self.handle_request():
|
||||
try:
|
||||
return super().fetch_text(*args, **kwargs)
|
||||
return super().get_text(*args, **kwargs)
|
||||
except CaptchaError:
|
||||
self.ignored_domains.add(urlparse(args[0]).netloc)
|
||||
raise
|
||||
|
|
@ -687,7 +701,7 @@ class Google(SearchBackend):
|
|||
"excludeTerms": ", ".join(self.EXCLUDE_PAGES),
|
||||
}
|
||||
|
||||
data: GoogleCustomSearchAPI.Response = self.fetch_json(
|
||||
data: GoogleCustomSearchAPI.Response = self.get_json(
|
||||
self.SEARCH_URL, params=params
|
||||
)
|
||||
for item in data.get("items", []):
|
||||
|
|
@ -712,7 +726,7 @@ class Google(SearchBackend):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Translator(RequestHandler):
|
||||
class Translator(LyricsRequestHandler):
|
||||
TRANSLATE_URL = "https://api.cognitive.microsofttranslator.com/translate"
|
||||
LINE_PARTS_RE = re.compile(r"^(\[\d\d:\d\d.\d\d\]|) *(.*)$")
|
||||
SEPARATOR = " | "
|
||||
|
|
@ -922,7 +936,7 @@ class RestFiles:
|
|||
ui.print_(textwrap.dedent(text))
|
||||
|
||||
|
||||
class LyricsPlugin(RequestHandler, plugins.BeetsPlugin):
|
||||
class LyricsPlugin(LyricsRequestHandler, plugins.BeetsPlugin):
|
||||
BACKEND_BY_NAME = {
|
||||
b.name: b for b in [LRCLib, Google, Genius, Tekstowo, MusiXmatch]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from beets.metadata_plugins import MetadataSourcePlugin
|
|||
from beets.util.deprecation import deprecate_for_user
|
||||
from beets.util.id_extractors import extract_release_id
|
||||
|
||||
from ._utils.requests import HTTPNotFoundError, TimeoutSession
|
||||
from ._utils.requests import HTTPNotFoundError, RequestHandler, TimeoutSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
|
@ -61,10 +61,6 @@ FIELDS_TO_MB_KEYS = {
|
|||
}
|
||||
|
||||
|
||||
class LimiterTimeoutSession(LimiterMixin, TimeoutSession):
|
||||
pass
|
||||
|
||||
|
||||
RELEASE_INCLUDES = [
|
||||
"artists",
|
||||
"media",
|
||||
|
|
@ -103,32 +99,39 @@ BROWSE_CHUNKSIZE = 100
|
|||
BROWSE_MAXTRACKS = 500
|
||||
|
||||
|
||||
class LimiterTimeoutSession(LimiterMixin, TimeoutSession):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MusicBrainzAPI:
|
||||
class MusicBrainzAPI(RequestHandler):
|
||||
session_type = LimiterTimeoutSession
|
||||
|
||||
api_host: str
|
||||
rate_limit: float
|
||||
|
||||
@cached_property
|
||||
def session(self) -> LimiterTimeoutSession:
|
||||
return LimiterTimeoutSession(per_second=self.rate_limit)
|
||||
return self.session_type(per_second=self.rate_limit)
|
||||
|
||||
def _get(self, entity: str, **kwargs) -> JSONDict:
|
||||
return self.session.get(
|
||||
f"{self.api_host}/ws/2/{entity}", params={**kwargs, "fmt": "json"}
|
||||
).json()
|
||||
|
||||
def get_release(self, id_: str) -> JSONDict:
|
||||
def get_entity(self, entity: str, **kwargs) -> JSONDict:
|
||||
return self._group_relations(
|
||||
self._get(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES))
|
||||
self.get_json(
|
||||
f"{self.api_host}/ws/2/{entity}",
|
||||
params={**kwargs, "fmt": "json"},
|
||||
)
|
||||
)
|
||||
|
||||
def get_release(self, id_: str) -> JSONDict:
|
||||
return self.get_entity(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES))
|
||||
|
||||
def get_recording(self, id_: str) -> JSONDict:
|
||||
return self._get(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES))
|
||||
return self.get_entity(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES))
|
||||
|
||||
def browse_recordings(self, **kwargs) -> list[JSONDict]:
|
||||
kwargs.setdefault("limit", BROWSE_CHUNKSIZE)
|
||||
kwargs.setdefault("inc", BROWSE_INCLUDES)
|
||||
return self._get("recording", **kwargs)["recordings"]
|
||||
return self.get_entity("recording", **kwargs)["recordings"]
|
||||
|
||||
@singledispatchmethod
|
||||
@classmethod
|
||||
|
|
@ -202,7 +205,7 @@ def _preferred_alias(
|
|||
if (
|
||||
alias["locale"] == locale
|
||||
and alias.get("primary")
|
||||
and alias.get("type", "").lower() not in ignored_alias_types
|
||||
and (alias.get("type") or "").lower() not in ignored_alias_types
|
||||
):
|
||||
matches.append(alias)
|
||||
|
||||
|
|
@ -852,7 +855,7 @@ class MusicBrainzPlugin(MetadataSourcePlugin):
|
|||
self._log.debug(
|
||||
"Searching for MusicBrainz {}s with: {!r}", query_type, query
|
||||
)
|
||||
return self.api._get(
|
||||
return self.api.get_entity(
|
||||
query_type, query=query, limit=self.config["search_limit"].get()
|
||||
)[f"{query_type}s"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1052,7 +1052,7 @@ class TestMusicBrainzPlugin(PluginMixin):
|
|||
|
||||
def test_item_candidates(self, monkeypatch, mb):
|
||||
monkeypatch.setattr(
|
||||
"beetsplug.musicbrainz.MusicBrainzAPI._get",
|
||||
"beetsplug.musicbrainz.MusicBrainzAPI.get_json",
|
||||
lambda *_, **__: {"recordings": [self.RECORDING]},
|
||||
)
|
||||
|
||||
|
|
@ -1063,7 +1063,7 @@ class TestMusicBrainzPlugin(PluginMixin):
|
|||
|
||||
def test_candidates(self, monkeypatch, mb):
|
||||
monkeypatch.setattr(
|
||||
"beetsplug.musicbrainz.MusicBrainzAPI._get",
|
||||
"beetsplug.musicbrainz.MusicBrainzAPI.get_json",
|
||||
lambda *_, **__: {"releases": [{"id": self.mbid}]},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
Loading…
Reference in a new issue