Merge pull request #4827 from wisp3rwind/dbcore_typing_1

dbcore/query: remove spurious dependency on library
This commit is contained in:
Adrian Sampson 2023-06-23 17:17:22 -07:00 committed by GitHub
commit 511824028c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -30,7 +30,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beets.library import Item
from beets.dbcore import Model
@ -68,7 +67,7 @@ class InvalidQueryArgumentValueError(ParsingError):
class Query:
"""An abstract class representing a query into the item database.
"""An abstract class representing a query into the database.
"""
def clause(self) -> Tuple[None, Tuple]:
@ -80,9 +79,9 @@ class Query:
"""
return None, ()
def match(self, item: Item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
def match(self, obj: Model) -> bool:
"""Check whether this query matches a given Model. Can be used to
perform queries on arbitrary sets of Model.
"""
raise NotImplementedError
@ -126,8 +125,8 @@ class FieldQuery(Query):
"""
raise NotImplementedError()
def match(self, item: Model):
return self.value_match(self.pattern, item.get(self.field))
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
def __repr__(self) -> str:
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
@ -142,7 +141,7 @@ class FieldQuery(Query):
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
"""A query that looks for exact matches in an Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
return self.field + " = ?", [self.pattern]
@ -161,8 +160,8 @@ class NoneQuery(FieldQuery):
def col_clause(self) -> Tuple[str, Tuple]:
return self.field + " IS NULL", ()
def match(self, item: 'Item') -> bool:
return item.get(self.field) is None
def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
@ -193,7 +192,7 @@ class StringFieldQuery(FieldQuery):
class StringQuery(StringFieldQuery):
"""A query that matches a whole string in a specific item field."""
"""A query that matches a whole string in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
search = (self.pattern
@ -210,7 +209,7 @@ class StringQuery(StringFieldQuery):
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
"""A query that matches a substring in a specific Model field."""
def col_clause(self) -> Tuple[str, List[str]]:
pattern = (self.pattern
@ -228,8 +227,7 @@ class SubstringQuery(StringFieldQuery):
class RegexpQuery(StringFieldQuery):
"""A query that matches a regular expression in a specific item
field.
"""A query that matches a regular expression in a specific Model field.
Raises InvalidQueryError when the pattern is not a valid regular
expression.
@ -344,10 +342,10 @@ class NumericQuery(FieldQuery):
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item: 'Item') -> bool:
if self.field not in item:
def match(self, obj: Model) -> bool:
if self.field not in obj:
return False
value = item[self.field]
value = obj[self.field]
if isinstance(value, str):
value = self._convert(value)
@ -394,8 +392,8 @@ class CollectionQuery(Query):
def __iter__(self) -> Iterator:
return iter(self.subqueries)
def __contains__(self, item) -> bool:
return item in self.subqueries
def __contains__(self, subq) -> bool:
return subq in self.subqueries
def clause_with_joiner(
self,
@ -450,9 +448,9 @@ class AnyFieldQuery(CollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item: 'Item') -> bool:
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(item):
if subq.match(obj):
return True
return False
@ -487,8 +485,8 @@ class AndQuery(MutableCollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('and')
def match(self, item) -> bool:
return all(q.match(item) for q in self.subqueries)
def match(self, obj: Model) -> bool:
return all(q.match(obj) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
@ -497,8 +495,8 @@ class OrQuery(MutableCollectionQuery):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item) -> bool:
return any(q.match(item) for q in self.subqueries)
def match(self, obj: Model) -> bool:
return any(q.match(obj) for q in self.subqueries)
class NotQuery(Query):
@ -518,8 +516,8 @@ class NotQuery(Query):
# is handled by match() for slow queries.
return clause, subvals
def match(self, item) -> bool:
return not self.subquery.match(item)
def match(self, obj: Model) -> bool:
return not self.subquery.match(obj)
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.subquery!r})".format(self)
@ -538,7 +536,7 @@ class TrueQuery(Query):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '1', ()
def match(self, item) -> bool:
def match(self, obj: Model) -> bool:
return True
@ -548,7 +546,7 @@ class FalseQuery(Query):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '0', ()
def match(self, item) -> bool:
def match(self, obj: Model) -> bool:
return False
@ -727,10 +725,10 @@ class DateQuery(FieldQuery):
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, item: 'Item') -> bool:
if self.field not in item:
def match(self, obj: Model) -> bool:
if self.field not in obj:
return False
timestamp = float(item[self.field])
timestamp = float(obj[self.field])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
@ -791,7 +789,7 @@ class DurationQuery(NumericQuery):
class Sort:
"""An abstract class representing a sort operation for a query into
the item database.
the database.
"""
def order_clause(self) -> None:
@ -907,8 +905,8 @@ class FieldSort(Sort):
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
def key(item: 'Item'):
field_val = item.get(self.field, '')
def key(obj: Model) -> Any:
field_val = obj.get(self.field, '')
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
return field_val