From cb6ad89ce65662a6daa9499aa9a8d0c6cf241836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Fri, 30 Jan 2026 22:28:52 +0000 Subject: [PATCH] Use a decorator-based approach --- beets/metadata_plugins.py | 187 +++++++++++----------------------- beets/plugins.py | 11 +- test/test_metadata_plugins.py | 17 ++-- 3 files changed, 76 insertions(+), 139 deletions(-) diff --git a/beets/metadata_plugins.py b/beets/metadata_plugins.py index 8c3b438b0..7c08d72a3 100644 --- a/beets/metadata_plugins.py +++ b/beets/metadata_plugins.py @@ -9,33 +9,26 @@ from __future__ import annotations import abc import re -from functools import cache, cached_property -from typing import ( - TYPE_CHECKING, - Callable, - Generic, - Literal, - TypedDict, - TypeVar, -) +from contextlib import contextmanager, nullcontext +from functools import cache, cached_property, wraps +from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar import unidecode from confuse import NotFoundError -from typing_extensions import NotRequired, ParamSpec +from typing_extensions import NotRequired from beets import config, logging from beets.util import cached_classproperty from beets.util.id_extractors import extract_release_id -from .plugins import BeetsPlugin, find_plugins, notify_info_yielded, send +from .plugins import BeetsPlugin, find_plugins, notify_info_yielded if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Callable, Iterable, Iterator, Sequence from .autotag.hooks import AlbumInfo, Item, TrackInfo - P = ParamSpec("P") - R = TypeVar("R") + Ret = TypeVar("Ret") # Global logger. log = logging.getLogger("beets") @@ -46,52 +39,68 @@ def find_metadata_source_plugins() -> list[MetadataSourcePlugin]: """Return a list of all loaded metadata source plugins.""" # TODO: Make this an isinstance(MetadataSourcePlugin, ...) check in v3.0.0 # This should also allow us to remove the type: ignore comments below. - metadata_plugins = [p for p in find_plugins() if hasattr(p, "data_source")] + return [p for p in find_plugins() if hasattr(p, "data_source")] # type: ignore[misc] - if config["raise_on_error"].get(bool): - return metadata_plugins # type: ignore[return-value] - else: - return list(map(SafeProxy, metadata_plugins)) # type: ignore[arg-type] + +@contextmanager +def handle_plugin_error(plugin: MetadataSourcePlugin, method_name: str): + """Safely call a plugin method, catching and logging exceptions.""" + try: + yield + except Exception as e: + log.error("Error in '{}.{}': {}", plugin.data_source, method_name, e) + log.debug("Exception details:", exc_info=True) + + +def _yield_from_plugins( + func: Callable[..., Iterable[Ret]], +) -> Callable[..., Iterator[Ret]]: + method_name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs) -> Iterator[Ret]: + for plugin in find_metadata_source_plugins(): + method = getattr(plugin, method_name) + with ( + nullcontext() + if config["raise_on_error"] + else handle_plugin_error(plugin, method_name) + ): + yield from filter(None, method(*args, **kwargs)) + + return wrapper @notify_info_yielded("albuminfo_received") -def candidates(*args, **kwargs) -> Iterable[AlbumInfo]: - """Return matching album candidates from all metadata source plugins.""" - for plugin in find_metadata_source_plugins(): - yield from plugin.candidates(*args, **kwargs) +@_yield_from_plugins +def candidates(*args, **kwargs) -> Iterator[AlbumInfo]: + yield from () @notify_info_yielded("trackinfo_received") -def item_candidates(*args, **kwargs) -> Iterable[TrackInfo]: - """Return matching track candidates from all metadata source plugins.""" - for plugin in find_metadata_source_plugins(): - yield from plugin.item_candidates(*args, **kwargs) +@_yield_from_plugins +def item_candidates(*args, **kwargs) -> Iterator[TrackInfo]: + yield from () + + +@notify_info_yielded("albuminfo_received") +@_yield_from_plugins +def albums_for_ids(*args, **kwargs) -> Iterator[AlbumInfo]: + yield from () + + +@notify_info_yielded("trackinfo_received") +@_yield_from_plugins +def tracks_for_ids(*args, **kwargs) -> Iterator[TrackInfo]: + yield from () def album_for_id(_id: str) -> AlbumInfo | None: - """Get AlbumInfo object for the given ID string. - - A single ID can yield just a single album, so we return the first match. - """ - for plugin in find_metadata_source_plugins(): - if info := plugin.album_for_id(_id): - send("albuminfo_received", info=info) - return info - - return None + return next(albums_for_ids([_id]), None) def track_for_id(_id: str) -> TrackInfo | None: - """Get TrackInfo object for the given ID string. - - A single ID can yield just a single track, so we return the first match. - """ - for plugin in find_metadata_source_plugins(): - if info := plugin.track_for_id(_id): - send("trackinfo_received", info=info) - return info - - return None + return next(tracks_for_ids([_id]), None) @cache @@ -279,11 +288,11 @@ class SearchFilter(TypedDict): album: NotRequired[str] -Res = TypeVar("Res", bound=IDResponse) +R = TypeVar("R", bound=IDResponse) class SearchApiMetadataSourcePlugin( - Generic[Res], MetadataSourcePlugin, metaclass=abc.ABCMeta + Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta ): """Helper class to implement a metadata source plugin with an API. @@ -308,7 +317,7 @@ class SearchApiMetadataSourcePlugin( query_type: Literal["album", "track"], filters: SearchFilter, query_string: str = "", - ) -> Sequence[Res]: + ) -> Sequence[R]: """Perform a search on the API. :param query_type: The type of query to perform. @@ -377,81 +386,3 @@ class SearchApiMetadataSourcePlugin( query = unidecode.unidecode(query) return query - - -# To have proper typing for the proxy class below, we need to -# trick mypy into thinking that SafeProxy is a subclass of -# MetadataSourcePlugin. -# https://stackoverflow.com/questions/71365594/how-to-make-a-proxy-object-with-typing-as-underlying-object-in-python -Proxied = TypeVar("Proxied", bound=MetadataSourcePlugin) -if TYPE_CHECKING: - base = MetadataSourcePlugin -else: - base = object - - -class SafeProxy(base): - """A proxy class that forwards all attribute access to the wrapped - MetadataSourcePlugin instance. - - We use this to catch and log exceptions from metadata source plugins - without crashing beets. E.g. on long running autotag operations. - """ - - __plugin: MetadataSourcePlugin - - def __init__(self, plugin: MetadataSourcePlugin): - self.__plugin = plugin - - def __getattribute__(self, name): - if name in { - "_SafeProxy__plugin", - "_SafeProxy__handle_exception", - "candidates", - "item_candidates", - "album_for_id", - "track_for_id", - }: - return super().__getattribute__(name) - else: - return getattr(self.__plugin, name) - - def __setattr__(self, name, value): - if name == "_SafeProxy__plugin": - super().__setattr__(name, value) - else: - self.__plugin.__setattr__(name, value) - - def __handle_exception(self, func: Callable[P, R], e: Exception) -> None: - """Helper function to log exceptions from metadata source plugins.""" - log.error( - "Error in '{}.{}': {}", - self.__plugin.data_source, - func.__name__, - e, - ) - log.debug("Exception details:", exc_info=True) - - def album_for_id(self, *args, **kwargs): - try: - return self.__plugin.album_for_id(*args, **kwargs) - except Exception as e: - return self.__handle_exception(self.__plugin.album_for_id, e) - - def track_for_id(self, *args, **kwargs): - try: - return self.__plugin.track_for_id(*args, **kwargs) - except Exception as e: - return self.__handle_exception(self.__plugin.track_for_id, e) - - def candidates(self, *args, **kwargs): - try: - yield from self.__plugin.candidates(*args, **kwargs) - except Exception as e: - return self.__handle_exception(self.__plugin.candidates, e) - - def item_candidates(self, *args, **kwargs): - try: - yield from self.__plugin.item_candidates(*args, **kwargs) - except Exception as e: - return self.__handle_exception(self.__plugin.item_candidates, e) diff --git a/beets/plugins.py b/beets/plugins.py index ec3f999c4..01d9d3327 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -35,7 +35,7 @@ from beets.util import unique_list from beets.util.deprecation import deprecate_for_maintainers, deprecate_for_user if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Sequence + from collections.abc import Callable, Iterable, Iterator, Sequence from confuse import ConfigView @@ -58,7 +58,6 @@ if TYPE_CHECKING: P = ParamSpec("P") Ret = TypeVar("Ret", bound=Any) Listener = Callable[..., Any] - IterF = Callable[P, Iterable[Ret]] PLUGIN_NAMESPACE = "beetsplug" @@ -548,7 +547,7 @@ def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: def notify_info_yielded( event: EventType, -) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: +) -> Callable[[Callable[P, Iterable[Ret]]], Callable[P, Iterator[Ret]]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -556,9 +555,11 @@ def notify_info_yielded( 'send'. """ - def decorator(func: IterF[P, Ret]) -> IterF[P, Ret]: + def decorator( + func: Callable[P, Iterable[Ret]], + ) -> Callable[P, Iterator[Ret]]: @wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[Ret]: for v in func(*args, **kwargs): send(event, info=v) yield v diff --git a/test/test_metadata_plugins.py b/test/test_metadata_plugins.py index edf66adcd..d34185330 100644 --- a/test/test_metadata_plugins.py +++ b/test/test_metadata_plugins.py @@ -45,18 +45,23 @@ class TestMetadataPluginsException(PluginMixin): self.unload_plugins() @pytest.mark.parametrize( - "method_name,args", + "method_name,error_method_name,args", [ - ("candidates", ()), - ("item_candidates", ()), - ("album_for_id", ("some_id",)), - ("track_for_id", ("some_id",)), + ("candidates", "candidates", ()), + ("item_candidates", "item_candidates", ()), + ("albums_for_ids", "albums_for_ids", (["some_id"],)), + ("tracks_for_ids", "tracks_for_ids", (["some_id"],)), + # Currently, singular methods call plural ones internally and log + # errors from there + ("album_for_id", "albums_for_ids", ("some_id",)), + ("track_for_id", "tracks_for_ids", ("some_id",)), ], ) def test_logging( self, caplog, method_name, + error_method_name, args, ): self.config["raise_on_error"] = False @@ -72,7 +77,7 @@ class TestMetadataPluginsException(PluginMixin): for msg in logs: assert ( msg - == f"Error in 'ErrorMetadataMockPlugin.{method_name}': Mocked error" + == f"Error in 'ErrorMetadataMockPlugin.{error_method_name}': Mocked error" # noqa: E501 ) caplog.clear()