From 0cc0db313a88f5a57aac78c72f4fa4ee30c87b43 Mon Sep 17 00:00:00 2001 From: Sebastian Mohr Date: Wed, 2 Apr 2025 15:57:16 +0200 Subject: [PATCH] Added typehints to the plugins file. --- beets/plugins.py | 187 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 128 insertions(+), 59 deletions(-) diff --git a/beets/plugins.py b/beets/plugins.py index 299c41815..3906bb041 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -14,17 +14,37 @@ """Support for beets plugins.""" +from __future__ import annotations + import abc import inspect import re import traceback from collections import defaultdict from functools import wraps +from typing import ( + TYPE_CHECKING, + Callable, + Generic, + Sequence, + TypedDict, + TypeVar, +) import mediafile import beets from beets import logging +from beets.autotag import Distance + +if TYPE_CHECKING: + from collections.abc import Iterable + + from beets.autotag import AlbumInfo, TrackInfo + from beets.dbcore import Query + from beets.library import Item + from beets.ui import Subcommand + PLUGIN_NAMESPACE = "beetsplug" @@ -145,55 +165,74 @@ class BeetsPlugin: return wrapper - def queries(self): + def queries(self) -> dict[str, type[Query]]: """Return a dict mapping prefixes to Query subclasses.""" return {} - def track_distance(self, item, info): + def track_distance( + self, + item: Item, + info: TrackInfo, + ) -> Distance: """Should return a Distance object to be added to the distance for every track comparison. """ - return beets.autotag.hooks.Distance() + return Distance() - def album_distance(self, items, album_info, mapping): + def album_distance( + self, + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], + ) -> Distance: """Should return a Distance object to be added to the distance for every album-level comparison. """ - return beets.autotag.hooks.Distance() + return Distance() - def candidates(self, items, artist, album, va_likely, extra_tags=None): + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, + ) -> Sequence[AlbumInfo]: """Should return a sequence of AlbumInfo objects that match the album whose items are provided. """ return () - def item_candidates(self, item, artist, title): + def item_candidates( + self, + item: Item, + artist: str, + title: str, + ) -> Sequence[TrackInfo]: """Should return a sequence of TrackInfo objects that match the item provided. """ return () - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: """Return an AlbumInfo object or None if no matching release was found. """ return None - def track_for_id(self, track_id): + def track_for_id(self, track_id: str) -> TrackInfo | None: """Return a TrackInfo object or None if no matching release was found. """ return None - def add_media_field(self, name, descriptor): + def add_media_field(self, name: str, descriptor: mediafile.MediaField): """Add a field that is synchronized between media files and items. When a media field is added ``item.write()`` will set the name property of the item's MediaFile to ``item[name]`` and save the changes. Similarly ``item.read()`` will set ``item[name]`` to the value of the name property of the media file. - - ``descriptor`` must be an instance of ``mediafile.MediaField``. """ # Defer import to prevent circular dependency from beets import library @@ -202,9 +241,9 @@ class BeetsPlugin: library.Item._media_fields.add(name) _raw_listeners = None - listeners = None + listeners: None | dict[str, list[Callable]] = None - def register_listener(self, event, func): + def register_listener(self, event: str, func: Callable): """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) @@ -221,7 +260,7 @@ class BeetsPlugin: album_template_fields = None @classmethod - def template_func(cls, name): + def template_func(cls, name: str): """Decorator that registers a path template function. The function will be invoked as ``%name{}`` from path format strings. @@ -236,7 +275,7 @@ class BeetsPlugin: return helper @classmethod - def template_field(cls, name): + def template_field(cls, name: str): """Decorator that registers a path template field computation. The value will be referenced as ``$name`` from path format strings. The function must accept a single parameter, the Item @@ -255,7 +294,7 @@ class BeetsPlugin: _classes = set() -def load_plugins(names=()): +def load_plugins(names: Sequence[str] = ()): """Imports the modules for a sequence of plugin names. Each name must be the name of a Python module under the "beetsplug" namespace package in sys.path; the module indicated should contain the @@ -293,7 +332,7 @@ def load_plugins(names=()): _instances = {} -def find_plugins(): +def find_plugins() -> list[BeetsPlugin]: """Returns a list of BeetsPlugin subclass instances from all currently loaded beets plugins. Loads the default plugin set first. @@ -316,7 +355,7 @@ def find_plugins(): # Communication with plugins. -def commands(): +def commands() -> list[Subcommand]: """Returns a list of Subcommand objects from all loaded plugins.""" out = [] for plugin in find_plugins(): @@ -324,11 +363,11 @@ def commands(): return out -def queries(): +def queries() -> dict[str, type[Query]]: """Returns a dict mapping prefix strings to Query subclasses all loaded plugins. """ - out = {} + out: dict[str, type[Query]] = {} for plugin in find_plugins(): out.update(plugin.queries()) return out @@ -361,7 +400,7 @@ def named_queries(model_cls): return queries -def track_distance(item, info): +def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ @@ -373,7 +412,11 @@ def track_distance(item, info): return dist -def album_distance(items, album_info, mapping): +def album_distance( + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], +) -> Distance: """Returns the album distance calculated by plugins.""" from beets.autotag.hooks import Distance @@ -383,7 +426,13 @@ def album_distance(items, album_info, mapping): return dist -def candidates(items, artist, album, va_likely, extra_tags=None): +def candidates( + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, +) -> Iterable[AlbumInfo]: """Gets MusicBrainz candidates for an album from each plugin.""" for plugin in find_plugins(): yield from plugin.candidates( @@ -391,13 +440,13 @@ def candidates(items, artist, album, va_likely, extra_tags=None): ) -def item_candidates(item, artist, title): +def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]: """Gets MusicBrainz candidates for an item from the plugins.""" for plugin in find_plugins(): yield from plugin.item_candidates(item, artist, title) -def album_for_id(album_id): +def album_for_id(album_id: str) -> Iterable[AlbumInfo]: """Get AlbumInfo objects for a given ID string.""" for plugin in find_plugins(): album = plugin.album_for_id(album_id) @@ -405,7 +454,7 @@ def album_for_id(album_id): yield album -def track_for_id(track_id): +def track_for_id(track_id: str) -> Iterable[TrackInfo]: """Get TrackInfo objects for a given ID string.""" for plugin in find_plugins(): track = plugin.track_for_id(track_id) @@ -443,7 +492,7 @@ def import_stages(): # New-style (lazy) plugin-provided fields. -def _check_conflicts_and_merge(plugin, plugin_funcs, funcs): +def _check_conflicts_and_merge(plugin: BeetsPlugin, plugin_funcs, funcs): """Check the provided template functions for conflicts and merge into funcs. Raises a `PluginConflictError` if a plugin defines template functions @@ -598,11 +647,11 @@ def notify_info_yielded(event): return decorator -def get_distance(config, data_source, info): +def get_distance(config, data_source, info) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ - dist = beets.autotag.Distance() + dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) return dist @@ -638,7 +687,27 @@ def apply_item_changes(lib, item, move, pretend, write): item.store() -class MetadataSourcePlugin(metaclass=abc.ABCMeta): +class Response(TypedDict): + """A dictionary with the response of a plugin API call. + + May be extended by plugins to include additional information, but id + is required. + """ + + id: str + + +class RegexDict(TypedDict): + """A dictionary with regex patterns as keys and match groups as values.""" + + pattern: str + match_group: int + + +R = TypeVar("R", bound=Response) + + +class MetadataSourcePlugin(Generic[R], metaclass=abc.ABCMeta): def __init__(self): super().__init__() self.config.add({"source_weight": 0.5}) @@ -664,19 +733,26 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): raise NotImplementedError @abc.abstractmethod - def _search_api(self, query_type, filters, keywords=""): + def _search_api(self, query_type, filters, keywords="") -> Sequence[R]: raise NotImplementedError @abc.abstractmethod - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: raise NotImplementedError @abc.abstractmethod - def track_for_id(self, track_id=None, track_data=None): + def track_for_id( + self, track_id: str | None = None, track_data: R | None = None + ) -> TrackInfo | None: raise NotImplementedError @staticmethod - def get_artist(artists, id_key="id", name_key="name", join_key=None): + def get_artist( + artists, + id_key: str | int = "id", + name_key: str | int = "name", + join_key: str | int | None = None, + ) -> tuple[str, str | None]: """Returns an artist string (all artists) and an artist_id (the main artist) for a list of artist object dicts. @@ -691,18 +767,14 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): :type artists: list[dict] or list[list] :param id_key: Key or index corresponding to the value of ``id`` for the main/first artist. Defaults to 'id'. - :type id_key: str or int :param name_key: Key or index corresponding to values of names to concatenate for the artist string (containing all artists). Defaults to 'name'. - :type name_key: str or int :param join_key: Key or index corresponding to a field containing a keyword to use for combining artists into a single string, for example "Feat.", "Vs.", "And" or similar. The default is None which keeps the default behaviour (comma-separated). - :type join_key: str or int :return: Normalized artist string. - :rtype: str """ artist_id = None artist_string = "" @@ -727,19 +799,15 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return artist_string, artist_id @staticmethod - def _get_id(url_type, id_, id_regex): + def _get_id(url_type: str, id_: str, id_regex: RegexDict) -> str | None: """Parse an ID from its URL if necessary. :param url_type: Type of URL. Either 'album' or 'track'. - :type url_type: str :param id_: Album/track ID or URL. - :type id_: str :param id_regex: A dictionary containing a regular expression extracting an ID from an URL (if it's not an ID already) in 'pattern' and the number of the match group in 'match_group'. - :type id_regex: dict :return: Album/track ID. - :rtype: str """ log.debug("Extracting {} ID from '{}'", url_type, id_) match = re.search(id_regex["pattern"].format(url_type), str(id_)) @@ -749,21 +817,22 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return id_ return None - def candidates(self, items, artist, album, va_likely, extra_tags=None): + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags=None, + ) -> Sequence[AlbumInfo]: """Returns a list of AlbumInfo objects for Search API results matching an ``album`` and ``artist`` (if not various). :param items: List of items comprised by an album to be matched. - :type items: list[beets.library.Item] :param artist: The artist of the album to be matched. - :type artist: str :param album: The name of the album to be matched. - :type album: str :param va_likely: True if the album to be matched likely has Various Artists. - :type va_likely: bool - :return: Candidate AlbumInfo objects. - :rtype: list[beets.autotag.hooks.AlbumInfo] """ query_filters = {"album": album} if not va_likely: @@ -772,23 +841,23 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): albums = [self.album_for_id(album_id=r["id"]) for r in results] return [a for a in albums if a is not None] - def item_candidates(self, item, artist, title): + def item_candidates( + self, item: Item, artist: str, title: str + ) -> Sequence[TrackInfo]: """Returns a list of TrackInfo objects for Search API results matching ``title`` and ``artist``. :param item: Singleton item to be matched. - :type item: beets.library.Item :param artist: The artist of the track to be matched. - :type artist: str :param title: The title of the track to be matched. - :type title: str - :return: Candidate TrackInfo objects. - :rtype: list[beets.autotag.hooks.TrackInfo] """ - tracks = self._search_api( + track_responses = self._search_api( query_type="track", keywords=title, filters={"artist": artist} ) - return [self.track_for_id(track_data=track) for track in tracks] + + tracks = [self.track_for_id(track_data=r) for r in track_responses] + + return [t for t in tracks if t is not None] def album_distance(self, items, album_info, mapping): return get_distance(