diff --git a/beets/plugins.py b/beets/plugins.py index 2b848d3cf..5ed8fe612 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -22,6 +22,7 @@ import re import sys import traceback from collections import defaultdict +from collections.abc import Iterable from functools import wraps from typing import ( TYPE_CHECKING, @@ -39,12 +40,11 @@ import beets from beets import logging if TYPE_CHECKING: - from collections.abc import Iterable - from confuse import ConfigView from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query + from beets.dbcore.db import FieldQueryType, SQLiteType from beets.importer import ImportSession, ImportTask from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -193,7 +193,7 @@ class BeetsPlugin: log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) self._log.setLevel(log_level) if argspec.varkw is None: - kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # type: ignore[assignment] try: return func(*args, **kwargs) @@ -418,10 +418,14 @@ def queries() -> dict[str, type[Query]]: return out -def types(model_cls: type[T]) -> dict[str, type[Any]]: +if TYPE_CHECKING: + AnyModel = TypeVar("AnyModel", Album, Item) + + +def types(model_cls: type[AnyModel]) -> dict[str, type[SQLiteType]]: # Gives us `item_types` and `album_types` attr_name = f"{model_cls.__name__.lower()}_types" - types: dict[str, type[Any]] = {} + types: dict[str, type[SQLiteType]] = {} for plugin in find_plugins(): plugin_types = getattr(plugin, attr_name, {}) for field in plugin_types: @@ -435,10 +439,10 @@ def types(model_cls: type[T]) -> dict[str, type[Any]]: return types -def named_queries(model_cls: type[T]) -> dict[str, Query]: +def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: # Gather `item_queries` and `album_queries` from the plugins. attr_name = f"{model_cls.__name__.lower()}_queries" - queries: dict[str, Query] = {} + queries: dict[str, FieldQueryType] = {} for plugin in find_plugins(): plugin_queries = getattr(plugin, attr_name, {}) queries.update(plugin_queries) @@ -639,7 +643,7 @@ def sanitize_choices( wildcard while keeping original stringlist order. """ seen: set[str] = set() - others: list[str] = [x for x in choices_all if x not in choices] + others = [x for x in choices_all if x not in choices] res: list[str] = [] for s in choices: if s not in seen: @@ -691,9 +695,12 @@ def sanitize_pairs( return res +IterF = Callable[P, Iterable[Ret]] + + def notify_info_yielded( event: str, -) -> Callable[[Callable[P, Iterable[Ret]]], Callable[P, Iterable[Ret]]]: +) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -702,8 +709,8 @@ def notify_info_yielded( """ def decorator( - generator: Callable[P, Iterable[Ret]], - ) -> Callable[P, Iterable[Ret]]: + generator: IterF[P, Ret], + ) -> IterF[P, Ret]: def decorated(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]: for v in generator(*args, **kwargs): send(event, info=v)