From 0cc0db313a88f5a57aac78c72f4fa4ee30c87b43 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Wed, 2 Apr 2025 15:57:16 +0200 Subject: [PATCH 01/16] Added typehints to the plugins file. --- beets/plugins.py | 187 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 128 insertions(+), 59 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 299c41815..3906bb041 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -14,17 +14,37 @@ """Support for beets plugins.""" +from __future__ import annotations + import abc import inspect import re import traceback from collections import defaultdict from functools import wraps +from typing import ( + TYPE_CHECKING, + Callable, + Generic, + Sequence, + TypedDict, + TypeVar, +) import mediafile import beets from beets import logging +from beets.autotag import Distance + +if TYPE_CHECKING: + from collections.abc import Iterable + + from beets.autotag import AlbumInfo, TrackInfo + from beets.dbcore import Query + from beets.library import Item + from beets.ui import Subcommand + PLUGIN_NAMESPACE = "beetsplug" @@ -145,55 +165,74 @@ class BeetsPlugin: return wrapper - def queries(self): + def queries(self) -> dict[str, type[Query]]: """Return a dict mapping prefixes to Query subclasses.""" return {} - def track_distance(self, item, info): + def track_distance( + self, + item: Item, + info: TrackInfo, + ) -> Distance: """Should return a Distance object to be added to the distance for every track comparison. """ - return beets.autotag.hooks.Distance() + return Distance() - def album_distance(self, items, album_info, mapping): + def album_distance( + self, + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], + ) -> Distance: """Should return a Distance object to be added to the distance for every album-level comparison. """ - return beets.autotag.hooks.Distance() + return Distance() - def candidates(self, items, artist, album, va_likely, extra_tags=None): + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, + ) -> Sequence[AlbumInfo]: """Should return a sequence of AlbumInfo objects that match the album whose items are provided. """ return () - def item_candidates(self, item, artist, title): + def item_candidates( + self, + item: Item, + artist: str, + title: str, + ) -> Sequence[TrackInfo]: """Should return a sequence of TrackInfo objects that match the item provided. """ return () - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: """Return an AlbumInfo object or None if no matching release was found. """ return None - def track_for_id(self, track_id): + def track_for_id(self, track_id: str) -> TrackInfo | None: """Return a TrackInfo object or None if no matching release was found. """ return None - def add_media_field(self, name, descriptor): + def add_media_field(self, name: str, descriptor: mediafile.MediaField): """Add a field that is synchronized between media files and items. When a media field is added ``item.write()`` will set the name property of the item's MediaFile to ``item[name]`` and save the changes. Similarly ``item.read()`` will set ``item[name]`` to the value of the name property of the media file. - - ``descriptor`` must be an instance of ``mediafile.MediaField``. """ # Defer import to prevent circular dependency from beets import library @@ -202,9 +241,9 @@ class BeetsPlugin: library.Item._media_fields.add(name) _raw_listeners = None - listeners = None + listeners: None | dict[str, list[Callable]] = None - def register_listener(self, event, func): + def register_listener(self, event: str, func: Callable): """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) @@ -221,7 +260,7 @@ class BeetsPlugin: album_template_fields = None @classmethod - def template_func(cls, name): + def template_func(cls, name: str): """Decorator that registers a path template function. The function will be invoked as ``%name{}`` from path format strings. @@ -236,7 +275,7 @@ class BeetsPlugin: return helper @classmethod - def template_field(cls, name): + def template_field(cls, name: str): """Decorator that registers a path template field computation. The value will be referenced as ``$name`` from path format strings. The function must accept a single parameter, the Item @@ -255,7 +294,7 @@ class BeetsPlugin: _classes = set() -def load_plugins(names=()): +def load_plugins(names: Sequence[str] = ()): """Imports the modules for a sequence of plugin names. Each name must be the name of a Python module under the "beetsplug" namespace package in sys.path; the module indicated should contain the @@ -293,7 +332,7 @@ def load_plugins(names=()): _instances = {} -def find_plugins(): +def find_plugins() -> list[BeetsPlugin]: """Returns a list of BeetsPlugin subclass instances from all currently loaded beets plugins. Loads the default plugin set first. @@ -316,7 +355,7 @@ def find_plugins(): # Communication with plugins. -def commands(): +def commands() -> list[Subcommand]: """Returns a list of Subcommand objects from all loaded plugins.""" out = [] for plugin in find_plugins(): @@ -324,11 +363,11 @@ def commands(): return out -def queries(): +def queries() -> dict[str, type[Query]]: """Returns a dict mapping prefix strings to Query subclasses all loaded plugins. """ - out = {} + out: dict[str, type[Query]] = {} for plugin in find_plugins(): out.update(plugin.queries()) return out @@ -361,7 +400,7 @@ def named_queries(model_cls): return queries -def track_distance(item, info): +def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ @@ -373,7 +412,11 @@ def track_distance(item, info): return dist -def album_distance(items, album_info, mapping): +def album_distance( + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], +) -> Distance: """Returns the album distance calculated by plugins.""" from beets.autotag.hooks import Distance @@ -383,7 +426,13 @@ def album_distance(items, album_info, mapping): return dist -def candidates(items, artist, album, va_likely, extra_tags=None): +def candidates( + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, +) -> Iterable[AlbumInfo]: """Gets MusicBrainz candidates for an album from each plugin.""" for plugin in find_plugins(): yield from plugin.candidates( @@ -391,13 +440,13 @@ def candidates(items, artist, album, va_likely, extra_tags=None): ) -def item_candidates(item, artist, title): +def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]: """Gets MusicBrainz candidates for an item from the plugins.""" for plugin in find_plugins(): yield from plugin.item_candidates(item, artist, title) -def album_for_id(album_id): +def album_for_id(album_id: str) -> Iterable[AlbumInfo]: """Get AlbumInfo objects for a given ID string.""" for plugin in find_plugins(): album = plugin.album_for_id(album_id) @@ -405,7 +454,7 @@ def album_for_id(album_id): yield album -def track_for_id(track_id): +def track_for_id(track_id: str) -> Iterable[TrackInfo]: """Get TrackInfo objects for a given ID string.""" for plugin in find_plugins(): track = plugin.track_for_id(track_id) @@ -443,7 +492,7 @@ def import_stages(): # New-style (lazy) plugin-provided fields. -def _check_conflicts_and_merge(plugin, plugin_funcs, funcs): +def _check_conflicts_and_merge(plugin: BeetsPlugin, plugin_funcs, funcs): """Check the provided template functions for conflicts and merge into funcs. Raises a `PluginConflictError` if a plugin defines template functions @@ -598,11 +647,11 @@ def notify_info_yielded(event): return decorator -def get_distance(config, data_source, info): +def get_distance(config, data_source, info) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ - dist = beets.autotag.Distance() + dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) return dist @@ -638,7 +687,27 @@ def apply_item_changes(lib, item, move, pretend, write): item.store() -class MetadataSourcePlugin(metaclass=abc.ABCMeta): +class Response(TypedDict): + """A dictionary with the response of a plugin API call. + + May be extended by plugins to include additional information, but id + is required. + """ + + id: str + + +class RegexDict(TypedDict): + """A dictionary with regex patterns as keys and match groups as values.""" + + pattern: str + match_group: int + + +R = TypeVar("R", bound=Response) + + +class MetadataSourcePlugin(Generic[R], metaclass=abc.ABCMeta): def __init__(self): super().__init__() self.config.add({"source_weight": 0.5}) @@ -664,19 +733,26 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): raise NotImplementedError @abc.abstractmethod - def _search_api(self, query_type, filters, keywords=""): + def _search_api(self, query_type, filters, keywords="") -> Sequence[R]: raise NotImplementedError @abc.abstractmethod - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: raise NotImplementedError @abc.abstractmethod - def track_for_id(self, track_id=None, track_data=None): + def track_for_id( + self, track_id: str | None = None, track_data: R | None = None + ) -> TrackInfo | None: raise NotImplementedError @staticmethod - def get_artist(artists, id_key="id", name_key="name", join_key=None): + def get_artist( + artists, + id_key: str | int = "id", + name_key: str | int = "name", + join_key: str | int | None = None, + ) -> tuple[str, str | None]: """Returns an artist string (all artists) and an artist_id (the main artist) for a list of artist object dicts. @@ -691,18 +767,14 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): :type artists: list[dict] or list[list] :param id_key: Key or index corresponding to the value of ``id`` for the main/first artist. Defaults to 'id'. - :type id_key: str or int :param name_key: Key or index corresponding to values of names to concatenate for the artist string (containing all artists). Defaults to 'name'. - :type name_key: str or int :param join_key: Key or index corresponding to a field containing a keyword to use for combining artists into a single string, for example "Feat.", "Vs.", "And" or similar. The default is None which keeps the default behaviour (comma-separated). - :type join_key: str or int :return: Normalized artist string. - :rtype: str """ artist_id = None artist_string = "" @@ -727,19 +799,15 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return artist_string, artist_id @staticmethod - def _get_id(url_type, id_, id_regex): + def _get_id(url_type: str, id_: str, id_regex: RegexDict) -> str | None: """Parse an ID from its URL if necessary. :param url_type: Type of URL. Either 'album' or 'track'. - :type url_type: str :param id_: Album/track ID or URL. - :type id_: str :param id_regex: A dictionary containing a regular expression extracting an ID from an URL (if it's not an ID already) in 'pattern' and the number of the match group in 'match_group'. - :type id_regex: dict :return: Album/track ID. - :rtype: str """ log.debug("Extracting {} ID from '{}'", url_type, id_) match = re.search(id_regex["pattern"].format(url_type), str(id_)) @@ -749,21 +817,22 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return id_ return None - def candidates(self, items, artist, album, va_likely, extra_tags=None): + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, + ) -> Sequence[AlbumInfo]: """Returns a list of AlbumInfo objects for Search API results matching an ``album`` and ``artist`` (if not various). :param items: List of items comprised by an album to be matched. - :type items: list[beets.library.Item] :param artist: The artist of the album to be matched. - :type artist: str :param album: The name of the album to be matched. - :type album: str :param va_likely: True if the album to be matched likely has Various Artists. - :type va_likely: bool - :return: Candidate AlbumInfo objects. - :rtype: list[beets.autotag.hooks.AlbumInfo] """ query_filters = {"album": album} if not va_likely: @@ -772,23 +841,23 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): albums = [self.album_for_id(album_id=r["id"]) for r in results] return [a for a in albums if a is not None] - def item_candidates(self, item, artist, title): + def item_candidates( + self, item: Item, artist: str, title: str + ) -> Sequence[TrackInfo]: """Returns a list of TrackInfo objects for Search API results matching ``title`` and ``artist``. :param item: Singleton item to be matched. - :type item: beets.library.Item :param artist: The artist of the track to be matched. - :type artist: str :param title: The title of the track to be matched. - :type title: str - :return: Candidate TrackInfo objects. - :rtype: list[beets.autotag.hooks.TrackInfo] """ - tracks = self._search_api( + track_responses = self._search_api( query_type="track", keywords=title, filters={"artist": artist} ) - return [self.track_for_id(track_data=track) for track in tracks] + + tracks = [self.track_for_id(track_data=r) for r in track_responses] + + return [t for t in tracks if t is not None] def album_distance(self, items, album_info, mapping): return get_distance( From 19b3330fc8a407bda69f481e10517fad8b7b44ab Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Wed, 2 Apr 2025 16:07:00 +0200 Subject: [PATCH 02/16] Fixed circular import of distance by consistently importing if whenever it is needed. --- beets/plugins.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 3906bb041..a5edb0ae5 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -35,12 +35,11 @@ import mediafile import beets from beets import logging -from beets.autotag import Distance if TYPE_CHECKING: from collections.abc import Iterable - from beets.autotag import AlbumInfo, TrackInfo + from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query from beets.library import Item from beets.ui import Subcommand @@ -177,6 +176,8 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every track comparison. """ + from beets.autotag.hooks import Distance + return Distance() def album_distance( @@ -188,6 +189,8 @@ class BeetsPlugin: """Should return a Distance object to be added to the distance for every album-level comparison. """ + from beets.autotag.hooks import Distance + return Distance() def candidates( @@ -651,6 +654,8 @@ def get_distance(config, data_source, info) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ + from beets.autotag.hooks import Distance + dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) From e6ce81891306ed2d7fc780285a6b8ec930ba5e08 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Wed, 2 Apr 2025 16:28:41 +0200 Subject: [PATCH 03/16] Added some more typehints where missing. --- beets/plugins.py | 61 +++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index a5edb0ae5..8fa9d4a3d 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query - from beets.library import Item + from beets.library import Item, Library from beets.ui import Subcommand @@ -536,7 +536,7 @@ def event_handlers(): """Find all event handlers from plugins as a dictionary mapping event names to sequences of callables. """ - all_handlers = defaultdict(list) + all_handlers: dict[str, list[Callable]] = defaultdict(list) for plugin in find_plugins(): if plugin.listeners: for event, handlers in plugin.listeners.items(): @@ -650,7 +650,9 @@ def notify_info_yielded(event): return decorator -def get_distance(config, data_source, info) -> Distance: +def get_distance( + config, data_source: str, info: AlbumInfo | TrackInfo +) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ @@ -662,20 +664,17 @@ def get_distance(config, data_source, info) -> Distance: return dist -def apply_item_changes(lib, item, move, pretend, write): +def apply_item_changes( + lib: Library, item: Item, move: bool, pretend: bool, write: bool +): """Store, move, and write the item according to the arguments. :param lib: beets library. - :type lib: beets.library.Library :param item: Item whose changes to apply. - :type item: beets.library.Item :param move: Move the item if it's in the library. - :type move: bool :param pretend: Return without moving, writing, or storing the item's metadata. - :type pretend: bool :param write: Write the item's metadata to its media file. - :type write: bool """ if pretend: return @@ -703,7 +702,9 @@ class Response(TypedDict): class RegexDict(TypedDict): - """A dictionary with regex patterns as keys and match groups as values.""" + """A dictionary containing a regex pattern and the number of the + match group. + """ pattern: str match_group: int @@ -712,29 +713,36 @@ class RegexDict(TypedDict): R = TypeVar("R", bound=Response) -class MetadataSourcePlugin(Generic[R], metaclass=abc.ABCMeta): +class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta): def __init__(self): super().__init__() self.config.add({"source_weight": 0.5}) - @abc.abstractproperty - def id_regex(self): + foo: str + + @property + @abc.abstractmethod + def id_regex(self) -> RegexDict: raise NotImplementedError - @abc.abstractproperty - def data_source(self): + @property + @abc.abstractmethod + def data_source(self) -> str: raise NotImplementedError - @abc.abstractproperty - def search_url(self): + @property + @abc.abstractmethod + def search_url(self) -> str: raise NotImplementedError - @abc.abstractproperty - def album_url(self): + @property + @abc.abstractmethod + def album_url(self) -> str: raise NotImplementedError - @abc.abstractproperty - def track_url(self): + @property + @abc.abstractmethod + def track_url(self) -> str: raise NotImplementedError @abc.abstractmethod @@ -864,12 +872,17 @@ class MetadataSourcePlugin(Generic[R], metaclass=abc.ABCMeta): return [t for t in tracks if t is not None] - def album_distance(self, items, album_info, mapping): + def album_distance( + self, + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], + ) -> Distance: return get_distance( data_source=self.data_source, info=album_info, config=self.config ) - def track_distance(self, item, track_info): + def track_distance(self, item: Item, info: TrackInfo) -> Distance: return get_distance( - data_source=self.data_source, info=track_info, config=self.config + data_source=self.data_source, info=info, config=self.config ) From 753bdd91061c112e03fca001f1171ba66cf76bd4 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 8 Apr 2025 18:15:08 +0200 Subject: [PATCH 04/16] Added typehints for _instance, _classes and for class attributes. --- beets/plugins.py | 37 +++++++++++++++++++++++++++---------- docs/dev/plugins.rst | 4 ++-- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 8fa9d4a3d..a25fe6d8e 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -24,9 +24,11 @@ from collections import defaultdict from functools import wraps from typing import ( TYPE_CHECKING, + Any, Callable, Generic, Sequence, + Type, TypedDict, TypeVar, ) @@ -39,9 +41,11 @@ from beets import logging if TYPE_CHECKING: from collections.abc import Iterable + from confuse import ConfigView + from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query - from beets.library import Item, Library + from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -79,6 +83,10 @@ class PluginLogFilter(logging.Filter): return True +# Typing for listeners +CallableWithAnyKwargs = Callable[[None, Any], None] + + # Managing the plugins themselves. @@ -88,16 +96,24 @@ class BeetsPlugin: the abstract methods defined here. """ + name: str + config: ConfigView + def __init__(self, name=None): """Perform one-time plugin setup.""" + self.name = name or self.__module__.split(".")[-1] self.config = beets.config[self.name] + + # Set class attributes if they are not already set + # for the type of plugin. if not self.template_funcs: self.template_funcs = {} if not self.template_fields: self.template_fields = {} if not self.album_template_fields: self.album_template_fields = {} + self.early_import_stages = [] self.import_stages = [] @@ -243,13 +259,14 @@ class BeetsPlugin: mediafile.MediaFile.add_field(name, descriptor) library.Item._media_fields.add(name) - _raw_listeners = None - listeners: None | dict[str, list[Callable]] = None + _raw_listeners: dict[str, list[Callable[[None, Any], None]]] | None = None + listeners: dict[str, list[Callable[[None, Any], None]]] | None = None - def register_listener(self, event: str, func: Callable): + def register_listener(self, event: str, func: CallableWithAnyKwargs): """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) + cls = self.__class__ cls = self.__class__ if cls.listeners is None or cls._raw_listeners is None: cls._raw_listeners = defaultdict(list) @@ -258,9 +275,9 @@ class BeetsPlugin: cls._raw_listeners[event].append(func) cls.listeners[event].append(wrapped_func) - template_funcs = None - template_fields = None - album_template_fields = None + template_funcs: dict[str, Callable[[str, Any], str]] | None = None + template_fields: dict[str, Callable[[Item], str]] | None = None + album_template_fields: dict[str, Callable[[Album], str]] | None = None @classmethod def template_func(cls, name: str): @@ -294,7 +311,7 @@ class BeetsPlugin: return helper -_classes = set() +_classes: set[type[BeetsPlugin]] = set() def load_plugins(names: Sequence[str] = ()): @@ -332,7 +349,7 @@ def load_plugins(names: Sequence[str] = ()): ) -_instances = {} +_instances: dict[Type[BeetsPlugin], BeetsPlugin] = {} def find_plugins() -> list[BeetsPlugin]: @@ -544,7 +561,7 @@ def event_handlers(): return all_handlers -def send(event, **arguments): +def send(event: str, **arguments: Any): """Send an event to all assigned event listeners. `event` is the name of the event to send, all other named arguments diff --git a/docs/dev/plugins.rst b/docs/dev/plugins.rst index 96e69153d..577947707 100644 --- a/docs/dev/plugins.rst +++ b/docs/dev/plugins.rst @@ -367,7 +367,7 @@ Here's an example:: super().__init__() self.template_funcs['initial'] = _tmpl_initial - def _tmpl_initial(text): + def _tmpl_initial(text:str) -> str: if text: return text[0].upper() else: @@ -387,7 +387,7 @@ Here's an example that adds a ``$disc_and_track`` field:: super().__init__() self.template_fields['disc_and_track'] = _tmpl_disc_and_track - def _tmpl_disc_and_track(item): + def _tmpl_disc_and_track(item: Item) -> str: """Expand to the disc number and track number if this is a multi-disc release. Otherwise, just expands to the track number. From 559b87ee5698e83cc497582f6759a69d315a2e52 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 8 Apr 2025 18:29:54 +0200 Subject: [PATCH 05/16] Updated changelog. Also edited my previous changelog entry to streamline typehint entries a bit. --- docs/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 88d87e32f..b5e8f8e57 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -91,10 +91,10 @@ Other changes: wrong (outdated) commit. Now the tag is created in the same workflow step right after committing the version update. :bug:`5539` -* Added some typehints: ImportSession and Pipeline have typehints now. Should - improve useability for new developers. * :doc:`/plugins/smartplaylist`: URL-encode additional item `fields` within generated EXTM3U playlists instead of JSON-encoding them. +* typehints: `./beets/importer.py` file now has improved typehints. +* typehints: `./beets/plugins.py` file now includes typehints. 2.2.0 (December 02, 2024) ------------------------- From 7e61027366675b57709c98e4de2a3034e4b78d2b Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 15 Apr 2025 11:49:24 +0200 Subject: [PATCH 06/16] Added suggestions from code review --- beets/plugins.py | 44 +++++++++++++++++++++++--------------------- docs/dev/plugins.rst | 2 +- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index a25fe6d8e..7e251590a 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -83,8 +83,9 @@ class PluginLogFilter(logging.Filter): return True -# Typing for listeners -CallableWithAnyKwargs = Callable[[None, Any], None] +# Typing for listeners, we are not too sure about the +# kwargs and args here, depends on the plugin +Listener = Callable[..., None] # Managing the plugins themselves. @@ -99,7 +100,7 @@ class BeetsPlugin: name: str config: ConfigView - def __init__(self, name=None): + def __init__(self, name: str | None = None): """Perform one-time plugin setup.""" self.name = name or self.__module__.split(".")[-1] @@ -215,7 +216,7 @@ class BeetsPlugin: artist: str, album: str, va_likely: bool, - extra_tags=None, + extra_tags: dict[str, Any] | None = None, ) -> Sequence[AlbumInfo]: """Should return a sequence of AlbumInfo objects that match the album whose items are provided. @@ -259,15 +260,15 @@ class BeetsPlugin: mediafile.MediaFile.add_field(name, descriptor) library.Item._media_fields.add(name) - _raw_listeners: dict[str, list[Callable[[None, Any], None]]] | None = None - listeners: dict[str, list[Callable[[None, Any], None]]] | None = None + _raw_listeners: dict[str, list[Listener]] | None = None + listeners: dict[str, list[Listener]] | None = None - def register_listener(self, event: str, func: CallableWithAnyKwargs): + def register_listener(self, event: str, func: Listener): """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) cls = self.__class__ - cls = self.__class__ + if cls.listeners is None or cls._raw_listeners is None: cls._raw_listeners = defaultdict(list) cls.listeners = defaultdict(list) @@ -451,7 +452,7 @@ def candidates( artist: str, album: str, va_likely: bool, - extra_tags=None, + extra_tags: dict[str, Any] | None = None, ) -> Iterable[AlbumInfo]: """Gets MusicBrainz candidates for an album from each plugin.""" for plugin in find_plugins(): @@ -510,9 +511,12 @@ def import_stages(): # New-style (lazy) plugin-provided fields. +F = TypeVar("F", Callable[[Item], str], Callable[[Album], str]) -def _check_conflicts_and_merge(plugin: BeetsPlugin, plugin_funcs, funcs): +def _check_conflicts_and_merge( + plugin: BeetsPlugin, plugin_funcs: dict[str, F] | None, funcs: dict[str, F] +): """Check the provided template functions for conflicts and merge into funcs. Raises a `PluginConflictError` if a plugin defines template functions @@ -528,19 +532,19 @@ def _check_conflicts_and_merge(plugin: BeetsPlugin, plugin_funcs, funcs): funcs.update(plugin_funcs) -def item_field_getters(): +def item_field_getters() -> dict[str, Callable[[Item], str]]: """Get a dictionary mapping field names to unary functions that compute the field's value. """ - funcs = {} + funcs: dict[str, Callable[[Item], str]] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.template_fields, funcs) return funcs -def album_field_getters(): +def album_field_getters() -> dict[str, Callable[[Album], str]]: """As above, for album fields.""" - funcs = {} + funcs: dict[str, Callable[[Album], str]] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs) return funcs @@ -549,11 +553,11 @@ def album_field_getters(): # Event dispatch. -def event_handlers(): +def event_handlers() -> dict[str, list[Listener]]: """Find all event handlers from plugins as a dictionary mapping event names to sequences of callables. """ - all_handlers: dict[str, list[Callable]] = defaultdict(list) + all_handlers: dict[str, list[Listener]] = defaultdict(list) for plugin in find_plugins(): if plugin.listeners: for event, handlers in plugin.listeners.items(): @@ -668,7 +672,7 @@ def notify_info_yielded(event): def get_distance( - config, data_source: str, info: AlbumInfo | TrackInfo + config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo ) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. @@ -711,7 +715,7 @@ def apply_item_changes( class Response(TypedDict): """A dictionary with the response of a plugin API call. - May be extended by plugins to include additional information, but id + May be extended by plugins to include additional information, but `id` is required. """ @@ -735,8 +739,6 @@ class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta): super().__init__() self.config.add({"source_weight": 0.5}) - foo: str - @property @abc.abstractmethod def id_regex(self) -> RegexDict: @@ -853,7 +855,7 @@ class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta): artist: str, album: str, va_likely: bool, - extra_tags=None, + extra_tags: dict[str, Any] | None = None, ) -> Sequence[AlbumInfo]: """Returns a list of AlbumInfo objects for Search API results matching an ``album`` and ``artist`` (if not various). diff --git a/docs/dev/plugins.rst b/docs/dev/plugins.rst index 577947707..0ebff3231 100644 --- a/docs/dev/plugins.rst +++ b/docs/dev/plugins.rst @@ -367,7 +367,7 @@ Here's an example:: super().__init__() self.template_funcs['initial'] = _tmpl_initial - def _tmpl_initial(text:str) -> str: + def _tmpl_initial(text: str) -> str: if text: return text[0].upper() else: From 90254bb511c36acf15df0923c3b028cdab4b0462 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 15 Apr 2025 11:57:52 +0200 Subject: [PATCH 07/16] Fixed lint error introduced by merging main. --- beets/plugins.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 823bebf41..bf7b6db31 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -28,7 +28,6 @@ from typing import ( Callable, Generic, Sequence, - Type, TypedDict, TypeVar, ) @@ -37,6 +36,7 @@ import mediafile import beets from beets import logging +from beets.library import Album, Item, Library if TYPE_CHECKING: from collections.abc import Iterable @@ -45,7 +45,6 @@ if TYPE_CHECKING: from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query - from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -350,7 +349,6 @@ def load_plugins(names: Sequence[str] = ()): ) - _instances: dict[type[BeetsPlugin], BeetsPlugin] = {} @@ -468,7 +466,6 @@ def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]: yield from plugin.item_candidates(item, artist, title) - def album_for_id(_id: str) -> AlbumInfo | None: """Get AlbumInfo object for the given ID string. From 287c7228af7fe73c2356de5345c3884f49b51319 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 15 Apr 2025 12:02:40 +0200 Subject: [PATCH 08/16] Fixed circular import issue introduced in last commit --- beets/plugins.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index bf7b6db31..dbf3802e1 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -36,7 +36,6 @@ import mediafile import beets from beets import logging -from beets.library import Album, Item, Library if TYPE_CHECKING: from collections.abc import Iterable @@ -45,6 +44,7 @@ if TYPE_CHECKING: from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query + from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -520,7 +520,11 @@ def import_stages(): # New-style (lazy) plugin-provided fields. -F = TypeVar("F", Callable[[Item], str], Callable[[Album], str]) + +if ( + TYPE_CHECKING +): # Needed because Item, Album circular introduce circular import + F = TypeVar("F", Callable[[Item], str], Callable[[Album], str]) def _check_conflicts_and_merge( From 62d28260c782f921538a7e1436d854a089a34657 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Tue, 15 Apr 2025 18:04:40 +0200 Subject: [PATCH 09/16] small nits --- beets/plugins.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index dbf3802e1..60fda6b7d 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -82,8 +82,6 @@ class PluginLogFilter(logging.Filter): return True -# Typing for listeners, we are not too sure about the -# kwargs and args here, depends on the plugin Listener = Callable[..., None] @@ -521,9 +519,8 @@ def import_stages(): # New-style (lazy) plugin-provided fields. -if ( - TYPE_CHECKING -): # Needed because Item, Album circular introduce circular import +if TYPE_CHECKING: + # Needed because Item, Album circular introduce circular import F = TypeVar("F", Callable[[Item], str], Callable[[Album], str]) @@ -578,7 +575,7 @@ def event_handlers() -> dict[str, list[Listener]]: return all_handlers -def send(event: str, **arguments: Any): +def send(event: str, **arguments: Any) -> list[Any]: """Send an event to all assigned event listeners. `event` is the name of the event to send, all other named arguments From b3c61d5c195031870860b5c46a19b91cc55041ca Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Thu, 17 Apr 2025 14:17:45 +0200 Subject: [PATCH 10/16] typed import_stages --- beets/plugins.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/beets/plugins.py b/beets/plugins.py index 60fda6b7d..6902fdd33 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query + from beets.importer import ImportSession, ImportTask from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -96,6 +97,8 @@ class BeetsPlugin: name: str config: ConfigView + early_import_stages: list[Callable[[ImportSession, ImportTask], None]] + import_stages: list[Callable[[ImportSession, ImportTask], None]] def __init__(self, name: str | None = None): """Perform one-time plugin setup.""" @@ -584,7 +587,7 @@ def send(event: str, **arguments: Any) -> list[Any]: Return a list of non-None values returned from the handlers. """ log.debug("Sending event: {0}", event) - results = [] + results: list[Any] = [] for handler in event_handlers()[event]: result = handler(**arguments) if result is not None: From fef81af67ddbc3d2325983b91655b4033c134406 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Fri, 18 Apr 2025 12:06:02 +0200 Subject: [PATCH 11/16] https://github.com/beetbox/beets/pull/5701#discussion_r2048644522 --- beets/plugins.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 6902fdd33..beb191c4c 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -522,9 +522,7 @@ def import_stages(): # New-style (lazy) plugin-provided fields. -if TYPE_CHECKING: - # Needed because Item, Album circular introduce circular import - F = TypeVar("F", Callable[[Item], str], Callable[[Album], str]) +F = TypeVar("F") def _check_conflicts_and_merge( From 39a5bdb0bd570e21f5b4cdde6080fea3e53e860d Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Fri, 18 Apr 2025 13:29:33 +0200 Subject: [PATCH 12/16] https://github.com/beetbox/beets/pull/5701#discussion_r2050488475 --- beets/plugins.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index beb191c4c..aec3ca340 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -246,7 +246,9 @@ class BeetsPlugin: """ return None - def add_media_field(self, name: str, descriptor: mediafile.MediaField): + def add_media_field( + self, name: str, descriptor: mediafile.MediaField + ) -> None: """Add a field that is synchronized between media files and items. When a media field is added ``item.write()`` will set the name @@ -263,7 +265,7 @@ class BeetsPlugin: _raw_listeners: dict[str, list[Listener]] | None = None listeners: dict[str, list[Listener]] | None = None - def register_listener(self, event: str, func: Listener): + def register_listener(self, event: str, func: Listener) -> None: """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) @@ -338,6 +340,7 @@ def load_plugins(names: Sequence[str] = ()): isinstance(obj, type) and issubclass(obj, BeetsPlugin) and obj != BeetsPlugin + and obj != MetadataSourcePlugin and obj not in _classes ): _classes.add(obj) From d7838b29c3beaa8db121235d6b25be76ec762cd0 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Fri, 18 Apr 2025 16:03:39 +0200 Subject: [PATCH 13/16] https://github.com/beetbox/beets/pull/5701#discussion_r2050637901 --- beets/plugins.py | 49 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index aec3ca340..c5d9c590a 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -19,6 +19,7 @@ from __future__ import annotations import abc import inspect import re +import sys import traceback from collections import defaultdict from functools import wraps @@ -48,6 +49,11 @@ if TYPE_CHECKING: from beets.library import Album, Item, Library from beets.ui import Subcommand +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + PLUGIN_NAMESPACE = "beetsplug" @@ -83,8 +89,16 @@ class PluginLogFilter(logging.Filter): return True +P = ParamSpec("P") +Ret = TypeVar("Ret", bound=Any) + Listener = Callable[..., None] +if TYPE_CHECKING: + ImportStageFunc = Callable[[ImportSession, ImportTask], None] + T = TypeVar("T", Album, Item, str) + TFunc = Callable[[T], str] + TFuncMap = dict[str, TFunc[T]] # Managing the plugins themselves. @@ -97,8 +111,8 @@ class BeetsPlugin: name: str config: ConfigView - early_import_stages: list[Callable[[ImportSession, ImportTask], None]] - import_stages: list[Callable[[ImportSession, ImportTask], None]] + early_import_stages: list[ImportStageFunc] + import_stages: list[ImportStageFunc] def __init__(self, name: str | None = None): """Perform one-time plugin setup.""" @@ -129,14 +143,17 @@ class BeetsPlugin: """ return () - def _set_stage_log_level(self, stages): + def _set_stage_log_level( + self, + stages: list[ImportStageFunc], + ) -> list[ImportStageFunc]: """Adjust all the stages in `stages` to WARNING logging level.""" return [ self._set_log_level_and_params(logging.WARNING, stage) for stage in stages ] - def get_early_import_stages(self): + def get_early_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages early in the pipeline. @@ -146,7 +163,7 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.early_import_stages) - def get_import_stages(self): + def get_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages. @@ -156,7 +173,11 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.import_stages) - def _set_log_level_and_params(self, base_log_level, func): + def _set_log_level_and_params( + self, + base_log_level: int, + func: Callable[P, Ret], + ) -> Callable[P, Ret]: """Wrap `func` to temporarily set this plugin's logger level to `base_log_level` + config options (and restore it to its previous value after the function returns). Also determines which params may not @@ -165,7 +186,7 @@ class BeetsPlugin: argspec = inspect.getfullargspec(func) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret: assert self._log.level == logging.NOTSET verbosity = beets.config["verbose"].get(int) @@ -278,9 +299,9 @@ class BeetsPlugin: cls._raw_listeners[event].append(func) cls.listeners[event].append(wrapped_func) - template_funcs: dict[str, Callable[[str, Any], str]] | None = None - template_fields: dict[str, Callable[[Item], str]] | None = None - album_template_fields: dict[str, Callable[[Album], str]] | None = None + template_funcs: TFuncMap[str] | None = None + template_fields: TFuncMap[Item] | None = None + album_template_fields: TFuncMap[Album] | None = None @classmethod def template_func(cls, name: str): @@ -546,19 +567,19 @@ def _check_conflicts_and_merge( funcs.update(plugin_funcs) -def item_field_getters() -> dict[str, Callable[[Item], str]]: +def item_field_getters() -> TFuncMap[Item]: """Get a dictionary mapping field names to unary functions that compute the field's value. """ - funcs: dict[str, Callable[[Item], str]] = {} + funcs: TFuncMap[Item] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.template_fields, funcs) return funcs -def album_field_getters() -> dict[str, Callable[[Album], str]]: +def album_field_getters() -> TFuncMap[Album]: """As above, for album fields.""" - funcs: dict[str, Callable[[Album], str]] = {} + funcs: TFuncMap[Album] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs) return funcs From 2f57dd9e1c10bd18f4b85cf99744a2e38e7e17ac Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Fri, 18 Apr 2025 23:03:32 +0200 Subject: [PATCH 14/16] Added missing return types. --- beets/plugins.py | 77 ++++++++++++++++++++++++++++-------------------- pyproject.toml | 9 ++++++ 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index c5d9c590a..2b848d3cf 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -137,7 +137,7 @@ class BeetsPlugin: if not any(isinstance(f, PluginLogFilter) for f in self._log.filters): self._log.addFilter(PluginLogFilter(self)) - def commands(self): + def commands(self) -> Sequence[Subcommand]: """Should return a list of beets.ui.Subcommand objects for commands that should be added to beets' CLI. """ @@ -304,13 +304,13 @@ class BeetsPlugin: album_template_fields: TFuncMap[Album] | None = None @classmethod - def template_func(cls, name: str): + def template_func(cls, name: str) -> Callable[[TFunc[str]], TFunc[str]]: """Decorator that registers a path template function. The function will be invoked as ``%name{}`` from path format strings. """ - def helper(func): + def helper(func: TFunc[str]) -> TFunc[str]: if cls.template_funcs is None: cls.template_funcs = {} cls.template_funcs[name] = func @@ -319,14 +319,14 @@ class BeetsPlugin: return helper @classmethod - def template_field(cls, name: str): + def template_field(cls, name: str) -> Callable[[TFunc[Item]], TFunc[Item]]: """Decorator that registers a path template field computation. The value will be referenced as ``$name`` from path format strings. The function must accept a single parameter, the Item being formatted. """ - def helper(func): + def helper(func: TFunc[Item]) -> TFunc[Item]: if cls.template_fields is None: cls.template_fields = {} cls.template_fields[name] = func @@ -338,7 +338,7 @@ class BeetsPlugin: _classes: set[type[BeetsPlugin]] = set() -def load_plugins(names: Sequence[str] = ()): +def load_plugins(names: Sequence[str] = ()) -> None: """Imports the modules for a sequence of plugin names. Each name must be the name of a Python module under the "beetsplug" namespace package in sys.path; the module indicated should contain the @@ -402,7 +402,7 @@ def find_plugins() -> list[BeetsPlugin]: def commands() -> list[Subcommand]: """Returns a list of Subcommand objects from all loaded plugins.""" - out = [] + out: list[Subcommand] = [] for plugin in find_plugins(): out += plugin.commands() return out @@ -418,10 +418,10 @@ def queries() -> dict[str, type[Query]]: return out -def types(model_cls): +def types(model_cls: type[T]) -> dict[str, type[Any]]: # Gives us `item_types` and `album_types` attr_name = f"{model_cls.__name__.lower()}_types" - types = {} + types: dict[str, type[Any]] = {} for plugin in find_plugins(): plugin_types = getattr(plugin, attr_name, {}) for field in plugin_types: @@ -435,10 +435,10 @@ def types(model_cls): return types -def named_queries(model_cls): +def named_queries(model_cls: type[T]) -> dict[str, Query]: # Gather `item_queries` and `album_queries` from the plugins. attr_name = f"{model_cls.__name__.lower()}_queries" - queries = {} + queries: dict[str, Query] = {} for plugin in find_plugins(): plugin_queries = getattr(plugin, attr_name, {}) queries.update(plugin_queries) @@ -517,28 +517,28 @@ def track_for_id(_id: str) -> TrackInfo | None: return None -def template_funcs(): +def template_funcs() -> TFuncMap[str]: """Get all the template functions declared by plugins as a dictionary. """ - funcs = {} + funcs: TFuncMap[str] = {} for plugin in find_plugins(): if plugin.template_funcs: funcs.update(plugin.template_funcs) return funcs -def early_import_stages(): +def early_import_stages() -> list[ImportStageFunc]: """Get a list of early import stage functions defined by plugins.""" - stages = [] + stages: list[ImportStageFunc] = [] for plugin in find_plugins(): stages += plugin.get_early_import_stages() return stages -def import_stages(): +def import_stages() -> list[ImportStageFunc]: """Get a list of import stage functions defined by plugins.""" - stages = [] + stages: list[ImportStageFunc] = [] for plugin in find_plugins(): stages += plugin.get_import_stages() return stages @@ -551,7 +551,7 @@ F = TypeVar("F") def _check_conflicts_and_merge( plugin: BeetsPlugin, plugin_funcs: dict[str, F] | None, funcs: dict[str, F] -): +) -> None: """Check the provided template functions for conflicts and merge into funcs. Raises a `PluginConflictError` if a plugin defines template functions @@ -617,7 +617,7 @@ def send(event: str, **arguments: Any) -> list[Any]: return results -def feat_tokens(for_artist=True): +def feat_tokens(for_artist: bool = True) -> str: """Return a regular expression that matches phrases like "featuring" that separate a main artist or a song title from secondary artists. The `for_artist` option determines whether the regex should be @@ -631,14 +631,16 @@ def feat_tokens(for_artist=True): ) -def sanitize_choices(choices, choices_all): +def sanitize_choices( + choices: Sequence[str], choices_all: Sequence[str] +) -> list[str]: """Clean up a stringlist configuration attribute: keep only choices elements present in choices_all, remove duplicate elements, expand '*' wildcard while keeping original stringlist order. """ - seen = set() - others = [x for x in choices_all if x not in choices] - res = [] + seen: set[str] = set() + others: list[str] = [x for x in choices_all if x not in choices] + res: list[str] = [] for s in choices: if s not in seen: if s in list(choices_all): @@ -649,7 +651,9 @@ def sanitize_choices(choices, choices_all): return res -def sanitize_pairs(pairs, pairs_all): +def sanitize_pairs( + pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]] +) -> list[tuple[str, str]]: """Clean up a single-element mapping configuration attribute as returned by Confuse's `Pairs` template: keep only two-element tuples present in pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*') @@ -665,10 +669,10 @@ def sanitize_pairs(pairs, pairs_all): ... ) [('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')] """ - pairs_all = list(pairs_all) - seen = set() + pairs_all: list[tuple[str, str]] = list(pairs_all) + seen: set[tuple[str, str]] = set() others = [x for x in pairs_all if x not in pairs] - res = [] + res: list[tuple[str, str]] = [] for k, values in pairs: for v in values.split(): x = (k, v) @@ -687,7 +691,9 @@ def sanitize_pairs(pairs, pairs_all): return res -def notify_info_yielded(event): +def notify_info_yielded( + event: str, +) -> Callable[[Callable[P, Iterable[Ret]]], Callable[P, Iterable[Ret]]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -695,8 +701,10 @@ def notify_info_yielded(event): 'send'. """ - def decorator(generator): - def decorated(*args, **kwargs): + def decorator( + generator: Callable[P, Iterable[Ret]], + ) -> Callable[P, Iterable[Ret]]: + def decorated(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]: for v in generator(*args, **kwargs): send(event, info=v) yield v @@ -722,7 +730,7 @@ def get_distance( def apply_item_changes( lib: Library, item: Item, move: bool, pretend: bool, write: bool -): +) -> None: """Store, move, and write the item according to the arguments. :param lib: beets library. @@ -800,7 +808,12 @@ class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta): raise NotImplementedError @abc.abstractmethod - def _search_api(self, query_type, filters, keywords="") -> Sequence[R]: + def _search_api( + self, + query_type: str, + filters: dict[str, str] | None, + keywords: str = "", + ) -> Sequence[R]: raise NotImplementedError @abc.abstractmethod diff --git a/pyproject.toml b/pyproject.toml index d985c54ea..6b705f68c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -281,3 +281,12 @@ ignore-variadic-names = true [tool.ruff.lint.pep8-naming] classmethod-decorators = ["cached_classproperty"] extend-ignore-names = ["assert*", "cached_classproperty"] + +# Temporary, until we decide on a mypy +# config for all files. +[[tool.mypy.overrides]] +module = "beets.plugins" +disallow_untyped_decorators = true +disallow_any_generics = true +check_untyped_defs = true +allow_redefinition = true From 557cb1a0197b9d99d411a1d72fcda390da3cca9a Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Sat, 19 Apr 2025 13:00:12 +0200 Subject: [PATCH 15/16] minor additions --- beets/plugins.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 2b848d3cf..5ed8fe612 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -22,6 +22,7 @@ import re import sys import traceback from collections import defaultdict +from collections.abc import Iterable from functools import wraps from typing import ( TYPE_CHECKING, @@ -39,12 +40,11 @@ import beets from beets import logging if TYPE_CHECKING: - from collections.abc import Iterable - from confuse import ConfigView from beets.autotag import AlbumInfo, Distance, TrackInfo from beets.dbcore import Query + from beets.dbcore.db import FieldQueryType, SQLiteType from beets.importer import ImportSession, ImportTask from beets.library import Album, Item, Library from beets.ui import Subcommand @@ -193,7 +193,7 @@ class BeetsPlugin: log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) self._log.setLevel(log_level) if argspec.varkw is None: - kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # type: ignore[assignment] try: return func(*args, **kwargs) @@ -418,10 +418,14 @@ def queries() -> dict[str, type[Query]]: return out -def types(model_cls: type[T]) -> dict[str, type[Any]]: +if TYPE_CHECKING: + AnyModel = TypeVar("AnyModel", Album, Item) + + +def types(model_cls: type[AnyModel]) -> dict[str, type[SQLiteType]]: # Gives us `item_types` and `album_types` attr_name = f"{model_cls.__name__.lower()}_types" - types: dict[str, type[Any]] = {} + types: dict[str, type[SQLiteType]] = {} for plugin in find_plugins(): plugin_types = getattr(plugin, attr_name, {}) for field in plugin_types: @@ -435,10 +439,10 @@ def types(model_cls: type[T]) -> dict[str, type[Any]]: return types -def named_queries(model_cls: type[T]) -> dict[str, Query]: +def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: # Gather `item_queries` and `album_queries` from the plugins. attr_name = f"{model_cls.__name__.lower()}_queries" - queries: dict[str, Query] = {} + queries: dict[str, FieldQueryType] = {} for plugin in find_plugins(): plugin_queries = getattr(plugin, attr_name, {}) queries.update(plugin_queries) @@ -639,7 +643,7 @@ def sanitize_choices( wildcard while keeping original stringlist order. """ seen: set[str] = set() - others: list[str] = [x for x in choices_all if x not in choices] + others = [x for x in choices_all if x not in choices] res: list[str] = [] for s in choices: if s not in seen: @@ -691,9 +695,12 @@ def sanitize_pairs( return res +IterF = Callable[P, Iterable[Ret]] + + def notify_info_yielded( event: str, -) -> Callable[[Callable[P, Iterable[Ret]]], Callable[P, Iterable[Ret]]]: +) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -702,8 +709,8 @@ def notify_info_yielded( """ def decorator( - generator: Callable[P, Iterable[Ret]], - ) -> Callable[P, Iterable[Ret]]: + generator: IterF[P, Ret], + ) -> IterF[P, Ret]: def decorated(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]: for v in generator(*args, **kwargs): send(event, info=v) From 6594bd7f245b39989c8e0661b65b266e25621c91 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Sat, 19 Apr 2025 16:19:07 +0200 Subject: [PATCH 16/16] Moved all typehint that need a typechecking guard to the top. see https://github.com/beetbox/beets/pull/5701#discussion_r2051486012 --- beets/plugins.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 5ed8fe612..d33458825 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -39,6 +39,12 @@ import mediafile import beets from beets import logging +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + if TYPE_CHECKING: from confuse import ConfigView @@ -49,10 +55,14 @@ if TYPE_CHECKING: from beets.library import Album, Item, Library from beets.ui import Subcommand -if sys.version_info >= (3, 10): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec + # TYPE_CHECKING guard is needed for any derived type + # which uses an import from `beets.library` and `beets.imported` + ImportStageFunc = Callable[[ImportSession, ImportTask], None] + T = TypeVar("T", Album, Item, str) + TFunc = Callable[[T], str] + TFuncMap = dict[str, TFunc[T]] + + AnyModel = TypeVar("AnyModel", Album, Item) PLUGIN_NAMESPACE = "beetsplug" @@ -64,6 +74,11 @@ LASTFM_KEY = "2dc3914abf35f0d9c92d97d8f8e42b43" log = logging.getLogger("beets") +P = ParamSpec("P") +Ret = TypeVar("Ret", bound=Any) +Listener = Callable[..., None] + + class PluginConflictError(Exception): """Indicates that the services provided by one plugin conflict with those of another. @@ -89,17 +104,6 @@ class PluginLogFilter(logging.Filter): return True -P = ParamSpec("P") -Ret = TypeVar("Ret", bound=Any) - -Listener = Callable[..., None] - -if TYPE_CHECKING: - ImportStageFunc = Callable[[ImportSession, ImportTask], None] - T = TypeVar("T", Album, Item, str) - TFunc = Callable[[T], str] - TFuncMap = dict[str, TFunc[T]] - # Managing the plugins themselves. @@ -418,10 +422,6 @@ def queries() -> dict[str, type[Query]]: return out -if TYPE_CHECKING: - AnyModel = TypeVar("AnyModel", Album, Item) - - def types(model_cls: type[AnyModel]) -> dict[str, type[SQLiteType]]: # Gives us `item_types` and `album_types` attr_name = f"{model_cls.__name__.lower()}_types"