diff --git a/beets/plugins.py b/beets/plugins.py index 2ca98649e..d33458825 100644 --- a/beets/plugins.py +++ b/beets/plugins.py @@ -19,18 +19,50 @@ from __future__ import annotations import abc import inspect import re +import sys import traceback from collections import defaultdict +from collections.abc import Iterable from functools import wraps -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Sequence, + TypedDict, + TypeVar, +) import mediafile import beets from beets import logging +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + if TYPE_CHECKING: - from beets.autotag.hooks import AlbumInfo, TrackInfo + from confuse import ConfigView + + from beets.autotag import AlbumInfo, Distance, TrackInfo + from beets.dbcore import Query + from beets.dbcore.db import FieldQueryType, SQLiteType + from beets.importer import ImportSession, ImportTask + from beets.library import Album, Item, Library + from beets.ui import Subcommand + + # TYPE_CHECKING guard is needed for any derived type + # which uses an import from `beets.library` and `beets.imported` + ImportStageFunc = Callable[[ImportSession, ImportTask], None] + T = TypeVar("T", Album, Item, str) + TFunc = Callable[[T], str] + TFuncMap = dict[str, TFunc[T]] + + AnyModel = TypeVar("AnyModel", Album, Item) PLUGIN_NAMESPACE = "beetsplug" @@ -42,6 +74,11 @@ LASTFM_KEY = "2dc3914abf35f0d9c92d97d8f8e42b43" log = logging.getLogger("beets") +P = ParamSpec("P") +Ret = TypeVar("Ret", bound=Any) +Listener = Callable[..., None] + + class PluginConflictError(Exception): """Indicates that the services provided by one plugin conflict with those of another. @@ -76,16 +113,26 @@ class BeetsPlugin: the abstract methods defined here. """ - def __init__(self, name=None): + name: str + config: ConfigView + early_import_stages: list[ImportStageFunc] + import_stages: list[ImportStageFunc] + + def __init__(self, name: str | None = None): """Perform one-time plugin setup.""" + self.name = name or self.__module__.split(".")[-1] self.config = beets.config[self.name] + + # Set class attributes if they are not already set + # for the type of plugin. if not self.template_funcs: self.template_funcs = {} if not self.template_fields: self.template_fields = {} if not self.album_template_fields: self.album_template_fields = {} + self.early_import_stages = [] self.import_stages = [] @@ -94,20 +141,23 @@ class BeetsPlugin: if not any(isinstance(f, PluginLogFilter) for f in self._log.filters): self._log.addFilter(PluginLogFilter(self)) - def commands(self): + def commands(self) -> Sequence[Subcommand]: """Should return a list of beets.ui.Subcommand objects for commands that should be added to beets' CLI. """ return () - def _set_stage_log_level(self, stages): + def _set_stage_log_level( + self, + stages: list[ImportStageFunc], + ) -> list[ImportStageFunc]: """Adjust all the stages in `stages` to WARNING logging level.""" return [ self._set_log_level_and_params(logging.WARNING, stage) for stage in stages ] - def get_early_import_stages(self): + def get_early_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages early in the pipeline. @@ -117,7 +167,7 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.early_import_stages) - def get_import_stages(self): + def get_import_stages(self) -> list[ImportStageFunc]: """Return a list of functions that should be called as importer pipelines stages. @@ -127,7 +177,11 @@ class BeetsPlugin: """ return self._set_stage_log_level(self.import_stages) - def _set_log_level_and_params(self, base_log_level, func): + def _set_log_level_and_params( + self, + base_log_level: int, + func: Callable[P, Ret], + ) -> Callable[P, Ret]: """Wrap `func` to temporarily set this plugin's logger level to `base_log_level` + config options (and restore it to its previous value after the function returns). Also determines which params may not @@ -136,14 +190,14 @@ class BeetsPlugin: argspec = inspect.getfullargspec(func) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret: assert self._log.level == logging.NOTSET verbosity = beets.config["verbose"].get(int) log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) self._log.setLevel(log_level) if argspec.varkw is None: - kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # type: ignore[assignment] try: return func(*args, **kwargs) @@ -152,55 +206,80 @@ class BeetsPlugin: return wrapper - def queries(self): + def queries(self) -> dict[str, type[Query]]: """Return a dict mapping prefixes to Query subclasses.""" return {} - def track_distance(self, item, info): + def track_distance( + self, + item: Item, + info: TrackInfo, + ) -> Distance: """Should return a Distance object to be added to the distance for every track comparison. """ - return beets.autotag.hooks.Distance() + from beets.autotag.hooks import Distance - def album_distance(self, items, album_info, mapping): + return Distance() + + def album_distance( + self, + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], + ) -> Distance: """Should return a Distance object to be added to the distance for every album-level comparison. """ - return beets.autotag.hooks.Distance() + from beets.autotag.hooks import Distance - def candidates(self, items, artist, album, va_likely, extra_tags=None): + return Distance() + + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags: dict[str, Any] | None = None, + ) -> Sequence[AlbumInfo]: """Should return a sequence of AlbumInfo objects that match the album whose items are provided. """ return () - def item_candidates(self, item, artist, title): + def item_candidates( + self, + item: Item, + artist: str, + title: str, + ) -> Sequence[TrackInfo]: """Should return a sequence of TrackInfo objects that match the item provided. """ return () - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: """Return an AlbumInfo object or None if no matching release was found. """ return None - def track_for_id(self, track_id): + def track_for_id(self, track_id: str) -> TrackInfo | None: """Return a TrackInfo object or None if no matching release was found. """ return None - def add_media_field(self, name, descriptor): + def add_media_field( + self, name: str, descriptor: mediafile.MediaField + ) -> None: """Add a field that is synchronized between media files and items. When a media field is added ``item.write()`` will set the name property of the item's MediaFile to ``item[name]`` and save the changes. Similarly ``item.read()`` will set ``item[name]`` to the value of the name property of the media file. - - ``descriptor`` must be an instance of ``mediafile.MediaField``. """ # Defer import to prevent circular dependency from beets import library @@ -208,14 +287,15 @@ class BeetsPlugin: mediafile.MediaFile.add_field(name, descriptor) library.Item._media_fields.add(name) - _raw_listeners = None - listeners = None + _raw_listeners: dict[str, list[Listener]] | None = None + listeners: dict[str, list[Listener]] | None = None - def register_listener(self, event, func): + def register_listener(self, event: str, func: Listener) -> None: """Add a function as a listener for the specified event.""" wrapped_func = self._set_log_level_and_params(logging.WARNING, func) cls = self.__class__ + if cls.listeners is None or cls._raw_listeners is None: cls._raw_listeners = defaultdict(list) cls.listeners = defaultdict(list) @@ -223,18 +303,18 @@ class BeetsPlugin: cls._raw_listeners[event].append(func) cls.listeners[event].append(wrapped_func) - template_funcs = None - template_fields = None - album_template_fields = None + template_funcs: TFuncMap[str] | None = None + template_fields: TFuncMap[Item] | None = None + album_template_fields: TFuncMap[Album] | None = None @classmethod - def template_func(cls, name): + def template_func(cls, name: str) -> Callable[[TFunc[str]], TFunc[str]]: """Decorator that registers a path template function. The function will be invoked as ``%name{}`` from path format strings. """ - def helper(func): + def helper(func: TFunc[str]) -> TFunc[str]: if cls.template_funcs is None: cls.template_funcs = {} cls.template_funcs[name] = func @@ -243,14 +323,14 @@ class BeetsPlugin: return helper @classmethod - def template_field(cls, name): + def template_field(cls, name: str) -> Callable[[TFunc[Item]], TFunc[Item]]: """Decorator that registers a path template field computation. The value will be referenced as ``$name`` from path format strings. The function must accept a single parameter, the Item being formatted. """ - def helper(func): + def helper(func: TFunc[Item]) -> TFunc[Item]: if cls.template_fields is None: cls.template_fields = {} cls.template_fields[name] = func @@ -259,10 +339,10 @@ class BeetsPlugin: return helper -_classes = set() +_classes: set[type[BeetsPlugin]] = set() -def load_plugins(names=()): +def load_plugins(names: Sequence[str] = ()) -> None: """Imports the modules for a sequence of plugin names. Each name must be the name of a Python module under the "beetsplug" namespace package in sys.path; the module indicated should contain the @@ -285,6 +365,7 @@ def load_plugins(names=()): isinstance(obj, type) and issubclass(obj, BeetsPlugin) and obj != BeetsPlugin + and obj != MetadataSourcePlugin and obj not in _classes ): _classes.add(obj) @@ -300,7 +381,7 @@ def load_plugins(names=()): _instances: dict[type[BeetsPlugin], BeetsPlugin] = {} -def find_plugins(): +def find_plugins() -> list[BeetsPlugin]: """Returns a list of BeetsPlugin subclass instances from all currently loaded beets plugins. Loads the default plugin set first. @@ -323,28 +404,28 @@ def find_plugins(): # Communication with plugins. -def commands(): +def commands() -> list[Subcommand]: """Returns a list of Subcommand objects from all loaded plugins.""" - out = [] + out: list[Subcommand] = [] for plugin in find_plugins(): out += plugin.commands() return out -def queries(): +def queries() -> dict[str, type[Query]]: """Returns a dict mapping prefix strings to Query subclasses all loaded plugins. """ - out = {} + out: dict[str, type[Query]] = {} for plugin in find_plugins(): out.update(plugin.queries()) return out -def types(model_cls): +def types(model_cls: type[AnyModel]) -> dict[str, type[SQLiteType]]: # Gives us `item_types` and `album_types` attr_name = f"{model_cls.__name__.lower()}_types" - types = {} + types: dict[str, type[SQLiteType]] = {} for plugin in find_plugins(): plugin_types = getattr(plugin, attr_name, {}) for field in plugin_types: @@ -358,17 +439,17 @@ def types(model_cls): return types -def named_queries(model_cls): +def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]: # Gather `item_queries` and `album_queries` from the plugins. attr_name = f"{model_cls.__name__.lower()}_queries" - queries = {} + queries: dict[str, FieldQueryType] = {} for plugin in find_plugins(): plugin_queries = getattr(plugin, attr_name, {}) queries.update(plugin_queries) return queries -def track_distance(item, info): +def track_distance(item: Item, info: TrackInfo) -> Distance: """Gets the track distance calculated by all loaded plugins. Returns a Distance object. """ @@ -380,7 +461,11 @@ def track_distance(item, info): return dist -def album_distance(items, album_info, mapping): +def album_distance( + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], +) -> Distance: """Returns the album distance calculated by plugins.""" from beets.autotag.hooks import Distance @@ -390,7 +475,13 @@ def album_distance(items, album_info, mapping): return dist -def candidates(items, artist, album, va_likely, extra_tags=None): +def candidates( + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags: dict[str, Any] | None = None, +) -> Iterable[AlbumInfo]: """Gets MusicBrainz candidates for an album from each plugin.""" for plugin in find_plugins(): yield from plugin.candidates( @@ -398,7 +489,7 @@ def candidates(items, artist, album, va_likely, extra_tags=None): ) -def item_candidates(item, artist, title): +def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]: """Gets MusicBrainz candidates for an item from the plugins.""" for plugin in find_plugins(): yield from plugin.item_candidates(item, artist, title) @@ -430,28 +521,28 @@ def track_for_id(_id: str) -> TrackInfo | None: return None -def template_funcs(): +def template_funcs() -> TFuncMap[str]: """Get all the template functions declared by plugins as a dictionary. """ - funcs = {} + funcs: TFuncMap[str] = {} for plugin in find_plugins(): if plugin.template_funcs: funcs.update(plugin.template_funcs) return funcs -def early_import_stages(): +def early_import_stages() -> list[ImportStageFunc]: """Get a list of early import stage functions defined by plugins.""" - stages = [] + stages: list[ImportStageFunc] = [] for plugin in find_plugins(): stages += plugin.get_early_import_stages() return stages -def import_stages(): +def import_stages() -> list[ImportStageFunc]: """Get a list of import stage functions defined by plugins.""" - stages = [] + stages: list[ImportStageFunc] = [] for plugin in find_plugins(): stages += plugin.get_import_stages() return stages @@ -459,8 +550,12 @@ def import_stages(): # New-style (lazy) plugin-provided fields. +F = TypeVar("F") -def _check_conflicts_and_merge(plugin, plugin_funcs, funcs): + +def _check_conflicts_and_merge( + plugin: BeetsPlugin, plugin_funcs: dict[str, F] | None, funcs: dict[str, F] +) -> None: """Check the provided template functions for conflicts and merge into funcs. Raises a `PluginConflictError` if a plugin defines template functions @@ -476,19 +571,19 @@ def _check_conflicts_and_merge(plugin, plugin_funcs, funcs): funcs.update(plugin_funcs) -def item_field_getters(): +def item_field_getters() -> TFuncMap[Item]: """Get a dictionary mapping field names to unary functions that compute the field's value. """ - funcs = {} + funcs: TFuncMap[Item] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.template_fields, funcs) return funcs -def album_field_getters(): +def album_field_getters() -> TFuncMap[Album]: """As above, for album fields.""" - funcs = {} + funcs: TFuncMap[Album] = {} for plugin in find_plugins(): _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs) return funcs @@ -497,11 +592,11 @@ def album_field_getters(): # Event dispatch. -def event_handlers(): +def event_handlers() -> dict[str, list[Listener]]: """Find all event handlers from plugins as a dictionary mapping event names to sequences of callables. """ - all_handlers = defaultdict(list) + all_handlers: dict[str, list[Listener]] = defaultdict(list) for plugin in find_plugins(): if plugin.listeners: for event, handlers in plugin.listeners.items(): @@ -509,7 +604,7 @@ def event_handlers(): return all_handlers -def send(event, **arguments): +def send(event: str, **arguments: Any) -> list[Any]: """Send an event to all assigned event listeners. `event` is the name of the event to send, all other named arguments @@ -518,7 +613,7 @@ def send(event, **arguments): Return a list of non-None values returned from the handlers. """ log.debug("Sending event: {0}", event) - results = [] + results: list[Any] = [] for handler in event_handlers()[event]: result = handler(**arguments) if result is not None: @@ -526,7 +621,7 @@ def send(event, **arguments): return results -def feat_tokens(for_artist=True): +def feat_tokens(for_artist: bool = True) -> str: """Return a regular expression that matches phrases like "featuring" that separate a main artist or a song title from secondary artists. The `for_artist` option determines whether the regex should be @@ -540,14 +635,16 @@ def feat_tokens(for_artist=True): ) -def sanitize_choices(choices, choices_all): +def sanitize_choices( + choices: Sequence[str], choices_all: Sequence[str] +) -> list[str]: """Clean up a stringlist configuration attribute: keep only choices elements present in choices_all, remove duplicate elements, expand '*' wildcard while keeping original stringlist order. """ - seen = set() + seen: set[str] = set() others = [x for x in choices_all if x not in choices] - res = [] + res: list[str] = [] for s in choices: if s not in seen: if s in list(choices_all): @@ -558,7 +655,9 @@ def sanitize_choices(choices, choices_all): return res -def sanitize_pairs(pairs, pairs_all): +def sanitize_pairs( + pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]] +) -> list[tuple[str, str]]: """Clean up a single-element mapping configuration attribute as returned by Confuse's `Pairs` template: keep only two-element tuples present in pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*') @@ -574,10 +673,10 @@ def sanitize_pairs(pairs, pairs_all): ... ) [('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')] """ - pairs_all = list(pairs_all) - seen = set() + pairs_all: list[tuple[str, str]] = list(pairs_all) + seen: set[tuple[str, str]] = set() others = [x for x in pairs_all if x not in pairs] - res = [] + res: list[tuple[str, str]] = [] for k, values in pairs: for v in values.split(): x = (k, v) @@ -596,7 +695,12 @@ def sanitize_pairs(pairs, pairs_all): return res -def notify_info_yielded(event): +IterF = Callable[P, Iterable[Ret]] + + +def notify_info_yielded( + event: str, +) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]: """Makes a generator send the event 'event' every time it yields. This decorator is supposed to decorate a generator, but any function returning an iterable should work. @@ -604,8 +708,10 @@ def notify_info_yielded(event): 'send'. """ - def decorator(generator): - def decorated(*args, **kwargs): + def decorator( + generator: IterF[P, Ret], + ) -> IterF[P, Ret]: + def decorated(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]: for v in generator(*args, **kwargs): send(event, info=v) yield v @@ -615,30 +721,31 @@ def notify_info_yielded(event): return decorator -def get_distance(config, data_source, info): +def get_distance( + config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo +) -> Distance: """Returns the ``data_source`` weight and the maximum source weight for albums or individual tracks. """ - dist = beets.autotag.Distance() + from beets.autotag.hooks import Distance + + dist = Distance() if info.data_source == data_source: dist.add("source", config["source_weight"].as_number()) return dist -def apply_item_changes(lib, item, move, pretend, write): +def apply_item_changes( + lib: Library, item: Item, move: bool, pretend: bool, write: bool +) -> None: """Store, move, and write the item according to the arguments. :param lib: beets library. - :type lib: beets.library.Library :param item: Item whose changes to apply. - :type item: beets.library.Item :param move: Move the item if it's in the library. - :type move: bool :param pretend: Return without moving, writing, or storing the item's metadata. - :type pretend: bool :param write: Write the item's metadata to its media file. - :type write: bool """ if pretend: return @@ -655,45 +762,84 @@ def apply_item_changes(lib, item, move, pretend, write): item.store() -class MetadataSourcePlugin(metaclass=abc.ABCMeta): +class Response(TypedDict): + """A dictionary with the response of a plugin API call. + + May be extended by plugins to include additional information, but `id` + is required. + """ + + id: str + + +class RegexDict(TypedDict): + """A dictionary containing a regex pattern and the number of the + match group. + """ + + pattern: str + match_group: int + + +R = TypeVar("R", bound=Response) + + +class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta): def __init__(self): super().__init__() self.config.add({"source_weight": 0.5}) - @abc.abstractproperty - def id_regex(self): + @property + @abc.abstractmethod + def id_regex(self) -> RegexDict: raise NotImplementedError - @abc.abstractproperty - def data_source(self): + @property + @abc.abstractmethod + def data_source(self) -> str: raise NotImplementedError - @abc.abstractproperty - def search_url(self): + @property + @abc.abstractmethod + def search_url(self) -> str: raise NotImplementedError - @abc.abstractproperty - def album_url(self): + @property + @abc.abstractmethod + def album_url(self) -> str: raise NotImplementedError - @abc.abstractproperty - def track_url(self): + @property + @abc.abstractmethod + def track_url(self) -> str: raise NotImplementedError @abc.abstractmethod - def _search_api(self, query_type, filters, keywords=""): + def _search_api( + self, + query_type: str, + filters: dict[str, str] | None, + keywords: str = "", + ) -> Sequence[R]: raise NotImplementedError @abc.abstractmethod - def album_for_id(self, album_id): + def album_for_id(self, album_id: str) -> AlbumInfo | None: raise NotImplementedError @abc.abstractmethod - def track_for_id(self, track_id=None, track_data=None): + def track_for_id( + self, track_id: str | None = None, track_data: R | None = None + ) -> TrackInfo | None: raise NotImplementedError @staticmethod - def get_artist(artists, id_key="id", name_key="name", join_key=None): + def get_artist( + artists, + id_key: str | int = "id", + name_key: str | int = "name", + join_key: str | int | None = None, + ) -> tuple[str, str | None]: """Returns an artist string (all artists) and an artist_id (the main artist) for a list of artist object dicts. @@ -708,18 +854,14 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): :type artists: list[dict] or list[list] :param id_key: Key or index corresponding to the value of ``id`` for the main/first artist. Defaults to 'id'. - :type id_key: str or int :param name_key: Key or index corresponding to values of names to concatenate for the artist string (containing all artists). Defaults to 'name'. - :type name_key: str or int :param join_key: Key or index corresponding to a field containing a keyword to use for combining artists into a single string, for example "Feat.", "Vs.", "And" or similar. The default is None which keeps the default behaviour (comma-separated). - :type join_key: str or int :return: Normalized artist string. - :rtype: str """ artist_id = None artist_string = "" @@ -744,19 +886,15 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return artist_string, artist_id @staticmethod - def _get_id(url_type, id_, id_regex): + def _get_id(url_type: str, id_: str, id_regex: RegexDict) -> str | None: """Parse an ID from its URL if necessary. :param url_type: Type of URL. Either 'album' or 'track'. - :type url_type: str :param id_: Album/track ID or URL. - :type id_: str :param id_regex: A dictionary containing a regular expression extracting an ID from an URL (if it's not an ID already) in 'pattern' and the number of the match group in 'match_group'. - :type id_regex: dict :return: Album/track ID. - :rtype: str """ log.debug("Extracting {} ID from '{}'", url_type, id_) match = re.search(id_regex["pattern"].format(url_type), str(id_)) @@ -766,21 +904,22 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): return id_ return None - def candidates(self, items, artist, album, va_likely, extra_tags=None): + def candidates( + self, + items: list[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags: dict[str, Any] | None = None, + ) -> Sequence[AlbumInfo]: """Returns a list of AlbumInfo objects for Search API results matching an ``album`` and ``artist`` (if not various). :param items: List of items comprised by an album to be matched. - :type items: list[beets.library.Item] :param artist: The artist of the album to be matched. - :type artist: str :param album: The name of the album to be matched. - :type album: str :param va_likely: True if the album to be matched likely has Various Artists. - :type va_likely: bool - :return: Candidate AlbumInfo objects. - :rtype: list[beets.autotag.hooks.AlbumInfo] """ query_filters = {"album": album} if not va_likely: @@ -789,30 +928,35 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta): albums = [self.album_for_id(album_id=r["id"]) for r in results] return [a for a in albums if a is not None] - def item_candidates(self, item, artist, title): + def item_candidates( + self, item: Item, artist: str, title: str + ) -> Sequence[TrackInfo]: """Returns a list of TrackInfo objects for Search API results matching ``title`` and ``artist``. :param item: Singleton item to be matched. - :type item: beets.library.Item :param artist: The artist of the track to be matched. - :type artist: str :param title: The title of the track to be matched. - :type title: str - :return: Candidate TrackInfo objects. - :rtype: list[beets.autotag.hooks.TrackInfo] """ - tracks = self._search_api( + track_responses = self._search_api( query_type="track", keywords=title, filters={"artist": artist} ) - return [self.track_for_id(track_data=track) for track in tracks] - def album_distance(self, items, album_info, mapping): + tracks = [self.track_for_id(track_data=r) for r in track_responses] + + return [t for t in tracks if t is not None] + + def album_distance( + self, + items: list[Item], + album_info: AlbumInfo, + mapping: dict[Item, TrackInfo], + ) -> Distance: return get_distance( data_source=self.data_source, info=album_info, config=self.config ) - def track_distance(self, item, track_info): + def track_distance(self, item: Item, info: TrackInfo) -> Distance: return get_distance( - data_source=self.data_source, info=track_info, config=self.config + data_source=self.data_source, info=info, config=self.config ) diff --git a/docs/changelog.rst b/docs/changelog.rst index 77d237e6d..e52f329b0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -96,10 +96,10 @@ Other changes: wrong (outdated) commit. Now the tag is created in the same workflow step right after committing the version update. :bug:`5539` -* Added some typehints: ImportSession and Pipeline have typehints now. Should - improve useability for new developers. * :doc:`/plugins/smartplaylist`: URL-encode additional item `fields` within generated EXTM3U playlists instead of JSON-encoding them. +* typehints: `./beets/importer.py` file now has improved typehints. +* typehints: `./beets/plugins.py` file now includes typehints. * :doc:`plugins/ftintitle`: Optimize the plugin by avoiding unnecessary writes to the database. * Database models are now serializable with pickle. diff --git a/docs/dev/plugins.rst b/docs/dev/plugins.rst index 96e69153d..0ebff3231 100644 --- a/docs/dev/plugins.rst +++ b/docs/dev/plugins.rst @@ -367,7 +367,7 @@ Here's an example:: super().__init__() self.template_funcs['initial'] = _tmpl_initial - def _tmpl_initial(text): + def _tmpl_initial(text: str) -> str: if text: return text[0].upper() else: @@ -387,7 +387,7 @@ Here's an example that adds a ``$disc_and_track`` field:: super().__init__() self.template_fields['disc_and_track'] = _tmpl_disc_and_track - def _tmpl_disc_and_track(item): + def _tmpl_disc_and_track(item: Item) -> str: """Expand to the disc number and track number if this is a multi-disc release. Otherwise, just expands to the track number. diff --git a/pyproject.toml b/pyproject.toml index d985c54ea..6b705f68c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -281,3 +281,12 @@ ignore-variadic-names = true [tool.ruff.lint.pep8-naming] classmethod-decorators = ["cached_classproperty"] extend-ignore-names = ["assert*", "cached_classproperty"] + +# Temporary, until we decide on a mypy +# config for all files. +[[tool.mypy.overrides]] +module = "beets.plugins" +disallow_untyped_decorators = true +disallow_any_generics = true +check_untyped_defs = true +allow_redefinition = true