Refactor HTTP request handling with RequestHandler base class

Introduce a new RequestHandler base class that provides a shared session
and centralizes HTTP request management and error handling across plugins.

Key changes:
- Add RequestHandler base class with a shared/cached session
- Convert TimeoutSession to use SingletonMeta for proper resource
  management
- Create LyricsRequestHandler subclass with lyrics-specific error
  handling
- Update MusicBrainzAPI to inherit from RequestHandler
This commit is contained in:
Šarūnas Nejus 2025-10-20 21:54:26 +01:00
parent 69dc06dff7
commit 71b20ae6f6
No known key found for this signature in database
4 changed files with 175 additions and 45 deletions

View file

@ -1,38 +1,149 @@
from __future__ import annotations
import atexit import atexit
import threading
from contextlib import contextmanager
from functools import cached_property
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import requests import requests
from beets import __version__ from beets import __version__
if TYPE_CHECKING:
class HTTPNotFoundError(requests.exceptions.HTTPError): from collections.abc import Iterator
pass
class CaptchaError(requests.exceptions.HTTPError): class BeetsHTTPError(requests.exceptions.HTTPError):
pass STATUS: ClassVar[HTTPStatus]
def __init__(self, *args, **kwargs) -> None:
super().__init__(
f"HTTP Error: {self.STATUS.value} {self.STATUS.phrase}",
*args,
**kwargs,
)
class TimeoutSession(requests.Session): class HTTPNotFoundError(BeetsHTTPError):
STATUS = HTTPStatus.NOT_FOUND
class Closeable(Protocol):
    """Protocol for objects that have a close method."""

    def close(self) -> None: ...


C = TypeVar("C", bound=Closeable)


class SingletonMeta(type, Generic[C]):
    """Metaclass ensuring a single shared instance per class.

    The first call to a class using this metaclass constructs the instance
    with the given arguments and registers its ``close`` method with
    ``atexit``. Every later call returns that same instance; arguments passed
    to later calls are ignored.
    """

    # Registry shared by every class that uses this metaclass, keyed by class.
    _instances: ClassVar[dict[type[Any], Any]] = {}
    # Re-entrant so that a singleton whose constructor instantiates another
    # singleton does not block on a lock its own thread already holds.
    _lock: ClassVar[threading.RLock] = threading.RLock()

    def __call__(cls, *args: Any, **kwargs: Any) -> C:
        """Return the shared instance of ``cls``, creating it on first use."""
        # Always go through SingletonMeta rather than ``cls`` so that a class
        # attribute named ``_instances`` or ``_lock`` cannot shadow the
        # registry or its lock.
        instances = SingletonMeta._instances
        if cls not in instances:
            with SingletonMeta._lock:
                # Check again under the lock: another thread may have created
                # the instance between the first check and lock acquisition.
                if cls not in instances:
                    instance = super().__call__(*args, **kwargs)
                    instances[cls] = instance
                    atexit.register(instance.close)
        return instances[cls]
