diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 084ceef99..accb62327 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -385,7 +385,7 @@ class Model: """ return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT - def _get(self, key, default: bool = None, raise_: bool = False): + def _get(self, key, default: Any = None, raise_: bool = False): """Get the value for a field, or `default`. Alternatively, raise a KeyError if the field is not available. """ diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 48bc8e4e2..fbc080426 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -23,14 +23,16 @@ from operator import mul from typing import ( Any, Collection, + Generic, Iterator, List, - MutableMapping, + MutableSequence, Optional, Pattern, Sequence, Tuple, Type, + TypeVar, Union, ) @@ -126,7 +128,12 @@ class NamedQuery(Query): ... -class FieldQuery(Query): +P = TypeVar("P") +SQLiteType = Union[str, float, int, memoryview] +AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType) + + +class FieldQuery(Query, Generic[P]): """An abstract query that searches in a specific field for a pattern. Subclasses must provide a `value_match` class method, which determines whether a certain pattern string matches a certain value @@ -134,15 +141,15 @@ class FieldQuery(Query): same matching functionality in SQLite. """ - def __init__(self, field: str, pattern: str, fast: bool = True): + def __init__(self, field: str, pattern: P, fast: bool = True): self.field = field self.pattern = pattern self.fast = fast - def col_clause(self) -> Union[Optional[str], Sequence[Any]]: + def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return None, () - def clause(self): + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: if self.fast: return self.col_clause() else: @@ -150,7 +157,7 @@ class FieldQuery(Query): return None, () @classmethod - def value_match(cls, pattern: str, value: str): + def value_match(cls, pattern: P, value: Any): """Determine whether the value matches the pattern. Both arguments are strings. """ @@ -171,24 +178,24 @@ class FieldQuery(Query): return hash((self.field, hash(self.pattern))) -class MatchQuery(FieldQuery): +class MatchQuery(FieldQuery[AnySQLiteType]): """A query that looks for exact matches in an Model field.""" - def col_clause(self) -> Tuple[str, List[str]]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.field + " = ?", [self.pattern] @classmethod - def value_match(cls, pattern: str, value: str) -> bool: + def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool: return pattern == value -class NoneQuery(FieldQuery): +class NoneQuery(FieldQuery[None]): """A query that checks whether a field is null.""" def __init__(self, field, fast: bool = True): - super().__init__(field, "", fast) + super().__init__(field, None, fast) - def col_clause(self) -> Tuple[str, Tuple]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.field + " IS NULL", () def match(self, obj: Model) -> bool: @@ -198,13 +205,13 @@ class NoneQuery(FieldQuery): return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self) -class StringFieldQuery(FieldQuery): +class StringFieldQuery(FieldQuery[P]): """A FieldQuery that converts values to strings before matching them. """ @classmethod - def value_match(cls, pattern: str, value: Any): + def value_match(cls, pattern: P, value: Any): """Determine whether the value matches the pattern. The value may have any type. """ @@ -213,7 +220,7 @@ class StringFieldQuery(FieldQuery): @classmethod def string_match( cls, - pattern: str, + pattern: P, value: str, ) -> bool: """Determine whether the value matches the pattern. Both @@ -222,10 +229,10 @@ class StringFieldQuery(FieldQuery): raise NotImplementedError() -class StringQuery(StringFieldQuery): +class StringQuery(StringFieldQuery[str]): """A query that matches a whole string in a specific Model field.""" - def col_clause(self) -> Tuple[str, List[str]]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: search = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -239,10 +246,10 @@ class StringQuery(StringFieldQuery): return pattern.lower() == value.lower() -class SubstringQuery(StringFieldQuery): +class SubstringQuery(StringFieldQuery[str]): """A query that matches a substring in a specific Model field.""" - def col_clause(self) -> Tuple[str, List[str]]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: pattern = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -257,7 +264,7 @@ class SubstringQuery(StringFieldQuery): return pattern.lower() in value.lower() -class RegexpQuery(StringFieldQuery): +class RegexpQuery(StringFieldQuery[Pattern]): """A query that matches a regular expression in a specific Model field. Raises InvalidQueryError when the pattern is not a valid regular @@ -276,7 +283,7 @@ class RegexpQuery(StringFieldQuery): super().__init__(field, pattern_re, fast) - def col_clause(self): + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return f" regexp({self.field}, ?)", [self.pattern.pattern] @staticmethod @@ -291,7 +298,7 @@ class RegexpQuery(StringFieldQuery): return pattern.search(cls._normalize(value)) is not None -class BooleanQuery(MatchQuery): +class BooleanQuery(MatchQuery[int]): """Matches a boolean field. Pattern should either be a boolean or a string reflecting a boolean. """ @@ -299,16 +306,18 @@ class BooleanQuery(MatchQuery): def __init__( self, field: str, - pattern: Union[bool, str], + pattern: bool, fast: bool = True, ): - super().__init__(field, pattern, fast) if isinstance(pattern, str): - self.pattern = util.str2bool(pattern) - self.pattern = int(self.pattern) + pattern = util.str2bool(pattern) + + pattern_int = int(pattern) + + super().__init__(field, pattern_int, fast) -class BytesQuery(MatchQuery): +class BytesQuery(FieldQuery[bytes]): """Match a raw bytes field (i.e., a path). This is a necessary hack to work around the `sqlite3` module's desire to treat `bytes` and `unicode` equivalently in Python 2. Always use this query instead of @@ -316,22 +325,30 @@ class BytesQuery(MatchQuery): """ def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): - super().__init__(field, pattern) - # Use a buffer/memoryview representation of the pattern for SQLite # matching. This instructs SQLite to treat the blob as binary # rather than encoded Unicode. - if isinstance(self.pattern, (str, bytes)): - if isinstance(self.pattern, str): - self.pattern = self.pattern.encode('utf-8') - self.buf_pattern = memoryview(self.pattern) + if isinstance(pattern, (str, bytes)): + if isinstance(pattern, str): + bytes_pattern = pattern.encode('utf-8') + else: + bytes_pattern = pattern + self.buf_pattern = memoryview(bytes_pattern) elif isinstance(self.pattern, memoryview): self.buf_pattern = self.pattern - self.pattern = bytes(self.pattern) + bytes_pattern = bytes(self.pattern) + else: + raise ValueError("pattern must be bytes, str, or memoryview") - def col_clause(self) -> Tuple[str, List[memoryview]]: + super().__init__(field, bytes_pattern) + + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.field + " = ?", [self.buf_pattern] + @classmethod + def value_match(cls, pattern: bytes, value: Any) -> bool: + return pattern == value + class NumericQuery(FieldQuery): """Matches numeric fields. A syntax using Ruby-style range ellipses @@ -390,7 +407,7 @@ class NumericQuery(FieldQuery): return False return True - def col_clause(self) -> Tuple[str, Tuple]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: if self.point is not None: return self.field + '=?', (self.point,) else: @@ -430,7 +447,7 @@ class CollectionQuery(Query): def clause_with_joiner( self, joiner: str, - ) -> Tuple[Optional[str], Collection]: + ) -> Tuple[Optional[str], Sequence[SQLiteType]]: """Return a clause created by joining together the clauses of all subqueries with the string joiner (padded by spaces). """ @@ -477,7 +494,7 @@ class AnyFieldQuery(CollectionQuery): # TYPING ERROR super().__init__(subqueries) - def clause(self) -> Tuple[Union[str, None], Collection]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return self.clause_with_joiner('or') def match(self, obj: Model) -> bool: @@ -502,7 +519,7 @@ class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the query is initialized. """ - subqueries: MutableMapping + subqueries: MutableSequence def __setitem__(self, key, value): self.subqueries[key] = value @@ -514,7 +531,7 @@ class MutableCollectionQuery(CollectionQuery): class AndQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[Union[str, None], Collection]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return self.clause_with_joiner('and') def match(self, obj: Model) -> bool: @@ -524,7 +541,7 @@ class AndQuery(MutableCollectionQuery): class OrQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[Union[str, None], Collection]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return self.clause_with_joiner('or') def match(self, obj: Model) -> bool: @@ -539,7 +556,7 @@ class NotQuery(Query): def __init__(self, subquery): self.subquery = subquery - def clause(self): + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: clause, subvals = self.subquery.clause() if clause: return f'not ({clause})', subvals @@ -565,7 +582,7 @@ class NotQuery(Query): class TrueQuery(Query): """A query that always matches.""" - def clause(self) -> Tuple[Union[str, None], Collection]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: return '1', () def match(self, obj: Model) -> bool: @@ -575,7 +592,7 @@ class TrueQuery(Query): class FalseQuery(Query): """A query that never matches.""" - def clause(self) -> Tuple[Union[str, None], Collection]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: return '0', () def match(self, obj: Model) -> bool: @@ -584,7 +601,7 @@ class FalseQuery(Query): # Time/date queries. -def _parse_periods(pattern: str) -> Tuple['Period', 'Period']: +def _parse_periods(pattern: str) -> Tuple[Optional[Period], Optional[Period]]: """Parse a string containing two dates separated by two dots (..). Return a pair of `Period` objects. """ @@ -661,6 +678,8 @@ class Period: if not string: return None + date: Optional[datetime] + # Check for a relative date. match_dq = re.match(cls.relative_re, string) if match_dq: @@ -678,7 +697,7 @@ class Period: # Check for an absolute date. date, ordinal = find_date_and_format(string) - if date is None: + if date is None or ordinal is None: raise InvalidQueryArgumentValueError(string, 'a valid date/time string') precision = cls.precisions[ordinal] @@ -724,7 +743,11 @@ class DateInterval: self.end = end @classmethod - def from_periods(cls, start: Period, end: Period) -> 'DateInterval': + def from_periods( + cls, + start: Optional[Period], + end: Optional[Period], + ) -> DateInterval: """Create an interval with two Periods as the endpoints. """ end_date = end.open_right_endpoint() if end is not None else None @@ -766,7 +789,7 @@ class DateQuery(FieldQuery): _clause_tmpl = "{0} {1} ?" - def col_clause(self) -> Tuple[Union[str, None], Collection]: + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: clause_parts = [] subvals = [] @@ -824,7 +847,7 @@ class Sort: the database. """ - def order_clause(self) -> None: + def order_clause(self) -> Optional[str]: """Generates a SQL fragment to be used in a ORDER BY clause, or None if no fragment is used (i.e., this is a slow sort). """ @@ -858,30 +881,22 @@ class MultipleSort(Sort): def add_sort(self, sort: Sort): self.sorts.append(sort) - def _sql_sorts(self) -> List[Sort]: - """Return the list of sub-sorts for which we can be (at least - partially) fast. + def order_clause(self) -> str: + """Return the list SQL clauses for those sub-sorts for which we can be + (at least partially) fast. A contiguous suffix of fast (SQL-capable) sub-sorts are executable in SQL. The remaining, even if they are fast independently, must be executed slowly. """ - sql_sorts = [] - for sort in reversed(self.sorts): - if not sort.order_clause() is None: - sql_sorts.append(sort) - else: - break - sql_sorts.reverse() - return sql_sorts - - def order_clause(self) -> str: order_strings = [] - for sort in self._sql_sorts(): - order = sort.order_clause() - order_strings.append(order) + for sort in reversed(self.sorts): + clause = sort.order_clause() + if clause is None: + break + order_strings.append(clause) + order_strings.reverse() - # TYPING ERROR return ", ".join(order_strings) def is_slow(self) -> bool: