Replace typing.cast with explicit type definitions and ignore TC006

This commit is contained in:
Šarūnas Nejus 2025-05-07 12:49:58 +01:00
parent 99dc0861c2
commit fdc1aba603
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 45 additions and 36 deletions

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import re import re
from functools import total_ordering from functools import total_ordering
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, cast from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar
from jellyfish import levenshtein_distance from jellyfish import levenshtein_distance
from unidecode import unidecode from unidecode import unidecode
@ -474,7 +474,6 @@ class Distance:
matched against `value2`. matched against `value2`.
""" """
if isinstance(value1, re.Pattern): if isinstance(value1, re.Pattern):
value2 = cast(str, value2)
return bool(value1.match(value2)) return bool(value1.match(value2))
return value1 == value2 return value1 == value2

View file

@ -20,10 +20,9 @@ from __future__ import annotations
import datetime import datetime
import re import re
from collections.abc import Iterable, Sequence
from enum import IntEnum from enum import IntEnum
from functools import cache from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import lap import lap
import numpy as np import numpy as np
@ -40,6 +39,8 @@ from beets.autotag import (
from beets.util import plurality from beets.util import plurality
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from beets.library import Item from beets.library import Item
# Artist signals that indicate "various artists". These are used at the # Artist signals that indicate "various artists". These are used at the
@ -241,12 +242,14 @@ def distance(
# Album. # Album.
dist.add_string("album", likelies["album"], album_info.album) dist.add_string("album", likelies["album"], album_info.album)
preferred_config = config["match"]["preferred"]
# Current or preferred media. # Current or preferred media.
if album_info.media: if album_info.media:
# Preferred media options. # Preferred media options.
patterns = config["match"]["preferred"]["media"].as_str_seq() media_patterns: Sequence[str] = preferred_config["media"].as_str_seq()
patterns = cast(Sequence[str], patterns) options = [
options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns] re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in media_patterns
]
if options: if options:
dist.add_priority("media", album_info.media, options) dist.add_priority("media", album_info.media, options)
# Current media. # Current media.
@ -258,7 +261,7 @@ def distance(
dist.add_number("mediums", likelies["disctotal"], album_info.mediums) dist.add_number("mediums", likelies["disctotal"], album_info.mediums)
# Prefer earliest release. # Prefer earliest release.
if album_info.year and config["match"]["preferred"]["original_year"]: if album_info.year and preferred_config["original_year"]:
# Assume 1889 (earliest first gramophone discs) if we don't know the # Assume 1889 (earliest first gramophone discs) if we don't know the
# original year. # original year.
original = album_info.original_year or 1889 original = album_info.original_year or 1889
@ -282,9 +285,8 @@ def distance(
dist.add("year", 1.0) dist.add("year", 1.0)
# Preferred countries. # Preferred countries.
patterns = config["match"]["preferred"]["countries"].as_str_seq() country_patterns: Sequence[str] = preferred_config["countries"].as_str_seq()
patterns = cast(Sequence[str], patterns) options = [re.compile(pat, re.I) for pat in country_patterns]
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options: if album_info.country and options:
dist.add_priority("country", album_info.country, options) dist.add_priority("country", album_info.country, options)
# Country. # Country.
@ -447,9 +449,8 @@ def _add_candidate(
return return
# Discard matches without required tags. # Discard matches without required tags.
for req_tag in cast( required_tags: Sequence[str] = config["match"]["required"].as_str_seq()
Sequence[str], config["match"]["required"].as_str_seq() for req_tag in required_tags:
):
if getattr(info, req_tag) is None: if getattr(info, req_tag) is None:
log.debug("Ignored. Missing required tag: {0}", req_tag) log.debug("Ignored. Missing required tag: {0}", req_tag)
return return
@ -462,8 +463,8 @@ def _add_candidate(
# Skip matches with ignored penalties. # Skip matches with ignored penalties.
penalties = [key for key, _ in dist] penalties = [key for key, _ in dist]
ignored = cast(Sequence[str], config["match"]["ignored"].as_str_seq()) ignored_tags: Sequence[str] = config["match"]["ignored"].as_str_seq()
for penalty in ignored: for penalty in ignored_tags:
if penalty in penalties: if penalty in penalties:
log.debug("Ignored. Penalty: {0}", penalty) log.debug("Ignored. Penalty: {0}", penalty)
return return
@ -499,8 +500,8 @@ def tag_album(
""" """
# Get current metadata. # Get current metadata.
likelies, consensus = current_metadata(items) likelies, consensus = current_metadata(items)
cur_artist = cast(str, likelies["artist"]) cur_artist: str = likelies["artist"]
cur_album = cast(str, likelies["album"]) cur_album: str = likelies["album"]
log.debug("Tagging {0} - {1}", cur_artist, cur_album) log.debug("Tagging {0} - {1}", cur_artist, cur_album)
# The output result, keys are the MB album ID. # The output result, keys are the MB album ID.

View file

@ -19,9 +19,8 @@ from __future__ import annotations
import re import re
import traceback import traceback
from collections import Counter from collections import Counter
from collections.abc import Iterator, Sequence
from itertools import product from itertools import product
from typing import Any, cast from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin from urllib.parse import urljoin
import musicbrainzngs import musicbrainzngs
@ -37,6 +36,9 @@ from beets.util.id_extractors import (
spotify_id_regex, spotify_id_regex,
) )
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377" VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377"
BASE_URL = "https://musicbrainz.org/" BASE_URL = "https://musicbrainz.org/"
@ -178,15 +180,18 @@ def _preferred_alias(aliases: list):
return matches[0] return matches[0]
def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]: def _preferred_release_event(
release: dict[str, Any],
) -> tuple[str | None, str | None]:
"""Given a release, select and return the user's preferred release """Given a release, select and return the user's preferred release
event as a tuple of (country, release_date). Fall back to the event as a tuple of (country, release_date). Fall back to the
default release event if a preferred event is not found. default release event if a preferred event is not found.
""" """
countries = config["match"]["preferred"]["countries"].as_str_seq() preferred_countries: Sequence[str] = config["match"]["preferred"][
countries = cast(Sequence, countries) "countries"
].as_str_seq()
for country in countries: for country in preferred_countries:
for event in release.get("release-event-list", {}): for event in release.get("release-event-list", {}):
try: try:
if country in event["area"]["iso-3166-1-code-list"]: if country in event["area"]["iso-3166-1-code-list"]:
@ -194,7 +199,7 @@ def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]:
except KeyError: except KeyError:
pass pass
return (cast(str, release.get("country")), cast(str, release.get("date"))) return release.get("country"), release.get("date")
def _multi_artist_credit( def _multi_artist_credit(
@ -589,6 +594,8 @@ def album_info(release: dict) -> beets.autotag.hooks.AlbumInfo:
if not release_date: if not release_date:
# Fall back if release-specific date is not available. # Fall back if release-specific date is not available.
release_date = release_group_date release_date = release_group_date
if release_date:
_set_date_str(info, release_date, False) _set_date_str(info, release_date, False)
_set_date_str(info, release_group_date, True) _set_date_str(info, release_group_date, True)

View file

@ -26,7 +26,7 @@ from abc import ABC
from collections import defaultdict from collections import defaultdict
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from sqlite3 import Connection from sqlite3 import Connection
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic, TypeVar, cast from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic, TypeVar
from unidecode import unidecode from unidecode import unidecode
@ -126,8 +126,8 @@ class FormattedMapping(Mapping[str, str]):
value = value.decode("utf-8", "ignore") value = value.decode("utf-8", "ignore")
if self.for_path: if self.for_path:
sep_repl = cast(str, beets.config["path_sep_replace"].as_str()) sep_repl: str = beets.config["path_sep_replace"].as_str()
sep_drive = cast(str, beets.config["drive_sep_replace"].as_str()) sep_drive: str = beets.config["drive_sep_replace"].as_str()
if re.match(r"^\w:", value): if re.match(r"^\w:", value):
value = re.sub(r"(?<=^\w):", sep_drive, value) value = re.sub(r"(?<=^\w):", sep_drive, value)

View file

@ -28,7 +28,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from threading import Event, Thread from threading import Event, Thread
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from typing import TYPE_CHECKING, Any, Callable, TypeVar
from beets import ui from beets import ui
from beets.plugins import BeetsPlugin from beets.plugins import BeetsPlugin
@ -576,7 +576,7 @@ class CommandBackend(Backend):
} }
) )
self.command = cast(str, config["command"].as_str()) self.command: str = config["command"].as_str()
if self.command: if self.command:
# Explicit executable path. # Explicit executable path.
@ -1225,7 +1225,7 @@ class ReplayGainPlugin(BeetsPlugin):
# FIXME: Consider renaming the configuration option and deprecating the # FIXME: Consider renaming the configuration option and deprecating the
# old name 'overwrite'. # old name 'overwrite'.
self.force_on_import = cast(bool, self.config["overwrite"].get(bool)) self.force_on_import: bool = self.config["overwrite"].get(bool)
# Remember which backend is used for CLI feedback # Remember which backend is used for CLI feedback
self.backend_name = self.config["backend"].as_str() self.backend_name = self.config["backend"].as_str()
@ -1491,7 +1491,7 @@ class ReplayGainPlugin(BeetsPlugin):
def import_begin(self, session: ImportSession): def import_begin(self, session: ImportSession):
"""Handle `import_begin` event -> open pool""" """Handle `import_begin` event -> open pool"""
threads = cast(int, self.config["threads"].get(int)) threads: int = self.config["threads"].get(int)
if ( if (
self.config["parallel_on_import"] self.config["parallel_on_import"]
@ -1526,9 +1526,7 @@ class ReplayGainPlugin(BeetsPlugin):
# Bypass self.open_pool() if called with `--threads 0` # Bypass self.open_pool() if called with `--threads 0`
if opts.threads != 0: if opts.threads != 0:
threads = opts.threads or cast( threads: int = opts.threads or self.config["threads"].get(int)
int, self.config["threads"].get(int)
)
self.open_pool(threads) self.open_pool(threads)
if opts.album: if opts.album:

View file

@ -263,6 +263,10 @@ select = [
"TCH", # flake8-type-checking "TCH", # flake8-type-checking
"W", # pycodestyle "W", # pycodestyle
] ]
ignore = [
"TC006" # no need to quote 'cast's since we use 'from __future__ import annotations'
]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"beets/**" = ["PT"] "beets/**" = ["PT"]
"test/test_util.py" = ["E501"] "test/test_util.py" = ["E501"]