mirror of
https://github.com/beetbox/beets.git
synced 2025-12-14 12:35:19 +01:00
Add typing to queries modules
This commit is contained in:
parent
0ee3342257
commit
e29337d4e6
2 changed files with 155 additions and 104 deletions
|
|
@ -17,11 +17,16 @@
|
|||
|
||||
import re
|
||||
from operator import mul
|
||||
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \
|
||||
Collection, Mapping
|
||||
|
||||
from beets import util
|
||||
from datetime import datetime, timedelta
|
||||
import unicodedata
|
||||
from functools import reduce
|
||||
|
||||
from beets.library import Item
|
||||
|
||||
|
||||
class ParsingError(ValueError):
|
||||
"""Abstract class for any unparseable user-requested album/query
|
||||
|
|
@ -60,7 +65,7 @@ class Query:
|
|||
"""An abstract class representing a query into the item database.
|
||||
"""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Union[None, Tuple]:
|
||||
"""Generate an SQLite expression implementing the query.
|
||||
|
||||
Return (clause, subvals) where clause is a valid sqlite
|
||||
|
|
@ -69,19 +74,19 @@ class Query:
|
|||
"""
|
||||
return None, ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item):
|
||||
"""Check whether this query matches a given Item. Can be used to
|
||||
perform queries on arbitrary sets of Items.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -93,12 +98,12 @@ class FieldQuery(Query):
|
|||
same matching functionality in SQLite.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
|
||||
self.field = field
|
||||
self.pattern = pattern
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Union[None, Tuple]:
|
||||
return None, ()
|
||||
|
||||
def clause(self):
|
||||
|
|
@ -109,51 +114,51 @@ class FieldQuery(Query):
|
|||
return None, ()
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls, pattern: str, value: str):
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item):
|
||||
return self.value_match(self.pattern, item.get(self.field))
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
|
||||
"{0.fast})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and self.pattern == other.pattern
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.field, hash(self.pattern)))
|
||||
|
||||
|
||||
class MatchQuery(FieldQuery):
|
||||
"""A query that looks for exact matches in an item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
return self.field + " = ?", [self.pattern]
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern == value
|
||||
|
||||
|
||||
class NoneQuery(FieldQuery):
|
||||
"""A query that checks whether a field is null."""
|
||||
|
||||
def __init__(self, field, fast=True):
|
||||
def __init__(self, field, fast: bool = True):
|
||||
super().__init__(field, None, fast)
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
return self.field + " IS NULL", ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item) -> bool:
|
||||
return item.get(self.field) is None
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
|
||||
|
||||
|
||||
|
|
@ -163,14 +168,18 @@ class StringFieldQuery(FieldQuery):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls: Type['StringFieldQuery'], pattern: str, value: Any):
|
||||
"""Determine whether the value matches the pattern. The value
|
||||
may have any type.
|
||||
"""
|
||||
return cls.string_match(pattern, util.as_string(value))
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(
|
||||
cls: Type['StringFieldQuery'],
|
||||
pattern: str,
|
||||
value: str,
|
||||
) -> bool:
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings. Subclasses implement this method.
|
||||
"""
|
||||
|
|
@ -180,7 +189,7 @@ class StringFieldQuery(FieldQuery):
|
|||
class StringQuery(StringFieldQuery):
|
||||
"""A query that matches a whole string in a specific item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
search = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -190,14 +199,14 @@ class StringQuery(StringFieldQuery):
|
|||
return clause, subvals
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern.lower() == value.lower()
|
||||
|
||||
|
||||
class SubstringQuery(StringFieldQuery):
|
||||
"""A query that matches a substring in a specific item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
pattern = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -208,7 +217,7 @@ class SubstringQuery(StringFieldQuery):
|
|||
return clause, subvals
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern.lower() in value.lower()
|
||||
|
||||
|
||||
|
|
@ -220,7 +229,7 @@ class RegexpQuery(StringFieldQuery):
|
|||
expression.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
pattern = self._normalize(pattern)
|
||||
try:
|
||||
|
|
@ -235,14 +244,14 @@ class RegexpQuery(StringFieldQuery):
|
|||
return f" regexp({self.field}, ?)", [self.pattern.pattern]
|
||||
|
||||
@staticmethod
|
||||
def _normalize(s):
|
||||
def _normalize(s: str) -> str:
|
||||
"""Normalize a Unicode string's representation (used on both
|
||||
patterns and matched values).
|
||||
"""
|
||||
return unicodedata.normalize('NFC', s)
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: Pattern, value: str) -> bool:
|
||||
return pattern.search(cls._normalize(value)) is not None
|
||||
|
||||
|
||||
|
|
@ -251,7 +260,12 @@ class BooleanQuery(MatchQuery):
|
|||
string reflecting a boolean.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
pattern: Union[bool, str],
|
||||
fast: bool = True,
|
||||
):
|
||||
super().__init__(field, pattern, fast)
|
||||
if isinstance(pattern, str):
|
||||
self.pattern = util.str2bool(pattern)
|
||||
|
|
@ -265,7 +279,7 @@ class BytesQuery(MatchQuery):
|
|||
`MatchQuery` when matching on BLOB values.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern):
|
||||
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
|
||||
super().__init__(field, pattern)
|
||||
|
||||
# Use a buffer/memoryview representation of the pattern for SQLite
|
||||
|
|
@ -279,7 +293,7 @@ class BytesQuery(MatchQuery):
|
|||
self.buf_pattern = self.pattern
|
||||
self.pattern = bytes(self.pattern)
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[memoryview]]:
|
||||
return self.field + " = ?", [self.buf_pattern]
|
||||
|
||||
|
||||
|
|
@ -292,7 +306,7 @@ class NumericQuery(FieldQuery):
|
|||
a float.
|
||||
"""
|
||||
|
||||
def _convert(self, s):
|
||||
def _convert(self, s: str) -> Union[float, int, None]:
|
||||
"""Convert a string to a numeric type (float or int).
|
||||
|
||||
Return None if `s` is empty.
|
||||
|
|
@ -309,7 +323,7 @@ class NumericQuery(FieldQuery):
|
|||
except ValueError:
|
||||
raise InvalidQueryArgumentValueError(s, "an int or a float")
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
|
||||
parts = pattern.split('..', 1)
|
||||
|
|
@ -324,7 +338,7 @@ class NumericQuery(FieldQuery):
|
|||
self.rangemin = self._convert(parts[0])
|
||||
self.rangemax = self._convert(parts[1])
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item) -> bool:
|
||||
if self.field not in item:
|
||||
return False
|
||||
value = item[self.field]
|
||||
|
|
@ -340,7 +354,7 @@ class NumericQuery(FieldQuery):
|
|||
return False
|
||||
return True
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
if self.point is not None:
|
||||
return self.field + '=?', (self.point,)
|
||||
else:
|
||||
|
|
@ -360,24 +374,27 @@ class CollectionQuery(Query):
|
|||
indexed like a list to access the sub-queries.
|
||||
"""
|
||||
|
||||
def __init__(self, subqueries=()):
|
||||
def __init__(self, subqueries: Mapping = ()):
|
||||
self.subqueries = subqueries
|
||||
|
||||
# Act like a sequence.
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.subqueries)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.subqueries[key]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.subqueries)
|
||||
|
||||
def __contains__(self, item):
|
||||
def __contains__(self, item) -> bool:
|
||||
return item in self.subqueries
|
||||
|
||||
def clause_with_joiner(self, joiner):
|
||||
def clause_with_joiner(
|
||||
self,
|
||||
joiner: str,
|
||||
) -> Tuple[Optional[str], Collection]:
|
||||
"""Return a clause created by joining together the clauses of
|
||||
all subqueries with the string joiner (padded by spaces).
|
||||
"""
|
||||
|
|
@ -393,14 +410,14 @@ class CollectionQuery(Query):
|
|||
clause = (' ' + joiner + ' ').join(clause_parts)
|
||||
return clause, subvals
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.subqueries == other.subqueries
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""Since subqueries are mutable, this object should not be hashable.
|
||||
However and for conveniences purposes, it can be hashed.
|
||||
"""
|
||||
|
|
@ -413,7 +430,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
constructor.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern, fields, cls):
|
||||
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
|
||||
self.pattern = pattern
|
||||
self.fields = fields
|
||||
self.query_class = cls
|
||||
|
|
@ -423,24 +440,24 @@ class AnyFieldQuery(CollectionQuery):
|
|||
subqueries.append(cls(field, pattern, True))
|
||||
super().__init__(subqueries)
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[str | None, Collection]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item) -> bool:
|
||||
for subq in self.subqueries:
|
||||
if subq.match(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
|
||||
"{0.query_class.__name__})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.query_class == other.query_class
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.pattern, tuple(self.fields), self.query_class))
|
||||
|
||||
|
||||
|
|
@ -459,20 +476,20 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
class AndQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[str | None, Collection]:
|
||||
return self.clause_with_joiner('and')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return all(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
class OrQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[str | None, Collection]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return any(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
|
|
@ -493,43 +510,43 @@ class NotQuery(Query):
|
|||
# is handled by match() for slow queries.
|
||||
return clause, subvals
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return not self.subquery.match(item)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.subquery!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.subquery == other.subquery
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(('not', hash(self.subquery)))
|
||||
|
||||
|
||||
class TrueQuery(Query):
|
||||
"""A query that always matches."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[str | None, Collection]:
|
||||
return '1', ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class FalseQuery(Query):
|
||||
"""A query that never matches."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[str | None, Collection]:
|
||||
return '0', ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# Time/date queries.
|
||||
|
||||
def _parse_periods(pattern):
|
||||
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
|
||||
"""Parse a string containing two dates separated by two dots (..).
|
||||
Return a pair of `Period` objects.
|
||||
"""
|
||||
|
|
@ -563,7 +580,7 @@ class Period:
|
|||
relative_re = '(?P<sign>[+|-]?)(?P<quantity>[0-9]+)' + \
|
||||
'(?P<timespan>[y|m|w|d])'
|
||||
|
||||
def __init__(self, date, precision):
|
||||
def __init__(self, date: datetime, precision: str):
|
||||
"""Create a period with the given date (a `datetime` object) and
|
||||
precision (a string, one of "year", "month", "day", "hour", "minute",
|
||||
or "second").
|
||||
|
|
@ -574,7 +591,7 @@ class Period:
|
|||
self.precision = precision
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string):
|
||||
def parse(cls: Type['Period'], string: str) -> Optional['Period']:
|
||||
"""Parse a date and return a `Period` object or `None` if the
|
||||
string is empty, or raise an InvalidQueryArgumentValueError if
|
||||
the string cannot be parsed to a date.
|
||||
|
|
@ -591,7 +608,8 @@ class Period:
|
|||
and a "year" is exactly 365 days.
|
||||
"""
|
||||
|
||||
def find_date_and_format(string):
|
||||
def find_date_and_format(string: str) -> \
|
||||
Union[Tuple[None, None], Tuple[datetime, int]]:
|
||||
for ord, format in enumerate(cls.date_formats):
|
||||
for format_option in format:
|
||||
try:
|
||||
|
|
@ -628,7 +646,7 @@ class Period:
|
|||
precision = cls.precisions[ordinal]
|
||||
return cls(date, precision)
|
||||
|
||||
def open_right_endpoint(self):
|
||||
def open_right_endpoint(self) -> datetime:
|
||||
"""Based on the precision, convert the period to a precise
|
||||
`datetime` for use as a right endpoint in a right-open interval.
|
||||
"""
|
||||
|
|
@ -660,7 +678,7 @@ class DateInterval:
|
|||
A right endpoint of None means towards infinity.
|
||||
"""
|
||||
|
||||
def __init__(self, start, end):
|
||||
def __init__(self, start: Optional[datetime], end: Optional[datetime]):
|
||||
if start is not None and end is not None and not start < end:
|
||||
raise ValueError("start date {} is not before end date {}"
|
||||
.format(start, end))
|
||||
|
|
@ -668,21 +686,21 @@ class DateInterval:
|
|||
self.end = end
|
||||
|
||||
@classmethod
|
||||
def from_periods(cls, start, end):
|
||||
def from_periods(cls, start: Period, end: Period) -> 'DateInterval':
|
||||
"""Create an interval with two Periods as the endpoints.
|
||||
"""
|
||||
end_date = end.open_right_endpoint() if end is not None else None
|
||||
start_date = start.date if start is not None else None
|
||||
return cls(start_date, end_date)
|
||||
|
||||
def contains(self, date):
|
||||
def contains(self, date: datetime) -> bool:
|
||||
if self.start is not None and date < self.start:
|
||||
return False
|
||||
if self.end is not None and date >= self.end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'[{self.start}, {self.end})'
|
||||
|
||||
|
||||
|
|
@ -696,12 +714,12 @@ class DateQuery(FieldQuery):
|
|||
using an ellipsis interval syntax similar to that of NumericQuery.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field, pattern, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
start, end = _parse_periods(pattern)
|
||||
self.interval = DateInterval.from_periods(start, end)
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item) -> bool:
|
||||
if self.field not in item:
|
||||
return False
|
||||
timestamp = float(item[self.field])
|
||||
|
|
@ -710,7 +728,7 @@ class DateQuery(FieldQuery):
|
|||
|
||||
_clause_tmpl = "{0} {1} ?"
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str | None, Collection]:
|
||||
clause_parts = []
|
||||
subvals = []
|
||||
|
||||
|
|
@ -742,7 +760,7 @@ class DurationQuery(NumericQuery):
|
|||
or M:SS time interval.
|
||||
"""
|
||||
|
||||
def _convert(self, s):
|
||||
def _convert(self, s: str) -> Optional[float]:
|
||||
"""Convert a M:SS or numeric string to a float.
|
||||
|
||||
Return None if `s` is empty.
|
||||
|
|
@ -774,21 +792,21 @@ class Sort:
|
|||
"""
|
||||
return None
|
||||
|
||||
def sort(self, items):
|
||||
def sort(self, items: List) -> List:
|
||||
"""Sort the list of objects and return a list.
|
||||
"""
|
||||
return sorted(items)
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
"""Indicate whether this query is *slow*, meaning that it cannot
|
||||
be executed in SQL and must be executed in Python.
|
||||
"""
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other)
|
||||
|
||||
|
||||
|
|
@ -796,13 +814,13 @@ class MultipleSort(Sort):
|
|||
"""Sort that encapsulates multiple sub-sorts.
|
||||
"""
|
||||
|
||||
def __init__(self, sorts=None):
|
||||
def __init__(self, sorts: Optional[List[Sort]] = None):
|
||||
self.sorts = sorts or []
|
||||
|
||||
def add_sort(self, sort):
|
||||
def add_sort(self, sort: Sort):
|
||||
self.sorts.append(sort)
|
||||
|
||||
def _sql_sorts(self):
|
||||
def _sql_sorts(self) -> List[Sort]:
|
||||
"""Return the list of sub-sorts for which we can be (at least
|
||||
partially) fast.
|
||||
|
||||
|
|
@ -819,7 +837,7 @@ class MultipleSort(Sort):
|
|||
sql_sorts.reverse()
|
||||
return sql_sorts
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> str:
|
||||
order_strings = []
|
||||
for sort in self._sql_sorts():
|
||||
order = sort.order_clause()
|
||||
|
|
@ -827,7 +845,7 @@ class MultipleSort(Sort):
|
|||
|
||||
return ", ".join(order_strings)
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
for sort in self.sorts:
|
||||
if sort.is_slow():
|
||||
return True
|
||||
|
|
@ -865,17 +883,22 @@ class FieldSort(Sort):
|
|||
any kind).
|
||||
"""
|
||||
|
||||
def __init__(self, field, ascending=True, case_insensitive=True):
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
ascending: bool = True,
|
||||
case_insensitive: bool = True,
|
||||
):
|
||||
self.field = field
|
||||
self.ascending = ascending
|
||||
self.case_insensitive = case_insensitive
|
||||
|
||||
def sort(self, objs):
|
||||
def sort(self, objs: Collection):
|
||||
# TODO: Conversion and null-detection here. In Python 3,
|
||||
# comparisons with None fail. We should also support flexible
|
||||
# attributes with different types without falling over.
|
||||
|
||||
def key(item):
|
||||
def key(item: Item):
|
||||
field_val = item.get(self.field, '')
|
||||
if self.case_insensitive and isinstance(field_val, str):
|
||||
field_val = field_val.lower()
|
||||
|
|
@ -883,17 +906,17 @@ class FieldSort(Sort):
|
|||
|
||||
return sorted(objs, key=key, reverse=not self.ascending)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return '<{}: {}{}>'.format(
|
||||
type(self).__name__,
|
||||
self.field,
|
||||
'+' if self.ascending else '-',
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.field, self.ascending))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and \
|
||||
self.ascending == other.ascending
|
||||
|
|
@ -903,7 +926,7 @@ class FixedFieldSort(FieldSort):
|
|||
"""Sort object to sort on a fixed field.
|
||||
"""
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> str:
|
||||
order = "ASC" if self.ascending else "DESC"
|
||||
if self.case_insensitive:
|
||||
field = '(CASE ' \
|
||||
|
|
@ -920,24 +943,24 @@ class SlowFieldSort(FieldSort):
|
|||
i.e., a computed or flexible field.
|
||||
"""
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class NullSort(Sort):
|
||||
"""No sorting. Leave results unsorted."""
|
||||
|
||||
def sort(self, items):
|
||||
def sort(self, items: List) -> List:
|
||||
return items
|
||||
|
||||
def __nonzero__(self):
|
||||
def __nonzero__(self) -> bool:
|
||||
return self.__bool__()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other) or other is None
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -17,7 +17,10 @@
|
|||
|
||||
import re
|
||||
import itertools
|
||||
from . import query
|
||||
from typing import Dict, Type, Tuple, Optional, Mapping, Collection, List
|
||||
|
||||
from . import query, Model
|
||||
from .query import Sort
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
|
|
@ -34,8 +37,12 @@ PARSE_QUERY_PART_REGEX = re.compile(
|
|||
)
|
||||
|
||||
|
||||
def parse_query_part(part, query_classes={}, prefixes={},
|
||||
default_class=query.SubstringQuery):
|
||||
def parse_query_part(
|
||||
part: str,
|
||||
query_classes: Dict = {},
|
||||
prefixes: Dict = {},
|
||||
default_class: Type[query.SubstringQuery] = query.SubstringQuery,
|
||||
) -> Tuple[Optional[str], str, Type[query.Query], bool]:
|
||||
"""Parse a single *query part*, which is a chunk of a complete query
|
||||
string representing a single criterion.
|
||||
|
||||
|
|
@ -100,7 +107,11 @@ def parse_query_part(part, query_classes={}, prefixes={},
|
|||
return key, term, query_class, negate
|
||||
|
||||
|
||||
def construct_query_part(model_cls, prefixes, query_part):
|
||||
def construct_query_part(
|
||||
model_cls: Type[Model],
|
||||
prefixes: Mapping[str, Type[query.Query]],
|
||||
query_part: str,
|
||||
) -> query.Query:
|
||||
"""Parse a *query part* string and return a :class:`Query` object.
|
||||
|
||||
:param model_cls: The :class:`Model` class that this is a query for.
|
||||
|
|
@ -158,7 +169,12 @@ def construct_query_part(model_cls, prefixes, query_part):
|
|||
return out_query
|
||||
|
||||
|
||||
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
||||
def query_from_strings(
|
||||
query_cls: Type[query.Query],
|
||||
model_cls: Type[Model],
|
||||
prefixes: Mapping[str, Type[query.Query]],
|
||||
query_parts: Collection[str],
|
||||
) -> query.Query:
|
||||
"""Creates a collection query of type `query_cls` from a list of
|
||||
strings in the format used by parse_query_part. `model_cls`
|
||||
determines how queries are constructed from strings.
|
||||
|
|
@ -171,7 +187,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
|||
return query_cls(subqueries)
|
||||
|
||||
|
||||
def construct_sort_part(model_cls, part, case_insensitive=True):
|
||||
def construct_sort_part(
|
||||
model_cls: Type[Model],
|
||||
part: str,
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
"""Create a `Sort` from a single string criterion.
|
||||
|
||||
`model_cls` is the `Model` being queried. `part` is a single string
|
||||
|
|
@ -197,7 +217,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
|
|||
return sort
|
||||
|
||||
|
||||
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
|
||||
def sort_from_strings(
|
||||
model_cls: Type[Model],
|
||||
sort_parts: Collection[str],
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
"""Create a `Sort` from a list of sort criteria (strings).
|
||||
"""
|
||||
if not sort_parts:
|
||||
|
|
@ -212,8 +236,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
|
|||
return sort
|
||||
|
||||
|
||||
def parse_sorted_query(model_cls, parts, prefixes={},
|
||||
case_insensitive=True):
|
||||
def parse_sorted_query(
|
||||
model_cls: Type[Model],
|
||||
parts: List[str],
|
||||
prefixes: Mapping[str, Type[query.Query]] = {},
|
||||
case_insensitive: bool = True,
|
||||
) -> Tuple[query.Query, Sort]:
|
||||
"""Given a list of strings, create the `Query` and `Sort` that they
|
||||
represent.
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue