diff --git a/beets/plugins.py b/beets/plugins.py index aec3ca340..c5d9c590a 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -19,6 +19,7 @@ from __future__ import annotations import abc import inspect import re +import sys import traceback from collections import defaultdict from functools import wraps @@ -48,6 +49,11 @@ if TYPE_CHECKING: from beets.library import Album, Item, Library from beets.ui import Subcommand +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + PLUGIN_NAMESPACE = "beetsplug" @@ -83,8 +89,16 @@ class PluginLogFilter(logging.Filter): return True +P = ParamSpec("P") +Ret = TypeVar("Ret", bound=Any) + Listener = Callable[..., None] +if TYPE_CHECKING: + ImportStageFunc = Callable[[ImportSession, ImportTask], None] + T = TypeVar("T", Album, Item, str) + TFunc = Callable[[T], str] + TFuncMap = dict[str, TFunc[T]] # Managing the plugins themselves. @@ -97,8 +111,8 @@ class BeetsPlugin: name: str config: ConfigView - early_import_stages: list[Callable[[ImportSession, ImportTask], None]] - import_stages: list[Callable[[ImportSession, ImportTask], None]] + early_import_stages: list[ImportStageFunc] + import_stages: list[ImportStageFunc] def __init__(self, name: str | None = None): """Perform one-time plugin setup.""" @@ -129,14 +143,17 @@ class BeetsPlugin: """ return () - def _set_stage_log_level(self, stages): + def _set_stage_log_level( + self, + stages: list[ImportStageFunc], + ) -> list[ImportStageFunc]: """Adjust all the stages in `stages` to WARNING logging level.""" return [ self._set_log_level_and_params(logging.WARNING, stage) for stage in stages ] - def get_early_import_stages(self): + def get_early_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages early in the pipeline. @@ -146,7 +163,7 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.early_import_stages) - def get_import_stages(self): + def get_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages. @@ -156,7 +173,11 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.import_stages) - def _set_log_level_and_params(self, base_log_level, func): + def _set_log_level_and_params( + self, + base_log_level: int, + func: Callable[P, Ret], + ) -> Callable[P, Ret]: """Wrap `func` to temporarily set this plugin's logger level to `base_log_level` + config options (and restore it to its previous value after the function returns). Also determines which params may not @@ -165,7 +186,7 @@ class BeetsPlugin: argspec = inspect.getfullargspec(func) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret: assert self._log.level == logging.NOTSET verbosity = beets.config["verbose"].get(int) @@ -278,9 +299,9 @@ class BeetsPlugin: cls._raw_listeners[event].append(func) cls.listeners[event].append(wrapped_func) - template_funcs: dict[str, Callable[[str, Any], str]] | None = None - template_fields: dict[str, Callable[[Item], str]] | None = None - album_template_fields: dict[str, Callable[[Album], str]] | None = None + template_funcs: TFuncMap[str] | None = None + template_fields: TFuncMap[Item] | None = None + album_template_fields: TFuncMap[Album] | None = None @classmethod def template_func(cls, name: str): @@ -546,19 +567,19 @@ def _check_conflicts_and_merge( funcs.update(plugin_funcs) -def item_field_getters() -> dict[str, Callable[[Item], str]]: +def item_field_getters() -> TFuncMap[Item]: """Get a dictionary mapping field names to unary functions that compute the field's value. """ - funcs: dict[str, Callable[[Item], str]] = {} + funcs: TFuncMap[Item] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.template_fields, funcs) return funcs -def album_field_getters() -> dict[str, Callable[[Album], str]]: +def album_field_getters() -> TFuncMap[Album]: """As above, for album fields.""" - funcs: dict[str, Callable[[Album], str]] = {} + funcs: TFuncMap[Album] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs) return funcs