Add ability to filter flexible attributes through the Query

For a flexible attribute query, replace the `col_name` property with
a function call that extracts that attribute from the `field_attrs`
field introduced in the earlier commit.

Additionally, for boolean, numeric and date queries CAST the value to
NUMERIC SQLite affinity to ensure that our queries like 'flex:1..5' and
'flex:true' continue working fine.

This removes the concept of 'slow query', since every query for any
field now has an SQL clause.
This commit is contained in:
Šarūnas Nejus 2024-05-07 20:03:58 +01:00
parent 969d847110
commit 484c00e223
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
4 changed files with 55 additions and 88 deletions

View file

@ -815,7 +815,6 @@ class Results(Generic[AnyModel]):
model_class: Type[AnyModel],
rows: List[Mapping],
db: "Database",
query: Optional[Query] = None,
sort=None,
):
"""Create a result set that will construct objects of type
@ -825,9 +824,7 @@ class Results(Generic[AnyModel]):
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
If `query` is provided, it is used as a predicate to filter the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first
one.
@ -835,7 +832,6 @@ class Results(Generic[AnyModel]):
self.model_class = model_class
self.rows = rows
self.db = db
self.query = query
self.sort = sort
# We keep a queue of rows we haven't yet consumed for
@ -871,13 +867,10 @@ class Results(Generic[AnyModel]):
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
# If there is a slow-query predicate, ensurer that the
# object passes it.
if not self.query or self.query.match(obj):
self._objects.append(obj)
index += 1
yield obj
break
self._objects.append(obj)
index += 1
yield obj
break
def __iter__(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for all matching
@ -906,16 +899,8 @@ class Results(Generic[AnyModel]):
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
count += 1
return count
else:
# A fast query. Just count the rows.
# Just count the rows.
return self._row_count
def __nonzero__(self) -> bool:
@ -1144,7 +1129,9 @@ class Database:
def regexp(value, pattern):
if isinstance(value, bytes):
value = value.decode()
return re.search(pattern, str(value)) is not None
return (
value is not None and re.search(pattern, str(value)) is not None
)
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
"""A custom ``bytelower`` sqlite function so we can compare
@ -1306,7 +1293,6 @@ class Database:
model_cls,
rows,
self,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)

View file

