Merge branch 'pr_autotag_typing_fixes'

This commit is contained in:
wisp3rwind 2023-02-22 20:04:35 +01:00
commit 7aa7df2a93
3 changed files with 103 additions and 63 deletions

View file

@ -19,7 +19,7 @@ from collections import namedtuple
from functools import total_ordering
import re
from typing import Dict, List, Tuple, Iterator, Union, Any, Optional,\
Iterable, Callable, TypeVar
Iterable, Callable, cast
from beets import logging
from beets import plugins
@ -33,9 +33,6 @@ from unidecode import unidecode
log = logging.getLogger('beets')
T = TypeVar('T')
# Classes used to represent candidate options.
class AttrDict(dict):
"""A dictionary that supports attribute ("dot") access, so `d.field`
@ -51,7 +48,7 @@ class AttrDict(dict):
def __setattr__(self, key, value):
self.__setitem__(key, value)
def __hash__(self) -> int:
def __hash__(self):
return id(self)
@ -80,9 +77,9 @@ class AlbumInfo(AttrDict):
asin: Optional[str] = None,
albumtype: Optional[str] = None,
va: bool = False,
year: Optional[str] = None,
month: Optional[str] = None,
day: Optional[str] = None,
year: Optional[int] = None,
month: Optional[int] = None,
day: Optional[int] = None,
label: Optional[str] = None,
mediums: Optional[int] = None,
artist_sort: Optional[str] = None,
@ -98,9 +95,9 @@ class AlbumInfo(AttrDict):
albumdisambig: Optional[str] = None,
releasegroupdisambig: Optional[str] = None,
artist_credit: Optional[str] = None,
original_year: Optional[str] = None,
original_month: Optional[str] = None,
original_day: Optional[str] = None,
original_year: Optional[int] = None,
original_month: Optional[int] = None,
original_day: Optional[int] = None,
data_source: Optional[str] = None,
data_url: Optional[str] = None,
discogs_albumid: Optional[str] = None,
@ -191,7 +188,7 @@ class TrackInfo(AttrDict):
release_track_id: Optional[str] = None,
artist: Optional[str] = None,
artist_id: Optional[str] = None,
length: Optional[str] = None,
length: Optional[float] = None,
index: Optional[int] = None,
medium: Optional[int] = None,
medium_index: Optional[int] = None,
@ -298,7 +295,7 @@ def _string_dist_basic(str1: str, str2: str) -> float:
return levenshtein_distance(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1: str, str2: str) -> float:
def string_dist(str1: Optional[str], str2: Optional[str]) -> float:
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
@ -382,6 +379,7 @@ class Distance:
def __init__(self):
self._penalties = {}
self.tracks: Dict[TrackInfo, Distance] = {}
@LazyClassProperty
def _weights(cls) -> Dict[str, float]: # noqa: N805
@ -496,12 +494,13 @@ class Distance:
# Adding components.
def _eq(self, value1: T, value2: T) -> bool:
def _eq(self, value1: Union[re.Pattern, Any], value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
value2 = cast(str, value2)
return bool(value1.match(value2))
return value1 == value2
@ -521,7 +520,7 @@ class Distance:
self,
key: str,
value: Any,
options: Union[List[T, ...], Tuple[T, ...], T],
options: Union[List[Any], Tuple[Any, ...], Any],
):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
@ -564,7 +563,7 @@ class Distance:
self,
key: str,
value: Any,
options: Union[List[T, ...], Tuple[T, ...], T],
options: Union[List[Any], Tuple[Any, ...], Any],
):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
@ -599,7 +598,7 @@ class Distance:
dist = 0.0
self.add(key, dist)
def add_string(self, key: str, str1: str, str2: str):
def add_string(self, key: str, str1: Optional[str], str2: Optional[str]):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
@ -628,6 +627,7 @@ def album_for_mbid(release_id: str) -> Optional[AlbumInfo]:
return album
except mb.MusicBrainzAPIError as exc:
exc.log(log)
return None
def track_for_mbid(recording_id: str) -> Optional[TrackInfo]:
@ -641,6 +641,7 @@ def track_for_mbid(recording_id: str) -> Optional[TrackInfo]:
return track
except mb.MusicBrainzAPIError as exc:
exc.log(log)
return None
def albums_for_id(album_id: str) -> Iterable[AlbumInfo]:

View file

@ -19,7 +19,18 @@ releases and tracks.
import datetime
import re
from typing import List, Dict, Tuple, Iterable, Union, Optional
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from munkres import Munkres
from collections import namedtuple
@ -64,7 +75,9 @@ Proposal = namedtuple('Proposal', ('candidates', 'recommendation'))
# Primary matching functionality.
def current_metadata(items: List[Item]) -> Tuple[Dict, Dict]:
def current_metadata(
items: Iterable[Item],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
@ -90,9 +103,9 @@ def current_metadata(items: List[Item]) -> Tuple[Dict, Dict]:
def assign_items(
items: List[Item],
tracks: List[TrackInfo],
) -> Tuple[Dict, List[Item], List[TrackInfo]]:
items: Sequence[Item],
tracks: Sequence[TrackInfo],
) -> Tuple[Dict[Item, TrackInfo], List[Item], List[TrackInfo]]:
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
objects, a set of extra Items, and a set of extra TrackInfo
@ -100,10 +113,10 @@ def assign_items(
of objects of the two types.
"""
# Construct the cost matrix.
costs = []
costs: List[List[Distance]] = []
for item in items:
row = []
for i, track in enumerate(tracks):
for track in tracks:
row.append(track_distance(item, track))
costs.append(row)
@ -141,10 +154,18 @@ def track_distance(
# Length.
if track_info.length:
diff = abs(item.length - track_info.length) - \
config['match']['track_length_grace'].as_number()
dist.add_ratio('track_length', diff,
config['match']['track_length_max'].as_number())
item_length = cast(float, item.length)
track_length_grace = cast(
Union[float, int],
config['match']['track_length_grace'].as_number(),
)
track_length_max = cast(
Union[float, int],
config['match']['track_length_max'].as_number(),
)
diff = abs(item_length - track_info.length) - track_length_grace
dist.add_ratio('track_length', diff, track_length_max)
# Title.
dist.add_string('track_title', item.title, track_info.title)
@ -169,7 +190,7 @@ def track_distance(
def distance(
items: Iterable[Item],
items: Sequence[Item],
album_info: AlbumInfo,
mapping: Dict[Item, TrackInfo],
) -> Distance:
@ -196,6 +217,7 @@ def distance(
if album_info.media:
# Preferred media options.
patterns = config['match']['preferred']['media'].as_str_seq()
patterns = cast(Sequence, patterns)
options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
if options:
dist.add_priority('media', album_info.media, options)
@ -232,6 +254,7 @@ def distance(
# Preferred countries.
patterns = config['match']['preferred']['countries'].as_str_seq()
patterns = cast(Sequence, patterns)
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
dist.add_priority('country', album_info.country, options)
@ -265,11 +288,11 @@ def distance(
dist.add('tracks', dist.tracks[track].distance)
# Missing tracks.
for i in range(len(album_info.tracks) - len(mapping)):
for _ in range(len(album_info.tracks) - len(mapping)):
dist.add('missing_tracks', 1.0)
# Unmatched tracks.
for i in range(len(items) - len(mapping)):
for _ in range(len(items) - len(mapping)):
dist.add('unmatched_tracks', 1.0)
# Plugins.
@ -303,7 +326,7 @@ def match_by_id(items: Iterable[Item]):
def _recommendation(
results: List[Union[AlbumMatch, TrackMatch]],
results: Sequence[Union[AlbumMatch, TrackMatch]],
) -> Recommendation:
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
recommendation based on the results' distances.
@ -355,12 +378,19 @@ def _recommendation(
return rec
def _sort_candidates(candidates) -> Iterable:
AnyMatch = TypeVar("AnyMatch", TrackMatch, AlbumMatch)
def _sort_candidates(candidates: Iterable[AnyMatch]) -> Sequence[AnyMatch]:
"""Sort candidates by distance."""
return sorted(candidates, key=lambda match: match.distance)
def _add_candidate(items: Iterable[Item], results: Dict, info: AlbumInfo):
def _add_candidate(
items: Sequence[Item],
results: Dict[Any, AlbumMatch],
info: AlbumInfo,
):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
@ -380,7 +410,7 @@ def _add_candidate(items: Iterable[Item], results: Dict, info: AlbumInfo):
return
# Discard matches without required tags.
for req_tag in config['match']['required'].as_str_seq():
for req_tag in cast(Sequence, config['match']['required'].as_str_seq()):
if getattr(info, req_tag) is None:
log.debug('Ignored. Missing required tag: {0}', req_tag)
return
@ -393,7 +423,8 @@ def _add_candidate(items: Iterable[Item], results: Dict, info: AlbumInfo):
# Skip matches with ignored penalties.
penalties = [key for key, _ in dist]
for penalty in config['match']['ignored'].as_str_seq():
ignored = cast(Sequence[str], config['match']['ignored'].as_str_seq())
for penalty in ignored:
if penalty in penalties:
log.debug('Ignored. Penalty: {0}', penalty)
return
@ -428,20 +459,19 @@ def tag_album(
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = likelies['artist']
cur_album = likelies['album']
cur_artist = cast(str, likelies['artist'])
cur_album = cast(str, likelies['album'])
log.debug('Tagging {0} - {1}', cur_artist, cur_album)
# The output result (distance, AlbumInfo) tuples (keyed by MB album
# ID).
candidates = {}
# The output result, keys are the MB album ID.
candidates: Dict[Any, AlbumMatch] = {}
# Search by explicit ID.
if search_ids:
for search_id in search_ids:
log.debug('Searching for album ID: {0}', search_id)
for id_candidate in hooks.albums_for_id(search_id):
_add_candidate(items, candidates, id_candidate)
for album_info_for_id in hooks.albums_for_id(search_id):
_add_candidate(items, candidates, album_info_for_id)
# Use existing metadata or text search.
else:
@ -488,9 +518,9 @@ def tag_album(
log.debug('Evaluating {0} candidates.', len(candidates))
# Sort and get the recommendation.
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
return cur_artist, cur_album, Proposal(candidates, rec)
candidates_sorted = _sort_candidates(candidates.values())
rec = _recommendation(candidates_sorted)
return cur_artist, cur_album, Proposal(candidates_sorted, rec)
def tag_item(
@ -510,6 +540,7 @@ def tag_item(
# Holds candidates found so far: keys are MBIDs; values are
# (distance, TrackInfo) pairs.
candidates = {}
rec: Optional[Recommendation] = None
# First, try matching by MusicBrainz ID.
trackids = search_ids or [t for t in [item.mb_trackid] if t]
@ -530,6 +561,7 @@ def tag_item(
# If we're searching by ID, don't proceed.
if search_ids:
if candidates:
assert rec is not None
return Proposal(_sort_candidates(candidates.values()), rec)
else:
return Proposal([], Recommendation.none)
@ -546,6 +578,6 @@ def tag_item(
# Sort by distance and return with recommendation.
log.debug('Found {0} candidates.', len(candidates))
candidates = _sort_candidates(candidates.values())
rec = _recommendation(candidates)
return Proposal(candidates, rec)
candidates_sorted = _sort_candidates(candidates.values())
rec = _recommendation(candidates_sorted)
return Proposal(candidates_sorted, rec)

View file

@ -15,7 +15,7 @@
"""Searches for albums in the MusicBrainz database.
"""
from __future__ import annotations
from typing import List, Tuple, Dict, Optional, Iterator
from typing import Any, List, Sequence, Tuple, Dict, Optional, Iterator, cast
import musicbrainzngs
import re
@ -140,12 +140,13 @@ def _preferred_alias(aliases: List):
return matches[0]
def _preferred_release_event(release: Dict) -> Tuple[str, str]:
def _preferred_release_event(release: Dict[str, Any]) -> Tuple[str, str]:
"""Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found.
"""
countries = config['match']['preferred']['countries'].as_str_seq()
countries = cast(Sequence, countries)
for country in countries:
for event in release.get('release-event-list', {}):
@ -155,7 +156,10 @@ def _preferred_release_event(release: Dict) -> Tuple[str, str]:
except KeyError:
pass
return release.get('country'), release.get('date')
return (
cast(str, release.get('country')),
cast(str, release.get('date'))
)
def _flatten_artist_credit(credit: List[Dict]) -> Tuple[str, str, str]:
@ -258,7 +262,7 @@ def track_info(
)
if recording.get('length'):
info.length = int(recording['length']) / (1000.0)
info.length = int(recording['length']) / 1000.0
info.trackdisambig = recording.get('disambiguation')
@ -498,12 +502,14 @@ def album_info(release: Dict) -> beets.autotag.hooks.AlbumInfo:
release['release-group'].get('genre-list', []),
release.get('genre-list', []),
]
genres = Counter()
genres: Counter[str] = Counter()
for source in sources:
for genreitem in source:
genres[genreitem['name']] += int(genreitem['count'])
info.genre = '; '.join(g[0] for g in sorted(genres.items(),
key=lambda g: -g[1]))
info.genre = '; '.join(
genre for genre, _count
in sorted(genres.items(), key=lambda g: -g[1])
)
extra_albumdatas = plugins.send('mb_album_extract', data=release)
for extra_albumdata in extra_albumdatas:
@ -517,7 +523,7 @@ def match_album(
artist: str,
album: str,
tracks: Optional[int] = None,
extra_tags: Dict = None,
extra_tags: Optional[Dict[str, Any]] = None,
) -> Iterator[beets.autotag.hooks.AlbumInfo]:
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
@ -538,9 +544,9 @@ def match_album(
# Additional search cues from existing metadata.
if extra_tags:
for tag in extra_tags:
for tag, value in extra_tags.items():
key = FIELDS_TO_MB_KEYS[tag]
value = str(extra_tags.get(tag, '')).lower().strip()
value = str(value).lower().strip()
if key == 'catno':
value = value.replace(' ', '')
if value:
@ -596,8 +602,9 @@ def _parse_id(s: str) -> Optional[str]:
"""
# Find the first thing that looks like a UUID/MBID.
match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
if match:
return match.group()
if match is not None:
return match.group() if match else None
return None
def album_for_id(releaseid: str) -> Optional[beets.autotag.hooks.AlbumInfo]:
@ -609,7 +616,7 @@ def album_for_id(releaseid: str) -> Optional[beets.autotag.hooks.AlbumInfo]:
albumid = _parse_id(releaseid)
if not albumid:
log.debug('Invalid MBID ({0}).', releaseid)
return
return None
try:
res = musicbrainzngs.get_release_by_id(albumid,
RELEASE_INCLUDES)
@ -629,7 +636,7 @@ def track_for_id(releaseid: str) -> Optional[beets.autotag.hooks.TrackInfo]:
trackid = _parse_id(releaseid)
if not trackid:
log.debug('Invalid MBID ({0}).', releaseid)
return
return None
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError: