dbcore/query: improve/fix typing

This commit is contained in:
wisp3rwind 2023-06-25 10:59:20 +02:00
parent 7fbf562d24
commit bffeb9816c
2 changed files with 82 additions and 67 deletions

View file

@ -385,7 +385,7 @@ class Model:
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def _get(self, key, default: bool = None, raise_: bool = False):
def _get(self, key, default: Any = None, raise_: bool = False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""

View file

@ -23,14 +23,16 @@ from operator import mul
from typing import (
Any,
Collection,
Generic,
Iterator,
List,
MutableMapping,
MutableSequence,
Optional,
Pattern,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
@ -126,7 +128,12 @@ class NamedQuery(Query):
...
class FieldQuery(Query):
P = TypeVar("P")
SQLiteType = Union[str, float, int, memoryview]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
class FieldQuery(Query, Generic[P]):
"""An abstract query that searches in a specific field for a
pattern. Subclasses must provide a `value_match` class method, which
determines whether a certain pattern string matches a certain value
@ -134,15 +141,15 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self) -> Union[Optional[str], Sequence[Any]]:
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return None, ()
def clause(self):
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
return self.col_clause()
else:
@ -150,7 +157,7 @@ class FieldQuery(Query):
return None, ()
@classmethod
def value_match(cls, pattern: str, value: str):
def value_match(cls, pattern: P, value: Any):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
@ -171,24 +178,24 @@ class FieldQuery(Query):
return hash((self.field, hash(self.pattern)))
class MatchQuery(FieldQuery):
class MatchQuery(FieldQuery[AnySQLiteType]):
"""A query that looks for exact matches in an Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.pattern]
@classmethod
def value_match(cls, pattern: str, value: str) -> bool:
def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool:
return pattern == value
class NoneQuery(FieldQuery):
class NoneQuery(FieldQuery[None]):
"""A query that checks whether a field is null."""
def __init__(self, field, fast: bool = True):
super().__init__(field, "", fast)
super().__init__(field, None, fast)
def col_clause(self) -> Tuple[str, Tuple]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " IS NULL", ()
def match(self, obj: Model) -> bool:
@ -198,13 +205,13 @@ class NoneQuery(FieldQuery):
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
class StringFieldQuery(FieldQuery):
class StringFieldQuery(FieldQuery[P]):
"""A FieldQuery that converts values to strings before matching
them.
"""
@classmethod
def value_match(cls, pattern: str, value: Any):
def value_match(cls, pattern: P, value: Any):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
@ -213,7 +220,7 @@ class StringFieldQuery(FieldQuery):
@classmethod
def string_match(
cls,
pattern: str,
pattern: P,
value: str,
) -> bool:
"""Determine whether the value matches the pattern. Both
@ -222,10 +229,10 @@ class StringFieldQuery(FieldQuery):
raise NotImplementedError()
class StringQuery(StringFieldQuery):
class StringQuery(StringFieldQuery[str]):
"""A query that matches a whole string in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
search = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -239,10 +246,10 @@ class StringQuery(StringFieldQuery):
return pattern.lower() == value.lower()
class SubstringQuery(StringFieldQuery):
class SubstringQuery(StringFieldQuery[str]):
"""A query that matches a substring in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -257,7 +264,7 @@ class SubstringQuery(StringFieldQuery):
return pattern.lower() in value.lower()
class RegexpQuery(StringFieldQuery):
class RegexpQuery(StringFieldQuery[Pattern]):
"""A query that matches a regular expression in a specific Model field.
Raises InvalidQueryError when the pattern is not a valid regular
@ -276,7 +283,7 @@ class RegexpQuery(StringFieldQuery):
super().__init__(field, pattern_re, fast)
def col_clause(self):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@staticmethod
@ -291,7 +298,7 @@ class RegexpQuery(StringFieldQuery):
return pattern.search(cls._normalize(value)) is not None
class BooleanQuery(MatchQuery):
class BooleanQuery(MatchQuery[int]):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
@ -299,16 +306,18 @@ class BooleanQuery(MatchQuery):
def __init__(
self,
field: str,
pattern: Union[bool, str],
pattern: bool,
fast: bool = True,
):
super().__init__(field, pattern, fast)
if isinstance(pattern, str):
self.pattern = util.str2bool(pattern)
self.pattern = int(self.pattern)
pattern = util.str2bool(pattern)
pattern_int = int(pattern)
super().__init__(field, pattern_int, fast)
class BytesQuery(MatchQuery):
class BytesQuery(FieldQuery[bytes]):
"""Match a raw bytes field (i.e., a path). This is a necessary hack
to work around the `sqlite3` module's desire to treat `bytes` and
`unicode` equivalently in Python 2. Always use this query instead of
@ -316,22 +325,30 @@ class BytesQuery(MatchQuery):
"""
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
super().__init__(field, pattern)
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
if isinstance(self.pattern, (str, bytes)):
if isinstance(self.pattern, str):
self.pattern = self.pattern.encode('utf-8')
self.buf_pattern = memoryview(self.pattern)
if isinstance(pattern, (str, bytes)):
if isinstance(pattern, str):
bytes_pattern = pattern.encode('utf-8')
else:
bytes_pattern = pattern
self.buf_pattern = memoryview(bytes_pattern)
elif isinstance(self.pattern, memoryview):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
bytes_pattern = bytes(self.pattern)
else:
raise ValueError("pattern must be bytes, str, or memoryview")
def col_clause(self) -> Tuple[str, List[memoryview]]:
super().__init__(field, bytes_pattern)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
@classmethod
def value_match(cls, pattern: bytes, value: Any) -> bool:
return pattern == value
class NumericQuery(FieldQuery):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
@ -390,7 +407,7 @@ class NumericQuery(FieldQuery):
return False
return True
def col_clause(self) -> Tuple[str, Tuple]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
if self.point is not None:
return self.field + '=?', (self.point,)
else:
@ -430,7 +447,7 @@ class CollectionQuery(Query):
def clause_with_joiner(
self,
joiner: str,
) -> Tuple[Optional[str], Collection]:
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@ -477,7 +494,7 @@ class AnyFieldQuery(CollectionQuery):
# TYPING ERROR
super().__init__(subqueries)
def clause(self) -> Tuple[Union[str, None], Collection]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner('or')
def match(self, obj: Model) -> bool:
@ -502,7 +519,7 @@ class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
"""
subqueries: MutableMapping
subqueries: MutableSequence
def __setitem__(self, key, value):
self.subqueries[key] = value
@ -514,7 +531,7 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[Union[str, None], Collection]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner('and')
def match(self, obj: Model) -> bool:
@ -524,7 +541,7 @@ class AndQuery(MutableCollectionQuery):
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[Union[str, None], Collection]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner('or')
def match(self, obj: Model) -> bool:
@ -539,7 +556,7 @@ class NotQuery(Query):
def __init__(self, subquery):
self.subquery = subquery
def clause(self):
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
clause, subvals = self.subquery.clause()
if clause:
return f'not ({clause})', subvals
@ -565,7 +582,7 @@ class NotQuery(Query):
class TrueQuery(Query):
"""A query that always matches."""
def clause(self) -> Tuple[Union[str, None], Collection]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return '1', ()
def match(self, obj: Model) -> bool:
@ -575,7 +592,7 @@ class TrueQuery(Query):
class FalseQuery(Query):
"""A query that never matches."""
def clause(self) -> Tuple[Union[str, None], Collection]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return '0', ()
def match(self, obj: Model) -> bool:
@ -584,7 +601,7 @@ class FalseQuery(Query):
# Time/date queries.
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
def _parse_periods(pattern: str) -> Tuple[Optional[Period], Optional[Period]]:
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
@ -661,6 +678,8 @@ class Period:
if not string:
return None
date: Optional[datetime]
# Check for a relative date.
match_dq = re.match(cls.relative_re, string)
if match_dq:
@ -678,7 +697,7 @@ class Period:
# Check for an absolute date.
date, ordinal = find_date_and_format(string)
if date is None:
if date is None or ordinal is None:
raise InvalidQueryArgumentValueError(string,
'a valid date/time string')
precision = cls.precisions[ordinal]
@ -724,7 +743,11 @@ class DateInterval:
self.end = end
@classmethod
def from_periods(cls, start: Period, end: Period) -> 'DateInterval':
def from_periods(
cls,
start: Optional[Period],
end: Optional[Period],
) -> DateInterval:
"""Create an interval with two Periods as the endpoints.
"""
end_date = end.open_right_endpoint() if end is not None else None
@ -766,7 +789,7 @@ class DateQuery(FieldQuery):
_clause_tmpl = "{0} {1} ?"
def col_clause(self) -> Tuple[Union[str, None], Collection]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
clause_parts = []
subvals = []
@ -824,7 +847,7 @@ class Sort:
the database.
"""
def order_clause(self) -> None:
def order_clause(self) -> Optional[str]:
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
@ -858,30 +881,22 @@ class MultipleSort(Sort):
def add_sort(self, sort: Sort):
self.sorts.append(sort)
def _sql_sorts(self) -> List[Sort]:
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
def order_clause(self) -> str:
"""Return the list SQL clauses for those sub-sorts for which we can be
(at least partially) fast.
A contiguous suffix of fast (SQL-capable) sub-sorts are
executable in SQL. The remaining, even if they are fast
independently, must be executed slowly.
"""
sql_sorts = []
for sort in reversed(self.sorts):
if not sort.order_clause() is None:
sql_sorts.append(sort)
else:
break
sql_sorts.reverse()
return sql_sorts
def order_clause(self) -> str:
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
order_strings.append(order)
for sort in reversed(self.sorts):
clause = sort.order_clause()
if clause is None:
break
order_strings.append(clause)
order_strings.reverse()
# TYPING ERROR
return ", ".join(order_strings)
def is_slow(self) -> bool: