diff --git a/beetsplug/_utils/requests.py b/beetsplug/_utils/requests.py index b8ac541e9..313ed13b4 100644 --- a/beetsplug/_utils/requests.py +++ b/beetsplug/_utils/requests.py @@ -67,7 +67,7 @@ class TimeoutAndRetrySession(requests.Session, metaclass=SingletonMeta): * default beets User-Agent header * default request timeout - * automatic retries on transient connection errors + * automatic retries on transient connection or server errors * raises exceptions for HTTP error status codes """ @@ -75,7 +75,18 @@ class TimeoutAndRetrySession(requests.Session, metaclass=SingletonMeta): super().__init__(*args, **kwargs) self.headers["User-Agent"] = f"beets/{__version__} https://beets.io/" - retry = Retry(connect=2, total=2, backoff_factor=1) + retry = Retry( + connect=2, + total=2, + backoff_factor=1, + # Retry on server errors + status_forcelist=[ + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT, + ], + ) adapter = HTTPAdapter(max_retries=retry) self.mount("https://", adapter) self.mount("http://", adapter) diff --git a/test/plugins/utils/test_request_handler.py b/test/plugins/utils/test_request_handler.py index c17a9387b..6887283dc 100644 --- a/test/plugins/utils/test_request_handler.py +++ b/test/plugins/utils/test_request_handler.py @@ -48,11 +48,20 @@ class TestRequestHandlerRetry: assert response.status_code == HTTPStatus.OK @pytest.mark.parametrize( - "last_response", [ConnectionResetError], ids=["conn_error"] + "last_response", + [ + ConnectionResetError, + HTTPResponse( + body=io.BytesIO(b"Server Error"), + status=HTTPStatus.INTERNAL_SERVER_ERROR, + preload_content=False, + ), + ], + ids=["conn_error", "server_error"], ) def test_retry_exhaustion(self, request_handler): """Verify that the handler raises an error after exhausting retries.""" with pytest.raises( - requests.exceptions.ConnectionError, match="Max retries exceeded" + requests.exceptions.RequestException, match="Max retries exceeded" ): request_handler.get("http://example.com/api")