diff --git a/beets/plugins.py b/beets/plugins.py index 3e04ccdfc..b78058d36 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -23,7 +23,7 @@ import traceback from collections import defaultdict from functools import wraps from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar import mediafile from typing_extensions import ParamSpec @@ -32,17 +32,14 @@ import beets from beets import logging if TYPE_CHECKING: - from beets.event_types import EventType - - -if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable, Sequence from confuse import ConfigView from beets.dbcore import Query from beets.dbcore.db import FieldQueryType from beets.dbcore.types import Type + from beets.event_types import EventType from beets.importer import ImportSession, ImportTask from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -58,7 +55,7 @@ if TYPE_CHECKING: P = ParamSpec("P") Ret = TypeVar("Ret", bound=Any) - Listener = Callable[..., None] + Listener = Callable[..., Any] IterF = Callable[P, Iterable[Ret]] @@ -105,6 +102,14 @@ class BeetsPlugin(metaclass=abc.ABCMeta): the abstract methods defined here. """ + _raw_listeners: ClassVar[dict[EventType, list[Listener]]] = defaultdict( + list + ) + listeners: ClassVar[dict[EventType, list[Listener]]] = defaultdict(list) + template_funcs: TFuncMap[str] | None = None + template_fields: TFuncMap[Item] | None = None + album_template_fields: TFuncMap[Album] | None = None + name: str config: ConfigView early_import_stages: list[ImportStageFunc] @@ -218,25 +223,13 @@ class BeetsPlugin(metaclass=abc.ABCMeta): mediafile.MediaFile.add_field(name, descriptor) library.Item._media_fields.add(name) - _raw_listeners: dict[str, list[Listener]] | None = None - listeners: dict[str, list[Listener]] | None = None - - def register_listener(self, event: "EventType", func: Listener): + def register_listener(self, event: EventType, func: Listener) -> None: """Add a function as a listener for the specified event.""" - wrapped_func = self._set_log_level_and_params(logging.WARNING, func) - - cls = self.__class__ - - if cls.listeners is None or cls._raw_listeners is None: - cls._raw_listeners = defaultdict(list) - cls.listeners = defaultdict(list) - if func not in cls._raw_listeners[event]: - cls._raw_listeners[event].append(func) - cls.listeners[event].append(wrapped_func) - - template_funcs: TFuncMap[str] | None = None - template_fields: TFuncMap[Item] | None = None - album_template_fields: TFuncMap[Album] | None = None + if func not in self._raw_listeners[event]: + self._raw_listeners[event].append(func) + self.listeners[event].append( + self._set_log_level_and_params(logging.WARNING, func) + ) @classmethod def template_func(cls, name: str) -> Callable[[TFunc[str]], TFunc[str]]: @@ -383,7 +376,9 @@ def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: } -def notify_info_yielded(event: str) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: +def notify_info_yielded( + event: EventType, +) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -474,19 +469,7 @@ def album_field_getters() -> TFuncMap[Album]: # Event dispatch. -def event_handlers() -> dict[str, list[Listener]]: - """Find all event handlers from plugins as a dictionary mapping - event names to sequences of callables. - """ - all_handlers: dict[str, list[Listener]] = defaultdict(list) - for plugin in find_plugins(): - if plugin.listeners: - for event, handlers in plugin.listeners.items(): - all_handlers[event] += handlers - return all_handlers - - -def send(event: str, **arguments: Any) -> list[Any]: +def send(event: EventType, **arguments: Any) -> list[Any]: """Send an event to all assigned event listeners. `event` is the name of the event to send, all other named arguments @@ -495,12 +478,11 @@ def send(event: str, **arguments: Any) -> list[Any]: Return a list of non-None values returned from the handlers. """ log.debug("Sending event: {0}", event) - results: list[Any] = [] - for handler in event_handlers()[event]: - result = handler(**arguments) - if result is not None: - results.append(result) - return results + return [ + r + for handler in BeetsPlugin.listeners[event] + if (r := handler(**arguments)) is not None + ] def feat_tokens(for_artist: bool = True) -> str: diff --git a/beets/test/helper.py b/beets/test/helper.py index eb024a7aa..767ff41fe 100644 --- a/beets/test/helper.py +++ b/beets/test/helper.py @@ -498,8 +498,8 @@ class PluginMixin(ConfigMixin): def unload_plugins(self) -> None: """Unload all plugins and remove them from the configuration.""" # FIXME this should eventually be handled by a plugin manager - for plugin_class in beets.plugins._instances: - plugin_class.listeners = None + beets.plugins.BeetsPlugin.listeners.clear() + beets.plugins.BeetsPlugin._raw_listeners.clear() self.config["plugins"] = [] beets.plugins._classes = set() beets.plugins._instances = {} diff --git a/test/test_plugins.py b/test/test_plugins.py index 95378fc7b..413d87bb4 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -243,15 +243,7 @@ class ListenersTest(PluginLoaderTestCase): d.register_listener("cli_exit", d2.dummy) assert DummyPlugin._raw_listeners["cli_exit"] == [d.dummy, d2.dummy] - @patch("beets.plugins.find_plugins") - @patch("inspect.getfullargspec") - def test_events_called(self, mock_gfa, mock_find_plugins): - mock_gfa.return_value = Mock( - args=(), - varargs="args", - varkw="kwargs", - ) - + def test_events_called(self): class DummyPlugin(plugins.BeetsPlugin): def __init__(self): super().__init__() @@ -261,7 +253,6 @@ class ListenersTest(PluginLoaderTestCase): self.register_listener("event_bar", self.bar) d = DummyPlugin() - mock_find_plugins.return_value = (d,) plugins.send("event") d.foo.assert_has_calls([]) @@ -271,8 +262,7 @@ class ListenersTest(PluginLoaderTestCase): d.foo.assert_called_once_with(var="tagada") d.bar.assert_has_calls([]) - @patch("beets.plugins.find_plugins") - def test_listener_params(self, mock_find_plugins): + def test_listener_params(self): class DummyPlugin(plugins.BeetsPlugin): def __init__(self): super().__init__() @@ -316,8 +306,7 @@ class ListenersTest(PluginLoaderTestCase): def dummy9(self, **kwargs): assert kwargs == {"foo": 5} - d = DummyPlugin() - mock_find_plugins.return_value = (d,) + DummyPlugin() plugins.send("event1", foo=5) plugins.send("event2", foo=5)