Sebastian Mohr 2025-04-18 16:03:39 +02:00
parent 39a5bdb0bd
commit d7838b29c3

View file

@ -19,6 +19,7 @@ from __future__ import annotations
import abc
import inspect
import re
import sys
import traceback
from collections import defaultdict
from functools import wraps
@ -48,6 +49,11 @@ if TYPE_CHECKING:
from beets.library import Album, Item, Library
from beets.ui import Subcommand
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
PLUGIN_NAMESPACE = "beetsplug"
@ -83,8 +89,16 @@ class PluginLogFilter(logging.Filter):
return True
P = ParamSpec("P")
Ret = TypeVar("Ret", bound=Any)
Listener = Callable[..., None]
if TYPE_CHECKING:
ImportStageFunc = Callable[[ImportSession, ImportTask], None]
T = TypeVar("T", Album, Item, str)
TFunc = Callable[[T], str]
TFuncMap = dict[str, TFunc[T]]
# Managing the plugins themselves.
@ -97,8 +111,8 @@ class BeetsPlugin:
name: str
config: ConfigView
early_import_stages: list[Callable[[ImportSession, ImportTask], None]]
import_stages: list[Callable[[ImportSession, ImportTask], None]]
early_import_stages: list[ImportStageFunc]
import_stages: list[ImportStageFunc]
def __init__(self, name: str | None = None):
"""Perform one-time plugin setup."""
@ -129,14 +143,17 @@ class BeetsPlugin:
"""
return ()
def _set_stage_log_level(self, stages):
def _set_stage_log_level(
self,
stages: list[ImportStageFunc],
) -> list[ImportStageFunc]:
"""Adjust all the stages in `stages` to WARNING logging level."""
return [
self._set_log_level_and_params(logging.WARNING, stage)
for stage in stages
]
def get_early_import_stages(self):
def get_early_import_stages(self) -> list[ImportStageFunc]:
"""Return a list of functions that should be called as importer
pipelines stages early in the pipeline.
@ -146,7 +163,7 @@ class BeetsPlugin:
"""
return self._set_stage_log_level(self.early_import_stages)
def get_import_stages(self):
def get_import_stages(self) -> list[ImportStageFunc]:
"""Return a list of functions that should be called as importer
pipelines stages.
@ -156,7 +173,11 @@ class BeetsPlugin:
"""
return self._set_stage_log_level(self.import_stages)
def _set_log_level_and_params(self, base_log_level, func):
def _set_log_level_and_params(
self,
base_log_level: int,
func: Callable[P, Ret],
) -> Callable[P, Ret]:
"""Wrap `func` to temporarily set this plugin's logger level to
`base_log_level` + config options (and restore it to its previous
value after the function returns). Also determines which params may not
@ -165,7 +186,7 @@ class BeetsPlugin:
argspec = inspect.getfullargspec(func)
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret:
assert self._log.level == logging.NOTSET
verbosity = beets.config["verbose"].get(int)
@ -278,9 +299,9 @@ class BeetsPlugin:
cls._raw_listeners[event].append(func)
cls.listeners[event].append(wrapped_func)
template_funcs: dict[str, Callable[[str, Any], str]] | None = None
template_fields: dict[str, Callable[[Item], str]] | None = None
album_template_fields: dict[str, Callable[[Album], str]] | None = None
template_funcs: TFuncMap[str] | None = None
template_fields: TFuncMap[Item] | None = None
album_template_fields: TFuncMap[Album] | None = None
@classmethod
def template_func(cls, name: str):
@ -546,19 +567,19 @@ def _check_conflicts_and_merge(
funcs.update(plugin_funcs)
def item_field_getters() -> dict[str, Callable[[Item], str]]:
def item_field_getters() -> TFuncMap[Item]:
"""Get a dictionary mapping field names to unary functions that
compute the field's value.
"""
funcs: dict[str, Callable[[Item], str]] = {}
funcs: TFuncMap[Item] = {}
for plugin in find_plugins():
_check_conflicts_and_merge(plugin, plugin.template_fields, funcs)
return funcs
def album_field_getters() -> dict[str, Callable[[Album], str]]:
def album_field_getters() -> TFuncMap[Album]:
"""As above, for album fields."""
funcs: dict[str, Callable[[Album], str]] = {}
funcs: TFuncMap[Album] = {}
for plugin in find_plugins():
_check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs)
return funcs