Merge branch 'master' into dbcore_typing_0

This commit is contained in:
Adrian Sampson 2023-06-23 17:27:34 -07:00 committed by GitHub
commit 6c77e1a78d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -43,7 +43,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beets.library import Item
from beets.dbcore import Model
@ -81,7 +80,7 @@ class InvalidQueryArgumentValueError(ParsingError):
class Query(ABC):
"""An abstract class representing a query into the item database.
"""An abstract class representing a query into the database.
"""
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
@ -98,8 +97,8 @@ class Query(ABC):
@abstractmethod
def match(self, item: Item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
"""Check whether this query matches a given Model. Can be used to
perform queries on arbitrary sets of Model.
"""
...
@ -157,8 +156,8 @@ class FieldQuery(Query):
"""
raise NotImplementedError()
def match(self, item: Model):
return self.value_match(self.pattern, item.get(self.field))
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
def __repr__(self) -> str:
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
@ -173,7 +172,7 @@ class FieldQuery(Query):
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
"""A query that looks for exact matches in an Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
return self.field + " = ?", [self.pattern]
@ -192,8 +191,8 @@ class NoneQuery(FieldQuery):
def col_clause(self) -> Tuple[str, Tuple]:
return self.field + " IS NULL", ()
def match(self, item: 'Item') -> bool:
return item.get(self.field) is None
def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
@ -224,7 +223,7 @@ class StringFieldQuery(FieldQuery):
class StringQuery(StringFieldQuery):
"""A query that matches a whole string in a specific item field."""
"""A query that matches a whole string in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
search = (self.pattern
@ -241,7 +240,7 @@ class StringQuery(StringFieldQuery):
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
"""A query that matches a substring in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
pattern = (self.pattern
@ -259,8 +258,7 @@ class SubstringQuery(StringFieldQuery):
class RegexpQuery(StringFieldQuery):
"""A query that matches a regular expression in a specific item
field.
"""A query that matches a regular expression in a specific Model field.
Raises InvalidQueryError when the pattern is not a valid regular
expression.
@ -375,10 +373,10 @@ class NumericQuery(FieldQuery):
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item: 'Item') -> bool:
if self.field not in item:
def match(self, obj: Model) -> bool:
if self.field not in obj:
return False
value = item[self.field]
value = obj[self.field]
if isinstance(value, str):
value = self._convert(value)
@ -425,8 +423,8 @@ class CollectionQuery(Query):
def __iter__(self) -> Iterator:
return iter(self.subqueries)
def __contains__(self, item) -> bool:
return item in self.subqueries
def __contains__(self, subq) -> bool:
return subq in self.subqueries
def clause_with_joiner(
self,
@ -481,9 +479,9 @@ class AnyFieldQuery(CollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item: 'Item') -> bool:
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(item):
if subq.match(obj):
return True
return False
@ -518,8 +516,8 @@ class AndQuery(MutableCollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('and')
def match(self, item) -> bool:
return all(q.match(item) for q in self.subqueries)
def match(self, obj: Model) -> bool:
return all(q.match(obj) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
@ -528,8 +526,8 @@ class OrQuery(MutableCollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item) -> bool:
return any(q.match(item) for q in self.subqueries)
def match(self, obj: Model) -> bool:
return any(q.match(obj) for q in self.subqueries)
class NotQuery(Query):
@ -549,8 +547,8 @@ class NotQuery(Query):
# is handled by match() for slow queries.
return clause, subvals
def match(self, item) -> bool:
return not self.subquery.match(item)
def match(self, obj: Model) -> bool:
return not self.subquery.match(obj)
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.subquery!r})".format(self)
@ -569,7 +567,7 @@ class TrueQuery(Query):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '1', ()
def match(self, item) -> bool:
def match(self, obj: Model) -> bool:
return True
@ -579,7 +577,7 @@ class FalseQuery(Query):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '0', ()
def match(self, item) -> bool:
def match(self, obj: Model) -> bool:
return False
@ -758,10 +756,10 @@ class DateQuery(FieldQuery):
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, item: 'Item') -> bool:
if self.field not in item:
def match(self, obj: Model) -> bool:
if self.field not in obj:
return False
timestamp = float(item[self.field])
timestamp = float(obj[self.field])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
@ -822,7 +820,7 @@ class DurationQuery(NumericQuery):
class Sort:
"""An abstract class representing a sort operation for a query into
the item database.
the database.
"""
def order_clause(self) -> None:
@ -938,8 +936,8 @@ class FieldSort(Sort):
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
def key(item: 'Item'):
field_val = item.get(self.field, '')
def key(obj: Model) -> Any:
field_val = obj.get(self.field, '')
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
return field_val