Add typing for module

This commit is contained in:
Serene-Arc 2022-12-15 20:30:24 +10:00
parent 4606ff20ce
commit 5044d13d8d

View file

@ -17,10 +17,13 @@
from collections import namedtuple
from functools import total_ordering
import re
from typing import Dict, List, Tuple, Iterator, Union, NewType, Any, Optional, \
Iterable, Callable
from beets import logging
from beets import plugins
from beets import config
from beets.library import Item
from beets.util import as_string
from beets.autotag import mb
from jellyfish import levenshtein_distance
@ -31,8 +34,10 @@ log = logging.getLogger('beets')
# The name of the type for patterns in re changed in Python 3.7.
try:
Pattern = re._pattern_type
Patterntype = NewType('Patterntype', re._pattern_type)
except AttributeError:
Pattern = re.Pattern
Patterntype = NewType('Patterntype', re.Pattern)
# Classes used to represent candidate options.
@ -68,17 +73,45 @@ class AlbumInfo(AttrDict):
The others are optional and may be None.
"""
def __init__(self, tracks, album=None, album_id=None, artist=None,
artist_id=None, asin=None, albumtype=None, va=False,
year=None, month=None, day=None, label=None, mediums=None,
artist_sort=None, releasegroup_id=None, catalognum=None,
script=None, language=None, country=None, style=None,
genre=None, albumstatus=None, media=None, albumdisambig=None,
releasegroupdisambig=None, artist_credit=None,
original_year=None, original_month=None,
original_day=None, data_source=None, data_url=None,
discogs_albumid=None, discogs_labelid=None,
discogs_artistid=None, **kwargs):
# TYPING: are all of these correct? I've assumed optional strings
def __init__(
self,
tracks: List['TrackInfo'],
album: Optional[str] = None,
album_id: Optional[str] = None,
artist: Optional[str] = None,
artist_id: Optional[str] = None,
asin: Optional[str] = None,
albumtype: Optional[str] = None,
va: bool = False,
year: Optional[str] = None,
month: Optional[str] = None,
day: Optional[str] = None,
label: Optional[str] = None,
mediums: Optional[str] = None,
artist_sort: Optional[str] = None,
releasegroup_id: Optional[str] = None,
catalognum: Optional[str] = None,
script: Optional[str] = None,
language: Optional[str] = None,
country: Optional[str] = None,
style: Optional[str] = None,
genre: Optional[str] = None,
albumstatus: Optional[str] = None,
media: Optional[str] = None,
albumdisambig: Optional[str] = None,
releasegroupdisambig: Optional[str] = None,
artist_credit: Optional[str] = None,
original_year: Optional[str] = None,
original_month: Optional[str] = None,
original_day: Optional[str] = None,
data_source: Optional[str] = None,
data_url: Optional[str] = None,
discogs_albumid: Optional[str] = None,
discogs_labelid: Optional[str] = None,
discogs_artistid: Optional[str] = None,
**kwargs,
):
self.album = album
self.album_id = album_id
self.artist = artist
@ -118,7 +151,7 @@ class AlbumInfo(AttrDict):
# Work around a bug in python-musicbrainz-ngs that causes some
# strings to be bytes rather than Unicode.
# https://github.com/alastair/python-musicbrainz-ngs/issues/85
def decode(self, codec='utf-8'):
def decode(self, codec: str = 'utf-8'):
"""Ensure that all string attributes on this object, and the
constituent `TrackInfo` objects, are decoded to Unicode.
"""
@ -135,7 +168,7 @@ class AlbumInfo(AttrDict):
for track in self.tracks:
track.decode(codec)
def copy(self):
def copy(self) -> 'AlbumInfo':
dupe = AlbumInfo([])
dupe.update(self)
dupe.tracks = [track.copy() for track in self.tracks]
@ -154,15 +187,38 @@ class TrackInfo(AttrDict):
are all 1-based.
"""
def __init__(self, title=None, track_id=None, release_track_id=None,
artist=None, artist_id=None, length=None, index=None,
medium=None, medium_index=None, medium_total=None,
artist_sort=None, disctitle=None, artist_credit=None,
data_source=None, data_url=None, media=None, lyricist=None,
composer=None, composer_sort=None, arranger=None,
track_alt=None, work=None, mb_workid=None,
work_disambig=None, bpm=None, initial_key=None, genre=None,
**kwargs):
# TYPING: are all of these correct? I've assumed optional strings
def __init__(
self,
title: Optional[str] = None,
track_id: Optional[str] = None,
release_track_id: Optional[str] = None,
artist: Optional[str] = None,
artist_id: Optional[str] = None,
length: Optional[str] = None,
index: Optional[str] = None,
medium: Optional[str] = None,
medium_index: Optional[str] = None,
medium_total: Optional[str] = None,
artist_sort: Optional[str] = None,
disctitle: Optional[str] = None,
artist_credit: Optional[str] = None,
data_source: Optional[str] = None,
data_url: Optional[str] = None,
media: Optional[str] = None,
lyricist: Optional[str] = None,
composer: Optional[str] = None,
composer_sort: Optional[str] = None,
arranger: Optional[str] = None,
track_alt: Optional[str] = None,
work: Optional[str] = None,
mb_workid: Optional[str] = None,
work_disambig: Optional[str] = None,
bpm: Optional[str] = None,
initial_key: Optional[str] = None,
genre: Optional[str] = None,
**kwargs,
):
self.title = title
self.track_id = track_id
self.release_track_id = release_track_id
@ -203,7 +259,7 @@ class TrackInfo(AttrDict):
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, 'ignore'))
def copy(self):
def copy(self) -> 'TrackInfo':
dupe = TrackInfo()
dupe.update(self)
return dupe
@ -229,7 +285,7 @@ SD_REPLACE = [
]
def _string_dist_basic(str1, str2):
def _string_dist_basic(str1: str, str2: str) -> float:
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
@ -246,7 +302,7 @@ def _string_dist_basic(str1, str2):
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
def string_dist(str1: str, str2: str) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
@ -332,7 +388,7 @@ class Distance:
self._penalties = {}
@LazyClassProperty
def _weights(cls): # noqa: N805
def _weights(cls) -> Dict: # noqa: N805
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
@ -344,7 +400,7 @@ class Distance:
# Access the components and their aggregates.
@property
def distance(self):
def distance(self) -> float:
"""Return a weighted and normalized distance across all
penalties.
"""
@ -354,7 +410,7 @@ class Distance:
return 0.0
@property
def max_distance(self):
def max_distance(self) -> float:
"""Return the maximum distance penalty (normalization factor).
"""
dist_max = 0.0
@ -363,7 +419,7 @@ class Distance:
return dist_max
@property
def raw_distance(self):
def raw_distance(self) -> float:
"""Return the raw (denormalized) distance.
"""
dist_raw = 0.0
@ -371,7 +427,7 @@ class Distance:
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self):
def items(self) -> List[Tuple[str, float]]:
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
@ -389,32 +445,32 @@ class Distance:
key=lambda key_and_dist: (-key_and_dist[1], key_and_dist[0])
)
def __hash__(self):
def __hash__(self) -> int:
return id(self)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return self.distance == other
# Behave like a float.
def __lt__(self, other):
def __lt__(self, other) -> bool:
return self.distance < other
def __float__(self):
def __float__(self) -> float:
return self.distance
def __sub__(self, other):
def __sub__(self, other) -> float:
return self.distance - other
def __rsub__(self, other):
def __rsub__(self, other) -> float:
return other - self.distance
def __str__(self):
def __str__(self) -> str:
return f"{self.distance:.2f}"
# Behave like a dict.
def __getitem__(self, key):
def __getitem__(self, key) -> float:
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * self._weights[key]
@ -423,16 +479,16 @@ class Distance:
return dist / dist_max
return 0.0
def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, float]]:
return iter(self.items())
def __len__(self):
def __len__(self) -> int:
return len(self.items())
def keys(self):
def keys(self) -> List[str]:
return [key for key, _ in self.items()]
def update(self, dist):
def update(self, dist: 'Distance'):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
@ -444,7 +500,7 @@ class Distance:
# Adding components.
def _eq(self, value1, value2):
def _eq(self, value1: Union['Distance', Patterntype], value2) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
@ -453,7 +509,7 @@ class Distance:
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
def add(self, key: str, dist: float):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
@ -465,7 +521,12 @@ class Distance:
)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
def add_equality(
self,
key: str,
value: Any,
options: Union[List, Tuple, Patterntype],
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
@ -481,7 +542,7 @@ class Distance:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
def add_expr(self, key: str, expr: bool):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
@ -490,7 +551,7 @@ class Distance:
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
def add_number(self, key: str, number1: float, number2: float):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
@ -503,7 +564,12 @@ class Distance:
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
def add_priority(
self,
key: str,
value: Any,
options: Union[List, Tuple, Patterntype],
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
@ -521,7 +587,7 @@ class Distance:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
def add_ratio(self, key: str, number1: float, number2: float):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
@ -532,7 +598,7 @@ class Distance:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
def add_string(self, key: str, str1: str, str2: str):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
@ -550,7 +616,7 @@ TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
# Aggregation of sources.
def album_for_mbid(release_id):
def album_for_mbid(release_id: str) -> Optional[AlbumInfo]:
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
if the ID is not found.
"""
@ -563,7 +629,7 @@ def album_for_mbid(release_id):
exc.log(log)
def track_for_mbid(recording_id):
def track_for_mbid(recording_id: str) -> Optional[TrackInfo]:
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
"""
@ -576,7 +642,7 @@ def track_for_mbid(recording_id):
exc.log(log)
def albums_for_id(album_id):
def albums_for_id(album_id: str) -> Iterable[Union[None, AlbumInfo]]:
"""Get a list of albums for an ID."""
a = album_for_mbid(album_id)
if a:
@ -587,7 +653,7 @@ def albums_for_id(album_id):
yield a
def tracks_for_id(track_id):
def tracks_for_id(track_id: str) -> Iterable[Union[None, TrackInfo]]:
"""Get a list of tracks for an ID."""
t = track_for_mbid(track_id)
if t:
@ -598,7 +664,7 @@ def tracks_for_id(track_id):
yield t
def invoke_mb(call_func, *args):
def invoke_mb(call_func: Callable, *args):
try:
return call_func(*args)
except mb.MusicBrainzAPIError as exc:
@ -607,7 +673,13 @@ def invoke_mb(call_func, *args):
@plugins.notify_info_yielded('albuminfo_received')
def album_candidates(items, artist, album, va_likely, extra_tags):
def album_candidates(
items: List[Item],
artist: str,
album: str,
va_likely: bool,
extra_tags: Dict,
) -> Iterable[Tuple]:
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
@ -633,7 +705,7 @@ def album_candidates(items, artist, album, va_likely, extra_tags):
@plugins.notify_info_yielded('trackinfo_received')
def item_candidates(item, artist, title):
def item_candidates(item: Item, artist: str, title: str) -> Iterable[Tuple]:
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
are specified by the user.