Add typing to queries modules

This commit is contained in:
Serene-Arc 2022-09-15 16:02:46 +10:00
parent 0ee3342257
commit e29337d4e6
2 changed files with 155 additions and 104 deletions

View file

@ -17,11 +17,16 @@
import re
from operator import mul
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \
Collection, Mapping
from beets import util
from datetime import datetime, timedelta
import unicodedata
from functools import reduce
from beets.library import Item
class ParsingError(ValueError):
"""Abstract class for any unparseable user-requested album/query
@ -60,7 +65,7 @@ class Query:
"""An abstract class representing a query into the item database.
"""
def clause(self) -> Tuple[Optional[str], Collection]:
    """Generate an SQLite expression implementing the query.

    Return a ``(clause, subvals)`` pair. This base implementation
    always returns ``(None, ())``.
    """
    return None, ()
def match(self, item: 'Item') -> bool:
    """Check whether this query matches a given Item. Can be used to
    perform queries on arbitrary sets of Items.

    This base implementation always raises NotImplementedError.
    """
    raise NotImplementedError
def __repr__(self) -> str:
    """Return ``ClassName()`` using the instance's class name."""
    return f"{self.__class__.__name__}()"
def __eq__(self, other) -> bool:
    """Two queries are equal when they have exactly the same type."""
    return type(self) == type(other)
def __hash__(self) -> int:
    """Return a constant hash (0) for every instance."""
    return 0
@ -93,12 +98,12 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
    """Store the field name, the pattern to match against it, and the
    `fast` flag on the instance.
    """
    self.field = field
    self.pattern = pattern
    self.fast = fast
def col_clause(self) -> Tuple[Optional[str], Collection]:
    """Return a ``(clause, subvals)`` pair for this field.

    This base implementation always returns ``(None, ())``.
    """
    return None, ()
def clause(self):
@ -109,51 +114,51 @@ class FieldQuery(Query):
return None, ()
@classmethod
def value_match(cls, pattern: str, value: str) -> bool:
    """Determine whether the value matches the pattern. Both
    arguments are strings.

    This base implementation always raises NotImplementedError.
    """
    raise NotImplementedError()
def match(self, item: 'Item') -> bool:
    """Match the item by passing this query's pattern and the item's
    value for `self.field` (looked up with ``item.get``) to
    `value_match`.
    """
    return self.value_match(self.pattern, item.get(self.field))
def __repr__(self) -> str:
    """Return ``ClassName(field, pattern, fast)`` with the field and
    pattern shown via their repr.
    """
    return (
        f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
        f"{self.fast})"
    )
def __eq__(self, other) -> bool:
    """Equal when the parent class's equality check passes and both
    the field and the pattern are equal.
    """
    return (
        super().__eq__(other)
        and self.field == other.field
        and self.pattern == other.pattern
    )
def __hash__(self) -> int:
    """Hash on the field name together with the hash of the pattern."""
    return hash((self.field, hash(self.pattern)))
class MatchQuery(FieldQuery):
    """A query that looks for exact matches in an item field."""

    def col_clause(self) -> Tuple[str, List[str]]:
        """Return the SQL fragment ``<field> = ?`` and a one-element
        list holding the pattern to bind to the placeholder.
        """
        return self.field + " = ?", [self.pattern]

    @classmethod
    def value_match(cls, pattern: str, value: str) -> bool:
        """Return whether `value` equals `pattern`."""
        return pattern == value
class NoneQuery(FieldQuery):
    """A query that checks whether a field is null."""

    def __init__(self, field: str, fast: bool = True):
        """Initialise the parent with `field`, a pattern of None, and
        the `fast` flag.
        """
        super().__init__(field, None, fast)

    def col_clause(self) -> Tuple[str, Tuple]:
        """Return the SQL fragment ``<field> IS NULL`` with no bound
        values.
        """
        return self.field + " IS NULL", ()

    def match(self, item: 'Item') -> bool:
        """Return whether ``item.get(self.field)`` is None."""
        return item.get(self.field) is None

    def __repr__(self) -> str:
        """Return ``ClassName(field, fast)``; the pattern is omitted."""
        return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
@ -163,14 +168,18 @@ class StringFieldQuery(FieldQuery):
"""
@classmethod
def value_match(
    cls: Type['StringFieldQuery'],
    pattern: str,
    value: Any,
) -> bool:
    """Determine whether the value matches the pattern. The value
    may have any type.

    The value is converted with ``util.as_string`` and the comparison
    is delegated to `string_match`.
    """
    return cls.string_match(pattern, util.as_string(value))
@classmethod
def string_match(cls, pattern, value):
def string_match(
cls: Type['StringFieldQuery'],
pattern: str,
value: str,
) -> bool:
"""Determine whether the value matches the pattern. Both
arguments are strings. Subclasses implement this method.
"""
@ -180,7 +189,7 @@ class StringFieldQuery(FieldQuery):
class StringQuery(StringFieldQuery):
"""A query that matches a whole string in a specific item field."""
def col_clause(self):
def col_clause(self) -> Tuple[str, List[str]]:
search = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -190,14 +199,14 @@ class StringQuery(StringFieldQuery):
return clause, subvals
@classmethod
def string_match(cls, pattern: str, value: str) -> bool:
    """Return whether `pattern` and `value` are equal once both are
    lower-cased.
    """
    return pattern.lower() == value.lower()
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
def col_clause(self):
def col_clause(self) -> Tuple[str, List[str]]:
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -208,7 +217,7 @@ class SubstringQuery(StringFieldQuery):
return clause, subvals
@classmethod
def string_match(cls, pattern: str, value: str) -> bool:
    """Return whether the lower-cased `pattern` occurs anywhere in
    the lower-cased `value`.
    """
    return pattern.lower() in value.lower()
@ -220,7 +229,7 @@ class RegexpQuery(StringFieldQuery):
expression.
"""
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
pattern = self._normalize(pattern)
try:
@ -235,14 +244,14 @@ class RegexpQuery(StringFieldQuery):
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@staticmethod
def _normalize(s: str) -> str:
    """Normalize a Unicode string's representation (used on both
    patterns and matched values).

    The string is converted to NFC form.
    """
    return unicodedata.normalize('NFC', s)
@classmethod
def string_match(cls, pattern: Pattern, value: str) -> bool:
    """Return whether the compiled `pattern` is found anywhere in
    `value` after `value` has been passed through `_normalize`.
    """
    return pattern.search(cls._normalize(value)) is not None
@ -251,7 +260,12 @@ class BooleanQuery(MatchQuery):
string reflecting a boolean.
"""
def __init__(self, field, pattern, fast=True):
def __init__(
self,
field: str,
pattern: Union[bool, str],
fast: bool = True,
):
super().__init__(field, pattern, fast)
if isinstance(pattern, str):
self.pattern = util.str2bool(pattern)
@ -265,7 +279,7 @@ class BytesQuery(MatchQuery):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field, pattern):
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
super().__init__(field, pattern)
# Use a buffer/memoryview representation of the pattern for SQLite
@ -279,7 +293,7 @@ class BytesQuery(MatchQuery):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
def col_clause(self) -> Tuple[str, List[memoryview]]:
    """Return the SQL fragment ``<field> = ?`` and a one-element list
    holding `self.buf_pattern` as the value to bind.
    """
    return self.field + " = ?", [self.buf_pattern]
@ -292,7 +306,7 @@ class NumericQuery(FieldQuery):
a float.
"""
def _convert(self, s):
def _convert(self, s: str) -> Union[float, int, None]:
"""Convert a string to a numeric type (float or int).
Return None if `s` is empty.
@ -309,7 +323,7 @@ class NumericQuery(FieldQuery):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
parts = pattern.split('..', 1)
@ -324,7 +338,7 @@ class NumericQuery(FieldQuery):
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item):
def match(self, item: Item) -> bool:
if self.field not in item:
return False
value = item[self.field]
@ -340,7 +354,7 @@ class NumericQuery(FieldQuery):
return False
return True
def col_clause(self):
def col_clause(self) -> Tuple[str, Tuple]:
if self.point is not None:
return self.field + '=?', (self.point,)
else:
@ -360,24 +374,27 @@ class CollectionQuery(Query):
indexed like a list to access the sub-queries.
"""
def __init__(self, subqueries: Collection['Query'] = ()):
    """Store the given sub-queries (an empty tuple by default).

    The object passed in is kept as-is, not copied.
    """
    self.subqueries = subqueries
# Act like a sequence.

def __len__(self) -> int:
    """Return the number of sub-queries."""
    return len(self.subqueries)
def __getitem__(self, key):
    """Return ``self.subqueries[key]``."""
    return self.subqueries[key]
def __iter__(self) -> Iterator:
    """Return an iterator over the sub-queries."""
    return iter(self.subqueries)
def __contains__(self, item) -> bool:
    """Return whether `item` is one of the sub-queries."""
    return item in self.subqueries
def clause_with_joiner(self, joiner):
def clause_with_joiner(
self,
joiner: str,
) -> Tuple[Optional[str], Collection]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@ -393,14 +410,14 @@ class CollectionQuery(Query):
clause = (' ' + joiner + ' ').join(clause_parts)
return clause, subvals
def __repr__(self) -> str:
    """Return ``ClassName(subqueries)`` with the sub-queries shown
    via their repr.
    """
    return f"{self.__class__.__name__}({self.subqueries!r})"
def __eq__(self, other) -> bool:
    """Equal when the parent class's equality check passes and the
    sub-queries compare equal.
    """
    return (
        super().__eq__(other)
        and self.subqueries == other.subqueries
    )
def __hash__(self):
def __hash__(self) -> int:
"""Since subqueries are mutable, this object should not be hashable.
However and for conveniences purposes, it can be hashed.
"""
@ -413,7 +430,7 @@ class AnyFieldQuery(CollectionQuery):
constructor.
"""
def __init__(self, pattern, fields, cls):
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls
@ -423,24 +440,24 @@ class AnyFieldQuery(CollectionQuery):
subqueries.append(cls(field, pattern, True))
super().__init__(subqueries)
def clause(self) -> Tuple[Optional[str], Collection]:
    """Return the result of `clause_with_joiner` called with ``'or'``."""
    return self.clause_with_joiner('or')
def match(self, item: 'Item') -> bool:
    """Return True if at least one sub-query matches `item`, and
    False otherwise (including when there are no sub-queries).
    """
    return any(subq.match(item) for subq in self.subqueries)
def __repr__(self) -> str:
    """Return ``ClassName(pattern, fields, QueryClassName)`` with the
    pattern and fields shown via their repr.
    """
    return (
        f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
        f"{self.query_class.__name__})"
    )
def __eq__(self, other) -> bool:
    """Equal when the parent class's equality check passes and both
    objects use the same query class.
    """
    return (
        super().__eq__(other)
        and self.query_class == other.query_class
    )
def __hash__(self) -> int:
    """Hash on the pattern, the fields (as a tuple), and the query
    class.
    """
    return hash((self.pattern, tuple(self.fields), self.query_class))
@ -459,20 +476,20 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
    """A conjunction of a list of other queries."""

    def clause(self) -> Tuple[Optional[str], Collection]:
        """Return the result of `clause_with_joiner` called with
        ``'and'``.
        """
        return self.clause_with_joiner('and')

    def match(self, item: 'Item') -> bool:
        """Return True only if every sub-query matches `item` (also
        True when there are no sub-queries).
        """
        return all(q.match(item) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
    """A disjunction of a list of other queries."""

    def clause(self) -> Tuple[Optional[str], Collection]:
        """Return the result of `clause_with_joiner` called with
        ``'or'``.
        """
        return self.clause_with_joiner('or')

    def match(self, item: 'Item') -> bool:
        """Return True if at least one sub-query matches `item` (False
        when there are no sub-queries).
        """
        return any(q.match(item) for q in self.subqueries)
@ -493,43 +510,43 @@ class NotQuery(Query):
# is handled by match() for slow queries.
return clause, subvals
def match(self, item: 'Item') -> bool:
    """Return the negation of the wrapped sub-query's match result."""
    return not self.subquery.match(item)
