Use a decorator-based approach

This commit is contained in:
Šarūnas Nejus 2026-01-30 22:28:52 +00:00
parent 8e0b3f1323
commit cb6ad89ce6
3 changed files with 76 additions and 139 deletions

View file

@ -9,33 +9,26 @@ from __future__ import annotations
import abc
import re
from functools import cache, cached_property
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Literal,
TypedDict,
TypeVar,
)
from contextlib import contextmanager, nullcontext
from functools import cache, cached_property, wraps
from typing import TYPE_CHECKING, Generic, Literal, TypedDict, TypeVar
import unidecode
from confuse import NotFoundError
from typing_extensions import NotRequired, ParamSpec
from typing_extensions import NotRequired
from beets import config, logging
from beets.util import cached_classproperty
from beets.util.id_extractors import extract_release_id
from .plugins import BeetsPlugin, find_plugins, notify_info_yielded, send
from .plugins import BeetsPlugin, find_plugins, notify_info_yielded
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from .autotag.hooks import AlbumInfo, Item, TrackInfo
P = ParamSpec("P")
R = TypeVar("R")
Ret = TypeVar("Ret")
# Global logger.
log = logging.getLogger("beets")
@ -46,52 +39,68 @@ def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
"""Return a list of all loaded metadata source plugins."""
# TODO: Make this an isinstance(MetadataSourcePlugin, ...) check in v3.0.0
# This should also allow us to remove the type: ignore comments below.
metadata_plugins = [p for p in find_plugins() if hasattr(p, "data_source")]
return [p for p in find_plugins() if hasattr(p, "data_source")] # type: ignore[misc]
if config["raise_on_error"].get(bool):
return metadata_plugins # type: ignore[return-value]
else:
return list(map(SafeProxy, metadata_plugins)) # type: ignore[arg-type]
@contextmanager
def handle_plugin_error(plugin: MetadataSourcePlugin, method_name: str):
"""Safely call a plugin method, catching and logging exceptions."""
try:
yield
except Exception as e:
log.error("Error in '{}.{}': {}", plugin.data_source, method_name, e)
log.debug("Exception details:", exc_info=True)
def _yield_from_plugins(
func: Callable[..., Iterable[Ret]],
) -> Callable[..., Iterator[Ret]]:
method_name = func.__name__
@wraps(func)
def wrapper(*args, **kwargs) -> Iterator[Ret]:
for plugin in find_metadata_source_plugins():
method = getattr(plugin, method_name)
with (
nullcontext()
if config["raise_on_error"]
else handle_plugin_error(plugin, method_name)
):
yield from filter(None, method(*args, **kwargs))
return wrapper
@notify_info_yielded("albuminfo_received")
def candidates(*args, **kwargs) -> Iterable[AlbumInfo]:
"""Return matching album candidates from all metadata source plugins."""
for plugin in find_metadata_source_plugins():
yield from plugin.candidates(*args, **kwargs)
@_yield_from_plugins
def candidates(*args, **kwargs) -> Iterator[AlbumInfo]:
yield from ()
@notify_info_yielded("trackinfo_received")
def item_candidates(*args, **kwargs) -> Iterable[TrackInfo]:
"""Return matching track candidates from all metadata source plugins."""
for plugin in find_metadata_source_plugins():
yield from plugin.item_candidates(*args, **kwargs)
@_yield_from_plugins
def item_candidates(*args, **kwargs) -> Iterator[TrackInfo]:
yield from ()
@notify_info_yielded("albuminfo_received")
@_yield_from_plugins
def albums_for_ids(*args, **kwargs) -> Iterator[AlbumInfo]:
yield from ()
@notify_info_yielded("trackinfo_received")
@_yield_from_plugins
def tracks_for_ids(*args, **kwargs) -> Iterator[TrackInfo]:
yield from ()
def album_for_id(_id: str) -> AlbumInfo | None:
"""Get AlbumInfo object for the given ID string.
A single ID can yield just a single album, so we return the first match.
"""
for plugin in find_metadata_source_plugins():
if info := plugin.album_for_id(_id):
send("albuminfo_received", info=info)
return info
return None
return next(albums_for_ids([_id]), None)
def track_for_id(_id: str) -> TrackInfo | None:
"""Get TrackInfo object for the given ID string.
A single ID can yield just a single track, so we return the first match.
"""
for plugin in find_metadata_source_plugins():
if info := plugin.track_for_id(_id):
send("trackinfo_received", info=info)
return info
return None
return next(tracks_for_ids([_id]), None)
@cache
@ -279,11 +288,11 @@ class SearchFilter(TypedDict):
album: NotRequired[str]
Res = TypeVar("Res", bound=IDResponse)
R = TypeVar("R", bound=IDResponse)
class SearchApiMetadataSourcePlugin(
Generic[Res], MetadataSourcePlugin, metaclass=abc.ABCMeta
Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta
):
"""Helper class to implement a metadata source plugin with an API.
@ -308,7 +317,7 @@ class SearchApiMetadataSourcePlugin(
query_type: Literal["album", "track"],
filters: SearchFilter,
query_string: str = "",
) -> Sequence[Res]:
) -> Sequence[R]:
"""Perform a search on the API.
:param query_type: The type of query to perform.
@ -377,81 +386,3 @@ class SearchApiMetadataSourcePlugin(
query = unidecode.unidecode(query)
return query
# To have proper typing for the proxy class below, we need to
# trick mypy into thinking that SafeProxy is a subclass of
# MetadataSourcePlugin.
# https://stackoverflow.com/questions/71365594/how-to-make-a-proxy-object-with-typing-as-underlying-object-in-python
Proxied = TypeVar("Proxied", bound=MetadataSourcePlugin)
if TYPE_CHECKING:
base = MetadataSourcePlugin
else:
base = object
class SafeProxy(base):
"""A proxy class that forwards all attribute access to the wrapped
MetadataSourcePlugin instance.
We use this to catch and log exceptions from metadata source plugins
without crashing beets. E.g. on long running autotag operations.
"""
__plugin: MetadataSourcePlugin
def __init__(self, plugin: MetadataSourcePlugin):
self.__plugin = plugin
def __getattribute__(self, name):
if name in {
"_SafeProxy__plugin",
"_SafeProxy__handle_exception",
"candidates",
"item_candidates",
"album_for_id",
"track_for_id",
}:
return super().__getattribute__(name)
else:
return getattr(self.__plugin, name)
def __setattr__(self, name, value):
if name == "_SafeProxy__plugin":
super().__setattr__(name, value)
else:
self.__plugin.__setattr__(name, value)
def __handle_exception(self, func: Callable[P, R], e: Exception) -> None:
"""Helper function to log exceptions from metadata source plugins."""
log.error(
"Error in '{}.{}': {}",
self.__plugin.data_source,
func.__name__,
e,
)
log.debug("Exception details:", exc_info=True)
def album_for_id(self, *args, **kwargs):
try:
return self.__plugin.album_for_id(*args, **kwargs)
except Exception as e:
return self.__handle_exception(self.__plugin.album_for_id, e)
def track_for_id(self, *args, **kwargs):
try:
return self.__plugin.track_for_id(*args, **kwargs)
except Exception as e:
return self.__handle_exception(self.__plugin.track_for_id, e)
def candidates(self, *args, **kwargs):
try:
yield from self.__plugin.candidates(*args, **kwargs)
except Exception as e:
return self.__handle_exception(self.__plugin.candidates, e)
def item_candidates(self, *args, **kwargs):
try:
yield from self.__plugin.item_candidates(*args, **kwargs)
except Exception as e:
return self.__handle_exception(self.__plugin.item_candidates, e)

View file

@ -35,7 +35,7 @@ from beets.util import unique_list
from beets.util.deprecation import deprecate_for_maintainers, deprecate_for_user
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from confuse import ConfigView
@ -58,7 +58,6 @@ if TYPE_CHECKING:
P = ParamSpec("P")
Ret = TypeVar("Ret", bound=Any)
Listener = Callable[..., Any]
IterF = Callable[P, Iterable[Ret]]
PLUGIN_NAMESPACE = "beetsplug"
@ -548,7 +547,7 @@ def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]:
def notify_info_yielded(
event: EventType,
) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]:
) -> Callable[[Callable[P, Iterable[Ret]]], Callable[P, Iterator[Ret]]]:
"""Makes a generator send the event 'event' every time it yields.
This decorator is supposed to decorate a generator, but any function
returning an iterable should work.
@ -556,9 +555,11 @@ def notify_info_yielded(
'send'.
"""
def decorator(func: IterF[P, Ret]) -> IterF[P, Ret]:
def decorator(
func: Callable[P, Iterable[Ret]],
) -> Callable[P, Iterator[Ret]]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[Ret]:
for v in func(*args, **kwargs):
send(event, info=v)
yield v

View file

@ -45,18 +45,23 @@ class TestMetadataPluginsException(PluginMixin):
self.unload_plugins()
@pytest.mark.parametrize(
"method_name,args",
"method_name,error_method_name,args",
[
("candidates", ()),
("item_candidates", ()),
("album_for_id", ("some_id",)),
("track_for_id", ("some_id",)),
("candidates", "candidates", ()),
("item_candidates", "item_candidates", ()),
("albums_for_ids", "albums_for_ids", (["some_id"],)),
("tracks_for_ids", "tracks_for_ids", (["some_id"],)),
# Currently, singular methods call plural ones internally and log
# errors from there
("album_for_id", "albums_for_ids", ("some_id",)),
("track_for_id", "tracks_for_ids", ("some_id",)),
],
)
def test_logging(
self,
caplog,
method_name,
error_method_name,
args,
):
self.config["raise_on_error"] = False
@ -72,7 +77,7 @@ class TestMetadataPluginsException(PluginMixin):
for msg in logs:
assert (
msg
== f"Error in 'ErrorMetadataMockPlugin.{method_name}': Mocked error"
== f"Error in 'ErrorMetadataMockPlugin.{error_method_name}': Mocked error" # noqa: E501
)
caplog.clear()