Improve typing for template fields and funcs

This commit is contained in:
asardaes 2025-11-15 16:31:20 +01:00 committed by Alexis Sardá
parent 9c37f94171
commit 2eff2d25f5

View file

@ -151,9 +151,9 @@ class BeetsPlugin(metaclass=abc.ABCMeta):
list list
) )
listeners: ClassVar[dict[EventType, list[Listener]]] = defaultdict(list) listeners: ClassVar[dict[EventType, list[Listener]]] = defaultdict(list)
template_funcs: TFuncMap[str] | None = None template_funcs: ClassVar[TFuncMap[str]] = {}
template_fields: TFuncMap[Item] | None = None template_fields: ClassVar[TFuncMap[Item]] = {}
album_template_fields: TFuncMap[Album] | None = None album_template_fields: ClassVar[TFuncMap[Album]] = {}
name: str name: str
config: ConfigView config: ConfigView
@ -222,11 +222,11 @@ class BeetsPlugin(metaclass=abc.ABCMeta):
# Set class attributes if they are not already set # Set class attributes if they are not already set
# for the type of plugin. # for the type of plugin.
if not self.template_funcs: if not self.template_funcs:
self.template_funcs = {} self.template_funcs = {} # type: ignore[misc]
if not self.template_fields: if not self.template_fields:
self.template_fields = {} self.template_fields = {} # type: ignore[misc]
if not self.album_template_fields: if not self.album_template_fields:
self.album_template_fields = {} self.album_template_fields = {} # type: ignore[misc]
self.early_import_stages = [] self.early_import_stages = []
self.import_stages = [] self.import_stages = []
@ -368,8 +368,6 @@ class BeetsPlugin(metaclass=abc.ABCMeta):
""" """
def helper(func: TFunc[str]) -> TFunc[str]: def helper(func: TFunc[str]) -> TFunc[str]:
if cls.template_funcs is None:
cls.template_funcs = {}
cls.template_funcs[name] = func cls.template_funcs[name] = func
return func return func
@ -384,8 +382,6 @@ class BeetsPlugin(metaclass=abc.ABCMeta):
""" """
def helper(func: TFunc[Item]) -> TFunc[Item]: def helper(func: TFunc[Item]) -> TFunc[Item]:
if cls.template_fields is None:
cls.template_fields = {}
cls.template_fields[name] = func cls.template_fields[name] = func
return func return func
@ -565,7 +561,6 @@ def template_funcs() -> TFuncMap[str]:
""" """
funcs: TFuncMap[str] = {} funcs: TFuncMap[str] = {}
for plugin in find_plugins(): for plugin in find_plugins():
if plugin.template_funcs:
funcs.update(plugin.template_funcs) funcs.update(plugin.template_funcs)
return funcs return funcs
@ -592,14 +587,13 @@ F = TypeVar("F")
def _check_conflicts_and_merge( def _check_conflicts_and_merge(
plugin: BeetsPlugin, plugin_funcs: dict[str, F] | None, funcs: dict[str, F] plugin: BeetsPlugin, plugin_funcs: dict[str, F], funcs: dict[str, F]
) -> None: ) -> None:
"""Check the provided template functions for conflicts and merge into funcs. """Check the provided template functions for conflicts and merge into funcs.
Raises a `PluginConflictError` if a plugin defines template functions Raises a `PluginConflictError` if a plugin defines template functions
for fields that another plugin has already defined template functions for. for fields that another plugin has already defined template functions for.
""" """
if plugin_funcs:
if not plugin_funcs.keys().isdisjoint(funcs.keys()): if not plugin_funcs.keys().isdisjoint(funcs.keys()):
conflicted_fields = ", ".join(plugin_funcs.keys() & funcs.keys()) conflicted_fields = ", ".join(plugin_funcs.keys() & funcs.keys())
raise PluginConflictError( raise PluginConflictError(