def __repr__(self) -> str:
    """Return ``ClassName(subquery)`` with the sub-query shown via
    its repr.
    """
    return f"{self.__class__.__name__}({self.subquery!r})"
def __eq__(self, other) -> bool:
    """Equal when the parent class's equality check passes and the
    wrapped sub-queries compare equal.
    """
    return (
        super().__eq__(other)
        and self.subquery == other.subquery
    )
def __hash__(self) -> int:
    """Hash on the tag ``'not'`` together with the sub-query's hash."""
    return hash(('not', hash(self.subquery)))
class TrueQuery(Query):
    """A query that always matches."""

    def clause(self) -> Tuple[Optional[str], Collection]:
        """Return the constant SQL expression ``'1'`` with no bound
        values.
        """
        return '1', ()

    def match(self, item: 'Item') -> bool:
        """Return True for every item."""
        return True
class FalseQuery(Query):
    """A query that never matches."""

    def clause(self) -> Tuple[Optional[str], Collection]:
        """Return the constant SQL expression ``'0'`` with no bound
        values.
        """
        return '0', ()

    def match(self, item: 'Item') -> bool:
        """Return False for every item."""
        return False
# Time/date queries.
def _parse_periods(pattern):
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
@ -563,7 +580,7 @@ class Period:
relative_re = '(?P<sign>[+|-]?)(?P<quantity>[0-9]+)' + \
'(?P<timespan>[y|m|w|d])'
def __init__(self, date, precision):
def __init__(self, date: datetime, precision: str):
"""Create a period with the given date (a `datetime` object) and
precision (a string, one of "year", "month", "day", "hour", "minute",
or "second").
@ -574,7 +591,7 @@ class Period:
self.precision = precision
@classmethod
def parse(cls, string):
def parse(cls: Type['Period'], string: str) -> Optional['Period']:
"""Parse a date and return a `Period` object or `None` if the
string is empty, or raise an InvalidQueryArgumentValueError if
the string cannot be parsed to a date.
@ -591,7 +608,8 @@ class Period:
and a "year" is exactly 365 days.
"""
def find_date_and_format(string):
def find_date_and_format(string: str) -> \
Union[Tuple[None, None], Tuple[datetime, int]]:
for ord, format in enumerate(cls.date_formats):
for format_option in format:
try:
@ -628,7 +646,7 @@ class Period:
precision = cls.precisions[ordinal]
return cls(date, precision)
def open_right_endpoint(self):
def open_right_endpoint(self) -> datetime:
"""Based on the precision, convert the period to a precise
`datetime` for use as a right endpoint in a right-open interval.
"""
@ -660,7 +678,7 @@ class DateInterval:
A right endpoint of None means towards infinity.
"""
def __init__(self, start, end):
def __init__(self, start: Optional[datetime], end: Optional[datetime]):
if start is not None and end is not None and not start < end:
raise ValueError("start date {} is not before end date {}"
.format(start, end))
@ -668,21 +686,21 @@ class DateInterval:
self.end = end
@classmethod
def from_periods(
    cls,
    start: Optional['Period'],
    end: Optional['Period'],
) -> 'DateInterval':
    """Create an interval with two Periods as the endpoints.

    Either endpoint may be None. The interval starts at ``start.date``
    and ends at ``end.open_right_endpoint()``; a None period yields a
    None bound on that side.
    """
    end_date = end.open_right_endpoint() if end is not None else None
    start_date = start.date if start is not None else None
    return cls(start_date, end_date)
