Add retries for connection errors

This commit is contained in:
Šarūnas Nejus 2025-12-19 20:36:19 +00:00
parent 9dad040977
commit d1aa45a008
No known key found for this signature in database
3 changed files with 88 additions and 15 deletions

View file

@ -8,6 +8,8 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from beets import __version__
@ -60,18 +62,24 @@ class SingletonMeta(type, Generic[C]):
return SingletonMeta._instances[cls]
class TimeoutSession(requests.Session, metaclass=SingletonMeta):
"""HTTP session with automatic timeout and status checking.
class TimeoutAndRetrySession(requests.Session, metaclass=SingletonMeta):
"""HTTP session with sensible defaults.
Extends requests.Session to provide sensible defaults for beets HTTP
requests: automatic timeout enforcement, status code validation, and
proper user agent identification.
* default beets User-Agent header
* default request timeout
* automatic retries on transient connection errors
* raises exceptions for HTTP error status codes
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/"
retry = Retry(connect=2, total=2, backoff_factor=1)
adapter = HTTPAdapter(max_retries=retry)
self.mount("https://", adapter)
self.mount("http://", adapter)
def request(self, *args, **kwargs):
"""Execute HTTP request with automatic timeout and status validation.
@ -106,15 +114,21 @@ class RequestHandler:
Feel free to define common methods that are used in multiple plugins.
"""
session_type: ClassVar[type[TimeoutSession]] = TimeoutSession
explicit_http_errors: ClassVar[list[type[BeetsHTTPError]]] = [
HTTPNotFoundError
]
def create_session(self) -> TimeoutAndRetrySession:
"""Create a new HTTP session instance.
Can be overridden by subclasses to provide custom session types.
"""
return TimeoutAndRetrySession()
@cached_property
def session(self) -> Any:
def session(self) -> TimeoutAndRetrySession:
"""Lazily initialize and cache the HTTP session."""
return self.session_type()
return self.create_session()
def status_to_error(
self, code: int

View file

@ -35,7 +35,11 @@ from beets.metadata_plugins import MetadataSourcePlugin
from beets.util.deprecation import deprecate_for_user
from beets.util.id_extractors import extract_release_id
from ._utils.requests import HTTPNotFoundError, RequestHandler, TimeoutSession
from ._utils.requests import (
HTTPNotFoundError,
RequestHandler,
TimeoutAndRetrySession,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
@ -99,20 +103,17 @@ BROWSE_CHUNKSIZE = 100
BROWSE_MAXTRACKS = 500
class LimiterTimeoutSession(LimiterMixin, TimeoutSession):
class LimiterTimeoutSession(LimiterMixin, TimeoutAndRetrySession):
pass
@dataclass
class MusicBrainzAPI(RequestHandler):
session_type = LimiterTimeoutSession
api_host: str
rate_limit: float
@cached_property
def session(self) -> LimiterTimeoutSession:
return self.session_type(per_second=self.rate_limit)
def create_session(self) -> LimiterTimeoutSession:
return LimiterTimeoutSession(per_second=self.rate_limit)
def get_entity(self, entity: str, **kwargs) -> JSONDict:
return self._group_relations(

View file

@ -0,0 +1,58 @@
import io
from http import HTTPStatus
from unittest.mock import Mock
from urllib.error import URLError
import pytest
import requests
from urllib3 import HTTPResponse
from urllib3.exceptions import NewConnectionError
from beetsplug._utils.requests import RequestHandler
class TestRequestHandlerRetry:
@pytest.fixture(autouse=True)
def patch_connection(self, monkeypatch, last_response):
monkeypatch.setattr(
"urllib3.connectionpool.HTTPConnectionPool._make_request",
Mock(
side_effect=[
NewConnectionError(None, "Connection failed"),
URLError("bad"),
last_response,
]
),
)
@pytest.fixture
def request_handler(self):
return RequestHandler()
@pytest.mark.parametrize(
"last_response",
[
HTTPResponse(
body=io.BytesIO(b"success"),
status=HTTPStatus.OK,
preload_content=False,
),
],
ids=["success"],
)
def test_retry_on_connection_error(self, request_handler):
"""Verify that the handler retries on connection errors."""
response = request_handler.get("http://example.com/api")
assert response.text == "success"
assert response.status_code == HTTPStatus.OK
@pytest.mark.parametrize(
"last_response", [ConnectionResetError], ids=["conn_error"]
)
def test_retry_exhaustion(self, request_handler):
"""Verify that the handler raises an error after exhausting retries."""
with pytest.raises(
requests.exceptions.ConnectionError, match="Max retries exceeded"
):
request_handler.get("http://example.com/api")