mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 08:39:17 +01:00
Add ability to filter flexible attributes through the Query
For a flexible attribute query, replace the `col_name` property with a function call that extracts that attribute from the `field_attrs` field introduced in the earlier commit. Additionally, for boolean, numeric and date queries CAST the value to NUMERIC SQLite affinity to ensure that our queries like 'flex:1..5' and 'flex:true' continue working fine. This removes the concept of 'slow query', since every query for any field now has an SQL clause.
This commit is contained in:
parent
969d847110
commit
484c00e223
4 changed files with 55 additions and 88 deletions
|
|
@ -815,7 +815,6 @@ class Results(Generic[AnyModel]):
|
|||
model_class: Type[AnyModel],
|
||||
rows: List[Mapping],
|
||||
db: "Database",
|
||||
query: Optional[Query] = None,
|
||||
sort=None,
|
||||
):
|
||||
"""Create a result set that will construct objects of type
|
||||
|
|
@ -825,9 +824,7 @@ class Results(Generic[AnyModel]):
|
|||
constructed. `rows` is a query result: a list of mappings. The
|
||||
new objects will be associated with the database `db`.
|
||||
|
||||
If `query` is provided, it is used as a predicate to filter the
|
||||
results for a "slow query" that cannot be evaluated by the
|
||||
database directly. If `sort` is provided, it is used to sort the
|
||||
If `sort` is provided, it is used to sort the
|
||||
full list of results before returning. This means it is a "slow
|
||||
sort" and all objects must be built before returning the first
|
||||
one.
|
||||
|
|
@ -835,7 +832,6 @@ class Results(Generic[AnyModel]):
|
|||
self.model_class = model_class
|
||||
self.rows = rows
|
||||
self.db = db
|
||||
self.query = query
|
||||
self.sort = sort
|
||||
|
||||
# We keep a queue of rows we haven't yet consumed for
|
||||
|
|
@ -871,13 +867,10 @@ class Results(Generic[AnyModel]):
|
|||
while self._rows:
|
||||
row = self._rows.pop(0)
|
||||
obj = self._make_model(row)
|
||||
# If there is a slow-query predicate, ensurer that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
self._objects.append(obj)
|
||||
index += 1
|
||||
yield obj
|
||||
break
|
||||
self._objects.append(obj)
|
||||
index += 1
|
||||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self) -> Iterator[AnyModel]:
|
||||
"""Construct and generate Model objects for all matching
|
||||
|
|
@ -906,16 +899,8 @@ class Results(Generic[AnyModel]):
|
|||
if not self._rows:
|
||||
# Fully materialized. Just count the objects.
|
||||
return len(self._objects)
|
||||
|
||||
elif self.query:
|
||||
# A slow query. Fall back to testing every object.
|
||||
count = 0
|
||||
for obj in self:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
else:
|
||||
# A fast query. Just count the rows.
|
||||
# Just count the rows.
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self) -> bool:
|
||||
|
|
@ -1144,7 +1129,9 @@ class Database:
|
|||
def regexp(value, pattern):
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
return re.search(pattern, str(value)) is not None
|
||||
return (
|
||||
value is not None and re.search(pattern, str(value)) is not None
|
||||
)
|
||||
|
||||
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
|
||||
"""A custom ``bytelower`` sqlite function so we can compare
|
||||
|
|
@ -1306,7 +1293,6 @@ class Database:
|
|||
model_cls,
|
||||
rows,
|
||||
self,
|
||||
None if where else query, # Slow query component.
|
||||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -93,11 +93,8 @@ class Query(ABC):
|
|||
Return (clause, subvals) where clause is a valid sqlite
|
||||
WHERE clause implementing the query and subvals is a list of
|
||||
items to be substituted for ?s in the clause.
|
||||
|
||||
The default implementation returns None, falling back to a slow query
|
||||
using `match()`.
|
||||
"""
|
||||
return None, ()
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def match(self, obj: Model):
|
||||
|
|
@ -146,17 +143,16 @@ class FieldQuery(Query, Generic[P]):
|
|||
|
||||
@property
|
||||
def col_name(self) -> str:
|
||||
if not self.fast:
|
||||
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
|
||||
|
||||
return f"{self.table}.{self.field}" if self.table else self.field
|
||||
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.col_name, ()
|
||||
raise NotImplementedError
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
if self.fast:
|
||||
return self.col_clause()
|
||||
else:
|
||||
# Matching a flexattr. This is a slow query.
|
||||
return None, ()
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.col_clause()
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern: P, value: Any):
|
||||
|
|
@ -305,7 +301,17 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
|
|||
return pattern.search(cls._normalize(value)) is not None
|
||||
|
||||
|
||||
class BooleanQuery(MatchQuery[int]):
|
||||
class NumericColumnQuery(MatchQuery[AnySQLiteType]):
|
||||
"""A base class for queries that work with NUMERIC SQLite affinity."""
|
||||
|
||||
@property
|
||||
def col_name(self) -> str:
|
||||
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
|
||||
col_name = super().col_name
|
||||
return col_name if self.fast else f"CAST({col_name} AS NUMERIC)"
|
||||
|
||||
|
||||
class BooleanQuery(NumericColumnQuery[bool]):
|
||||
"""Matches a boolean field. Pattern should either be a boolean or a
|
||||
string reflecting a boolean.
|
||||
"""
|
||||
|
|
@ -357,7 +363,7 @@ class BytesQuery(FieldQuery[bytes]):
|
|||
return pattern == value
|
||||
|
||||
|
||||
class NumericQuery(FieldQuery[str]):
|
||||
class NumericQuery(NumericColumnQuery[Union[int, float]]):
|
||||
"""Matches numeric fields. A syntax using Ruby-style range ellipses
|
||||
(``..``) lets users specify one- or two-sided ranges. For example,
|
||||
``year:2001..`` finds music released since the turn of the century.
|
||||
|
|
@ -483,7 +489,7 @@ class CollectionQuery(Query):
|
|||
def clause_with_joiner(
|
||||
self,
|
||||
joiner: str,
|
||||
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
"""Return a clause created by joining together the clauses of
|
||||
all subqueries with the string joiner (padded by spaces).
|
||||
"""
|
||||
|
|
@ -491,9 +497,6 @@ class CollectionQuery(Query):
|
|||
subvals = []
|
||||
for subq in self.subqueries:
|
||||
subq_clause, subq_subvals = subq.clause()
|
||||
if not subq_clause:
|
||||
# Fall back to slow query.
|
||||
return None, ()
|
||||
clause_parts.append("(" + subq_clause + ")")
|
||||
subvals += subq_subvals
|
||||
clause = (" " + joiner + " ").join(clause_parts)
|
||||
|
|
@ -533,7 +536,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
def field_names(self) -> Set[str]:
|
||||
return set(self.fields)
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner("or")
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -572,7 +575,7 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
class AndQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner("and")
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -582,7 +585,7 @@ class AndQuery(MutableCollectionQuery):
|
|||
class OrQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner("or")
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
|
|
@ -602,14 +605,9 @@ class NotQuery(Query):
|
|||
"""Return a set with field names that this query operates on."""
|
||||
return self.subquery.field_names
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
clause, subvals = self.subquery.clause()
|
||||
if clause:
|
||||
return f"not ({clause})", subvals
|
||||
else:
|
||||
# If there is no clause, there is nothing to negate. All the logic
|
||||
# is handled by match() for slow queries.
|
||||
return clause, subvals
|
||||
return f"not ({clause})", subvals
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
return not self.subquery.match(obj)
|
||||
|
|
@ -816,7 +814,7 @@ class DateInterval:
|
|||
return f"[{self.start}, {self.end})"
|
||||
|
||||
|
||||
class DateQuery(FieldQuery[str]):
|
||||
class DateQuery(NumericColumnQuery[int]):
|
||||
"""Matches date fields stored as seconds since Unix epoch time.
|
||||
|
||||
Dates can be specified as ``year-month-day`` strings where only year
|
||||
|
|
@ -910,7 +908,7 @@ class Sort:
|
|||
return sorted(items)
|
||||
|
||||
def is_slow(self) -> bool:
|
||||
"""Indicate whether this query is *slow*, meaning that it cannot
|
||||
"""Indicate whether this sort is *slow*, meaning that it cannot
|
||||
be executed in SQL and must be executed in Python.
|
||||
"""
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
|
|||
# temporary `Item` object to generate any computed fields.
|
||||
tmp_item = library.Item(lib, **info)
|
||||
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
|
||||
dup_query = library.Album.all_fields_query(
|
||||
dup_query = library.Item.all_fields_query(
|
||||
{key: tmp_item.get(key) for key in keys}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -857,17 +857,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper):
|
|||
|
||||
def test_match_slow(self):
|
||||
item = self.add_item()
|
||||
matched = self.lib.items(NoneQuery("rg_track_peak", fast=False))
|
||||
matched = self.lib.items(NoneQuery("rg_track_peak"))
|
||||
self.assertInResult(item, matched)
|
||||
|
||||
def test_match_slow_after_set_none(self):
|
||||
item = self.add_item(rg_track_gain=0)
|
||||
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
|
||||
matched = self.lib.items(NoneQuery("rg_track_gain"))
|
||||
self.assertNotInResult(item, matched)
|
||||
|
||||
item["rg_track_gain"] = None
|
||||
item.store()
|
||||
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
|
||||
matched = self.lib.items(NoneQuery("rg_track_gain"))
|
||||
self.assertInResult(item, matched)
|
||||
|
||||
|
||||
|
|
@ -1097,37 +1097,6 @@ class NotQueryTest(DummyDataTestCase):
|
|||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["baz qux"])
|
||||
|
||||
def test_fast_vs_slow(self):
|
||||
"""Test that the results are the same regardless of the `fast` flag
|
||||
for negated `FieldQuery`s.
|
||||
|
||||
TODO: investigate NoneQuery(fast=False), as it is raising
|
||||
AttributeError: type object 'NoneQuery' has no attribute 'field'
|
||||
at NoneQuery.match() (due to being @classmethod, and no self?)
|
||||
"""
|
||||
classes = [
|
||||
(dbcore.query.DateQuery, ["added", "2001-01-01"]),
|
||||
(dbcore.query.MatchQuery, ["artist", "one"]),
|
||||
# (dbcore.query.NoneQuery, ['rg_track_gain']),
|
||||
(dbcore.query.NumericQuery, ["year", "2002"]),
|
||||
(dbcore.query.StringFieldQuery, ["year", "2001"]),
|
||||
(dbcore.query.RegexpQuery, ["album", "^.a"]),
|
||||
(dbcore.query.SubstringQuery, ["title", "x"]),
|
||||
]
|
||||
|
||||
for klass, args in classes:
|
||||
q_fast = dbcore.query.NotQuery(klass(*(args + [True])))
|
||||
q_slow = dbcore.query.NotQuery(klass(*(args + [False])))
|
||||
|
||||
try:
|
||||
self.assertEqual(
|
||||
[i.title for i in self.lib.items(q_fast)],
|
||||
[i.title for i in self.lib.items(q_slow)],
|
||||
)
|
||||
except NotImplementedError:
|
||||
# ignore classes that do not provide `fast` implementation
|
||||
pass
|
||||
|
||||
|
||||
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
||||
"""Test album-level queries with track-level filters and vice-versa."""
|
||||
|
|
@ -1143,12 +1112,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
for item_idx in range(1, 3):
|
||||
item = _common.item()
|
||||
item.album = album_name
|
||||
item.title = f"{album_name} Item{item_idx}"
|
||||
title = f"{album_name} Item{item_idx}"
|
||||
item.title = title
|
||||
item.item_flex1 = f"{title} Flex1"
|
||||
item.item_flex2 = f"{title} Flex2"
|
||||
self.lib.add(item)
|
||||
album_items.append(item)
|
||||
album = self.lib.add_album(album_items)
|
||||
album.artpath = f"{album_name} Artpath"
|
||||
album.catalognum = "ABC"
|
||||
album.album_flex = f"{album_name} Flex"
|
||||
album.store()
|
||||
albums.append(album)
|
||||
|
||||
|
|
@ -1169,6 +1142,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_get_items_filter_by_track_flex(self):
|
||||
q = "item_flex1:Item1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
|
||||
|
||||
def test_get_albums_filter_by_album_flex(self):
|
||||
q = "album_flex:Album1"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue