diff --git a/beets/plugins.py b/beets/plugins.py index 8fa9d4a3d..a25fe6d8e 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -24,9 +24,11 @@ from collections import defaultdict from functools import wraps from typing import ( TYPE_CHECKING, + Any, Callable, Generic, Sequence, + Type, TypedDict, TypeVar, ) @@ -39,9 +41,11 @@ from beets import logging if TYPE_CHECKING: from collections.abc import Iterable + from confuse import ConfigView + from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query - from beets.library import Item, Library + from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -79,6 +83,10 @@ class PluginLogFilter(logging.Filter): return True +# Typing for listeners +CallableWithAnyKwargs = Callable[[None, Any], None] + + # Managing the plugins themselves. @@ -88,16 +96,24 @@ class BeetsPlugin: the abstract methods defined here. """ + name: str + config: ConfigView + def __init__(self, name=None): """Perform one-time plugin setup.""" + self.name = name or self.__module__.split(".")[-1] self.config = beets.config[self.name] + + # Set class attributes if they are not already set + # for the type of plugin. if not self.template_funcs: self.template_funcs = {} if not self.template_fields: self.template_fields = {} if not self.album_template_fields: self.album_template_fields = {} + self.early_import_stages = [] self.import_stages = [] @@ -243,13 +259,14 @@ class BeetsPlugin: mediafile.MediaFile.add_field(name, descriptor) library.Item._media_fields.add(name) - _raw_listeners = None - listeners: None | dict[str, list[Callable]] = None + _raw_listeners: dict[str, list[Callable[[None, Any], None]]] | None = None + listeners: dict[str, list[Callable[[None, Any], None]]] | None = None - def register_listener(self, event: str, func: Callable): + def register_listener(self, event: str, func: CallableWithAnyKwargs): """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) + cls = self.__class__ cls = self.__class__ if cls.listeners is None or cls._raw_listeners is None: cls._raw_listeners = defaultdict(list) @@ -258,9 +275,9 @@ class BeetsPlugin: cls._raw_listeners[event].append(func) cls.listeners[event].append(wrapped_func) - template_funcs = None - template_fields = None - album_template_fields = None + template_funcs: dict[str, Callable[[str, Any], str]] | None = None + template_fields: dict[str, Callable[[Item], str]] | None = None + album_template_fields: dict[str, Callable[[Album], str]] | None = None @classmethod def template_func(cls, name: str): @@ -294,7 +311,7 @@ class BeetsPlugin: return helper -_classes = set() +_classes: set[type[BeetsPlugin]] = set() def load_plugins(names: Sequence[str] = ()): @@ -332,7 +349,7 @@ def load_plugins(names: Sequence[str] = ()): ) -_instances = {} +_instances: dict[Type[BeetsPlugin], BeetsPlugin] = {} def find_plugins() -> list[BeetsPlugin]: @@ -544,7 +561,7 @@ def event_handlers(): return all_handlers -def send(event, **arguments): +def send(event: str, **arguments: Any): """Send an event to all assigned event listeners. `event` is the name of the event to send, all other named arguments diff --git a/docs/dev/plugins.rst b/docs/dev/plugins.rst index 96e69153d..577947707 100644 --- a/docs/dev/plugins.rst +++ b/docs/dev/plugins.rst @@ -367,7 +367,7 @@ Here's an example:: super().__init__() self.template_funcs['initial'] = _tmpl_initial - def _tmpl_initial(text): + def _tmpl_initial(text:str) -> str: if text: return text[0].upper() else: @@ -387,7 +387,7 @@ Here's an example that adds a ``$disc_and_track`` field:: super().__init__() self.template_fields['disc_and_track'] = _tmpl_disc_and_track - def _tmpl_disc_and_track(item): + def _tmpl_disc_and_track(item: Item) -> str: """Expand to the disc number and track number if this is a multi-disc release. Otherwise, just expands to the track number.