diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 016fe2c1a..e190083c5 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -17,11 +17,16 @@ import re from operator import mul +from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \ + Collection, Mapping + from beets import util from datetime import datetime, timedelta import unicodedata from functools import reduce +from beets.library import Item + class ParsingError(ValueError): """Abstract class for any unparseable user-requested album/query @@ -60,7 +65,7 @@ class Query: """An abstract class representing a query into the item database. """ - def clause(self): + def clause(self) -> Union[None, Tuple]: """Generate an SQLite expression implementing the query. Return (clause, subvals) where clause is a valid sqlite @@ -69,19 +74,19 @@ class Query: """ return None, () - def match(self, item): + def match(self, item: Item): """Check whether this query matches a given Item. Can be used to perform queries on arbitrary sets of Items. """ raise NotImplementedError - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) - def __hash__(self): + def __hash__(self) -> int: return 0 @@ -93,12 +98,12 @@ class FieldQuery(Query): same matching functionality in SQLite. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: Optional[str], fast: bool = True): self.field = field self.pattern = pattern self.fast = fast - def col_clause(self): + def col_clause(self) -> Union[None, Tuple]: return None, () def clause(self): @@ -109,51 +114,51 @@ class FieldQuery(Query): return None, () @classmethod - def value_match(cls, pattern, value): + def value_match(cls, pattern: str, value: str): """Determine whether the value matches the pattern. Both arguments are strings. """ raise NotImplementedError() - def match(self, item): + def match(self, item: Item): return self.value_match(self.pattern, item.get(self.field)) - def __repr__(self): + def __repr__(self) -> str: return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, " "{0.fast})".format(self)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.field == other.field and self.pattern == other.pattern - def __hash__(self): + def __hash__(self) -> int: return hash((self.field, hash(self.pattern))) class MatchQuery(FieldQuery): """A query that looks for exact matches in an item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: return self.field + " = ?", [self.pattern] @classmethod - def value_match(cls, pattern, value): + def value_match(cls, pattern: str, value: str) -> bool: return pattern == value class NoneQuery(FieldQuery): """A query that checks whether a field is null.""" - def __init__(self, field, fast=True): + def __init__(self, field, fast: bool = True): super().__init__(field, None, fast) - def col_clause(self): + def col_clause(self) -> Tuple[str, Tuple]: return self.field + " IS NULL", () - def match(self, item): + def match(self, item: Item) -> bool: return item.get(self.field) is None - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self) @@ -163,14 +168,18 @@ class StringFieldQuery(FieldQuery): """ @classmethod - def value_match(cls, pattern, value): + def value_match(cls: Type['StringFieldQuery'], pattern: str, value: Any): """Determine whether the value matches the pattern. The value may have any type. """ return cls.string_match(pattern, util.as_string(value)) @classmethod - def string_match(cls, pattern, value): + def string_match( + cls: Type['StringFieldQuery'], + pattern: str, + value: str, + ) -> bool: """Determine whether the value matches the pattern. Both arguments are strings. Subclasses implement this method. """ @@ -180,7 +189,7 @@ class StringFieldQuery(FieldQuery): class StringQuery(StringFieldQuery): """A query that matches a whole string in a specific item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: search = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -190,14 +199,14 @@ class StringQuery(StringFieldQuery): return clause, subvals @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: str, value: str) -> bool: return pattern.lower() == value.lower() class SubstringQuery(StringFieldQuery): """A query that matches a substring in a specific item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: pattern = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -208,7 +217,7 @@ class SubstringQuery(StringFieldQuery): return clause, subvals @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: str, value: str) -> bool: return pattern.lower() in value.lower() @@ -220,7 +229,7 @@ class RegexpQuery(StringFieldQuery): expression. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: str, fast: bool = True): super().__init__(field, pattern, fast) pattern = self._normalize(pattern) try: @@ -235,14 +244,14 @@ class RegexpQuery(StringFieldQuery): return f" regexp({self.field}, ?)", [self.pattern.pattern] @staticmethod - def _normalize(s): + def _normalize(s: str) -> str: """Normalize a Unicode string's representation (used on both patterns and matched values). """ return unicodedata.normalize('NFC', s) @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: Pattern, value: str) -> bool: return pattern.search(cls._normalize(value)) is not None @@ -251,7 +260,12 @@ class BooleanQuery(MatchQuery): string reflecting a boolean. """ - def __init__(self, field, pattern, fast=True): + def __init__( + self, + field: str, + pattern: Union[bool, str], + fast: bool = True, + ): super().__init__(field, pattern, fast) if isinstance(pattern, str): self.pattern = util.str2bool(pattern) @@ -265,7 +279,7 @@ class BytesQuery(MatchQuery): `MatchQuery` when matching on BLOB values. """ - def __init__(self, field, pattern): + def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): super().__init__(field, pattern) # Use a buffer/memoryview representation of the pattern for SQLite @@ -279,7 +293,7 @@ class BytesQuery(MatchQuery): self.buf_pattern = self.pattern self.pattern = bytes(self.pattern) - def col_clause(self): + def col_clause(self) -> Tuple[str, List[memoryview]]: return self.field + " = ?", [self.buf_pattern] @@ -292,7 +306,7 @@ class NumericQuery(FieldQuery): a float. """ - def _convert(self, s): + def _convert(self, s: str) -> Union[float, int, None]: """Convert a string to a numeric type (float or int). Return None if `s` is empty. @@ -309,7 +323,7 @@ class NumericQuery(FieldQuery): except ValueError: raise InvalidQueryArgumentValueError(s, "an int or a float") - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: str, fast: bool = True): super().__init__(field, pattern, fast) parts = pattern.split('..', 1) @@ -324,7 +338,7 @@ class NumericQuery(FieldQuery): self.rangemin = self._convert(parts[0]) self.rangemax = self._convert(parts[1]) - def match(self, item): + def match(self, item: Item) -> bool: if self.field not in item: return False value = item[self.field] @@ -340,7 +354,7 @@ class NumericQuery(FieldQuery): return False return True - def col_clause(self): + def col_clause(self) -> Tuple[str, Tuple]: if self.point is not None: return self.field + '=?', (self.point,) else: @@ -360,24 +374,27 @@ class CollectionQuery(Query): indexed like a list to access the sub-queries. """ - def __init__(self, subqueries=()): + def __init__(self, subqueries: Mapping = ()): self.subqueries = subqueries # Act like a sequence. - def __len__(self): + def __len__(self) -> int: return len(self.subqueries) def __getitem__(self, key): return self.subqueries[key] - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.subqueries) - def __contains__(self, item): + def __contains__(self, item) -> bool: return item in self.subqueries - def clause_with_joiner(self, joiner): + def clause_with_joiner( + self, + joiner: str, + ) -> Tuple[Optional[str], Collection]: """Return a clause created by joining together the clauses of all subqueries with the string joiner (padded by spaces). """ @@ -393,14 +410,14 @@ class CollectionQuery(Query): clause = (' ' + joiner + ' ').join(clause_parts) return clause, subvals - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.subqueries!r})".format(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.subqueries == other.subqueries - def __hash__(self): + def __hash__(self) -> int: """Since subqueries are mutable, this object should not be hashable. However and for conveniences purposes, it can be hashed. """ @@ -413,7 +430,7 @@ class AnyFieldQuery(CollectionQuery): constructor. """ - def __init__(self, pattern, fields, cls): + def __init__(self, pattern, fields, cls: Type[FieldQuery]): self.pattern = pattern self.fields = fields self.query_class = cls @@ -423,24 +440,24 @@ class AnyFieldQuery(CollectionQuery): subqueries.append(cls(field, pattern, True)) super().__init__(subqueries) - def clause(self): + def clause(self) -> Tuple[str | None, Collection]: return self.clause_with_joiner('or') - def match(self, item): + def match(self, item: Item) -> bool: for subq in self.subqueries: if subq.match(item): return True return False - def __repr__(self): + def __repr__(self) -> str: return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, " "{0.query_class.__name__})".format(self)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.query_class == other.query_class - def __hash__(self): + def __hash__(self) -> int: return hash((self.pattern, tuple(self.fields), self.query_class)) @@ -459,20 +476,20 @@ class MutableCollectionQuery(CollectionQuery): class AndQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self): + def clause(self) -> Tuple[str | None, Collection]: return self.clause_with_joiner('and') - def match(self, item): + def match(self, item) -> bool: return all(q.match(item) for q in self.subqueries) class OrQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self): + def clause(self) -> Tuple[str | None, Collection]: return self.clause_with_joiner('or') - def match(self, item): + def match(self, item) -> bool: return any(q.match(item) for q in self.subqueries) @@ -493,43 +510,43 @@ class NotQuery(Query): # is handled by match() for slow queries. return clause, subvals - def match(self, item): + def match(self, item) -> bool: return not self.subquery.match(item) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.subquery!r})".format(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.subquery == other.subquery - def __hash__(self): + def __hash__(self) -> int: return hash(('not', hash(self.subquery))) class TrueQuery(Query): """A query that always matches.""" - def clause(self): + def clause(self) -> Tuple[str | None, Collection]: return '1', () - def match(self, item): + def match(self, item) -> bool: return True class FalseQuery(Query): """A query that never matches.""" - def clause(self): + def clause(self) -> Tuple[str | None, Collection]: return '0', () - def match(self, item): + def match(self, item) -> bool: return False # Time/date queries. -def _parse_periods(pattern): +def _parse_periods(pattern: str) -> Tuple['Period', 'Period']: """Parse a string containing two dates separated by two dots (..). Return a pair of `Period` objects. """ @@ -563,7 +580,7 @@ class Period: relative_re = '(?P[+|-]?)(?P[0-9]+)' + \ '(?P[y|m|w|d])' - def __init__(self, date, precision): + def __init__(self, date: datetime, precision: str): """Create a period with the given date (a `datetime` object) and precision (a string, one of "year", "month", "day", "hour", "minute", or "second"). @@ -574,7 +591,7 @@ class Period: self.precision = precision @classmethod - def parse(cls, string): + def parse(cls: Type['Period'], string: str) -> Optional['Period']: """Parse a date and return a `Period` object or `None` if the string is empty, or raise an InvalidQueryArgumentValueError if the string cannot be parsed to a date. @@ -591,7 +608,8 @@ class Period: and a "year" is exactly 365 days. """ - def find_date_and_format(string): + def find_date_and_format(string: str) -> \ + Union[Tuple[None, None], Tuple[datetime, int]]: for ord, format in enumerate(cls.date_formats): for format_option in format: try: @@ -628,7 +646,7 @@ class Period: precision = cls.precisions[ordinal] return cls(date, precision) - def open_right_endpoint(self): + def open_right_endpoint(self) -> datetime: """Based on the precision, convert the period to a precise `datetime` for use as a right endpoint in a right-open interval. """ @@ -660,7 +678,7 @@ class DateInterval: A right endpoint of None means towards infinity. """ - def __init__(self, start, end): + def __init__(self, start: Optional[datetime], end: Optional[datetime]): if start is not None and end is not None and not start < end: raise ValueError("start date {} is not before end date {}" .format(start, end)) @@ -668,21 +686,21 @@ class DateInterval: self.end = end @classmethod - def from_periods(cls, start, end): + def from_periods(cls, start: Period, end: Period) -> 'DateInterval': """Create an interval with two Periods as the endpoints. """ end_date = end.open_right_endpoint() if end is not None else None start_date = start.date if start is not None else None return cls(start_date, end_date) - def contains(self, date): + def contains(self, date: datetime) -> bool: if self.start is not None and date < self.start: return False if self.end is not None and date >= self.end: return False return True - def __str__(self): + def __str__(self) -> str: return f'[{self.start}, {self.end})' @@ -696,12 +714,12 @@ class DateQuery(FieldQuery): using an ellipsis interval syntax similar to that of NumericQuery. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field, pattern, fast: bool = True): super().__init__(field, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) - def match(self, item): + def match(self, item: Item) -> bool: if self.field not in item: return False timestamp = float(item[self.field]) @@ -710,7 +728,7 @@ class DateQuery(FieldQuery): _clause_tmpl = "{0} {1} ?" - def col_clause(self): + def col_clause(self) -> Tuple[str | None, Collection]: clause_parts = [] subvals = [] @@ -742,7 +760,7 @@ class DurationQuery(NumericQuery): or M:SS time interval. """ - def _convert(self, s): + def _convert(self, s: str) -> Optional[float]: """Convert a M:SS or numeric string to a float. Return None if `s` is empty. @@ -774,21 +792,21 @@ class Sort: """ return None - def sort(self, items): + def sort(self, items: List) -> List: """Sort the list of objects and return a list. """ return sorted(items) - def is_slow(self): + def is_slow(self) -> bool: """Indicate whether this query is *slow*, meaning that it cannot be executed in SQL and must be executed in Python. """ return False - def __hash__(self): + def __hash__(self) -> int: return 0 - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) @@ -796,13 +814,13 @@ class MultipleSort(Sort): """Sort that encapsulates multiple sub-sorts. """ - def __init__(self, sorts=None): + def __init__(self, sorts: Optional[List[Sort]] = None): self.sorts = sorts or [] - def add_sort(self, sort): + def add_sort(self, sort: Sort): self.sorts.append(sort) - def _sql_sorts(self): + def _sql_sorts(self) -> List[Sort]: """Return the list of sub-sorts for which we can be (at least partially) fast. @@ -819,7 +837,7 @@ class MultipleSort(Sort): sql_sorts.reverse() return sql_sorts - def order_clause(self): + def order_clause(self) -> str: order_strings = [] for sort in self._sql_sorts(): order = sort.order_clause() @@ -827,7 +845,7 @@ class MultipleSort(Sort): return ", ".join(order_strings) - def is_slow(self): + def is_slow(self) -> bool: for sort in self.sorts: if sort.is_slow(): return True @@ -865,17 +883,22 @@ class FieldSort(Sort): any kind). """ - def __init__(self, field, ascending=True, case_insensitive=True): + def __init__( + self, + field, + ascending: bool = True, + case_insensitive: bool = True, + ): self.field = field self.ascending = ascending self.case_insensitive = case_insensitive - def sort(self, objs): + def sort(self, objs: Collection): # TODO: Conversion and null-detection here. In Python 3, # comparisons with None fail. We should also support flexible # attributes with different types without falling over. - def key(item): + def key(item: Item): field_val = item.get(self.field, '') if self.case_insensitive and isinstance(field_val, str): field_val = field_val.lower() @@ -883,17 +906,17 @@ class FieldSort(Sort): return sorted(objs, key=key, reverse=not self.ascending) - def __repr__(self): + def __repr__(self) -> str: return '<{}: {}{}>'.format( type(self).__name__, self.field, '+' if self.ascending else '-', ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.field, self.ascending)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.field == other.field and \ self.ascending == other.ascending @@ -903,7 +926,7 @@ class FixedFieldSort(FieldSort): """Sort object to sort on a fixed field. """ - def order_clause(self): + def order_clause(self) -> str: order = "ASC" if self.ascending else "DESC" if self.case_insensitive: field = '(CASE ' \ @@ -920,24 +943,24 @@ class SlowFieldSort(FieldSort): i.e., a computed or flexible field. """ - def is_slow(self): + def is_slow(self) -> bool: return True class NullSort(Sort): """No sorting. Leave results unsorted.""" - def sort(self, items): + def sort(self, items: List) -> List: return items - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.__bool__() - def __bool__(self): + def __bool__(self) -> bool: return False - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) or other is None - def __hash__(self): + def __hash__(self) -> int: return 0 diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 3bf02e4d2..cc2ff14fe 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -17,7 +17,10 @@ import re import itertools -from . import query +from typing import Dict, Type, Tuple, Optional, Mapping, Collection, List + +from . import query, Model +from .query import Sort PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. @@ -34,8 +37,12 @@ PARSE_QUERY_PART_REGEX = re.compile( ) -def parse_query_part(part, query_classes={}, prefixes={}, - default_class=query.SubstringQuery): +def parse_query_part( + part: str, + query_classes: Dict = {}, + prefixes: Dict = {}, + default_class: Type[query.SubstringQuery] = query.SubstringQuery, +) -> Tuple[Optional[str], str, Type[query.Query], bool]: """Parse a single *query part*, which is a chunk of a complete query string representing a single criterion. @@ -100,7 +107,11 @@ def parse_query_part(part, query_classes={}, prefixes={}, return key, term, query_class, negate -def construct_query_part(model_cls, prefixes, query_part): +def construct_query_part( + model_cls: Type[Model], + prefixes: Mapping[str, Type[query.Query]], + query_part: str, +) -> query.Query: """Parse a *query part* string and return a :class:`Query` object. :param model_cls: The :class:`Model` class that this is a query for. @@ -158,7 +169,12 @@ def construct_query_part(model_cls, prefixes, query_part): return out_query -def query_from_strings(query_cls, model_cls, prefixes, query_parts): +def query_from_strings( + query_cls: Type[query.Query], + model_cls: Type[Model], + prefixes: Mapping[str, Type[query.Query]], + query_parts: Collection[str], +) -> query.Query: """Creates a collection query of type `query_cls` from a list of strings in the format used by parse_query_part. `model_cls` determines how queries are constructed from strings. @@ -171,7 +187,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts): return query_cls(subqueries) -def construct_sort_part(model_cls, part, case_insensitive=True): +def construct_sort_part( + model_cls: Type[Model], + part: str, + case_insensitive: bool = True, +) -> Sort: """Create a `Sort` from a single string criterion. `model_cls` is the `Model` being queried. `part` is a single string @@ -197,7 +217,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True): return sort -def sort_from_strings(model_cls, sort_parts, case_insensitive=True): +def sort_from_strings( + model_cls: Type[Model], + sort_parts: Collection[str], + case_insensitive: bool = True, +) -> Sort: """Create a `Sort` from a list of sort criteria (strings). """ if not sort_parts: @@ -212,8 +236,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True): return sort -def parse_sorted_query(model_cls, parts, prefixes={}, - case_insensitive=True): +def parse_sorted_query( + model_cls: Type[Model], + parts: List[str], + prefixes: Mapping[str, Type[query.Query]] = {}, + case_insensitive: bool = True, +) -> Tuple[query.Query, Sort]: """Given a list of strings, create the `Query` and `Sort` that they represent. """