def contains(self, date: datetime) -> bool:
    """Return whether `date` lies within the interval.

    The start bound is inclusive and the end bound exclusive; a bound
    of None places no limit on that side.
    """
    if self.start is not None and date < self.start:
        return False
    if self.end is not None and date >= self.end:
        return False
    return True
def __str__(self) -> str:
    """Render the interval as ``[start, end)``."""
    return f'[{self.start}, {self.end})'
@ -696,12 +714,12 @@ class DateQuery(FieldQuery):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
    """Initialise the parent, then parse `pattern` into a start and
    end period with `_parse_periods` and store the resulting
    `DateInterval` as `self.interval`.
    """
    super().__init__(field, pattern, fast)
    start, end = _parse_periods(pattern)
    self.interval = DateInterval.from_periods(start, end)
def match(self, item):
def match(self, item: Item) -> bool:
if self.field not in item:
return False
timestamp = float(item[self.field])
@ -710,7 +728,7 @@ class DateQuery(FieldQuery):
_clause_tmpl = "{0} {1} ?"
def col_clause(self):
def col_clause(self) -> Tuple[str | None, Collection]:
clause_parts = []
subvals = []
@ -742,7 +760,7 @@ class DurationQuery(NumericQuery):
or M:SS time interval.
"""
def _convert(self, s):
def _convert(self, s: str) -> Optional[float]:
"""Convert a M:SS or numeric string to a float.
Return None if `s` is empty.
@ -774,21 +792,21 @@ class Sort:
"""
return None
def sort(self, items: List) -> List:
    """Sort the list of objects and return a list.

    This base implementation returns ``sorted(items)``, leaving the
    input unmodified.
    """
    return sorted(items)
def is_slow(self) -> bool:
    """Indicate whether this query is *slow*, meaning that it cannot
    be executed in SQL and must be executed in Python.

    This base implementation always returns False.
    """
    return False
def __hash__(self) -> int:
    """Return a constant hash (0) for every instance."""
    return 0
def __eq__(self, other) -> bool:
    """Two sorts are equal when they have exactly the same type."""
    return type(self) == type(other)
@ -796,13 +814,13 @@ class MultipleSort(Sort):
"""Sort that encapsulates multiple sub-sorts.
"""
def __init__(self, sorts: Optional[List['Sort']] = None):
    """Store the given sub-sorts.

    A falsy argument (None or an empty list) is replaced by a fresh
    empty list; otherwise the list passed in is kept as-is.
    """
    self.sorts = sorts or []
def add_sort(self, sort: 'Sort') -> None:
    """Append `sort` to the end of the list of sub-sorts."""
    self.sorts.append(sort)
def _sql_sorts(self):
def _sql_sorts(self) -> List[Sort]:
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
@ -819,7 +837,7 @@ class MultipleSort(Sort):
sql_sorts.reverse()
return sql_sorts
def order_clause(self):
def order_clause(self) -> str:
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
@ -827,7 +845,7 @@ class MultipleSort(Sort):
return ", ".join(order_strings)
def is_slow(self):
def is_slow(self) -> bool:
for sort in self.sorts:
if sort.is_slow():
return True
@ -865,17 +883,22 @@ class FieldSort(Sort):
any kind).
"""
def __init__(
    self,
    field: str,
    ascending: bool = True,
    case_insensitive: bool = True,
):
    """Store the field to sort on, the sort direction, and whether
    comparison should ignore case.
    """
    self.field = field
    self.ascending = ascending
    self.case_insensitive = case_insensitive
def sort(self, objs):
def sort(self, objs: Collection):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
def key(item):
def key(item: Item):
field_val = item.get(self.field, '')
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
@ -883,17 +906,17 @@ class FieldSort(Sort):
return sorted(objs, key=key, reverse=not self.ascending)
def __repr__(self) -> str:
    """Return ``<ClassName: field+>`` for an ascending sort and
    ``<ClassName: field->`` for a descending one.
    """
    direction = '+' if self.ascending else '-'
    return f'<{type(self).__name__}: {self.field}{direction}>'