class TimeoutSession(requests.Session, metaclass=SingletonMeta):
"""HTTP session with automatic timeout and status checking.
Extends requests.Session to provide sensible defaults for beets HTTP
requests: automatic timeout enforcement, status code validation, and
proper user agent identification.
"""
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/" self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/"
@atexit.register
def close_session():
"""Close the requests session on shut down."""
self.close()
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
"""Wrap the request method to raise an exception on HTTP errors.""" """Execute HTTP request with automatic timeout and status validation.
Ensures all requests have a timeout (defaults to 10 seconds) and raises
an exception for HTTP error status codes.
"""
kwargs.setdefault("timeout", 10) kwargs.setdefault("timeout", 10)
r = super().request(*args, **kwargs) r = super().request(*args, **kwargs)
if r.status_code == HTTPStatus.NOT_FOUND:
raise HTTPNotFoundError("HTTP Error: Not Found", response=r)
if 300 <= r.status_code < 400:
raise CaptchaError("Captcha is required", response=r)
r.raise_for_status() r.raise_for_status()
return r return r
class RequestHandler:
    """Make HTTP requests through a lazily created, cached session.

    ``requests.exceptions.HTTPError`` raised while performing a request is
    translated into the beets-specific exception registered for the response
    status code, when one exists. Subclasses can customise behaviour by
    overriding ``session_type``, ``explicit_http_errors`` or
    ``status_to_error``.
    """

    # Session class instantiated, without arguments, by ``session``.
    session_type: ClassVar[type[TimeoutSession]] = TimeoutSession
    # Exception types raised in place of a plain HTTPError; each is matched
    # against the response status code through its ``STATUS`` attribute.
    explicit_http_errors: ClassVar[list[type[BeetsHTTPError]]] = [
        HTTPNotFoundError
    ]

    @cached_property
    def session(self) -> Any:
        """Lazily initialize and cache the HTTP session."""
        return self.session_type()

    def status_to_error(
        self, code: int
    ) -> type[requests.exceptions.HTTPError] | None:
        """Map an HTTP status code to a beets-specific exception type.

        Returns the first entry of ``explicit_http_errors`` whose ``STATUS``
        equals ``code``, or None when no entry matches.
        """
        return next(
            (e for e in self.explicit_http_errors if e.STATUS == code), None
        )

    @contextmanager
    def handle_http_error(self) -> Iterator[None]:
        """Convert standard HTTP errors to beets-specific exceptions.

        Wraps operations that may raise ``requests.exceptions.HTTPError``.
        When the error carries a response whose status code maps to an
        exception type via ``status_to_error``, that exception is raised
        instead, chained to the original. All other HTTP errors, including
        those raised without a response, are re-raised unchanged.
        """
        try:
            yield
        except requests.exceptions.HTTPError as e:
            response = e.response
            # HTTPError can be raised without a response attached; there is
            # then no status code to map, so let the original error through.
            if response is not None and (
                beets_error := self.status_to_error(response.status_code)
            ):
                raise beets_error(response=response) from e
            raise

    def request(self, method: str, *args, **kwargs) -> requests.Response:
        """Perform an HTTP request using the session with error handling.

        ``method`` is the name of the session method to call (for example
        ``"get"`` or ``"post"``); remaining arguments are passed to it as-is.
        Recognized HTTP errors are converted by ``handle_http_error``.
        """
        with self.handle_http_error():
            return getattr(self.session, method)(*args, **kwargs)

    def get(self, *args, **kwargs) -> requests.Response:
        """Perform HTTP GET request with automatic error handling."""
        return self.request("get", *args, **kwargs)

    def fetch_json(self, *args, **kwargs):
        """Fetch and parse JSON data from an HTTP endpoint."""
        return self.get(*args, **kwargs).json()

View file

