Merge branch 'serene_arc_autotag_typings'

This commit is contained in:
wisp3rwind 2023-02-19 11:58:59 +01:00
commit 5a2ce43fa8
4 changed files with 213 additions and 88 deletions

View file

@ -14,7 +14,9 @@
"""Facilities for automatically determining files' correct metadata.
"""
from typing import Mapping
from beets.library import Item
from beets import logging
from beets import config
@ -71,7 +73,7 @@ SPECIAL_FIELDS = {
# Additional utilities for the main interface.
def apply_item_metadata(item, track_info):
def apply_item_metadata(item: Item, track_info: TrackInfo):
"""Set an item's metadata from its matched TrackInfo object.
"""
item.artist = track_info.artist
@ -95,7 +97,7 @@ def apply_item_metadata(item, track_info):
# and track number). Perhaps these should be emptied?
def apply_metadata(album_info, mapping):
def apply_metadata(album_info: AlbumInfo, mapping: Mapping[Item, TrackInfo]):
"""Set the items' metadata to match an AlbumInfo object using a
mapping from Items to TrackInfo objects.
"""

View file

@ -14,13 +14,17 @@
"""Glue between metadata sources and the matching logic."""
from __future__ import annotations
from collections import namedtuple
from functools import total_ordering
import re
from typing import Dict, List, Tuple, Iterator, Union, NewType, Any, Optional,\
Iterable, Callable, TypeVar
from beets import logging
from beets import plugins
from beets import config
from beets.library import Item
from beets.util import as_string
from beets.autotag import mb
from jellyfish import levenshtein_distance
@ -29,6 +33,9 @@ from unidecode import unidecode
log = logging.getLogger('beets')
T = TypeVar('T')
# Classes used to represent candidate options.
class AttrDict(dict):
"""A dictionary that supports attribute ("dot") access, so `d.field`
@ -44,7 +51,7 @@ class AttrDict(dict):
def __setattr__(self, key, value):
self.__setitem__(key, value)
def __hash__(self):
def __hash__(self) -> int:
return id(self)
@ -62,17 +69,45 @@ class AlbumInfo(AttrDict):
The others are optional and may be None.
"""
def __init__(self, tracks, album=None, album_id=None, artist=None,
artist_id=None, asin=None, albumtype=None, va=False,
year=None, month=None, day=None, label=None, mediums=None,
artist_sort=None, releasegroup_id=None, catalognum=None,
script=None, language=None, country=None, style=None,
genre=None, albumstatus=None, media=None, albumdisambig=None,
releasegroupdisambig=None, artist_credit=None,
original_year=None, original_month=None,
original_day=None, data_source=None, data_url=None,
discogs_albumid=None, discogs_labelid=None,
discogs_artistid=None, **kwargs):
# TYPING: are all of these correct? I've assumed optional strings
def __init__(
self,
tracks: List['TrackInfo'],
album: Optional[str] = None,
album_id: Optional[str] = None,
artist: Optional[str] = None,
artist_id: Optional[str] = None,
asin: Optional[str] = None,
albumtype: Optional[str] = None,
va: bool = False,
year: Optional[str] = None,
month: Optional[str] = None,
day: Optional[str] = None,
label: Optional[str] = None,
mediums: Optional[int] = None,
artist_sort: Optional[str] = None,
releasegroup_id: Optional[str] = None,
catalognum: Optional[str] = None,
script: Optional[str] = None,
language: Optional[str] = None,
country: Optional[str] = None,
style: Optional[str] = None,
genre: Optional[str] = None,
albumstatus: Optional[str] = None,
media: Optional[str] = None,
albumdisambig: Optional[str] = None,
releasegroupdisambig: Optional[str] = None,
artist_credit: Optional[str] = None,
original_year: Optional[str] = None,
original_month: Optional[str] = None,
original_day: Optional[str] = None,
data_source: Optional[str] = None,
data_url: Optional[str] = None,
discogs_albumid: Optional[str] = None,
discogs_labelid: Optional[str] = None,
discogs_artistid: Optional[str] = None,
**kwargs,
):
self.album = album
self.album_id = album_id
self.artist = artist
@ -112,7 +147,7 @@ class AlbumInfo(AttrDict):
# Work around a bug in python-musicbrainz-ngs that causes some
# strings to be bytes rather than Unicode.
# https://github.com/alastair/python-musicbrainz-ngs/issues/85
def decode(self, codec='utf-8'):
def decode(self, codec: str = 'utf-8'):
"""Ensure that all string attributes on this object, and the
constituent `TrackInfo` objects, are decoded to Unicode.
"""
@ -129,7 +164,7 @@ class AlbumInfo(AttrDict):
for track in self.tracks:
track.decode(codec)
def copy(self):
def copy(self) -> 'AlbumInfo':
dupe = AlbumInfo([])
dupe.update(self)
dupe.tracks = [track.copy() for track in self.tracks]
@ -148,15 +183,38 @@ class TrackInfo(AttrDict):
are all 1-based.
"""
def __init__(self, title=None, track_id=None, release_track_id=None,
artist=None, artist_id=None, length=None, index=None,
medium=None, medium_index=None, medium_total=None,
artist_sort=None, disctitle=None, artist_credit=None,
data_source=None, data_url=None, media=None, lyricist=None,
composer=None, composer_sort=None, arranger=None,
track_alt=None, work=None, mb_workid=None,
work_disambig=None, bpm=None, initial_key=None, genre=None,
**kwargs):
# TYPING: are all of these correct? I've assumed optional strings
def __init__(
self,
title: Optional[str] = None,
track_id: Optional[str] = None,
release_track_id: Optional[str] = None,
artist: Optional[str] = None,
artist_id: Optional[str] = None,
length: Optional[str] = None,
index: Optional[int] = None,
medium: Optional[int] = None,
medium_index: Optional[int] = None,
medium_total: Optional[int] = None,
artist_sort: Optional[str] = None,
disctitle: Optional[str] = None,
artist_credit: Optional[str] = None,
data_source: Optional[str] = None,
data_url: Optional[str] = None,
media: Optional[str] = None,
lyricist: Optional[str] = None,
composer: Optional[str] = None,
composer_sort: Optional[str] = None,
arranger: Optional[str] = None,
track_alt: Optional[str] = None,
work: Optional[str] = None,
mb_workid: Optional[str] = None,
work_disambig: Optional[str] = None,
bpm: Optional[str] = None,
initial_key: Optional[str] = None,
genre: Optional[str] = None,
**kwargs,
):
self.title = title
self.track_id = track_id
self.release_track_id = release_track_id
@ -197,7 +255,7 @@ class TrackInfo(AttrDict):
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
def copy(self):
def copy(self) -> 'TrackInfo':
dupe = TrackInfo()
dupe.update(self)
return dupe
@ -223,7 +281,7 @@ SD_REPLACE = [
]
def _string_dist_basic(str1, str2):
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
@ -240,7 +298,7 @@ def _string_dist_basic(str1, str2):
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
def string_dist(str1: str, str2: str) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
@ -326,7 +384,7 @@ class Distance:
self._penalties = {}
@LazyClassProperty
def _weights(cls): # noqa: N805
def _weights(cls) -> Dict[str, float]: # noqa: N805
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
@ -338,7 +396,7 @@ class Distance:
# Access the components and their aggregates.
@property
def distance(self):
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
@ -348,7 +406,7 @@ class Distance:
return 0.0
@property
def max_distance(self):
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor).
"""
dist_max = 0.0
@ -357,7 +415,7 @@ class Distance:
return dist_max
@property
def raw_distance(self):
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance.
"""
dist_raw = 0.0
@ -365,7 +423,7 @@ class Distance:
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self):
def items(self) -> List[Tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
@ -383,32 +441,32 @@ class Distance:
key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self):
def __hash__(self) -> int:
return id(self)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other):
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self):
def __float__(self) -> float:
return self.distance
def __sub__(self, other):
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other):
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self):
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key):
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * self._weights[key]
@ -417,16 +475,16 @@ class Distance:
return dist / dist_max
return 0.0
def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, float]]:
return iter(self.items())
def __len__(self):
def __len__(self) -> int:
return len(self.items())
def keys(self):
def keys(self) -> List[str]:
return [key for key, _ in self.items()]
def update(self, dist):
def update(self, dist: 'Distance'):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
@ -438,7 +496,7 @@ class Distance:
# Adding components.
def _eq(self, value1, value2):
def _eq(self, value1: T, value2: T) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
@ -447,7 +505,7 @@ class Distance:
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
@ -459,7 +517,12 @@ class Distance:
)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
def add_equality(
self,
key: str,
value: Any,
options: Union[List[T, ...], Tuple[T, ...], T],
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
@ -475,7 +538,7 @@ class Distance:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
@ -484,7 +547,7 @@ class Distance:
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
def add_number(self, key: str, number1: int, number2: int):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
@ -497,7 +560,12 @@ class Distance:
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
def add_priority(
self,
key: str,
value: Any,
options: Union[List[T, ...], Tuple[T, ...], T],
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
@ -515,7 +583,12 @@ class Distance:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
def add_ratio(
self,
key: str,
number1: Union[int, float],
number2: Union[int, float],
):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
@ -526,7 +599,7 @@ class Distance:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
def add_string(self, key: str, str1: str, str2: str):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
@ -544,7 +617,7 @@ TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
# Aggregation of sources.
def album_for_mbid(release_id):
def album_for_mbid(release_id: str) -> Optional[AlbumInfo]:
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
if the ID is not found.
"""
@ -557,7 +630,7 @@ def album_for_mbid(release_id):
exc.log(log)
def track_for_mbid(recording_id):
def track_for_mbid(recording_id: str) -> Optional[TrackInfo]:
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
"""
@ -570,7 +643,7 @@ def track_for_mbid(recording_id):
exc.log(log)
def albums_for_id(album_id):
def albums_for_id(album_id: str) -> Iterable[AlbumInfo]:
"""Get a list of albums for an ID."""
a = album_for_mbid(album_id)
if a:
@ -581,7 +654,7 @@ def albums_for_id(album_id):
yield a
def tracks_for_id(track_id):
def tracks_for_id(track_id: str) -> Iterable[TrackInfo]:
"""Get a list of tracks for an ID."""
t = track_for_mbid(track_id)
if t:
@ -592,7 +665,7 @@ def tracks_for_id(track_id):
yield t
def invoke_mb(call_func, *args):
def invoke_mb(call_func: Callable, *args):
try:
return call_func(*args)
except mb.MusicBrainzAPIError as exc:
@ -601,7 +674,13 @@ def invoke_mb(call_func, *args):
@plugins.notify_info_yielded('albuminfo_received')
def album_candidates(items, artist, album, va_likely, extra_tags):
def album_candidates(
items: List[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags: Dict,
) -> Iterable[Tuple]:
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
@ -627,7 +706,7 @@ def album_candidates(items, artist, album, va_likely, extra_tags):
@plugins.notify_info_yielded('trackinfo_received')
def item_candidates(item, artist, title):
def item_candidates(item: Item, artist: str, title: str) -> Iterable[Tuple]:
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
are specified by the user.

View file

@ -19,14 +19,18 @@ releases and tracks.
import datetime
import re
from typing import List, Dict, Tuple, Iterable, Union, Optional
from munkres import Munkres
from collections import namedtuple
from beets import logging
from beets import plugins
from beets import config
from beets.library import Item
from beets.util import plurality
from beets.autotag import hooks
from beets.autotag import hooks, TrackInfo, Distance, AlbumInfo, TrackMatch, \
AlbumMatch
from beets.util.enumeration import OrderedEnum
# Artist signals that indicate "various artists". These are used at the
@ -60,7 +64,7 @@ Proposal = namedtuple('Proposal', ('candidates', 'recommendation'))
# Primary matching functionality.
def current_metadata(items):
def current_metadata(items: List[Item]) -> Tuple[Dict, Dict]:
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
@ -85,7 +89,10 @@ def current_metadata(items):
return likelies, consensus
def assign_items(items, tracks):
def assign_items(
items: List[Item],
tracks: List[TrackInfo],
) -> Tuple[Dict, List[Item], List[TrackInfo]]:
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
objects, a set of extra Items, and a set of extra TrackInfo
@ -114,14 +121,18 @@ def assign_items(items, tracks):
return mapping, extra_items, extra_tracks
def track_index_changed(item, track_info):
def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
"""Returns True if the item and track info index is different. Tolerates
per disc and per release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(item, track_info, incl_artist=False):
def track_distance(
item: Item,
track_info: TrackInfo,
incl_artist: bool = False,
) -> Distance:
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
@ -157,7 +168,11 @@ def track_distance(item, track_info, incl_artist=False):
return dist
def distance(items, album_info, mapping):
def distance(
items: Iterable[Item],
album_info: AlbumInfo,
mapping: Dict[Item, TrackInfo],
) -> Distance:
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
@ -263,7 +278,7 @@ def distance(items, album_info, mapping):
return dist
def match_by_id(items):
def match_by_id(items: Iterable[Item]):
"""If the items are tagged with a MusicBrainz album ID, returns an
AlbumInfo object for the corresponding album. Otherwise, returns
None.
@ -287,7 +302,9 @@ def match_by_id(items):
return hooks.album_for_mbid(first)
def _recommendation(results):
def _recommendation(
results: List[Union[AlbumMatch, TrackMatch]],
) -> Recommendation:
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
recommendation based on the results' distances.
@ -338,12 +355,12 @@ def _recommendation(results):
return rec
def _sort_candidates(candidates):
def _sort_candidates(candidates) -> Iterable:
"""Sort candidates by distance."""
return sorted(candidates, key=lambda match: match.distance)
def _add_candidate(items, results, info):
def _add_candidate(items: Iterable[Item], results: Dict, info: AlbumInfo):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
@ -386,8 +403,12 @@ def _add_candidate(items, results, info):
extra_items, extra_tracks)
def tag_album(items, search_artist=None, search_album=None,
search_ids=[]):
def tag_album(
items,
search_artist: Optional[str] = None,
search_album: Optional[str] = None,
search_ids: List = [],
) -> Tuple[str, str, Proposal]:
"""Return a tuple of the current artist name, the current album
name, and a `Proposal` containing `AlbumMatch` candidates.
@ -472,8 +493,12 @@ def tag_album(items, search_artist=None, search_album=None,
return cur_artist, cur_album, Proposal(candidates, rec)
def tag_item(item, search_artist=None, search_title=None,
search_ids=[]):
def tag_item(
item,
search_artist: Optional[str] = None,
search_title: Optional[str] = None,
search_ids: List = [],
) -> Proposal:
"""Find metadata for a single track. Return a `Proposal` consisting
of `TrackMatch` objects.

View file

@ -14,6 +14,8 @@
"""Searches for albums in the MusicBrainz database.
"""
from __future__ import annotations
from typing import List, Tuple, Dict, Optional, Iterator
import musicbrainzngs
import re
@ -82,11 +84,11 @@ if 'genres' in musicbrainzngs.VALID_INCLUDES['recording']:
RELEASE_INCLUDES += ['genres']
def track_url(trackid):
def track_url(trackid: str) -> str:
return urljoin(BASE_URL, 'recording/' + trackid)
def album_url(albumid):
def album_url(albumid: str) -> str:
return urljoin(BASE_URL, 'release/' + albumid)
@ -106,7 +108,7 @@ def configure():
)
def _preferred_alias(aliases):
def _preferred_alias(aliases: List):
"""Given an list of alias structures for an artist credit, select
and return the user's preferred alias or None if no matching
alias is found.
@ -138,7 +140,7 @@ def _preferred_alias(aliases):
return matches[0]
def _preferred_release_event(release):
def _preferred_release_event(release: Dict) -> Tuple[str, str]:
"""Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found.
@ -156,7 +158,7 @@ def _preferred_release_event(release):
return release.get('country'), release.get('date')
def _flatten_artist_credit(credit):
def _flatten_artist_credit(credit: List[Dict]) -> Tuple[str, str, str]:
"""Given a list representing an ``artist-credit`` block, flatten the
data into a triple of joined artist name strings: canonical, sort, and
credit.
@ -215,8 +217,13 @@ def _get_related_artist_names(relations, relation_type):
return ', '.join(related_artists)
def track_info(recording, index=None, medium=None, medium_index=None,
medium_total=None):
def track_info(
recording: Dict,
index: Optional[int] = None,
medium: Optional[int] = None,
medium_index: Optional[int] = None,
medium_total: Optional[int] = None,
) -> beets.autotag.hooks.TrackInfo:
"""Translates a MusicBrainz recording result dictionary into a beets
``TrackInfo`` object. Four parameters are optional and are used
only for tracks that appear on releases (non-singletons): ``index``,
@ -303,7 +310,11 @@ def track_info(recording, index=None, medium=None, medium_index=None,
return info
def _set_date_str(info, date_str, original=False):
def _set_date_str(
info: beets.autotag.hooks.AlbumInfo,
date_str: str,
original: bool = False,
):
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
object, set the object's release date fields appropriately. If
`original`, then set the original_year, etc., fields.
@ -323,7 +334,7 @@ def _set_date_str(info, date_str, original=False):
setattr(info, key, date_num)
def album_info(release):
def album_info(release: Dict) -> beets.autotag.hooks.AlbumInfo:
"""Takes a MusicBrainz release result dictionary and returns a beets
AlbumInfo object containing the interesting data about that release.
"""
@ -502,7 +513,12 @@ def album_info(release):
return info
def match_album(artist, album, tracks=None, extra_tags=None):
def match_album(
artist: str,
album: str,
tracks: Optional[int] = None,
extra_tags: Dict = None,
) -> Iterator[beets.autotag.hooks.AlbumInfo]:
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
MusicBrainzAPIError.
@ -549,7 +565,10 @@ def match_album(artist, album, tracks=None, extra_tags=None):
yield albuminfo
def match_track(artist, title):
def match_track(
artist: str,
title: str,
) -> Iterator[beets.autotag.hooks.TrackInfo]:
"""Searches for a single track and returns an iterable of TrackInfo
objects. May raise a MusicBrainzAPIError.
"""
@ -571,7 +590,7 @@ def match_track(artist, title):
yield track_info(recording)
def _parse_id(s):
def _parse_id(s: str) -> Optional[str]:
"""Search for a MusicBrainz ID in the given string and return it. If
no ID can be found, return None.
"""
@ -581,7 +600,7 @@ def _parse_id(s):
return match.group()
def album_for_id(releaseid):
def album_for_id(releaseid: str) -> Optional[beets.autotag.hooks.AlbumInfo]:
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
object or None if the album is not found. May raise a
MusicBrainzAPIError.
@ -603,7 +622,7 @@ def album_for_id(releaseid):
return album_info(res['release'])
def track_for_id(releaseid):
def track_for_id(releaseid: str) -> Optional[beets.autotag.hooks.TrackInfo]:
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
or None if no track is found. May raise a MusicBrainzAPIError.
"""