From d3bdf137ea7e7a3d252f72235b9fa4635fada60e Mon Sep 17 00:00:00 2001 From: "Arav K." Date: Sun, 2 Jun 2024 19:19:44 +0200 Subject: [PATCH] Resolve some 'mypy' errors --- beets/autotag/hooks.py | 22 +++++++++++++--------- beets/autotag/match.py | 12 +++++++----- beets/dbcore/db.py | 3 +-- beets/dbcore/query.py | 8 ++++---- beets/dbcore/types.py | 2 +- beets/library.py | 9 +++++---- beets/util/pipeline.py | 6 +++++- beetsplug/bareasc.py | 2 +- beetsplug/bpd/__init__.py | 3 ++- beetsplug/fuzzy.py | 4 ++-- beetsplug/replaygain.py | 4 +++- beetsplug/the.py | 3 ++- test/plugins/test_albumtypes.py | 10 ++++++++-- 13 files changed, 54 insertions(+), 34 deletions(-) diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index 67546f47c..999e98b70 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -28,6 +28,7 @@ from typing import ( List, Optional, Tuple, + TypeVar, Union, cast, ) @@ -42,20 +43,23 @@ from beets.util import as_string log = logging.getLogger("beets") +V = TypeVar("V") + # Classes used to represent candidate options. -class AttrDict(dict): +class AttrDict(Dict[str, V]): """A dictionary that supports attribute ("dot") access, so `d.field` is equivalent to `d['field']`. """ - def __getattr__(self, attr): - if attr in self: - return self.get(attr) + def __getattr__(self, attr: str) -> V: + result = self.get(attr) + if result is not None: + return result else: raise AttributeError - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: V): self.__setitem__(key, value) def __hash__(self): @@ -79,7 +83,7 @@ class AlbumInfo(AttrDict): # TYPING: are all of these correct? I've assumed optional strings def __init__( self, - tracks: List["TrackInfo"], + tracks: List[TrackInfo], album: Optional[str] = None, album_id: Optional[str] = None, artist: Optional[str] = None, @@ -201,7 +205,7 @@ class AlbumInfo(AttrDict): for track in self.tracks: track.decode(codec) - def copy(self) -> "AlbumInfo": + def copy(self) -> AlbumInfo: dupe = AlbumInfo([]) dupe.update(self) dupe.tracks = [track.copy() for track in self.tracks] @@ -309,7 +313,7 @@ class TrackInfo(AttrDict): if isinstance(value, bytes): setattr(self, fld, value.decode(codec, "ignore")) - def copy(self) -> "TrackInfo": + def copy(self) -> TrackInfo: dupe = TrackInfo() dupe.update(self) return dupe @@ -545,7 +549,7 @@ class Distance: # Adding components. - def _eq(self, value1: Union[re.Pattern, Any], value2: Any) -> bool: + def _eq(self, value1: Union[re.Pattern[str], Any], value2: Any) -> bool: """Returns True if `value1` is equal to `value2`. `value1` may be a compiled regular expression, in which case it will be matched against `value2`. diff --git a/beets/autotag/match.py b/beets/autotag/match.py index a256960f7..63db9e33c 100644 --- a/beets/autotag/match.py +++ b/beets/autotag/match.py @@ -238,7 +238,7 @@ def distance( if album_info.media: # Preferred media options. patterns = config["match"]["preferred"]["media"].as_str_seq() - patterns = cast(Sequence, patterns) + patterns = cast(Sequence[str], patterns) options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns] if options: dist.add_priority("media", album_info.media, options) @@ -276,7 +276,7 @@ def distance( # Preferred countries. patterns = config["match"]["preferred"]["countries"].as_str_seq() - patterns = cast(Sequence, patterns) + patterns = cast(Sequence[str], patterns) options = [re.compile(pat, re.I) for pat in patterns] if album_info.country and options: dist.add_priority("country", album_info.country, options) @@ -440,7 +440,9 @@ def _add_candidate( return # Discard matches without required tags. - for req_tag in cast(Sequence, config["match"]["required"].as_str_seq()): + for req_tag in cast( + Sequence[str], config["match"]["required"].as_str_seq() + ): if getattr(info, req_tag) is None: log.debug("Ignored. Missing required tag: {0}", req_tag) return @@ -469,7 +471,7 @@ def tag_album( items, search_artist: Optional[str] = None, search_album: Optional[str] = None, - search_ids: List = [], + search_ids: List[str] = [], ) -> Tuple[str, str, Proposal]: """Return a tuple of the current artist name, the current album name, and a `Proposal` containing `AlbumMatch` candidates. @@ -561,7 +563,7 @@ def tag_item( item, search_artist: Optional[str] = None, search_title: Optional[str] = None, - search_ids: List = [], + search_ids: Optional[List[str]] = None, ) -> Proposal: """Find metadata for a single track. Return a `Proposal` consisting of `TrackMatch` objects. diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 369b1ffe0..dce111267 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -58,7 +58,6 @@ from . import types from .query import ( AndQuery, FieldQuery, - FieldSort, MatchQuery, NullSort, Query, @@ -303,7 +302,7 @@ class Model(ABC): """Optional Types for non-fixed (i.e., flexible and computed) fields. """ - _sorts: Dict[str, Type[FieldSort]] = {} + _sorts: Dict[str, Type[Sort]] = {} """Optional named sort criteria. The keys are strings and the values are subclasses of `Sort`. """ diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index ffa89168e..2e1385ca2 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -256,7 +256,7 @@ class SubstringQuery(StringFieldQuery[str]): return pattern.lower() in value.lower() -class RegexpQuery(StringFieldQuery[Pattern]): +class RegexpQuery(StringFieldQuery[Pattern[str]]): """A query that matches a regular expression in a specific Model field. Raises InvalidQueryError when the pattern is not a valid regular @@ -342,7 +342,7 @@ class BytesQuery(FieldQuery[bytes]): return pattern == value -class NumericQuery(FieldQuery): +class NumericQuery(FieldQuery[str]): """Matches numeric fields. A syntax using Ruby-style range ellipses (``..``) lets users specify one- or two-sided ranges. For example, ``year:2001..`` finds music released since the turn of the century. @@ -787,7 +787,7 @@ class DateInterval: return f"[{self.start}, {self.end})" -class DateQuery(FieldQuery): +class DateQuery(FieldQuery[str]): """Matches date fields stored as seconds since Unix epoch time. Dates can be specified as ``year-month-day`` strings where only year @@ -797,7 +797,7 @@ class DateQuery(FieldQuery): using an ellipsis interval syntax similar to that of NumericQuery. """ - def __init__(self, field, pattern, fast: bool = True): + def __init__(self, field: str, pattern: str, fast: bool = True): super().__init__(field, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 99007fb60..469991c84 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -232,7 +232,7 @@ class BaseFloat(Type[float, N]): """ sql = "REAL" - query = NumericQuery + query: type[FieldQuery[Any]] = NumericQuery model_type = float def __init__(self, digits: int = 1): diff --git a/beets/library.py b/beets/library.py index 90841c493..d44ccdd25 100644 --- a/beets/library.py +++ b/beets/library.py @@ -22,6 +22,7 @@ import string import sys import time import unicodedata +from typing import Never from functools import cached_property from mediafile import MediaFile, UnreadableFileError @@ -49,7 +50,7 @@ log = logging.getLogger("beets") # Library-specific query types. -class SingletonQuery(dbcore.FieldQuery): +class SingletonQuery(dbcore.FieldQuery[Never]): """This query is responsible for the 'singleton' lookup. It is based on the FieldQuery and constructs a SQL clause @@ -67,7 +68,7 @@ class SingletonQuery(dbcore.FieldQuery): return dbcore.query.NotQuery(query) -class PathQuery(dbcore.FieldQuery): +class PathQuery(dbcore.FieldQuery[bytes]): """A query that matches all items under a given path. Matching can either be case-insensitive or case-sensitive. By @@ -185,7 +186,7 @@ class DateType(types.Float): return self.null -class PathType(types.Type): +class PathType(types.Type[bytes, bytes]): """A dbcore type for filesystem paths. These are represented as `bytes` objects, in keeping with @@ -384,7 +385,7 @@ class LibModel(dbcore.Model): """Shared concrete functionality for Items and Albums.""" # Config key that specifies how an instance should be formatted. - _format_config_key = None + _format_config_key: str def _template_funcs(self): funcs = DefaultTemplateFunctions(self, self._db).functions() diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index c4933ff00..bf856eb55 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -35,6 +35,7 @@ in place of any single coroutine. import queue import sys from threading import Lock, Thread +from typing import TypeVar BUBBLE = "__PIPELINE_BUBBLE__" POISON = "__PIPELINE_POISON__" @@ -84,7 +85,10 @@ def _invalidate_queue(q, val=None, sync=True): q.mutex.release() -class CountedQueue(queue.Queue): +T = TypeVar("T") + + +class CountedQueue(queue.Queue[T]): """A queue that keeps track of the number of threads that are still feeding into it. The queue is poisoned when all threads are finished with the queue. diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py index dff0b2e93..8cdcbb113 100644 --- a/beetsplug/bareasc.py +++ b/beetsplug/bareasc.py @@ -27,7 +27,7 @@ from beets.plugins import BeetsPlugin from beets.ui import decargs, print_ -class BareascQuery(StringFieldQuery): +class BareascQuery(StringFieldQuery[str]): """Compare items using bare ASCII, without accents etc.""" @classmethod diff --git a/beetsplug/bpd/__init__.py b/beetsplug/bpd/__init__.py index 3d7396401..a4cb4d291 100644 --- a/beetsplug/bpd/__init__.py +++ b/beetsplug/bpd/__init__.py @@ -27,6 +27,7 @@ import sys import time import traceback from string import Template +from typing import List from mediafile import MediaFile @@ -1059,7 +1060,7 @@ class Command: raise BPDError(ERROR_SYSTEM, "server error", self.name) -class CommandList(list): +class CommandList(List[Command]): """A list of commands issued by the client for processing by the server. May be verbose, in which case the response is delimited, or not. Should be a list of `Command` objects. diff --git a/beetsplug/fuzzy.py b/beetsplug/fuzzy.py index d3d14d86c..45ada8b0b 100644 --- a/beetsplug/fuzzy.py +++ b/beetsplug/fuzzy.py @@ -23,9 +23,9 @@ from beets.dbcore.query import StringFieldQuery from beets.plugins import BeetsPlugin -class FuzzyQuery(StringFieldQuery): +class FuzzyQuery(StringFieldQuery[str]): @classmethod - def string_match(cls, pattern, val): + def string_match(cls, pattern: str, val: str): # smartcase if pattern.islower(): val = val.lower() diff --git a/beetsplug/replaygain.py b/beetsplug/replaygain.py index 78640b6a8..583555530 100644 --- a/beetsplug/replaygain.py +++ b/beetsplug/replaygain.py @@ -1181,7 +1181,9 @@ class ExceptionWatcher(Thread): Once an exception occurs, raise it and execute a callback. """ - def __init__(self, queue: queue.Queue, callback: Callable[[], None]): + def __init__( + self, queue: queue.Queue[Exception], callback: Callable[[], None] + ): self._queue = queue self._callback = callback self._stopevent = Event() diff --git a/beetsplug/the.py b/beetsplug/the.py index 2deab9cd5..c6fb46ddf 100644 --- a/beetsplug/the.py +++ b/beetsplug/the.py @@ -16,6 +16,7 @@ import re +from typing import List from beets.plugins import BeetsPlugin @@ -28,7 +29,7 @@ FORMAT = "{0}, {1}" class ThePlugin(BeetsPlugin): - patterns = [] + patterns: List[str] = [] def __init__(self): super().__init__() diff --git a/test/plugins/test_albumtypes.py b/test/plugins/test_albumtypes.py index 532fdc69c..6b3b48d10 100644 --- a/test/plugins/test_albumtypes.py +++ b/test/plugins/test_albumtypes.py @@ -16,6 +16,7 @@ import unittest +from typing import Sequence, Tuple from beets.autotag.mb import VARIOUS_ARTISTS_ID from beets.test.helper import TestHelper @@ -98,12 +99,17 @@ class AlbumTypesPluginTest(unittest.TestCase, TestHelper): result = subject._atypes(album) self.assertEqual("[EP][Single][OST][Live][Remix]", result) - def _set_config(self, types: [(str, str)], ignore_va: [str], bracket: str): + def _set_config( + self, + types: Sequence[Tuple[str, str]], + ignore_va: Sequence[str], + bracket: str, + ): self.config["albumtypes"]["types"] = types self.config["albumtypes"]["ignore_va"] = ignore_va self.config["albumtypes"]["bracket"] = bracket - def _create_album(self, album_types: [str], artist_id: str = 0): + def _create_album(self, album_types: Sequence[str], artist_id: str = "0"): return self.add_album( albumtypes=album_types, mb_albumartistid=artist_id )