@ -26,7 +26,7 @@ from functools import cached_property, partial, total_ordering
from html import unescape from html import unescape
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Iterator, NamedTuple from typing import TYPE_CHECKING, NamedTuple
from urllib.parse import quote, quote_plus, urlencode, urlparse from urllib.parse import quote, quote_plus, urlencode, urlparse
import langdetect import langdetect
@ -34,14 +34,17 @@ import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from unidecode import unidecode from unidecode import unidecode
import beets
from beets import plugins, ui from beets import plugins, ui
from beets.autotag.distance import string_dist from beets.autotag.distance import string_dist
from beets.util.config import sanitize_choices from beets.util.config import sanitize_choices
from ._utils.requests import CaptchaError, HTTPNotFoundError, TimeoutSession from ._utils.requests import HTTPNotFoundError, RequestHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
import confuse
from beets.importer import ImportTask from beets.importer import ImportTask
from beets.library import Item, Library from beets.library import Item, Library
from beets.logging import BeetsLogger as Logger from beets.logging import BeetsLogger as Logger
@ -57,7 +60,9 @@ if TYPE_CHECKING:
INSTRUMENTAL_LYRICS = "[Instrumental]" INSTRUMENTAL_LYRICS = "[Instrumental]"
r_session = TimeoutSession() class CaptchaError(requests.exceptions.HTTPError):
def __init__(self, *args, **kwargs) -> None:
super().__init__("Captcha is required", *args, **kwargs)
# Utilities. # Utilities.
@ -153,9 +158,18 @@ def slug(text: str) -> str:
return re.sub(r"\W+", "-", unidecode(text).lower().strip()).strip("-") return re.sub(r"\W+", "-", unidecode(text).lower().strip()).strip("-")
class RequestHandler: class LyricsRequestHandler(RequestHandler):
_log: Logger _log: Logger
def status_to_error(self, code: int) -> type[requests.HTTPError] | None:
if err := super().status_to_error(code):
return err
if 300 <= code < 400:
return CaptchaError
return None
def debug(self, message: str, *args) -> None: def debug(self, message: str, *args) -> None:
"""Log a debug message with the class name.""" """Log a debug message with the class name."""
self._log.debug(f"{self.__class__.__name__}: {message}", *args) self._log.debug(f"{self.__class__.__name__}: {message}", *args)
@ -185,7 +199,7 @@ class RequestHandler:
""" """
url = self.format_url(url, params) url = self.format_url(url, params)
self.debug("Fetching HTML from {}", url) self.debug("Fetching HTML from {}", url)
r = r_session.get(url, **kwargs) r = self.get(url, **kwargs)
r.encoding = None r.encoding = None
return r.text return r.text
@ -193,13 +207,13 @@ class RequestHandler:
"""Return JSON data from the given URL.""" """Return JSON data from the given URL."""
url = self.format_url(url, params) url = self.format_url(url, params)
self.debug("Fetching JSON from {}", url) self.debug("Fetching JSON from {}", url)
return r_session.get(url, **kwargs).json() return super().fetch_json(url, **kwargs)
def post_json(self, url: str, params: JSONDict | None = None, **kwargs): def post_json(self, url: str, params: JSONDict | None = None, **kwargs):
"""Send POST request and return JSON response.""" """Send POST request and return JSON response."""
url = self.format_url(url, params) url = self.format_url(url, params)
self.debug("Posting JSON to {}", url) self.debug("Posting JSON to {}", url)
return r_session.post(url, **kwargs).json() return self.request("post", url, **kwargs).json()
@contextmanager @contextmanager
def handle_request(self) -> Iterator[None]: def handle_request(self) -> Iterator[None]:
@ -218,8 +232,10 @@ class BackendClass(type):
return cls.__name__.lower() return cls.__name__.lower()
class Backend(RequestHandler, metaclass=BackendClass): class Backend(LyricsRequestHandler, metaclass=BackendClass):
def __init__(self, config, log): config: confuse.Subview
def __init__(self, config: confuse.Subview, log: Logger) -> None:
self._log = log self._log = log
self.config = config self.config = config
@ -710,7 +726,7 @@ class Google(SearchBackend):
@dataclass @dataclass
class Translator(RequestHandler): class Translator(LyricsRequestHandler):
TRANSLATE_URL = "https://api.cognitive.microsofttranslator.com/translate" TRANSLATE_URL = "https://api.cognitive.microsofttranslator.com/translate"
LINE_PARTS_RE = re.compile(r"^(\[\d\d:\d\d.\d\d\]|) *(.*)$") LINE_PARTS_RE = re.compile(r"^(\[\d\d:\d\d.\d\d\]|) *(.*)$")
SEPARATOR = " | " SEPARATOR = " | "
@ -918,7 +934,7 @@ class RestFiles:
ui.print_(textwrap.dedent(text)) ui.print_(textwrap.dedent(text))
class LyricsPlugin(RequestHandler, plugins.BeetsPlugin): class LyricsPlugin(LyricsRequestHandler, plugins.BeetsPlugin):
BACKEND_BY_NAME = { BACKEND_BY_NAME = {
b.name: b for b in [LRCLib, Google, Genius, Tekstowo, MusiXmatch] b.name: b for b in [LRCLib, Google, Genius, Tekstowo, MusiXmatch]
} }

View file

