Replace typing.cast with explicit type definitions and ignore TC006

This commit is contained in:
Šarūnas Nejus 2025-05-07 12:49:58 +01:00
parent 99dc0861c2
commit fdc1aba603
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 45 additions and 36 deletions

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import re
from functools import total_ordering
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar
from jellyfish import levenshtein_distance
from unidecode import unidecode
@ -474,7 +474,6 @@ class Distance:
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
value2 = cast(str, value2)
return bool(value1.match(value2))
return value1 == value2

View file

@ -20,10 +20,9 @@ from __future__ import annotations
import datetime
import re
from collections.abc import Iterable, Sequence
from enum import IntEnum
from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import lap
import numpy as np
@ -40,6 +39,8 @@ from beets.autotag import (
from beets.util import plurality
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from beets.library import Item
# Artist signals that indicate "various artists". These are used at the
@ -241,12 +242,14 @@ def distance(
# Album.
dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media.
if album_info.media:
# Preferred media options.
patterns = config["match"]["preferred"]["media"].as_str_seq()
patterns = cast(Sequence[str], patterns)
options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns]
media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
options = [
re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options:
dist.add_priority("media", album_info.media, options)
# Current media.
@ -258,7 +261,7 @@ def distance(
dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release.
if album_info.year and config["match"]["preferred"]["original_year"]:
if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
@ -282,9 +285,8 @@ def distance(
dist.add("year", 1.0)
# Preferred countries.
patterns = config["match"]["preferred"]["countries"].as_str_seq()
patterns = cast(Sequence[str], patterns)
options = [re.compile(pat, re.I) for pat in patterns]
country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
options = [re.compile(pat, re.I) for pat in country_patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
# Country.
@ -447,9 +449,8 @@ def _add_candidate(
return
# Discard matches without required tags.
for req_tag in cast(
Sequence[str], config["match"]["required"].as_str_seq()
):
required_tags: Sequence[str] = config["match"]["required"].as_str_seq()
for req_tag in required_tags:
if getattr(info, req_tag) is None:
log.debug("Ignored. Missing required tag: {0}", req_tag)
return
@ -462,8 +463,8 @@ def _add_candidate(
# Skip matches with ignored penalties.
penalties = [key for key, _ in dist]
ignored = cast(Sequence[str], config["match"]["ignored"].as_str_seq())
for penalty in ignored:
ignored_tags: Sequence[str] = config["match"]["ignored"].as_str_seq()
for penalty in ignored_tags:
if penalty in penalties:
log.debug("Ignored. Penalty: {0}", penalty)
return
@ -499,8 +500,8 @@ def tag_album(
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = cast(str, likelies["artist"])
cur_album = cast(str, likelies["album"])
cur_artist: str = likelies["artist"]
cur_album: str = likelies["album"]
log.debug("Tagging {0} - {1}", cur_artist, cur_album)
# The output result, keys are the MB album ID.

View file

@ -19,9 +19,8 @@ from __future__ import annotations
import re
import traceback
from collections import Counter
from collections.abc import Iterator, Sequence
from itertools import product
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
import musicbrainzngs
@ -37,6 +36,9 @@ from beets.util.id_extractors import (
spotify_id_regex,
)
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377"
BASE_URL = "https://musicbrainz.org/"
@ -178,15 +180,18 @@ def _preferred_alias(aliases: list):
return matches[0]
def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]:
def _preferred_release_event(
release: dict[str, Any],
) -> tuple[str | None, str | None]:
"""Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found.
"""
countries = config["match"]["preferred"]["countries"].as_str_seq()
countries = cast(Sequence, countries)
preferred_countries: Sequence[str] = config["match"]["preferred"][
"countries"
].as_str_seq()
for country in countries:
for country in preferred_countries:
for event in release.get("release-event-list", {}):
try:
if country in event["area"]["iso-3166-1-code-list"]:
@ -194,7 +199,7 @@ def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]:
except KeyError:
pass
return (cast(str, release.get("country")), cast(str, release.get("date")))
return release.get("country"), release.get("date")
def _multi_artist_credit(
@ -589,7 +594,9 @@ def album_info(release: dict) -> beets.autotag.hooks.AlbumInfo:
if not release_date:
# Fall back if release-specific date is not available.
release_date = release_group_date
_set_date_str(info, release_date, False)
if release_date:
_set_date_str(info, release_date, False)
_set_date_str(info, release_group_date, True)
# Label name.

View file

@ -26,7 +26,7 @@ from abc import ABC
from collections import defaultdict
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from sqlite3 import Connection
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic, TypeVar
from unidecode import unidecode
@ -126,8 +126,8 @@ class FormattedMapping(Mapping[str, str]):
value = value.decode("utf-8", "ignore")
if self.for_path:
sep_repl = cast(str, beets.config["path_sep_replace"].as_str())
sep_drive = cast(str, beets.config["drive_sep_replace"].as_str())
sep_repl: str = beets.config["path_sep_replace"].as_str()
sep_drive: str = beets.config["drive_sep_replace"].as_str()
if re.match(r"^\w:", value):
value = re.sub(r"(?<=^\w):", sep_drive, value)

View file

@ -28,7 +28,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from threading import Event, Thread
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from beets import ui
from beets.plugins import BeetsPlugin
@ -576,7 +576,7 @@ class CommandBackend(Backend):
}
)
self.command = cast(str, config["command"].as_str())
self.command: str = config["command"].as_str()
if self.command:
# Explicit executable path.
@ -1225,7 +1225,7 @@ class ReplayGainPlugin(BeetsPlugin):
# FIXME: Consider renaming the configuration option and deprecating the
# old name 'overwrite'.
self.force_on_import = cast(bool, self.config["overwrite"].get(bool))
self.force_on_import: bool = self.config["overwrite"].get(bool)
# Remember which backend is used for CLI feedback
self.backend_name = self.config["backend"].as_str()
@ -1491,7 +1491,7 @@ class ReplayGainPlugin(BeetsPlugin):
def import_begin(self, session: ImportSession):
"""Handle `import_begin` event -> open pool"""
threads = cast(int, self.config["threads"].get(int))
threads: int = self.config["threads"].get(int)
if (
self.config["parallel_on_import"]
@ -1526,9 +1526,7 @@ class ReplayGainPlugin(BeetsPlugin):
# Bypass self.open_pool() if called with `--threads 0`
if opts.threads != 0:
threads = opts.threads or cast(
int, self.config["threads"].get(int)
)
threads: int = opts.threads or self.config["threads"].get(int)
self.open_pool(threads)
if opts.album:

View file

@ -263,6 +263,10 @@ select = [
"TCH", # flake8-type-checking
"W", # pycodestyle
]
ignore = [
"TC006" # no need to quote 'cast's since we use 'from __future__ import annotations'
]
[tool.ruff.lint.per-file-ignores]
"beets/**" = ["PT"]
"test/test_util.py" = ["E501"]