@ -93,11 +93,8 @@ class Query(ABC):
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
The default implementation returns None, falling back to a slow query
using `match()`.
"""
return None, ()
raise NotImplementedError
@abstractmethod
def match(self, obj: Model):
@ -146,17 +143,16 @@ class FieldQuery(Query, Generic[P]):
@property
def col_name(self) -> str:
if not self.fast:
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
return f"{self.table}.{self.field}" if self.table else self.field
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_name, ()
raise NotImplementedError
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
return self.col_clause()
else:
# Matching a flexattr. This is a slow query.
return None, ()
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_clause()
@classmethod
def value_match(cls, pattern: P, value: Any):
@ -305,7 +301,17 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
return pattern.search(cls._normalize(value)) is not None
class BooleanQuery(MatchQuery[int]):
class NumericColumnQuery(MatchQuery[AnySQLiteType]):
"""A base class for queries that work with NUMERIC SQLite affinity."""
@property
def col_name(self) -> str:
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
col_name = super().col_name
return col_name if self.fast else f"CAST({col_name} AS NUMERIC)"
class BooleanQuery(NumericColumnQuery[bool]):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
@ -357,7 +363,7 @@ class BytesQuery(FieldQuery[bytes]):
return pattern == value
class NumericQuery(FieldQuery[str]):
class NumericQuery(NumericColumnQuery[Union[int, float]]):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
@ -483,7 +489,7 @@ class CollectionQuery(Query):
def clause_with_joiner(
self,
joiner: str,
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
) -> Tuple[str, Sequence[SQLiteType]]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@ -491,9 +497,6 @@ class CollectionQuery(Query):
subvals = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
# Fall back to slow query.
return None, ()
clause_parts.append("(" + subq_clause + ")")
subvals += subq_subvals
clause = (" " + joiner + " ").join(clause_parts)
@ -533,7 +536,7 @@ class AnyFieldQuery(CollectionQuery):
def field_names(self) -> Set[str]:
return set(self.fields)
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
@ -572,7 +575,7 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("and")
def match(self, obj: Model) -> bool:
@ -582,7 +585,7 @@ class AndQuery(MutableCollectionQuery):
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
@ -602,14 +605,9 @@ class NotQuery(Query):
"""Return a set with field names that this query operates on."""
return self.subquery.field_names
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
clause, subvals = self.subquery.clause()
if clause:
return f"not ({clause})", subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
return clause, subvals
return f"not ({clause})", subvals
def match(self, obj: Model) -> bool:
return not self.subquery.match(obj)
@ -816,7 +814,7 @@ class DateInterval:
return f"[{self.start}, {self.end})"
class DateQuery(FieldQuery[str]):
class DateQuery(NumericColumnQuery[int]):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
@ -910,7 +908,7 @@ class Sort:
return sorted(items)
def is_slow(self) -> bool:
"""Indicate whether this query is *slow*, meaning that it cannot
"""Indicate whether this sort is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False

View file

@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Album.all_fields_query(
dup_query = library.Item.all_fields_query(
{key: tmp_item.get(key) for key in keys}
)

View file

@ -857,17 +857,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper):
def test_match_slow(self):
item = self.add_item()
matched = self.lib.items(NoneQuery("rg_track_peak", fast=False))
matched = self.lib.items(NoneQuery("rg_track_peak"))
self.assertInResult(item, matched)
def test_match_slow_after_set_none(self):
item = self.add_item(rg_track_gain=0)
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
matched = self.lib.items(NoneQuery("rg_track_gain"))
self.assertNotInResult(item, matched)
item["rg_track_gain"] = None
item.store()
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
matched = self.lib.items(NoneQuery("rg_track_gain"))
self.assertInResult(item, matched)
@ -1097,37 +1097,6 @@ class NotQueryTest(DummyDataTestCase):
results = self.lib.items(q)
self.assert_items_matched(results, ["baz qux"])
def test_fast_vs_slow(self):
"""Test that the results are the same regardless of the `fast` flag
for negated `FieldQuery`s.
TODO: investigate NoneQuery(fast=False), as it is raising
AttributeError: type object 'NoneQuery' has no attribute 'field'
at NoneQuery.match() (due to being @classmethod, and no self?)
"""
classes = [
(dbcore.query.DateQuery, ["added", "2001-01-01"]),
(dbcore.query.MatchQuery, ["artist", "one"]),
# (dbcore.query.NoneQuery, ['rg_track_gain']),
(dbcore.query.NumericQuery, ["year", "2002"]),
(dbcore.query.StringFieldQuery, ["year", "2001"]),
(dbcore.query.RegexpQuery, ["album", "^.a"]),
(dbcore.query.SubstringQuery, ["title", "x"]),
]
for klass, args in classes:
q_fast = dbcore.query.NotQuery(klass(*(args + [True])))
q_slow = dbcore.query.NotQuery(klass(*(args + [False])))
try:
self.assertEqual(
[i.title for i in self.lib.items(q_fast)],
[i.title for i in self.lib.items(q_slow)],
)
except NotImplementedError:
# ignore classes that do not provide `fast` implementation
pass
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
"""Test album-level queries with track-level filters and vice-versa."""
@ -1143,12 +1112,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
for item_idx in range(1, 3):
item = _common.item()
item.album = album_name
item.title = f"{album_name} Item{item_idx}"
title = f"{album_name} Item{item_idx}"
item.title = title
item.item_flex1 = f"{title} Flex1"
item.item_flex2 = f"{title} Flex2"
self.lib.add(item)
album_items.append(item)
album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath"
album.catalognum = "ABC"
album.album_flex = f"{album_name} Flex"
album.store()
albums.append(album)
@ -1169,6 +1142,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_get_albums_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)