Resolve some 'mypy' errors

This commit is contained in:
Arav K. 2024-06-02 19:19:44 +02:00
parent d7bf28deed
commit d3bdf137ea
13 changed files with 54 additions and 34 deletions

View file

@ -28,6 +28,7 @@ from typing import (
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
@ -42,20 +43,23 @@ from beets.util import as_string
log = logging.getLogger("beets")
V = TypeVar("V")
# Classes used to represent candidate options.
class AttrDict(dict):
class AttrDict(Dict[str, V]):
"""A dictionary that supports attribute ("dot") access, so `d.field`
is equivalent to `d['field']`.
"""
def __getattr__(self, attr):
if attr in self:
return self.get(attr)
def __getattr__(self, attr: str) -> V:
result = self.get(attr)
if result is not None:
return result
else:
raise AttributeError
def __setattr__(self, key, value):
def __setattr__(self, key: str, value: V):
self.__setitem__(key, value)
def __hash__(self):
@ -79,7 +83,7 @@ class AlbumInfo(AttrDict):
# TYPING: are all of these correct? I've assumed optional strings
def __init__(
self,
tracks: List["TrackInfo"],
tracks: List[TrackInfo],
album: Optional[str] = None,
album_id: Optional[str] = None,
artist: Optional[str] = None,
@ -201,7 +205,7 @@ class AlbumInfo(AttrDict):
for track in self.tracks:
track.decode(codec)
def copy(self) -> "AlbumInfo":
def copy(self) -> AlbumInfo:
dupe = AlbumInfo([])
dupe.update(self)
dupe.tracks = [track.copy() for track in self.tracks]
@ -309,7 +313,7 @@ class TrackInfo(AttrDict):
if isinstance(value, bytes):
setattr(self, fld, value.decode(codec, "ignore"))
def copy(self) -> "TrackInfo":
def copy(self) -> TrackInfo:
dupe = TrackInfo()
dupe.update(self)
return dupe
@ -545,7 +549,7 @@ class Distance:
# Adding components.
def _eq(self, value1: Union[re.Pattern, Any], value2: Any) -> bool:
def _eq(self, value1: Union[re.Pattern[str], Any], value2: Any) -> bool:
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.

View file

@ -238,7 +238,7 @@ def distance(
if album_info.media:
# Preferred media options.
patterns = config["match"]["preferred"]["media"].as_str_seq()
patterns = cast(Sequence, patterns)
patterns = cast(Sequence[str], patterns)
options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns]
if options:
dist.add_priority("media", album_info.media, options)
@ -276,7 +276,7 @@ def distance(
# Preferred countries.
patterns = config["match"]["preferred"]["countries"].as_str_seq()
patterns = cast(Sequence, patterns)
patterns = cast(Sequence[str], patterns)
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
@ -440,7 +440,9 @@ def _add_candidate(
return
# Discard matches without required tags.
for req_tag in cast(Sequence, config["match"]["required"].as_str_seq()):
for req_tag in cast(
Sequence[str], config["match"]["required"].as_str_seq()
):
if getattr(info, req_tag) is None:
log.debug("Ignored. Missing required tag: {0}", req_tag)
return
@ -469,7 +471,7 @@ def tag_album(
items,
search_artist: Optional[str] = None,
search_album: Optional[str] = None,
search_ids: List = [],
search_ids: List[str] = [],
) -> Tuple[str, str, Proposal]:
"""Return a tuple of the current artist name, the current album
name, and a `Proposal` containing `AlbumMatch` candidates.
@ -561,7 +563,7 @@ def tag_item(
item,
search_artist: Optional[str] = None,
search_title: Optional[str] = None,
search_ids: List = [],
search_ids: Optional[List[str]] = None,
) -> Proposal:
"""Find metadata for a single track. Return a `Proposal` consisting
of `TrackMatch` objects.

View file

@ -58,7 +58,6 @@ from . import types
from .query import (
AndQuery,
FieldQuery,
FieldSort,
MatchQuery,
NullSort,
Query,
@ -303,7 +302,7 @@ class Model(ABC):
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
_sorts: Dict[str, Type[FieldSort]] = {}
_sorts: Dict[str, Type[Sort]] = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""

View file

@ -256,7 +256,7 @@ class SubstringQuery(StringFieldQuery[str]):
return pattern.lower() in value.lower()
class RegexpQuery(StringFieldQuery[Pattern]):
class RegexpQuery(StringFieldQuery[Pattern[str]]):
"""A query that matches a regular expression in a specific Model field.
Raises InvalidQueryError when the pattern is not a valid regular
@ -342,7 +342,7 @@ class BytesQuery(FieldQuery[bytes]):
return pattern == value
class NumericQuery(FieldQuery):
class NumericQuery(FieldQuery[str]):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
@ -787,7 +787,7 @@ class DateInterval:
return f"[{self.start}, {self.end})"
class DateQuery(FieldQuery):
class DateQuery(FieldQuery[str]):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
@ -797,7 +797,7 @@ class DateQuery(FieldQuery):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field, pattern, fast: bool = True):
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)

View file

@ -232,7 +232,7 @@ class BaseFloat(Type[float, N]):
"""
sql = "REAL"
query = NumericQuery
query: type[FieldQuery[Any]] = NumericQuery
model_type = float
def __init__(self, digits: int = 1):

View file

@ -22,6 +22,7 @@ import string
import sys
import time
import unicodedata
from typing import Never
from functools import cached_property
from mediafile import MediaFile, UnreadableFileError
@ -49,7 +50,7 @@ log = logging.getLogger("beets")
# Library-specific query types.
class SingletonQuery(dbcore.FieldQuery):
class SingletonQuery(dbcore.FieldQuery[Never]):
"""This query is responsible for the 'singleton' lookup.
It is based on the FieldQuery and constructs a SQL clause
@ -67,7 +68,7 @@ class SingletonQuery(dbcore.FieldQuery):
return dbcore.query.NotQuery(query)
class PathQuery(dbcore.FieldQuery):
class PathQuery(dbcore.FieldQuery[bytes]):
"""A query that matches all items under a given path.
Matching can either be case-insensitive or case-sensitive. By
@ -185,7 +186,7 @@ class DateType(types.Float):
return self.null
class PathType(types.Type):
class PathType(types.Type[bytes, bytes]):
"""A dbcore type for filesystem paths.
These are represented as `bytes` objects, in keeping with
@ -384,7 +385,7 @@ class LibModel(dbcore.Model):
"""Shared concrete functionality for Items and Albums."""
# Config key that specifies how an instance should be formatted.
_format_config_key = None
_format_config_key: str
def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()

View file

@ -35,6 +35,7 @@ in place of any single coroutine.
import queue
import sys
from threading import Lock, Thread
from typing import TypeVar
BUBBLE = "__PIPELINE_BUBBLE__"
POISON = "__PIPELINE_POISON__"
@ -84,7 +85,10 @@ def _invalidate_queue(q, val=None, sync=True):
q.mutex.release()
class CountedQueue(queue.Queue):
T = TypeVar("T")
class CountedQueue(queue.Queue[T]):
"""A queue that keeps track of the number of threads that are
still feeding into it. The queue is poisoned when all threads are
finished with the queue.

View file

@ -27,7 +27,7 @@ from beets.plugins import BeetsPlugin
from beets.ui import decargs, print_
class BareascQuery(StringFieldQuery):
class BareascQuery(StringFieldQuery[str]):
"""Compare items using bare ASCII, without accents etc."""
@classmethod

View file

@ -27,6 +27,7 @@ import sys
import time
import traceback
from string import Template
from typing import List
from mediafile import MediaFile
@ -1059,7 +1060,7 @@ class Command:
raise BPDError(ERROR_SYSTEM, "server error", self.name)
class CommandList(list):
class CommandList(List[Command]):
"""A list of commands issued by the client for processing by the
server. May be verbose, in which case the response is delimited, or
not. Should be a list of `Command` objects.

View file

@ -23,9 +23,9 @@ from beets.dbcore.query import StringFieldQuery
from beets.plugins import BeetsPlugin
class FuzzyQuery(StringFieldQuery):
class FuzzyQuery(StringFieldQuery[str]):
@classmethod
def string_match(cls, pattern, val):
def string_match(cls, pattern: str, val: str):
# smartcase
if pattern.islower():
val = val.lower()

View file

@ -1181,7 +1181,9 @@ class ExceptionWatcher(Thread):
Once an exception occurs, raise it and execute a callback.
"""
def __init__(self, queue: queue.Queue, callback: Callable[[], None]):
def __init__(
self, queue: queue.Queue[Exception], callback: Callable[[], None]
):
self._queue = queue
self._callback = callback
self._stopevent = Event()

View file

@ -16,6 +16,7 @@
import re
from typing import List
from beets.plugins import BeetsPlugin
@ -28,7 +29,7 @@ FORMAT = "{0}, {1}"
class ThePlugin(BeetsPlugin):
patterns = []
patterns: List[str] = []
def __init__(self):
super().__init__()

View file

@ -16,6 +16,7 @@
import unittest
from typing import Sequence, Tuple
from beets.autotag.mb import VARIOUS_ARTISTS_ID
from beets.test.helper import TestHelper
@ -98,12 +99,17 @@ class AlbumTypesPluginTest(unittest.TestCase, TestHelper):
result = subject._atypes(album)
self.assertEqual("[EP][Single][OST][Live][Remix]", result)
def _set_config(self, types: [(str, str)], ignore_va: [str], bracket: str):
def _set_config(
self,
types: Sequence[Tuple[str, str]],
ignore_va: Sequence[str],
bracket: str,
):
self.config["albumtypes"]["types"] = types
self.config["albumtypes"]["ignore_va"] = ignore_va
self.config["albumtypes"]["bracket"] = bracket
def _create_album(self, album_types: [str], artist_id: str = 0):
def _create_album(self, album_types: Sequence[str], artist_id: str = "0"):
return self.add_album(
albumtypes=album_types, mb_albumartistid=artist_id
)