Added typehints to the plugins file. (#5701)

Added some typehints to the `plugins.py` file. There are some typehints
missing but I think this is a good first step to get a better DevEx.
This commit is contained in:
Šarūnas Nejus 2025-04-19 20:20:30 +01:00 committed by GitHub
commit c5bfbde175
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 280 additions and 127 deletions

View file

@ -19,18 +19,50 @@ from __future__ import annotations
import abc import abc
import inspect import inspect
import re import re
import sys
import traceback import traceback
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Sequence,
TypedDict,
TypeVar,
)
import mediafile import mediafile
import beets import beets
from beets import logging from beets import logging
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if TYPE_CHECKING: if TYPE_CHECKING:
from beets.autotag.hooks import AlbumInfo, TrackInfo from confuse import ConfigView
from beets.autotag import AlbumInfo, Distance, TrackInfo
from beets.dbcore import Query
from beets.dbcore.db import FieldQueryType, SQLiteType
from beets.importer import ImportSession, ImportTask
from beets.library import Album, Item, Library
from beets.ui import Subcommand
# TYPE_CHECKING guard is needed for any derived type
# which uses an import from `beets.library` and `beets.imported`
ImportStageFunc = Callable[[ImportSession, ImportTask], None]
T = TypeVar("T", Album, Item, str)
TFunc = Callable[[T], str]
TFuncMap = dict[str, TFunc[T]]
AnyModel = TypeVar("AnyModel", Album, Item)
PLUGIN_NAMESPACE = "beetsplug" PLUGIN_NAMESPACE = "beetsplug"
@ -42,6 +74,11 @@ LASTFM_KEY = "2dc3914abf35f0d9c92d97d8f8e42b43"
log = logging.getLogger("beets") log = logging.getLogger("beets")
P = ParamSpec("P")
Ret = TypeVar("Ret", bound=Any)
Listener = Callable[..., None]
class PluginConflictError(Exception): class PluginConflictError(Exception):
"""Indicates that the services provided by one plugin conflict with """Indicates that the services provided by one plugin conflict with
those of another. those of another.
@ -76,16 +113,26 @@ class BeetsPlugin:
the abstract methods defined here. the abstract methods defined here.
""" """
def __init__(self, name=None): name: str
config: ConfigView
early_import_stages: list[ImportStageFunc]
import_stages: list[ImportStageFunc]
def __init__(self, name: str | None = None):
"""Perform one-time plugin setup.""" """Perform one-time plugin setup."""
self.name = name or self.__module__.split(".")[-1] self.name = name or self.__module__.split(".")[-1]
self.config = beets.config[self.name] self.config = beets.config[self.name]
# Set class attributes if they are not already set
# for the type of plugin.
if not self.template_funcs: if not self.template_funcs:
self.template_funcs = {} self.template_funcs = {}
if not self.template_fields: if not self.template_fields:
self.template_fields = {} self.template_fields = {}
if not self.album_template_fields: if not self.album_template_fields:
self.album_template_fields = {} self.album_template_fields = {}
self.early_import_stages = [] self.early_import_stages = []
self.import_stages = [] self.import_stages = []
@ -94,20 +141,23 @@ class BeetsPlugin:
if not any(isinstance(f, PluginLogFilter) for f in self._log.filters): if not any(isinstance(f, PluginLogFilter) for f in self._log.filters):
self._log.addFilter(PluginLogFilter(self)) self._log.addFilter(PluginLogFilter(self))
def commands(self): def commands(self) -> Sequence[Subcommand]:
"""Should return a list of beets.ui.Subcommand objects for """Should return a list of beets.ui.Subcommand objects for
commands that should be added to beets' CLI. commands that should be added to beets' CLI.
""" """
return () return ()
def _set_stage_log_level(self, stages): def _set_stage_log_level(
self,
stages: list[ImportStageFunc],
) -> list[ImportStageFunc]:
"""Adjust all the stages in `stages` to WARNING logging level.""" """Adjust all the stages in `stages` to WARNING logging level."""
return [ return [
self._set_log_level_and_params(logging.WARNING, stage) self._set_log_level_and_params(logging.WARNING, stage)
for stage in stages for stage in stages
] ]
def get_early_import_stages(self): def get_early_import_stages(self) -> list[ImportStageFunc]:
"""Return a list of functions that should be called as importer """Return a list of functions that should be called as importer
pipelines stages early in the pipeline. pipelines stages early in the pipeline.
@ -117,7 +167,7 @@ class BeetsPlugin:
""" """
return self._set_stage_log_level(self.early_import_stages) return self._set_stage_log_level(self.early_import_stages)
def get_import_stages(self): def get_import_stages(self) -> list[ImportStageFunc]:
"""Return a list of functions that should be called as importer """Return a list of functions that should be called as importer
pipelines stages. pipelines stages.
@ -127,7 +177,11 @@ class BeetsPlugin:
""" """
return self._set_stage_log_level(self.import_stages) return self._set_stage_log_level(self.import_stages)
def _set_log_level_and_params(self, base_log_level, func): def _set_log_level_and_params(
self,
base_log_level: int,
func: Callable[P, Ret],
) -> Callable[P, Ret]:
"""Wrap `func` to temporarily set this plugin's logger level to """Wrap `func` to temporarily set this plugin's logger level to
`base_log_level` + config options (and restore it to its previous `base_log_level` + config options (and restore it to its previous
value after the function returns). Also determines which params may not value after the function returns). Also determines which params may not
@ -136,14 +190,14 @@ class BeetsPlugin:
argspec = inspect.getfullargspec(func) argspec = inspect.getfullargspec(func)
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs) -> Ret:
assert self._log.level == logging.NOTSET assert self._log.level == logging.NOTSET
verbosity = beets.config["verbose"].get(int) verbosity = beets.config["verbose"].get(int)
log_level = max(logging.DEBUG, base_log_level - 10 * verbosity) log_level = max(logging.DEBUG, base_log_level - 10 * verbosity)
self._log.setLevel(log_level) self._log.setLevel(log_level)
if argspec.varkw is None: if argspec.varkw is None:
kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # type: ignore[assignment]
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
@ -152,55 +206,80 @@ class BeetsPlugin:
return wrapper return wrapper
def queries(self): def queries(self) -> dict[str, type[Query]]:
"""Return a dict mapping prefixes to Query subclasses.""" """Return a dict mapping prefixes to Query subclasses."""
return {} return {}
def track_distance(self, item, info): def track_distance(
self,
item: Item,
info: TrackInfo,
) -> Distance:
"""Should return a Distance object to be added to the """Should return a Distance object to be added to the
distance for every track comparison. distance for every track comparison.
""" """
return beets.autotag.hooks.Distance() from beets.autotag.hooks import Distance
def album_distance(self, items, album_info, mapping): return Distance()
def album_distance(
self,
items: list[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Should return a Distance object to be added to the """Should return a Distance object to be added to the
distance for every album-level comparison. distance for every album-level comparison.
""" """
return beets.autotag.hooks.Distance() from beets.autotag.hooks import Distance
def candidates(self, items, artist, album, va_likely, extra_tags=None): return Distance()
def candidates(
self,
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags: dict[str, Any] | None = None,
) -> Sequence[AlbumInfo]:
"""Should return a sequence of AlbumInfo objects that match the """Should return a sequence of AlbumInfo objects that match the
album whose items are provided. album whose items are provided.
""" """
return () return ()
def item_candidates(self, item, artist, title): def item_candidates(
self,
item: Item,
artist: str,
title: str,
) -> Sequence[TrackInfo]:
"""Should return a sequence of TrackInfo objects that match the """Should return a sequence of TrackInfo objects that match the
item provided. item provided.
""" """
return () return ()
def album_for_id(self, album_id): def album_for_id(self, album_id: str) -> AlbumInfo | None:
"""Return an AlbumInfo object or None if no matching release was """Return an AlbumInfo object or None if no matching release was
found. found.
""" """
return None return None
def track_for_id(self, track_id): def track_for_id(self, track_id: str) -> TrackInfo | None:
"""Return a TrackInfo object or None if no matching release was """Return a TrackInfo object or None if no matching release was
found. found.
""" """
return None return None
def add_media_field(self, name, descriptor): def add_media_field(
self, name: str, descriptor: mediafile.MediaField
) -> None:
"""Add a field that is synchronized between media files and items. """Add a field that is synchronized between media files and items.
When a media field is added ``item.write()`` will set the name When a media field is added ``item.write()`` will set the name
property of the item's MediaFile to ``item[name]`` and save the property of the item's MediaFile to ``item[name]`` and save the
changes. Similarly ``item.read()`` will set ``item[name]`` to changes. Similarly ``item.read()`` will set ``item[name]`` to
the value of the name property of the media file. the value of the name property of the media file.
``descriptor`` must be an instance of ``mediafile.MediaField``.
""" """
# Defer import to prevent circular dependency # Defer import to prevent circular dependency
from beets import library from beets import library
@ -208,14 +287,15 @@ class BeetsPlugin:
mediafile.MediaFile.add_field(name, descriptor) mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name) library.Item._media_fields.add(name)
_raw_listeners = None _raw_listeners: dict[str, list[Listener]] | None = None
listeners = None listeners: dict[str, list[Listener]] | None = None
def register_listener(self, event, func): def register_listener(self, event: str, func: Listener) -> None:
"""Add a function as a listener for the specified event.""" """Add a function as a listener for the specified event."""
wrapped_func = self._set_log_level_and_params(logging.WARNING, func) wrapped_func = self._set_log_level_and_params(logging.WARNING, func)
cls = self.__class__ cls = self.__class__
if cls.listeners is None or cls._raw_listeners is None: if cls.listeners is None or cls._raw_listeners is None:
cls._raw_listeners = defaultdict(list) cls._raw_listeners = defaultdict(list)
cls.listeners = defaultdict(list) cls.listeners = defaultdict(list)
@ -223,18 +303,18 @@ class BeetsPlugin:
cls._raw_listeners[event].append(func) cls._raw_listeners[event].append(func)
cls.listeners[event].append(wrapped_func) cls.listeners[event].append(wrapped_func)
template_funcs = None template_funcs: TFuncMap[str] | None = None
template_fields = None template_fields: TFuncMap[Item] | None = None
album_template_fields = None album_template_fields: TFuncMap[Album] | None = None
@classmethod @classmethod
def template_func(cls, name): def template_func(cls, name: str) -> Callable[[TFunc[str]], TFunc[str]]:
"""Decorator that registers a path template function. The """Decorator that registers a path template function. The
function will be invoked as ``%name{}`` from path format function will be invoked as ``%name{}`` from path format
strings. strings.
""" """
def helper(func): def helper(func: TFunc[str]) -> TFunc[str]:
if cls.template_funcs is None: if cls.template_funcs is None:
cls.template_funcs = {} cls.template_funcs = {}
cls.template_funcs[name] = func cls.template_funcs[name] = func
@ -243,14 +323,14 @@ class BeetsPlugin:
return helper return helper
@classmethod @classmethod
def template_field(cls, name): def template_field(cls, name: str) -> Callable[[TFunc[Item]], TFunc[Item]]:
"""Decorator that registers a path template field computation. """Decorator that registers a path template field computation.
The value will be referenced as ``$name`` from path format The value will be referenced as ``$name`` from path format
strings. The function must accept a single parameter, the Item strings. The function must accept a single parameter, the Item
being formatted. being formatted.
""" """
def helper(func): def helper(func: TFunc[Item]) -> TFunc[Item]:
if cls.template_fields is None: if cls.template_fields is None:
cls.template_fields = {} cls.template_fields = {}
cls.template_fields[name] = func cls.template_fields[name] = func
@ -259,10 +339,10 @@ class BeetsPlugin:
return helper return helper
_classes = set() _classes: set[type[BeetsPlugin]] = set()
def load_plugins(names=()): def load_plugins(names: Sequence[str] = ()) -> None:
"""Imports the modules for a sequence of plugin names. Each name """Imports the modules for a sequence of plugin names. Each name
must be the name of a Python module under the "beetsplug" namespace must be the name of a Python module under the "beetsplug" namespace
package in sys.path; the module indicated should contain the package in sys.path; the module indicated should contain the
@ -285,6 +365,7 @@ def load_plugins(names=()):
isinstance(obj, type) isinstance(obj, type)
and issubclass(obj, BeetsPlugin) and issubclass(obj, BeetsPlugin)
and obj != BeetsPlugin and obj != BeetsPlugin
and obj != MetadataSourcePlugin
and obj not in _classes and obj not in _classes
): ):
_classes.add(obj) _classes.add(obj)
@ -300,7 +381,7 @@ def load_plugins(names=()):
_instances: dict[type[BeetsPlugin], BeetsPlugin] = {} _instances: dict[type[BeetsPlugin], BeetsPlugin] = {}
def find_plugins(): def find_plugins() -> list[BeetsPlugin]:
"""Returns a list of BeetsPlugin subclass instances from all """Returns a list of BeetsPlugin subclass instances from all
currently loaded beets plugins. Loads the default plugin set currently loaded beets plugins. Loads the default plugin set
first. first.
@ -323,28 +404,28 @@ def find_plugins():
# Communication with plugins. # Communication with plugins.
def commands(): def commands() -> list[Subcommand]:
"""Returns a list of Subcommand objects from all loaded plugins.""" """Returns a list of Subcommand objects from all loaded plugins."""
out = [] out: list[Subcommand] = []
for plugin in find_plugins(): for plugin in find_plugins():
out += plugin.commands() out += plugin.commands()
return out return out
def queries(): def queries() -> dict[str, type[Query]]:
"""Returns a dict mapping prefix strings to Query subclasses all loaded """Returns a dict mapping prefix strings to Query subclasses all loaded
plugins. plugins.
""" """
out = {} out: dict[str, type[Query]] = {}
for plugin in find_plugins(): for plugin in find_plugins():
out.update(plugin.queries()) out.update(plugin.queries())
return out return out
def types(model_cls): def types(model_cls: type[AnyModel]) -> dict[str, type[SQLiteType]]:
# Gives us `item_types` and `album_types` # Gives us `item_types` and `album_types`
attr_name = f"{model_cls.__name__.lower()}_types" attr_name = f"{model_cls.__name__.lower()}_types"
types = {} types: dict[str, type[SQLiteType]] = {}
for plugin in find_plugins(): for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {}) plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types: for field in plugin_types:
@ -358,17 +439,17 @@ def types(model_cls):
return types return types
def named_queries(model_cls): def named_queries(model_cls: type[AnyModel]) -> dict[str, FieldQueryType]:
# Gather `item_queries` and `album_queries` from the plugins. # Gather `item_queries` and `album_queries` from the plugins.
attr_name = f"{model_cls.__name__.lower()}_queries" attr_name = f"{model_cls.__name__.lower()}_queries"
queries = {} queries: dict[str, FieldQueryType] = {}
for plugin in find_plugins(): for plugin in find_plugins():
plugin_queries = getattr(plugin, attr_name, {}) plugin_queries = getattr(plugin, attr_name, {})
queries.update(plugin_queries) queries.update(plugin_queries)
return queries return queries
def track_distance(item, info): def track_distance(item: Item, info: TrackInfo) -> Distance:
"""Gets the track distance calculated by all loaded plugins. """Gets the track distance calculated by all loaded plugins.
Returns a Distance object. Returns a Distance object.
""" """
@ -380,7 +461,11 @@ def track_distance(item, info):
return dist return dist
def album_distance(items, album_info, mapping): def album_distance(
items: list[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Returns the album distance calculated by plugins.""" """Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance from beets.autotag.hooks import Distance
@ -390,7 +475,13 @@ def album_distance(items, album_info, mapping):
return dist return dist
def candidates(items, artist, album, va_likely, extra_tags=None): def candidates(
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags: dict[str, Any] | None = None,
) -> Iterable[AlbumInfo]:
"""Gets MusicBrainz candidates for an album from each plugin.""" """Gets MusicBrainz candidates for an album from each plugin."""
for plugin in find_plugins(): for plugin in find_plugins():
yield from plugin.candidates( yield from plugin.candidates(
@ -398,7 +489,7 @@ def candidates(items, artist, album, va_likely, extra_tags=None):
) )
def item_candidates(item, artist, title): def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]:
"""Gets MusicBrainz candidates for an item from the plugins.""" """Gets MusicBrainz candidates for an item from the plugins."""
for plugin in find_plugins(): for plugin in find_plugins():
yield from plugin.item_candidates(item, artist, title) yield from plugin.item_candidates(item, artist, title)
@ -430,28 +521,28 @@ def track_for_id(_id: str) -> TrackInfo | None:
return None return None
def template_funcs(): def template_funcs() -> TFuncMap[str]:
"""Get all the template functions declared by plugins as a """Get all the template functions declared by plugins as a
dictionary. dictionary.
""" """
funcs = {} funcs: TFuncMap[str] = {}
for plugin in find_plugins(): for plugin in find_plugins():
if plugin.template_funcs: if plugin.template_funcs:
funcs.update(plugin.template_funcs) funcs.update(plugin.template_funcs)
return funcs return funcs
def early_import_stages(): def early_import_stages() -> list[ImportStageFunc]:
"""Get a list of early import stage functions defined by plugins.""" """Get a list of early import stage functions defined by plugins."""
stages = [] stages: list[ImportStageFunc] = []
for plugin in find_plugins(): for plugin in find_plugins():
stages += plugin.get_early_import_stages() stages += plugin.get_early_import_stages()
return stages return stages
def import_stages(): def import_stages() -> list[ImportStageFunc]:
"""Get a list of import stage functions defined by plugins.""" """Get a list of import stage functions defined by plugins."""
stages = [] stages: list[ImportStageFunc] = []
for plugin in find_plugins(): for plugin in find_plugins():
stages += plugin.get_import_stages() stages += plugin.get_import_stages()
return stages return stages
@ -459,8 +550,12 @@ def import_stages():
# New-style (lazy) plugin-provided fields. # New-style (lazy) plugin-provided fields.
F = TypeVar("F")
def _check_conflicts_and_merge(plugin, plugin_funcs, funcs):
def _check_conflicts_and_merge(
plugin: BeetsPlugin, plugin_funcs: dict[str, F] | None, funcs: dict[str, F]
) -> None:
"""Check the provided template functions for conflicts and merge into funcs. """Check the provided template functions for conflicts and merge into funcs.
Raises a `PluginConflictError` if a plugin defines template functions Raises a `PluginConflictError` if a plugin defines template functions
@ -476,19 +571,19 @@ def _check_conflicts_and_merge(plugin, plugin_funcs, funcs):
funcs.update(plugin_funcs) funcs.update(plugin_funcs)
def item_field_getters(): def item_field_getters() -> TFuncMap[Item]:
"""Get a dictionary mapping field names to unary functions that """Get a dictionary mapping field names to unary functions that
compute the field's value. compute the field's value.
""" """
funcs = {} funcs: TFuncMap[Item] = {}
for plugin in find_plugins(): for plugin in find_plugins():
_check_conflicts_and_merge(plugin, plugin.template_fields, funcs) _check_conflicts_and_merge(plugin, plugin.template_fields, funcs)
return funcs return funcs
def album_field_getters(): def album_field_getters() -> TFuncMap[Album]:
"""As above, for album fields.""" """As above, for album fields."""
funcs = {} funcs: TFuncMap[Album] = {}
for plugin in find_plugins(): for plugin in find_plugins():
_check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs) _check_conflicts_and_merge(plugin, plugin.album_template_fields, funcs)
return funcs return funcs
@ -497,11 +592,11 @@ def album_field_getters():
# Event dispatch. # Event dispatch.
def event_handlers(): def event_handlers() -> dict[str, list[Listener]]:
"""Find all event handlers from plugins as a dictionary mapping """Find all event handlers from plugins as a dictionary mapping
event names to sequences of callables. event names to sequences of callables.
""" """
all_handlers = defaultdict(list) all_handlers: dict[str, list[Listener]] = defaultdict(list)
for plugin in find_plugins(): for plugin in find_plugins():
if plugin.listeners: if plugin.listeners:
for event, handlers in plugin.listeners.items(): for event, handlers in plugin.listeners.items():
@ -509,7 +604,7 @@ def event_handlers():
return all_handlers return all_handlers
def send(event, **arguments): def send(event: str, **arguments: Any) -> list[Any]:
"""Send an event to all assigned event listeners. """Send an event to all assigned event listeners.
`event` is the name of the event to send, all other named arguments `event` is the name of the event to send, all other named arguments
@ -518,7 +613,7 @@ def send(event, **arguments):
Return a list of non-None values returned from the handlers. Return a list of non-None values returned from the handlers.
""" """
log.debug("Sending event: {0}", event) log.debug("Sending event: {0}", event)
results = [] results: list[Any] = []
for handler in event_handlers()[event]: for handler in event_handlers()[event]:
result = handler(**arguments) result = handler(**arguments)
if result is not None: if result is not None:
@ -526,7 +621,7 @@ def send(event, **arguments):
return results return results
def feat_tokens(for_artist=True): def feat_tokens(for_artist: bool = True) -> str:
"""Return a regular expression that matches phrases like "featuring" """Return a regular expression that matches phrases like "featuring"
that separate a main artist or a song title from secondary artists. that separate a main artist or a song title from secondary artists.
The `for_artist` option determines whether the regex should be The `for_artist` option determines whether the regex should be
@ -540,14 +635,16 @@ def feat_tokens(for_artist=True):
) )
def sanitize_choices(choices, choices_all): def sanitize_choices(
choices: Sequence[str], choices_all: Sequence[str]
) -> list[str]:
"""Clean up a stringlist configuration attribute: keep only choices """Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*' elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order. wildcard while keeping original stringlist order.
""" """
seen = set() seen: set[str] = set()
others = [x for x in choices_all if x not in choices] others = [x for x in choices_all if x not in choices]
res = [] res: list[str] = []
for s in choices: for s in choices:
if s not in seen: if s not in seen:
if s in list(choices_all): if s in list(choices_all):
@ -558,7 +655,9 @@ def sanitize_choices(choices, choices_all):
return res return res
def sanitize_pairs(pairs, pairs_all): def sanitize_pairs(
pairs: Sequence[tuple[str, str]], pairs_all: Sequence[tuple[str, str]]
) -> list[tuple[str, str]]:
"""Clean up a single-element mapping configuration attribute as returned """Clean up a single-element mapping configuration attribute as returned
by Confuse's `Pairs` template: keep only two-element tuples present in by Confuse's `Pairs` template: keep only two-element tuples present in
pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*') pairs_all, remove duplicate elements, expand ('str', '*') and ('*', '*')
@ -574,10 +673,10 @@ def sanitize_pairs(pairs, pairs_all):
... ) ... )
[('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')] [('foo', 'baz'), ('foo', 'bar'), ('key', 'value'), ('foo', 'foobar')]
""" """
pairs_all = list(pairs_all) pairs_all: list[tuple[str, str]] = list(pairs_all)
seen = set() seen: set[tuple[str, str]] = set()
others = [x for x in pairs_all if x not in pairs] others = [x for x in pairs_all if x not in pairs]
res = [] res: list[tuple[str, str]] = []
for k, values in pairs: for k, values in pairs:
for v in values.split(): for v in values.split():
x = (k, v) x = (k, v)
@ -596,7 +695,12 @@ def sanitize_pairs(pairs, pairs_all):
return res return res
def notify_info_yielded(event): IterF = Callable[P, Iterable[Ret]]
def notify_info_yielded(
event: str,
) -> Callable[[IterF[P, Ret]], IterF[P, Ret]]:
"""Makes a generator send the event 'event' every time it yields. """Makes a generator send the event 'event' every time it yields.
This decorator is supposed to decorate a generator, but any function This decorator is supposed to decorate a generator, but any function
returning an iterable should work. returning an iterable should work.
@ -604,8 +708,10 @@ def notify_info_yielded(event):
'send'. 'send'.
""" """
def decorator(generator): def decorator(
def decorated(*args, **kwargs): generator: IterF[P, Ret],
) -> IterF[P, Ret]:
def decorated(*args: P.args, **kwargs: P.kwargs) -> Iterable[Ret]:
for v in generator(*args, **kwargs): for v in generator(*args, **kwargs):
send(event, info=v) send(event, info=v)
yield v yield v
@ -615,30 +721,31 @@ def notify_info_yielded(event):
return decorator return decorator
def get_distance(config, data_source, info): def get_distance(
config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo
) -> Distance:
"""Returns the ``data_source`` weight and the maximum source weight """Returns the ``data_source`` weight and the maximum source weight
for albums or individual tracks. for albums or individual tracks.
""" """
dist = beets.autotag.Distance() from beets.autotag.hooks import Distance
dist = Distance()
if info.data_source == data_source: if info.data_source == data_source:
dist.add("source", config["source_weight"].as_number()) dist.add("source", config["source_weight"].as_number())
return dist return dist
def apply_item_changes(lib, item, move, pretend, write): def apply_item_changes(
lib: Library, item: Item, move: bool, pretend: bool, write: bool
) -> None:
"""Store, move, and write the item according to the arguments. """Store, move, and write the item according to the arguments.
:param lib: beets library. :param lib: beets library.
:type lib: beets.library.Library
:param item: Item whose changes to apply. :param item: Item whose changes to apply.
:type item: beets.library.Item
:param move: Move the item if it's in the library. :param move: Move the item if it's in the library.
:type move: bool
:param pretend: Return without moving, writing, or storing the item's :param pretend: Return without moving, writing, or storing the item's
metadata. metadata.
:type pretend: bool
:param write: Write the item's metadata to its media file. :param write: Write the item's metadata to its media file.
:type write: bool
""" """
if pretend: if pretend:
return return
@ -655,45 +762,84 @@ def apply_item_changes(lib, item, move, pretend, write):
item.store() item.store()
class MetadataSourcePlugin(metaclass=abc.ABCMeta): class Response(TypedDict):
"""A dictionary with the response of a plugin API call.
May be extended by plugins to include additional information, but `id`
is required.
"""
id: str
class RegexDict(TypedDict):
"""A dictionary containing a regex pattern and the number of the
match group.
"""
pattern: str
match_group: int
R = TypeVar("R", bound=Response)
class MetadataSourcePlugin(Generic[R], BeetsPlugin, metaclass=abc.ABCMeta):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.config.add({"source_weight": 0.5}) self.config.add({"source_weight": 0.5})
@abc.abstractproperty @property
def id_regex(self): @abc.abstractmethod
def id_regex(self) -> RegexDict:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
def data_source(self): @abc.abstractmethod
def data_source(self) -> str:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
def search_url(self): @abc.abstractmethod
def search_url(self) -> str:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
def album_url(self): @abc.abstractmethod
def album_url(self) -> str:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @property
def track_url(self): @abc.abstractmethod
def track_url(self) -> str:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def _search_api(self, query_type, filters, keywords=""): def _search_api(
self,
query_type: str,
filters: dict[str, str] | None,
keywords: str = "",
) -> Sequence[R]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def album_for_id(self, album_id): def album_for_id(self, album_id: str) -> AlbumInfo | None:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def track_for_id(self, track_id=None, track_data=None): def track_for_id(
self, track_id: str | None = None, track_data: R | None = None
) -> TrackInfo | None:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def get_artist(artists, id_key="id", name_key="name", join_key=None): def get_artist(
artists,
id_key: str | int = "id",
name_key: str | int = "name",
join_key: str | int | None = None,
) -> tuple[str, str | None]:
"""Returns an artist string (all artists) and an artist_id (the main """Returns an artist string (all artists) and an artist_id (the main
artist) for a list of artist object dicts. artist) for a list of artist object dicts.
@ -708,18 +854,14 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
:type artists: list[dict] or list[list] :type artists: list[dict] or list[list]
:param id_key: Key or index corresponding to the value of ``id`` for :param id_key: Key or index corresponding to the value of ``id`` for
the main/first artist. Defaults to 'id'. the main/first artist. Defaults to 'id'.
:type id_key: str or int
:param name_key: Key or index corresponding to values of names :param name_key: Key or index corresponding to values of names
to concatenate for the artist string (containing all artists). to concatenate for the artist string (containing all artists).
Defaults to 'name'. Defaults to 'name'.
:type name_key: str or int
:param join_key: Key or index corresponding to a field containing a :param join_key: Key or index corresponding to a field containing a
keyword to use for combining artists into a single string, for keyword to use for combining artists into a single string, for
example "Feat.", "Vs.", "And" or similar. The default is None example "Feat.", "Vs.", "And" or similar. The default is None
which keeps the default behaviour (comma-separated). which keeps the default behaviour (comma-separated).
:type join_key: str or int
:return: Normalized artist string. :return: Normalized artist string.
:rtype: str
""" """
artist_id = None artist_id = None
artist_string = "" artist_string = ""
@ -744,19 +886,15 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
return artist_string, artist_id return artist_string, artist_id
@staticmethod @staticmethod
def _get_id(url_type, id_, id_regex): def _get_id(url_type: str, id_: str, id_regex: RegexDict) -> str | None:
"""Parse an ID from its URL if necessary. """Parse an ID from its URL if necessary.
:param url_type: Type of URL. Either 'album' or 'track'. :param url_type: Type of URL. Either 'album' or 'track'.
:type url_type: str
:param id_: Album/track ID or URL. :param id_: Album/track ID or URL.
:type id_: str
:param id_regex: A dictionary containing a regular expression :param id_regex: A dictionary containing a regular expression
extracting an ID from an URL (if it's not an ID already) in extracting an ID from an URL (if it's not an ID already) in
'pattern' and the number of the match group in 'match_group'. 'pattern' and the number of the match group in 'match_group'.
:type id_regex: dict
:return: Album/track ID. :return: Album/track ID.
:rtype: str
""" """
log.debug("Extracting {} ID from '{}'", url_type, id_) log.debug("Extracting {} ID from '{}'", url_type, id_)
match = re.search(id_regex["pattern"].format(url_type), str(id_)) match = re.search(id_regex["pattern"].format(url_type), str(id_))
@ -766,21 +904,22 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
return id_ return id_
return None return None
def candidates(self, items, artist, album, va_likely, extra_tags=None): def candidates(
self,
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags: dict[str, Any] | None = None,
) -> Sequence[AlbumInfo]:
"""Returns a list of AlbumInfo objects for Search API results """Returns a list of AlbumInfo objects for Search API results
matching an ``album`` and ``artist`` (if not various). matching an ``album`` and ``artist`` (if not various).
:param items: List of items comprised by an album to be matched. :param items: List of items comprised by an album to be matched.
:type items: list[beets.library.Item]
:param artist: The artist of the album to be matched. :param artist: The artist of the album to be matched.
:type artist: str
:param album: The name of the album to be matched. :param album: The name of the album to be matched.
:type album: str
:param va_likely: True if the album to be matched likely has :param va_likely: True if the album to be matched likely has
Various Artists. Various Artists.
:type va_likely: bool
:return: Candidate AlbumInfo objects.
:rtype: list[beets.autotag.hooks.AlbumInfo]
""" """
query_filters = {"album": album} query_filters = {"album": album}
if not va_likely: if not va_likely:
@ -789,30 +928,35 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
albums = [self.album_for_id(album_id=r["id"]) for r in results] albums = [self.album_for_id(album_id=r["id"]) for r in results]
return [a for a in albums if a is not None] return [a for a in albums if a is not None]
def item_candidates(self, item, artist, title): def item_candidates(
self, item: Item, artist: str, title: str
) -> Sequence[TrackInfo]:
"""Returns a list of TrackInfo objects for Search API results """Returns a list of TrackInfo objects for Search API results
matching ``title`` and ``artist``. matching ``title`` and ``artist``.
:param item: Singleton item to be matched. :param item: Singleton item to be matched.
:type item: beets.library.Item
:param artist: The artist of the track to be matched. :param artist: The artist of the track to be matched.
:type artist: str
:param title: The title of the track to be matched. :param title: The title of the track to be matched.
:type title: str
:return: Candidate TrackInfo objects.
:rtype: list[beets.autotag.hooks.TrackInfo]
""" """
tracks = self._search_api( track_responses = self._search_api(
query_type="track", keywords=title, filters={"artist": artist} query_type="track", keywords=title, filters={"artist": artist}
) )
return [self.track_for_id(track_data=track) for track in tracks]
def album_distance(self, items, album_info, mapping): tracks = [self.track_for_id(track_data=r) for r in track_responses]
return [t for t in tracks if t is not None]
def album_distance(
self,
items: list[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
return get_distance( return get_distance(
data_source=self.data_source, info=album_info, config=self.config data_source=self.data_source, info=album_info, config=self.config
) )
def track_distance(self, item, track_info): def track_distance(self, item: Item, info: TrackInfo) -> Distance:
return get_distance( return get_distance(
data_source=self.data_source, info=track_info, config=self.config data_source=self.data_source, info=info, config=self.config
) )

View file

@ -96,10 +96,10 @@ Other changes:
wrong (outdated) commit. Now the tag is created in the same workflow step wrong (outdated) commit. Now the tag is created in the same workflow step
right after committing the version update. right after committing the version update.
:bug:`5539` :bug:`5539`
* Added some typehints: ImportSession and Pipeline have typehints now. Should
improve useability for new developers.
* :doc:`/plugins/smartplaylist`: URL-encode additional item `fields` within generated * :doc:`/plugins/smartplaylist`: URL-encode additional item `fields` within generated
EXTM3U playlists instead of JSON-encoding them. EXTM3U playlists instead of JSON-encoding them.
* typehints: `./beets/importer.py` file now has improved typehints.
* typehints: `./beets/plugins.py` file now includes typehints.
* :doc:`plugins/ftintitle`: Optimize the plugin by avoiding unnecessary writes * :doc:`plugins/ftintitle`: Optimize the plugin by avoiding unnecessary writes
to the database. to the database.
* Database models are now serializable with pickle. * Database models are now serializable with pickle.

View file

@ -367,7 +367,7 @@ Here's an example::
super().__init__() super().__init__()
self.template_funcs['initial'] = _tmpl_initial self.template_funcs['initial'] = _tmpl_initial
def _tmpl_initial(text): def _tmpl_initial(text: str) -> str:
if text: if text:
return text[0].upper() return text[0].upper()
else: else:
@ -387,7 +387,7 @@ Here's an example that adds a ``$disc_and_track`` field::
super().__init__() super().__init__()
self.template_fields['disc_and_track'] = _tmpl_disc_and_track self.template_fields['disc_and_track'] = _tmpl_disc_and_track
def _tmpl_disc_and_track(item): def _tmpl_disc_and_track(item: Item) -> str:
"""Expand to the disc number and track number if this is a """Expand to the disc number and track number if this is a
multi-disc release. Otherwise, just expands to the track multi-disc release. Otherwise, just expands to the track
number. number.

View file

@ -281,3 +281,12 @@ ignore-variadic-names = true
[tool.ruff.lint.pep8-naming] [tool.ruff.lint.pep8-naming]
classmethod-decorators = ["cached_classproperty"] classmethod-decorators = ["cached_classproperty"]
extend-ignore-names = ["assert*", "cached_classproperty"] extend-ignore-names = ["assert*", "cached_classproperty"]
# Temporary, until we decide on a mypy
# config for all files.
[[tool.mypy.overrides]]
module = "beets.plugins"
disallow_untyped_decorators = true
disallow_any_generics = true
check_untyped_defs = true
allow_redefinition = true