Added typehints to the plugins file.

This commit is contained in:
Sebastian Mohr 2025-04-02 15:57:16 +02:00
parent 030fd1fcf5
commit 0cc0db313a

View file

@ -14,17 +14,37 @@
"""Support for beets plugins."""
from __future__ import annotations
import abc
import inspect
import re
import traceback
from collections import defaultdict
from functools import wraps
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Sequence,
TypedDict,
TypeVar,
)
import mediafile
import beets
from beets import logging
from beets.autotag import Distance
if TYPE_CHECKING:
from collections.abc import Iterable
from beets.autotag import AlbumInfo, TrackInfo
from beets.dbcore import Query
from beets.library import Item
from beets.ui import Subcommand
PLUGIN_NAMESPACE = "beetsplug"
@ -145,55 +165,74 @@ class BeetsPlugin:
return wrapper
def queries(self):
def queries(self) -> dict[str, type[Query]]:
"""Return a dict mapping prefixes to Query subclasses."""
return {}
def track_distance(self, item, info):
def track_distance(
self,
item: Item,
info: TrackInfo,
) -> Distance:
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
return beets.autotag.hooks.Distance()
return Distance()
def album_distance(self, items, album_info, mapping):
def album_distance(
self,
items: list[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
return beets.autotag.hooks.Distance()
return Distance()
def candidates(self, items, artist, album, va_likely, extra_tags=None):
def candidates(
self,
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags=None,
) -> Sequence[AlbumInfo]:
"""Should return a sequence of AlbumInfo objects that match the
album whose items are provided.
"""
return ()
def item_candidates(self, item, artist, title):
def item_candidates(
self,
item: Item,
artist: str,
title: str,
) -> Sequence[TrackInfo]:
"""Should return a sequence of TrackInfo objects that match the
item provided.
"""
return ()
def album_for_id(self, album_id):
def album_for_id(self, album_id: str) -> AlbumInfo | None:
"""Return an AlbumInfo object or None if no matching release was
found.
"""
return None
def track_for_id(self, track_id):
def track_for_id(self, track_id: str) -> TrackInfo | None:
"""Return a TrackInfo object or None if no matching release was
found.
"""
return None
def add_media_field(self, name, descriptor):
def add_media_field(self, name: str, descriptor: mediafile.MediaField):
"""Add a field that is synchronized between media files and items.
When a media field is added ``item.write()`` will set the name
property of the item's MediaFile to ``item[name]`` and save the
changes. Similarly ``item.read()`` will set ``item[name]`` to
the value of the name property of the media file.
``descriptor`` must be an instance of ``mediafile.MediaField``.
"""
# Defer import to prevent circular dependency
from beets import library
@ -202,9 +241,9 @@ class BeetsPlugin:
library.Item._media_fields.add(name)
_raw_listeners = None
listeners = None
listeners: None | dict[str, list[Callable]] = None
def register_listener(self, event, func):
def register_listener(self, event: str, func: Callable):
"""Add a function as a listener for the specified event."""
wrapped_func = self._set_log_level_and_params(logging.WARNING, func)
@ -221,7 +260,7 @@ class BeetsPlugin:
album_template_fields = None
@classmethod
def template_func(cls, name):
def template_func(cls, name: str):
"""Decorator that registers a path template function. The
function will be invoked as ``%name{}`` from path format
strings.
@ -236,7 +275,7 @@ class BeetsPlugin:
return helper
@classmethod
def template_field(cls, name):
def template_field(cls, name: str):
"""Decorator that registers a path template field computation.
The value will be referenced as ``$name`` from path format
strings. The function must accept a single parameter, the Item
@ -255,7 +294,7 @@ class BeetsPlugin:
_classes = set()
def load_plugins(names=()):
def load_plugins(names: Sequence[str] = ()):
"""Imports the modules for a sequence of plugin names. Each name
must be the name of a Python module under the "beetsplug" namespace
package in sys.path; the module indicated should contain the
@ -293,7 +332,7 @@ def load_plugins(names=()):
_instances = {}
def find_plugins():
def find_plugins() -> list[BeetsPlugin]:
"""Returns a list of BeetsPlugin subclass instances from all
currently loaded beets plugins. Loads the default plugin set
first.
@ -316,7 +355,7 @@ def find_plugins():
# Communication with plugins.
def commands():
def commands() -> list[Subcommand]:
"""Returns a list of Subcommand objects from all loaded plugins."""
out = []
for plugin in find_plugins():
@ -324,11 +363,11 @@ def commands():
return out
def queries():
def queries() -> dict[str, type[Query]]:
"""Returns a dict mapping prefix strings to Query subclasses all loaded
plugins.
"""
out = {}
out: dict[str, type[Query]] = {}
for plugin in find_plugins():
out.update(plugin.queries())
return out
@ -361,7 +400,7 @@ def named_queries(model_cls):
return queries
def track_distance(item, info):
def track_distance(item: Item, info: TrackInfo) -> Distance:
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
@ -373,7 +412,11 @@ def track_distance(item, info):
return dist
def album_distance(items, album_info, mapping):
def album_distance(
items: list[Item],
album_info: AlbumInfo,
mapping: dict[Item, TrackInfo],
) -> Distance:
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
@ -383,7 +426,13 @@ def album_distance(items, album_info, mapping):
return dist
def candidates(items, artist, album, va_likely, extra_tags=None):
def candidates(
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags=None,
) -> Iterable[AlbumInfo]:
"""Gets MusicBrainz candidates for an album from each plugin."""
for plugin in find_plugins():
yield from plugin.candidates(
@ -391,13 +440,13 @@ def candidates(items, artist, album, va_likely, extra_tags=None):
)
def item_candidates(item, artist, title):
def item_candidates(item: Item, artist: str, title: str) -> Iterable[TrackInfo]:
"""Gets MusicBrainz candidates for an item from the plugins."""
for plugin in find_plugins():
yield from plugin.item_candidates(item, artist, title)
def album_for_id(album_id):
def album_for_id(album_id: str) -> Iterable[AlbumInfo]:
"""Get AlbumInfo objects for a given ID string."""
for plugin in find_plugins():
album = plugin.album_for_id(album_id)
@ -405,7 +454,7 @@ def album_for_id(album_id):
yield album
def track_for_id(track_id):
def track_for_id(track_id: str) -> Iterable[TrackInfo]:
"""Get TrackInfo objects for a given ID string."""
for plugin in find_plugins():
track = plugin.track_for_id(track_id)
@ -443,7 +492,7 @@ def import_stages():
# New-style (lazy) plugin-provided fields.
def _check_conflicts_and_merge(plugin, plugin_funcs, funcs):
def _check_conflicts_and_merge(plugin: BeetsPlugin, plugin_funcs, funcs):
"""Check the provided template functions for conflicts and merge into funcs.
Raises a `PluginConflictError` if a plugin defines template functions
@ -598,11 +647,11 @@ def notify_info_yielded(event):
return decorator
def get_distance(config, data_source, info):
def get_distance(config, data_source, info) -> Distance:
"""Returns the ``data_source`` weight and the maximum source weight
for albums or individual tracks.
"""
dist = beets.autotag.Distance()
dist = Distance()
if info.data_source == data_source:
dist.add("source", config["source_weight"].as_number())
return dist
@ -638,7 +687,27 @@ def apply_item_changes(lib, item, move, pretend, write):
item.store()
class MetadataSourcePlugin(metaclass=abc.ABCMeta):
class Response(TypedDict):
"""A dictionary with the response of a plugin API call.
May be extended by plugins to include additional information, but id
is required.
"""
id: str
class RegexDict(TypedDict):
"""A dictionary with regex patterns as keys and match groups as values."""
pattern: str
match_group: int
R = TypeVar("R", bound=Response)
class MetadataSourcePlugin(Generic[R], metaclass=abc.ABCMeta):
def __init__(self):
super().__init__()
self.config.add({"source_weight": 0.5})
@ -664,19 +733,26 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
def _search_api(self, query_type, filters, keywords=""):
def _search_api(self, query_type, filters, keywords="") -> Sequence[R]:
raise NotImplementedError
@abc.abstractmethod
def album_for_id(self, album_id):
def album_for_id(self, album_id: str) -> AlbumInfo | None:
raise NotImplementedError
@abc.abstractmethod
def track_for_id(self, track_id=None, track_data=None):
def track_for_id(
self, track_id: str | None = None, track_data: R | None = None
) -> TrackInfo | None:
raise NotImplementedError
@staticmethod
def get_artist(artists, id_key="id", name_key="name", join_key=None):
def get_artist(
artists,
id_key: str | int = "id",
name_key: str | int = "name",
join_key: str | int | None = None,
) -> tuple[str, str | None]:
"""Returns an artist string (all artists) and an artist_id (the main
artist) for a list of artist object dicts.
@ -691,18 +767,14 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
:type artists: list[dict] or list[list]
:param id_key: Key or index corresponding to the value of ``id`` for
the main/first artist. Defaults to 'id'.
:type id_key: str or int
:param name_key: Key or index corresponding to values of names
to concatenate for the artist string (containing all artists).
Defaults to 'name'.
:type name_key: str or int
:param join_key: Key or index corresponding to a field containing a
keyword to use for combining artists into a single string, for
example "Feat.", "Vs.", "And" or similar. The default is None
which keeps the default behaviour (comma-separated).
:type join_key: str or int
:return: Normalized artist string.
:rtype: str
"""
artist_id = None
artist_string = ""
@ -727,19 +799,15 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
return artist_string, artist_id
@staticmethod
def _get_id(url_type, id_, id_regex):
def _get_id(url_type: str, id_: str, id_regex: RegexDict) -> str | None:
"""Parse an ID from its URL if necessary.
:param url_type: Type of URL. Either 'album' or 'track'.
:type url_type: str
:param id_: Album/track ID or URL.
:type id_: str
:param id_regex: A dictionary containing a regular expression
extracting an ID from an URL (if it's not an ID already) in
'pattern' and the number of the match group in 'match_group'.
:type id_regex: dict
:return: Album/track ID.
:rtype: str
"""
log.debug("Extracting {} ID from '{}'", url_type, id_)
match = re.search(id_regex["pattern"].format(url_type), str(id_))
@ -749,21 +817,22 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
return id_
return None
def candidates(self, items, artist, album, va_likely, extra_tags=None):
def candidates(
self,
items: list[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags=None,
) -> Sequence[AlbumInfo]:
"""Returns a list of AlbumInfo objects for Search API results
matching an ``album`` and ``artist`` (if not various).
:param items: List of items comprised by an album to be matched.
:type items: list[beets.library.Item]
:param artist: The artist of the album to be matched.
:type artist: str
:param album: The name of the album to be matched.
:type album: str
:param va_likely: True if the album to be matched likely has
Various Artists.
:type va_likely: bool
:return: Candidate AlbumInfo objects.
:rtype: list[beets.autotag.hooks.AlbumInfo]
"""
query_filters = {"album": album}
if not va_likely:
@ -772,23 +841,23 @@ class MetadataSourcePlugin(metaclass=abc.ABCMeta):
albums = [self.album_for_id(album_id=r["id"]) for r in results]
return [a for a in albums if a is not None]
def item_candidates(self, item, artist, title):
def item_candidates(
self, item: Item, artist: str, title: str
) -> Sequence[TrackInfo]:
"""Returns a list of TrackInfo objects for Search API results
matching ``title`` and ``artist``.
:param item: Singleton item to be matched.
:type item: beets.library.Item
:param artist: The artist of the track to be matched.
:type artist: str
:param title: The title of the track to be matched.
:type title: str
:return: Candidate TrackInfo objects.
:rtype: list[beets.autotag.hooks.TrackInfo]
"""
tracks = self._search_api(
track_responses = self._search_api(
query_type="track", keywords=title, filters={"artist": artist}
)
return [self.track_for_id(track_data=track) for track in tracks]
tracks = [self.track_for_id(track_data=r) for r in track_responses]
return [t for t in tracks if t is not None]
def album_distance(self, items, album_info, mapping):
return get_distance(