diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index 30904ff29..0dcaa43ed 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -17,10 +17,13 @@ from collections import namedtuple from functools import total_ordering import re +from typing import Dict, List, Tuple, Iterator, Union, NewType, Any, Optional, \ + Iterable, Callable from beets import logging from beets import plugins from beets import config +from beets.library import Item from beets.util import as_string from beets.autotag import mb from jellyfish import levenshtein_distance @@ -31,8 +34,10 @@ log = logging.getLogger('beets') # The name of the type for patterns in re changed in Python 3.7. try: Pattern = re._pattern_type + Patterntype = NewType('Patterntype', re._pattern_type) except AttributeError: Pattern = re.Pattern + Patterntype = NewType('Patterntype', re.Pattern) # Classes used to represent candidate options. @@ -68,17 +73,45 @@ class AlbumInfo(AttrDict): The others are optional and may be None. """ - def __init__(self, tracks, album=None, album_id=None, artist=None, - artist_id=None, asin=None, albumtype=None, va=False, - year=None, month=None, day=None, label=None, mediums=None, - artist_sort=None, releasegroup_id=None, catalognum=None, - script=None, language=None, country=None, style=None, - genre=None, albumstatus=None, media=None, albumdisambig=None, - releasegroupdisambig=None, artist_credit=None, - original_year=None, original_month=None, - original_day=None, data_source=None, data_url=None, - discogs_albumid=None, discogs_labelid=None, - discogs_artistid=None, **kwargs): + # TYPING: are all of these correct? I've assumed optional strings + def __init__( + self, + tracks: List['TrackInfo'], + album: Optional[str] = None, + album_id: Optional[str] = None, + artist: Optional[str] = None, + artist_id: Optional[str] = None, + asin: Optional[str] = None, + albumtype: Optional[str] = None, + va: bool = False, + year: Optional[str] = None, + month: Optional[str] = None, + day: Optional[str] = None, + label: Optional[str] = None, + mediums: Optional[str] = None, + artist_sort: Optional[str] = None, + releasegroup_id: Optional[str] = None, + catalognum: Optional[str] = None, + script: Optional[str] = None, + language: Optional[str] = None, + country: Optional[str] = None, + style: Optional[str] = None, + genre: Optional[str] = None, + albumstatus: Optional[str] = None, + media: Optional[str] = None, + albumdisambig: Optional[str] = None, + releasegroupdisambig: Optional[str] = None, + artist_credit: Optional[str] = None, + original_year: Optional[str] = None, + original_month: Optional[str] = None, + original_day: Optional[str] = None, + data_source: Optional[str] = None, + data_url: Optional[str] = None, + discogs_albumid: Optional[str] = None, + discogs_labelid: Optional[str] = None, + discogs_artistid: Optional[str] = None, + **kwargs, + ): self.album = album self.album_id = album_id self.artist = artist @@ -118,7 +151,7 @@ class AlbumInfo(AttrDict): # Work around a bug in python-musicbrainz-ngs that causes some # strings to be bytes rather than Unicode. # https://github.com/alastair/python-musicbrainz-ngs/issues/85 - def decode(self, codec='utf-8'): + def decode(self, codec: str = 'utf-8'): """Ensure that all string attributes on this object, and the constituent `TrackInfo` objects, are decoded to Unicode. """ @@ -135,7 +168,7 @@ class AlbumInfo(AttrDict): for track in self.tracks: track.decode(codec) - def copy(self): + def copy(self) -> 'AlbumInfo': dupe = AlbumInfo([]) dupe.update(self) dupe.tracks = [track.copy() for track in self.tracks] @@ -154,15 +187,38 @@ class TrackInfo(AttrDict): are all 1-based. """ - def __init__(self, title=None, track_id=None, release_track_id=None, - artist=None, artist_id=None, length=None, index=None, - medium=None, medium_index=None, medium_total=None, - artist_sort=None, disctitle=None, artist_credit=None, - data_source=None, data_url=None, media=None, lyricist=None, - composer=None, composer_sort=None, arranger=None, - track_alt=None, work=None, mb_workid=None, - work_disambig=None, bpm=None, initial_key=None, genre=None, - **kwargs): + # TYPING: are all of these correct? I've assumed optional strings + def __init__( + self, + title: Optional[str] = None, + track_id: Optional[str] = None, + release_track_id: Optional[str] = None, + artist: Optional[str] = None, + artist_id: Optional[str] = None, + length: Optional[str] = None, + index: Optional[str] = None, + medium: Optional[str] = None, + medium_index: Optional[str] = None, + medium_total: Optional[str] = None, + artist_sort: Optional[str] = None, + disctitle: Optional[str] = None, + artist_credit: Optional[str] = None, + data_source: Optional[str] = None, + data_url: Optional[str] = None, + media: Optional[str] = None, + lyricist: Optional[str] = None, + composer: Optional[str] = None, + composer_sort: Optional[str] = None, + arranger: Optional[str] = None, + track_alt: Optional[str] = None, + work: Optional[str] = None, + mb_workid: Optional[str] = None, + work_disambig: Optional[str] = None, + bpm: Optional[str] = None, + initial_key: Optional[str] = None, + genre: Optional[str] = None, + **kwargs, + ): self.title = title self.track_id = track_id self.release_track_id = release_track_id @@ -203,7 +259,7 @@ class TrackInfo(AttrDict): if isinstance(value, bytes): setattr(self, fld, value.decode(codec, 'ignore')) - def copy(self): + def copy(self) -> 'TrackInfo': dupe = TrackInfo() dupe.update(self) return dupe @@ -229,7 +285,7 @@ SD_REPLACE = [ ] -def _string_dist_basic(str1, str2): +def _string_dist_basic(str1: str, str2: str) -> float: """Basic edit distance between two strings, ignoring non-alphanumeric characters and case. Comparisons are based on a transliteration/lowering to ASCII characters. Normalized by string @@ -246,7 +302,7 @@ def _string_dist_basic(str1, str2): return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2))) -def string_dist(str1, str2): +def string_dist(str1: str, str2: str) -> float: """Gives an "intuitive" edit distance between two strings. This is an edit distance, normalized by the string length, with a number of tweaks that reflect intuition about text. @@ -332,7 +388,7 @@ class Distance: self._penalties = {} @LazyClassProperty - def _weights(cls): # noqa: N805 + def _weights(cls) -> Dict: # noqa: N805 """A dictionary from keys to floating-point weights. """ weights_view = config['match']['distance_weights'] @@ -344,7 +400,7 @@ class Distance: # Access the components and their aggregates. @property - def distance(self): + def distance(self) -> float: """Return a weighted and normalized distance across all penalties. """ @@ -354,7 +410,7 @@ class Distance: return 0.0 @property - def max_distance(self): + def max_distance(self) -> float: """Return the maximum distance penalty (normalization factor). """ dist_max = 0.0 @@ -363,7 +419,7 @@ class Distance: return dist_max @property - def raw_distance(self): + def raw_distance(self) -> float: """Return the raw (denormalized) distance. """ dist_raw = 0.0 @@ -371,7 +427,7 @@ class Distance: dist_raw += sum(penalty) * self._weights[key] return dist_raw - def items(self): + def items(self) -> List[Tuple[str, float]]: """Return a list of (key, dist) pairs, with `dist` being the weighted distance, sorted from highest to lowest. Does not include penalties with a zero value. @@ -389,32 +445,32 @@ class Distance: key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0]) ) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.distance == other # Behave like a float. - def __lt__(self, other): + def __lt__(self, other) -> bool: return self.distance < other - def __float__(self): + def __float__(self) -> float: return self.distance - def __sub__(self, other): + def __sub__(self, other) -> float: return self.distance - other - def __rsub__(self, other): + def __rsub__(self, other) -> float: return other - self.distance - def __str__(self): + def __str__(self) -> str: return f"{self.distance:.2f}" # Behave like a dict. - def __getitem__(self, key): + def __getitem__(self, key) -> float: """Returns the weighted distance for a named penalty. """ dist = sum(self._penalties[key]) * self._weights[key] @@ -423,16 +479,16 @@ class Distance: return dist / dist_max return 0.0 - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[str, float]]: return iter(self.items()) - def __len__(self): + def __len__(self) -> int: return len(self.items()) - def keys(self): + def keys(self) -> List[str]: return [key for key, _ in self.items()] - def update(self, dist): + def update(self, dist: 'Distance'): """Adds all the distance penalties from `dist`. """ if not isinstance(dist, Distance): @@ -444,7 +500,7 @@ class Distance: # Adding components. - def _eq(self, value1, value2): + def _eq(self, value1: Union['Distance', Patterntype], value2) -> bool: """Returns True if `value1` is equal to `value2`. `value1` may be a compiled regular expression, in which case it will be matched against `value2`. @@ -453,7 +509,7 @@ class Distance: return bool(value1.match(value2)) return value1 == value2 - def add(self, key, dist): + def add(self, key: str, dist: float): """Adds a distance penalty. `key` must correspond with a configured weight setting. `dist` must be a float between 0.0 and 1.0, and will be added to any existing distance penalties @@ -465,7 +521,12 @@ class Distance: ) self._penalties.setdefault(key, []).append(dist) - def add_equality(self, key, value, options): + def add_equality( + self, + key: str, + value: Any, + options: Union[List, Tuple, Patterntype], + ): """Adds a distance penalty of 1.0 if `value` doesn't match any of the values in `options`. If an option is a compiled regular expression, it will be considered equal if it matches against @@ -481,7 +542,7 @@ class Distance: dist = 1.0 self.add(key, dist) - def add_expr(self, key, expr): + def add_expr(self, key: str, expr: bool): """Adds a distance penalty of 1.0 if `expr` evaluates to True, or 0.0. """ @@ -490,7 +551,7 @@ class Distance: else: self.add(key, 0.0) - def add_number(self, key, number1, number2): + def add_number(self, key: str, number1: float, number2: float): """Adds a distance penalty of 1.0 for each number of difference between `number1` and `number2`, or 0.0 when there is no difference. Use this when there is no upper limit on the @@ -503,7 +564,12 @@ class Distance: else: self.add(key, 0.0) - def add_priority(self, key, value, options): + def add_priority( + self, + key: str, + value: Any, + options: Union[List, Tuple, Patterntype], + ): """Adds a distance penalty that corresponds to the position at which `value` appears in `options`. A distance penalty of 0.0 for the first option, or 1.0 if there is no matching option. If @@ -521,7 +587,7 @@ class Distance: dist = 1.0 self.add(key, dist) - def add_ratio(self, key, number1, number2): + def add_ratio(self, key: str, number1: float, number2: float): """Adds a distance penalty for `number1` as a ratio of `number2`. `number1` is bound at 0 and `number2`. """ @@ -532,7 +598,7 @@ class Distance: dist = 0.0 self.add(key, dist) - def add_string(self, key, str1, str2): + def add_string(self, key: str, str1: str, str2: str): """Adds a distance penalty based on the edit distance between `str1` and `str2`. """ @@ -550,7 +616,7 @@ TrackMatch = namedtuple('TrackMatch', ['distance', 'info']) # Aggregation of sources. -def album_for_mbid(release_id): +def album_for_mbid(release_id: str) -> Optional[AlbumInfo]: """Get an AlbumInfo object for a MusicBrainz release ID. Return None if the ID is not found. """ @@ -563,7 +629,7 @@ def album_for_mbid(release_id): exc.log(log) -def track_for_mbid(recording_id): +def track_for_mbid(recording_id: str) -> Optional[TrackInfo]: """Get a TrackInfo object for a MusicBrainz recording ID. Return None if the ID is not found. """ @@ -576,7 +642,7 @@ def track_for_mbid(recording_id): exc.log(log) -def albums_for_id(album_id): +def albums_for_id(album_id: str) -> Iterable[Union[None, AlbumInfo]]: """Get a list of albums for an ID.""" a = album_for_mbid(album_id) if a: @@ -587,7 +653,7 @@ def albums_for_id(album_id): yield a -def tracks_for_id(track_id): +def tracks_for_id(track_id: str) -> Iterable[Union[None, TrackInfo]]: """Get a list of tracks for an ID.""" t = track_for_mbid(track_id) if t: @@ -598,7 +664,7 @@ def tracks_for_id(track_id): yield t -def invoke_mb(call_func, *args): +def invoke_mb(call_func: Callable, *args): try: return call_func(*args) except mb.MusicBrainzAPIError as exc: @@ -607,7 +673,13 @@ def invoke_mb(call_func, *args): @plugins.notify_info_yielded('albuminfo_received') -def album_candidates(items, artist, album, va_likely, extra_tags): +def album_candidates( + items: List[Item], + artist: str, + album: str, + va_likely: bool, + extra_tags: Dict, +) -> Iterable[Tuple]: """Search for album matches. ``items`` is a list of Item objects that make up the album. ``artist`` and ``album`` are the respective names (strings), which may be derived from the item list or may be @@ -633,7 +705,7 @@ def album_candidates(items, artist, album, va_likely, extra_tags): @plugins.notify_info_yielded('trackinfo_received') -def item_candidates(item, artist, title): +def item_candidates(item: Item, artist: str, title: str) -> Iterable[Tuple]: """Search for item matches. ``item`` is the Item to be matched. ``artist`` and ``title`` are strings and either reflect the item or are specified by the user.