mirror of
https://github.com/beetbox/beets.git
synced 2026-01-30 20:13:37 +01:00
dbcore/query: improve/fix typing
This commit is contained in:
parent
7fbf562d24
commit
bffeb9816c
2 changed files with 82 additions and 67 deletions
|
|
@ -385,7 +385,7 @@ class Model:
|
|||
"""
|
||||
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
|
||||
|
||||
def _get(self, key, default: bool = None, raise_: bool = False):
|
||||
def _get(self, key, default: Any = None, raise_: bool = False):
|
||||
"""Get the value for a field, or `default`. Alternatively,
|
||||
raise a KeyError if the field is not available.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -23,14 +23,16 @@ from operator import mul
|
|||
from typing import (
|
||||
Any,
|
||||
Collection,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
MutableMapping,
|
||||
MutableSequence,
|
||||
Optional,
|
||||
Pattern,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
|
@ -126,7 +128,12 @@ class NamedQuery(Query):
|
|||
...
|
||||
|
||||
|
||||
class FieldQuery(Query):
|
||||
P = TypeVar("P")
|
||||
SQLiteType = Union[str, float, int, memoryview]
|
||||
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
|
||||
|
||||
|
||||
class FieldQuery(Query, Generic[P]):
|
||||
"""An abstract query that searches in a specific field for a
|
||||
pattern. Subclasses must provide a `value_match` class method, which
|
||||
determines whether a certain pattern string matches a certain value
|
||||
|
|
@ -134,15 +141,15 @@ class FieldQuery(Query):
|
|||
same matching functionality in SQLite.
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
def __init__(self, field: str, pattern: P, fast: bool = True):
|
||||
self.field = field
|
||||
self.pattern = pattern
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self) -> Union[Optional[str], Sequence[Any]]:
|
||||
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
return None, ()
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
if self.fast:
|
||||
return self.col_clause()
|
||||
else:
|
||||
|
|
@ -150,7 +157,7 @@ class FieldQuery(Query):
|
|||
return None, ()
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern: str, value: str):
|
||||
def value_match(cls, pattern: P, value: Any):
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings.
|
||||
"""
|
||||
|
|
@ -171,24 +178,24 @@ class FieldQuery(Query):
|
|||
return hash((self.field, hash(self.pattern)))
|
||||
|
||||
|
||||
class MatchQuery(FieldQuery):
|
||||
class MatchQuery(FieldQuery[AnySQLiteType]):
|
||||
"""A query that looks for exact matches in an Model field."""
|
||||
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.field + " = ?", [self.pattern]
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern: str, value: str) -> bool:
|
||||
def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool:
|
||||
return pattern == value
|
||||
|
||||
|
||||
class NoneQuery(FieldQuery):
|
||||
class NoneQuery(FieldQuery[None]):
|
||||
"""A query that checks whether a field is null."""
|
||||
|
||||
def __init__(self, field, fast: bool = True):
|
||||
super().__init__(field, "", fast)
|
||||
super().__init__(field, None, fast)
|
||||
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.field + " IS NULL", ()
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -198,13 +205,13 @@ class NoneQuery(FieldQuery):
|
|||
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
|
||||
|
||||
|
||||
class StringFieldQuery(FieldQuery):
|
||||
class StringFieldQuery(FieldQuery[P]):
|
||||
"""A FieldQuery that converts values to strings before matching
|
||||
them.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern: str, value: Any):
|
||||
def value_match(cls, pattern: P, value: Any):
|
||||
"""Determine whether the value matches the pattern. The value
|
||||
may have any type.
|
||||
"""
|
||||
|
|
@ -213,7 +220,7 @@ class StringFieldQuery(FieldQuery):
|
|||
@classmethod
|
||||
def string_match(
|
||||
cls,
|
||||
pattern: str,
|
||||
pattern: P,
|
||||
value: str,
|
||||
) -> bool:
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
|
|
@ -222,10 +229,10 @@ class StringFieldQuery(FieldQuery):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StringQuery(StringFieldQuery):
|
||||
class StringQuery(StringFieldQuery[str]):
|
||||
"""A query that matches a whole string in a specific Model field."""
|
||||
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
search = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -239,10 +246,10 @@ class StringQuery(StringFieldQuery):
|
|||
return pattern.lower() == value.lower()
|
||||
|
||||
|
||||
class SubstringQuery(StringFieldQuery):
|
||||
class SubstringQuery(StringFieldQuery[str]):
|
||||
"""A query that matches a substring in a specific Model field."""
|
||||
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
pattern = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -257,7 +264,7 @@ class SubstringQuery(StringFieldQuery):
|
|||
return pattern.lower() in value.lower()
|
||||
|
||||
|
||||
class RegexpQuery(StringFieldQuery):
|
||||
class RegexpQuery(StringFieldQuery[Pattern]):
|
||||
"""A query that matches a regular expression in a specific Model field.
|
||||
|
||||
Raises InvalidQueryError when the pattern is not a valid regular
|
||||
|
|
@ -276,7 +283,7 @@ class RegexpQuery(StringFieldQuery):
|
|||
|
||||
super().__init__(field, pattern_re, fast)
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return f" regexp({self.field}, ?)", [self.pattern.pattern]
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -291,7 +298,7 @@ class RegexpQuery(StringFieldQuery):
|
|||
return pattern.search(cls._normalize(value)) is not None
|
||||
|
||||
|
||||
class BooleanQuery(MatchQuery):
|
||||
class BooleanQuery(MatchQuery[int]):
|
||||
"""Matches a boolean field. Pattern should either be a boolean or a
|
||||
string reflecting a boolean.
|
||||
"""
|
||||
|
|
@ -299,16 +306,18 @@ class BooleanQuery(MatchQuery):
|
|||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
pattern: Union[bool, str],
|
||||
pattern: bool,
|
||||
fast: bool = True,
|
||||
):
|
||||
super().__init__(field, pattern, fast)
|
||||
if isinstance(pattern, str):
|
||||
self.pattern = util.str2bool(pattern)
|
||||
self.pattern = int(self.pattern)
|
||||
pattern = util.str2bool(pattern)
|
||||
|
||||
pattern_int = int(pattern)
|
||||
|
||||
super().__init__(field, pattern_int, fast)
|
||||
|
||||
|
||||
class BytesQuery(MatchQuery):
|
||||
class BytesQuery(FieldQuery[bytes]):
|
||||
"""Match a raw bytes field (i.e., a path). This is a necessary hack
|
||||
to work around the `sqlite3` module's desire to treat `bytes` and
|
||||
`unicode` equivalently in Python 2. Always use this query instead of
|
||||
|
|
@ -316,22 +325,30 @@ class BytesQuery(MatchQuery):
|
|||
"""
|
||||
|
||||
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
|
||||
super().__init__(field, pattern)
|
||||
|
||||
# Use a buffer/memoryview representation of the pattern for SQLite
|
||||
# matching. This instructs SQLite to treat the blob as binary
|
||||
# rather than encoded Unicode.
|
||||
if isinstance(self.pattern, (str, bytes)):
|
||||
if isinstance(self.pattern, str):
|
||||
self.pattern = self.pattern.encode('utf-8')
|
||||
self.buf_pattern = memoryview(self.pattern)
|
||||
if isinstance(pattern, (str, bytes)):
|
||||
if isinstance(pattern, str):
|
||||
bytes_pattern = pattern.encode('utf-8')
|
||||
else:
|
||||
bytes_pattern = pattern
|
||||
self.buf_pattern = memoryview(bytes_pattern)
|
||||
elif isinstance(self.pattern, memoryview):
|
||||
self.buf_pattern = self.pattern
|
||||
self.pattern = bytes(self.pattern)
|
||||
bytes_pattern = bytes(self.pattern)
|
||||
else:
|
||||
raise ValueError("pattern must be bytes, str, or memoryview")
|
||||
|
||||
def col_clause(self) -> Tuple[str, List[memoryview]]:
|
||||
super().__init__(field, bytes_pattern)
|
||||
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.field + " = ?", [self.buf_pattern]
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern: bytes, value: Any) -> bool:
|
||||
return pattern == value
|
||||
|
||||
|
||||
class NumericQuery(FieldQuery):
|
||||
"""Matches numeric fields. A syntax using Ruby-style range ellipses
|
||||
|
|
@ -390,7 +407,7 @@ class NumericQuery(FieldQuery):
|
|||
return False
|
||||
return True
|
||||
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
if self.point is not None:
|
||||
return self.field + '=?', (self.point,)
|
||||
else:
|
||||
|
|
@ -430,7 +447,7 @@ class CollectionQuery(Query):
|
|||
def clause_with_joiner(
|
||||
self,
|
||||
joiner: str,
|
||||
) -> Tuple[Optional[str], Collection]:
|
||||
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
"""Return a clause created by joining together the clauses of
|
||||
all subqueries with the string joiner (padded by spaces).
|
||||
"""
|
||||
|
|
@ -477,7 +494,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
# TYPING ERROR
|
||||
super().__init__(subqueries)
|
||||
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -502,7 +519,7 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
"""A collection query whose subqueries may be modified after the
|
||||
query is initialized.
|
||||
"""
|
||||
subqueries: MutableMapping
|
||||
subqueries: MutableSequence
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.subqueries[key] = value
|
||||
|
|
@ -514,7 +531,7 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
class AndQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner('and')
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -524,7 +541,7 @@ class AndQuery(MutableCollectionQuery):
|
|||
class OrQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -539,7 +556,7 @@ class NotQuery(Query):
|
|||
def __init__(self, subquery):
|
||||
self.subquery = subquery
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
clause, subvals = self.subquery.clause()
|
||||
if clause:
|
||||
return f'not ({clause})', subvals
|
||||
|
|
@ -565,7 +582,7 @@ class NotQuery(Query):
|
|||
class TrueQuery(Query):
|
||||
"""A query that always matches."""
|
||||
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return '1', ()
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -575,7 +592,7 @@ class TrueQuery(Query):
|
|||
class FalseQuery(Query):
|
||||
"""A query that never matches."""
|
||||
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return '0', ()
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -584,7 +601,7 @@ class FalseQuery(Query):
|
|||
|
||||
# Time/date queries.
|
||||
|
||||
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
|
||||
def _parse_periods(pattern: str) -> Tuple[Optional[Period], Optional[Period]]:
|
||||
"""Parse a string containing two dates separated by two dots (..).
|
||||
Return a pair of `Period` objects.
|
||||
"""
|
||||
|
|
@ -661,6 +678,8 @@ class Period:
|
|||
if not string:
|
||||
return None
|
||||
|
||||
date: Optional[datetime]
|
||||
|
||||
# Check for a relative date.
|
||||
match_dq = re.match(cls.relative_re, string)
|
||||
if match_dq:
|
||||
|
|
@ -678,7 +697,7 @@ class Period:
|
|||
|
||||
# Check for an absolute date.
|
||||
date, ordinal = find_date_and_format(string)
|
||||
if date is None:
|
||||
if date is None or ordinal is None:
|
||||
raise InvalidQueryArgumentValueError(string,
|
||||
'a valid date/time string')
|
||||
precision = cls.precisions[ordinal]
|
||||
|
|
@ -724,7 +743,11 @@ class DateInterval:
|
|||
self.end = end
|
||||
|
||||
@classmethod
|
||||
def from_periods(cls, start: Period, end: Period) -> 'DateInterval':
|
||||
def from_periods(
|
||||
cls,
|
||||
start: Optional[Period],
|
||||
end: Optional[Period],
|
||||
) -> DateInterval:
|
||||
"""Create an interval with two Periods as the endpoints.
|
||||
"""
|
||||
end_date = end.open_right_endpoint() if end is not None else None
|
||||
|
|
@ -766,7 +789,7 @@ class DateQuery(FieldQuery):
|
|||
|
||||
_clause_tmpl = "{0} {1} ?"
|
||||
|
||||
def col_clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
clause_parts = []
|
||||
subvals = []
|
||||
|
||||
|
|
@ -824,7 +847,7 @@ class Sort:
|
|||
the database.
|
||||
"""
|
||||
|
||||
def order_clause(self) -> None:
|
||||
def order_clause(self) -> Optional[str]:
|
||||
"""Generates a SQL fragment to be used in a ORDER BY clause, or
|
||||
None if no fragment is used (i.e., this is a slow sort).
|
||||
"""
|
||||
|
|
@ -858,30 +881,22 @@ class MultipleSort(Sort):
|
|||
def add_sort(self, sort: Sort):
|
||||
self.sorts.append(sort)
|
||||
|
||||
def _sql_sorts(self) -> List[Sort]:
|
||||
"""Return the list of sub-sorts for which we can be (at least
|
||||
partially) fast.
|
||||
def order_clause(self) -> str:
|
||||
"""Return the list SQL clauses for those sub-sorts for which we can be
|
||||
(at least partially) fast.
|
||||
|
||||
A contiguous suffix of fast (SQL-capable) sub-sorts are
|
||||
executable in SQL. The remaining, even if they are fast
|
||||
independently, must be executed slowly.
|
||||
"""
|
||||
sql_sorts = []
|
||||
for sort in reversed(self.sorts):
|
||||
if not sort.order_clause() is None:
|
||||
sql_sorts.append(sort)
|
||||
else:
|
||||
break
|
||||
sql_sorts.reverse()
|
||||
return sql_sorts
|
||||
|
||||
def order_clause(self) -> str:
|
||||
order_strings = []
|
||||
for sort in self._sql_sorts():
|
||||
order = sort.order_clause()
|
||||
order_strings.append(order)
|
||||
for sort in reversed(self.sorts):
|
||||
clause = sort.order_clause()
|
||||
if clause is None:
|
||||
break
|
||||
order_strings.append(clause)
|
||||
order_strings.reverse()
|
||||
|
||||
# TYPING ERROR
|
||||
return ", ".join(order_strings)
|
||||
|
||||
def is_slow(self) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in a new issue