def __hash__(self) -> int:
    """Hash on the field name and the sort direction."""
    return hash((self.field, self.ascending))
def __eq__(self, other) -> bool:
    """Equal when the parent class's equality check passes and both
    the field and the sort direction are equal.
    """
    return (
        super().__eq__(other)
        and self.field == other.field
        and self.ascending == other.ascending
    )
@ -903,7 +926,7 @@ class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
def order_clause(self):
def order_clause(self) -> str:
order = "ASC" if self.ascending else "DESC"
if self.case_insensitive:
field = '(CASE ' \
@ -920,24 +943,24 @@ class SlowFieldSort(FieldSort):
i.e., a computed or flexible field.
"""
def is_slow(self) -> bool:
    """Always return True for this sort."""
    return True
class NullSort(Sort):
    """No sorting. Leave results unsorted."""

    def sort(self, items: List) -> List:
        """Return `items` itself, unchanged and not copied."""
        return items

    def __nonzero__(self) -> bool:
        # Python 2 spelling of the truth-value hook; Python 3 only
        # consults __bool__. Kept so the name stays available.
        return self.__bool__()

    def __bool__(self) -> bool:
        """A NullSort is always falsy."""
        return False

    def __eq__(self, other) -> bool:
        """Equal to any object of exactly the same type, and to None."""
        return type(self) == type(other) or other is None

    def __hash__(self) -> int:
        """Return a constant hash (0) for every instance."""
        return 0

View file

@ -17,7 +17,10 @@
import re
import itertools
from . import query
from typing import Dict, Type, Tuple, Optional, Mapping, Collection, List
from . import query, Model
from .query import Sort
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
@ -34,8 +37,12 @@ PARSE_QUERY_PART_REGEX = re.compile(
)
def parse_query_part(part, query_classes={}, prefixes={},
default_class=query.SubstringQuery):
def parse_query_part(
part: str,
query_classes: Dict = {},
prefixes: Dict = {},
default_class: Type[query.SubstringQuery] = query.SubstringQuery,
) -> Tuple[Optional[str], str, Type[query.Query], bool]:
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.
@ -100,7 +107,11 @@ def parse_query_part(part, query_classes={}, prefixes={},
return key, term, query_class, negate
def construct_query_part(model_cls, prefixes, query_part):
def construct_query_part(
model_cls: Type[Model],
prefixes: Mapping[str, Type[query.Query]],
query_part: str,
) -> query.Query:
"""Parse a *query part* string and return a :class:`Query` object.
:param model_cls: The :class:`Model` class that this is a query for.
@ -158,7 +169,12 @@ def construct_query_part(model_cls, prefixes, query_part):
return out_query
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
def query_from_strings(
query_cls: Type[query.Query],
model_cls: Type[Model],
prefixes: Mapping[str, Type[query.Query]],
query_parts: Collection[str],
) -> query.Query:
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
@ -171,7 +187,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
return query_cls(subqueries)
def construct_sort_part(model_cls, part, case_insensitive=True):
def construct_sort_part(
model_cls: Type[Model],
part: str,
case_insensitive: bool = True,
) -> Sort:
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
@ -197,7 +217,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
return sort
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
def sort_from_strings(
model_cls: Type[Model],
sort_parts: Collection[str],
case_insensitive: bool = True,
) -> Sort:
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
@ -212,8 +236,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
return sort
def parse_sorted_query(model_cls, parts, prefixes={},
case_insensitive=True):
def parse_sorted_query(
model_cls: Type[Model],
parts: List[str],
prefixes: Mapping[str, Type[query.Query]] = {},
case_insensitive: bool = True,
) -> Tuple[query.Query, Sort]:
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""