@ -34,6 +34,7 @@ from beets.metadata_plugins import MetadataSourcePlugin
from beets.util.id_extractors import extract_release_id from beets.util.id_extractors import extract_release_id
from ._utils.requests import HTTPNotFoundError, TimeoutSession from ._utils.requests import HTTPNotFoundError, TimeoutSession
# Import the base class from the module that defines it: the lyrics module
# only re-exports it, and importing through it would also execute that
# plugin's own top-level third-party imports.
from ._utils.requests import RequestHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Literal from typing import Literal
@ -58,10 +59,6 @@ FIELDS_TO_MB_KEYS = {
} }
class LimiterTimeoutSession(LimiterMixin, TimeoutSession):
pass
RELEASE_INCLUDES = [ RELEASE_INCLUDES = [
"artists", "artists",
"media", "media",
@ -99,30 +96,36 @@ BROWSE_CHUNKSIZE = 100
BROWSE_MAXTRACKS = 500 BROWSE_MAXTRACKS = 500
class LimiterTimeoutSession(LimiterMixin, TimeoutSession):
pass
@dataclass @dataclass
class MusicBrainzAPI: class MusicBrainzAPI(RequestHandler):
session_type = LimiterTimeoutSession
api_host: str api_host: str
rate_limit: float rate_limit: float
@cached_property @cached_property
def session(self) -> LimiterTimeoutSession: def session(self) -> LimiterTimeoutSession:
return LimiterTimeoutSession(per_second=self.rate_limit) return self.session_type(per_second=self.rate_limit)
def _get(self, entity: str, **kwargs) -> JSONDict: def get_entity(self, entity: str, **kwargs) -> JSONDict:
return self.session.get( return self.fetch_json(
f"{self.api_host}/ws/2/{entity}", params={**kwargs, "fmt": "json"} f"{self.api_host}/ws/2/{entity}", params={**kwargs, "fmt": "json"}
).json() )
def get_release(self, id_: str) -> JSONDict: def get_release(self, id_: str) -> JSONDict:
return self._get(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES)) return self.get_entity(f"release/{id_}", inc=" ".join(RELEASE_INCLUDES))
def get_recording(self, id_: str) -> JSONDict: def get_recording(self, id_: str) -> JSONDict:
return self._get(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES)) return self.get_entity(f"recording/{id_}", inc=" ".join(TRACK_INCLUDES))
def browse_recordings(self, **kwargs) -> list[JSONDict]: def browse_recordings(self, **kwargs) -> list[JSONDict]:
kwargs.setdefault("limit", BROWSE_CHUNKSIZE) kwargs.setdefault("limit", BROWSE_CHUNKSIZE)
kwargs.setdefault("inc", BROWSE_INCLUDES) kwargs.setdefault("inc", BROWSE_INCLUDES)
return self._get("recording", **kwargs)["recordings"] return self.get_entity("recording", **kwargs)["recordings"]
def _preferred_alias(aliases: list[JSONDict]): def _preferred_alias(aliases: list[JSONDict]):
@ -149,7 +152,7 @@ def _preferred_alias(aliases: list[JSONDict]):
if ( if (
alias["locale"] == locale alias["locale"] == locale
and alias.get("primary") and alias.get("primary")
and alias.get("type", "").lower() not in ignored_alias_types and (alias.get("type") or "").lower() not in ignored_alias_types
): ):
matches.append(alias) matches.append(alias)
@ -790,7 +793,7 @@ class MusicBrainzPlugin(MetadataSourcePlugin):
self._log.debug( self._log.debug(
"Searching for MusicBrainz {}s with: {!r}", query_type, query "Searching for MusicBrainz {}s with: {!r}", query_type, query
) )
return self.api._get( return self.api.get_entity(
query_type, query=query, limit=self.config["search_limit"].get() query_type, query=query, limit=self.config["search_limit"].get()
)[f"{query_type}s"] )[f"{query_type}s"]

View file

@ -1031,7 +1031,7 @@ class TestMusicBrainzPlugin(PluginMixin):
def test_item_candidates(self, monkeypatch, mb): def test_item_candidates(self, monkeypatch, mb):
monkeypatch.setattr( monkeypatch.setattr(
"beetsplug.musicbrainz.MusicBrainzAPI._get", "beetsplug.musicbrainz.MusicBrainzAPI.fetch_json",
lambda *_, **__: {"recordings": [self.RECORDING]}, lambda *_, **__: {"recordings": [self.RECORDING]},
) )
@ -1042,7 +1042,7 @@ class TestMusicBrainzPlugin(PluginMixin):
def test_candidates(self, monkeypatch, mb): def test_candidates(self, monkeypatch, mb):
monkeypatch.setattr( monkeypatch.setattr(
"beetsplug.musicbrainz.MusicBrainzAPI._get", "beetsplug.musicbrainz.MusicBrainzAPI.fetch_json",
lambda *_, **__: {"releases": [{"id": self.mbid}]}, lambda *_, **__: {"releases": [{"id": self.mbid}]},
) )
monkeypatch.setattr( monkeypatch.setattr(