Add ability to filter flexible attributes through the Query

For a flexible attribute query, replace the `col_name` property with
a `json_extract` function call that extracts that attribute from the
`flex_attrs` field introduced in the earlier commit.

Additionally, for boolean, numeric and date queries, CAST the extracted
value to SQLite's NUMERIC affinity so that queries like 'flex:1..5' and
'flex:true' continue to work.

This removes the concept of 'slow query', since every query for any
field now has an SQL clause.
This commit is contained in:
Šarūnas Nejus 2024-05-07 20:03:58 +01:00
parent 969d847110
commit 484c00e223
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
4 changed files with 55 additions and 88 deletions

View file

@ -815,7 +815,6 @@ class Results(Generic[AnyModel]):
model_class: Type[AnyModel], model_class: Type[AnyModel],
rows: List[Mapping], rows: List[Mapping],
db: "Database", db: "Database",
query: Optional[Query] = None,
sort=None, sort=None,
): ):
"""Create a result set that will construct objects of type """Create a result set that will construct objects of type
@ -825,9 +824,7 @@ class Results(Generic[AnyModel]):
constructed. `rows` is a query result: a list of mappings. The constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`. new objects will be associated with the database `db`.
If `query` is provided, it is used as a predicate to filter the If `sort` is provided, it is used to sort the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first sort" and all objects must be built before returning the first
one. one.
@ -835,7 +832,6 @@ class Results(Generic[AnyModel]):
self.model_class = model_class self.model_class = model_class
self.rows = rows self.rows = rows
self.db = db self.db = db
self.query = query
self.sort = sort self.sort = sort
# We keep a queue of rows we haven't yet consumed for # We keep a queue of rows we haven't yet consumed for
@ -871,13 +867,10 @@ class Results(Generic[AnyModel]):
while self._rows: while self._rows:
row = self._rows.pop(0) row = self._rows.pop(0)
obj = self._make_model(row) obj = self._make_model(row)
# If there is a slow-query predicate, ensurer that the self._objects.append(obj)
# object passes it. index += 1
if not self.query or self.query.match(obj): yield obj
self._objects.append(obj) break
index += 1
yield obj
break
def __iter__(self) -> Iterator[AnyModel]: def __iter__(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for all matching """Construct and generate Model objects for all matching
@ -906,16 +899,8 @@ class Results(Generic[AnyModel]):
if not self._rows: if not self._rows:
# Fully materialized. Just count the objects. # Fully materialized. Just count the objects.
return len(self._objects) return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
count += 1
return count
else: else:
# A fast query. Just count the rows. # Just count the rows.
return self._row_count return self._row_count
def __nonzero__(self) -> bool: def __nonzero__(self) -> bool:
@ -1144,7 +1129,9 @@ class Database:
def regexp(value, pattern): def regexp(value, pattern):
if isinstance(value, bytes): if isinstance(value, bytes):
value = value.decode() value = value.decode()
return re.search(pattern, str(value)) is not None return (
value is not None and re.search(pattern, str(value)) is not None
)
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]: def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
"""A custom ``bytelower`` sqlite function so we can compare """A custom ``bytelower`` sqlite function so we can compare
@ -1306,7 +1293,6 @@ class Database:
model_cls, model_cls,
rows, rows,
self, self,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component. sort if sort.is_slow() else None, # Slow sort component.
) )

View file

@ -93,11 +93,8 @@ class Query(ABC):
Return (clause, subvals) where clause is a valid sqlite Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause. items to be substituted for ?s in the clause.
The default implementation returns None, falling back to a slow query
using `match()`.
""" """
return None, () raise NotImplementedError
@abstractmethod @abstractmethod
def match(self, obj: Model): def match(self, obj: Model):
@ -146,17 +143,16 @@ class FieldQuery(Query, Generic[P]):
@property @property
def col_name(self) -> str: def col_name(self) -> str:
if not self.fast:
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
return f"{self.table}.{self.field}" if self.table else self.field return f"{self.table}.{self.field}" if self.table else self.field
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_name, () raise NotImplementedError
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
if self.fast: return self.col_clause()
return self.col_clause()
else:
# Matching a flexattr. This is a slow query.
return None, ()
@classmethod @classmethod
def value_match(cls, pattern: P, value: Any): def value_match(cls, pattern: P, value: Any):
@ -305,7 +301,17 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
return pattern.search(cls._normalize(value)) is not None return pattern.search(cls._normalize(value)) is not None
class BooleanQuery(MatchQuery[int]): class NumericColumnQuery(MatchQuery[AnySQLiteType]):
"""A base class for queries that work with NUMERIC SQLite affinity."""
@property
def col_name(self) -> str:
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
col_name = super().col_name
return col_name if self.fast else f"CAST({col_name} AS NUMERIC)"
class BooleanQuery(NumericColumnQuery[bool]):
"""Matches a boolean field. Pattern should either be a boolean or a """Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean. string reflecting a boolean.
""" """
@ -357,7 +363,7 @@ class BytesQuery(FieldQuery[bytes]):
return pattern == value return pattern == value
class NumericQuery(FieldQuery[str]): class NumericQuery(NumericColumnQuery[Union[int, float]]):
"""Matches numeric fields. A syntax using Ruby-style range ellipses """Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example, (``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century. ``year:2001..`` finds music released since the turn of the century.
@ -483,7 +489,7 @@ class CollectionQuery(Query):
def clause_with_joiner( def clause_with_joiner(
self, self,
joiner: str, joiner: str,
) -> Tuple[Optional[str], Sequence[SQLiteType]]: ) -> Tuple[str, Sequence[SQLiteType]]:
"""Return a clause created by joining together the clauses of """Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces). all subqueries with the string joiner (padded by spaces).
""" """
@ -491,9 +497,6 @@ class CollectionQuery(Query):
subvals = [] subvals = []
for subq in self.subqueries: for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause() subq_clause, subq_subvals = subq.clause()
if not subq_clause:
# Fall back to slow query.
return None, ()
clause_parts.append("(" + subq_clause + ")") clause_parts.append("(" + subq_clause + ")")
subvals += subq_subvals subvals += subq_subvals
clause = (" " + joiner + " ").join(clause_parts) clause = (" " + joiner + " ").join(clause_parts)
@ -533,7 +536,7 @@ class AnyFieldQuery(CollectionQuery):
def field_names(self) -> Set[str]: def field_names(self) -> Set[str]:
return set(self.fields) return set(self.fields)
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or") return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool: def match(self, obj: Model) -> bool:
@ -572,7 +575,7 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery): class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries.""" """A conjunction of a list of other queries."""
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("and") return self.clause_with_joiner("and")
def match(self, obj: Model) -> bool: def match(self, obj: Model) -> bool:
@ -582,7 +585,7 @@ class AndQuery(MutableCollectionQuery):
class OrQuery(MutableCollectionQuery): class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries.""" """A conjunction of a list of other queries."""
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or") return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool: def match(self, obj: Model) -> bool:
@ -602,14 +605,9 @@ class NotQuery(Query):
"""Return a set with field names that this query operates on.""" """Return a set with field names that this query operates on."""
return self.subquery.field_names return self.subquery.field_names
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
clause, subvals = self.subquery.clause() clause, subvals = self.subquery.clause()
if clause: return f"not ({clause})", subvals
return f"not ({clause})", subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
return clause, subvals
def match(self, obj: Model) -> bool: def match(self, obj: Model) -> bool:
return not self.subquery.match(obj) return not self.subquery.match(obj)
@ -816,7 +814,7 @@ class DateInterval:
return f"[{self.start}, {self.end})" return f"[{self.start}, {self.end})"
class DateQuery(FieldQuery[str]): class DateQuery(NumericColumnQuery[int]):
"""Matches date fields stored as seconds since Unix epoch time. """Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year Dates can be specified as ``year-month-day`` strings where only year
@ -910,7 +908,7 @@ class Sort:
return sorted(items) return sorted(items)
def is_slow(self) -> bool: def is_slow(self) -> bool:
"""Indicate whether this query is *slow*, meaning that it cannot """Indicate whether this sort is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python. be executed in SQL and must be executed in Python.
""" """
return False return False

View file

@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields. # temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info) tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq() keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Album.all_fields_query( dup_query = library.Item.all_fields_query(
{key: tmp_item.get(key) for key in keys} {key: tmp_item.get(key) for key in keys}
) )

View file

@ -857,17 +857,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper):
def test_match_slow(self): def test_match_slow(self):
item = self.add_item() item = self.add_item()
matched = self.lib.items(NoneQuery("rg_track_peak", fast=False)) matched = self.lib.items(NoneQuery("rg_track_peak"))
self.assertInResult(item, matched) self.assertInResult(item, matched)
def test_match_slow_after_set_none(self): def test_match_slow_after_set_none(self):
item = self.add_item(rg_track_gain=0) item = self.add_item(rg_track_gain=0)
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) matched = self.lib.items(NoneQuery("rg_track_gain"))
self.assertNotInResult(item, matched) self.assertNotInResult(item, matched)
item["rg_track_gain"] = None item["rg_track_gain"] = None
item.store() item.store()
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) matched = self.lib.items(NoneQuery("rg_track_gain"))
self.assertInResult(item, matched) self.assertInResult(item, matched)
@ -1097,37 +1097,6 @@ class NotQueryTest(DummyDataTestCase):
results = self.lib.items(q) results = self.lib.items(q)
self.assert_items_matched(results, ["baz qux"]) self.assert_items_matched(results, ["baz qux"])
def test_fast_vs_slow(self):
"""Test that the results are the same regardless of the `fast` flag
for negated `FieldQuery`s.
TODO: investigate NoneQuery(fast=False), as it is raising
AttributeError: type object 'NoneQuery' has no attribute 'field'
at NoneQuery.match() (due to being @classmethod, and no self?)
"""
classes = [
(dbcore.query.DateQuery, ["added", "2001-01-01"]),
(dbcore.query.MatchQuery, ["artist", "one"]),
# (dbcore.query.NoneQuery, ['rg_track_gain']),
(dbcore.query.NumericQuery, ["year", "2002"]),
(dbcore.query.StringFieldQuery, ["year", "2001"]),
(dbcore.query.RegexpQuery, ["album", "^.a"]),
(dbcore.query.SubstringQuery, ["title", "x"]),
]
for klass, args in classes:
q_fast = dbcore.query.NotQuery(klass(*(args + [True])))
q_slow = dbcore.query.NotQuery(klass(*(args + [False])))
try:
self.assertEqual(
[i.title for i in self.lib.items(q_fast)],
[i.title for i in self.lib.items(q_slow)],
)
except NotImplementedError:
# ignore classes that do not provide `fast` implementation
pass
class RelatedQueriesTest(_common.TestCase, AssertsMixin): class RelatedQueriesTest(_common.TestCase, AssertsMixin):
"""Test album-level queries with track-level filters and vice-versa.""" """Test album-level queries with track-level filters and vice-versa."""
@ -1143,12 +1112,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
for item_idx in range(1, 3): for item_idx in range(1, 3):
item = _common.item() item = _common.item()
item.album = album_name item.album = album_name
item.title = f"{album_name} Item{item_idx}" title = f"{album_name} Item{item_idx}"
item.title = title
item.item_flex1 = f"{title} Flex1"
item.item_flex2 = f"{title} Flex2"
self.lib.add(item) self.lib.add(item)
album_items.append(item) album_items.append(item)
album = self.lib.add_album(album_items) album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath" album.artpath = f"{album_name} Artpath"
album.catalognum = "ABC" album.catalognum = "ABC"
album.album_flex = f"{album_name} Flex"
album.store() album.store()
albums.append(album) albums.append(album)
@ -1169,6 +1142,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
results = self.lib.albums(q) results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"]) self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_get_albums_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def suite(): def suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)