From 484c00e22352c41d7f3b5cd74b6d6e5eff22fd6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Tue, 7 May 2024 20:03:58 +0100 Subject: [PATCH] Add ability to filter flexible attributes through the Query For a flexible attribute query, replace the `col_name` property with a function call that extracts that attribute from the `flex_attrs` field introduced in the earlier commit. Additionally, for boolean, numeric and date queries CAST the value to NUMERIC SQLite affinity to ensure that our queries like 'flex:1..5' and 'flex:true' continue working fine. This removes the concept of 'slow query', since every query for any field now has an SQL clause. --- beets/dbcore/db.py | 32 +++++++------------------ beets/dbcore/query.py | 56 +++++++++++++++++++++---------------------- beets/importer.py | 2 +- test/test_query.py | 53 ++++++++++++++-------------------------- 4 files changed, 55 insertions(+), 88 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 74d8d7f74..a73b4515f 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -815,7 +815,6 @@ class Results(Generic[AnyModel]): model_class: Type[AnyModel], rows: List[Mapping], db: "Database", - query: Optional[Query] = None, sort=None, ): """Create a result set that will construct objects of type @@ -825,9 +824,7 @@ class Results(Generic[AnyModel]): constructed. `rows` is a query result: a list of mappings. The new objects will be associated with the database `db`. - If `query` is provided, it is used as a predicate to filter the - results for a "slow query" that cannot be evaluated by the - database directly. If `sort` is provided, it is used to sort the + If `sort` is provided, it is used to sort the full list of results before returning. This means it is a "slow sort" and all objects must be built before returning the first one. 
@@ -835,7 +832,6 @@ class Results(Generic[AnyModel]): self.model_class = model_class self.rows = rows self.db = db - self.query = query self.sort = sort # We keep a queue of rows we haven't yet consumed for @@ -871,13 +867,10 @@ class Results(Generic[AnyModel]): while self._rows: row = self._rows.pop(0) obj = self._make_model(row) - # If there is a slow-query predicate, ensurer that the - # object passes it. - if not self.query or self.query.match(obj): - self._objects.append(obj) - index += 1 - yield obj - break + self._objects.append(obj) + index += 1 + yield obj + break def __iter__(self) -> Iterator[AnyModel]: """Construct and generate Model objects for all matching @@ -906,16 +899,8 @@ class Results(Generic[AnyModel]): if not self._rows: # Fully materialized. Just count the objects. return len(self._objects) - - elif self.query: - # A slow query. Fall back to testing every object. - count = 0 - for obj in self: - count += 1 - return count - else: - # A fast query. Just count the rows. + # Just count the rows. return self._row_count def __nonzero__(self) -> bool: @@ -1144,7 +1129,9 @@ class Database: def regexp(value, pattern): if isinstance(value, bytes): value = value.decode() - return re.search(pattern, str(value)) is not None + return ( + value is not None and re.search(pattern, str(value)) is not None + ) def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]: """A custom ``bytelower`` sqlite function so we can compare @@ -1306,7 +1293,6 @@ class Database: model_cls, rows, self, - None if where else query, # Slow query component. sort if sort.is_slow() else None, # Slow sort component. ) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 9eaf84576..d4a4fd4f7 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -93,11 +93,8 @@ class Query(ABC): Return (clause, subvals) where clause is a valid sqlite WHERE clause implementing the query and subvals is a list of items to be substituted for ?s in the clause. 
- - The default implementation returns None, falling back to a slow query - using `match()`. """ - return None, () + raise NotImplementedError @abstractmethod def match(self, obj: Model): @@ -146,17 +143,16 @@ class FieldQuery(Query, Generic[P]): @property def col_name(self) -> str: + if not self.fast: + return f'json_extract("flex_attrs [json_str]", "$.{self.field}")' + return f"{self.table}.{self.field}" if self.table else self.field def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return self.col_name, () + raise NotImplementedError - def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: - if self.fast: - return self.col_clause() - else: - # Matching a flexattr. This is a slow query. - return None, () + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: + return self.col_clause() @classmethod def value_match(cls, pattern: P, value: Any): @@ -305,7 +301,17 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): return pattern.search(cls._normalize(value)) is not None -class BooleanQuery(MatchQuery[int]): +class NumericColumnQuery(MatchQuery[AnySQLiteType]): + """A base class for queries that work with NUMERIC SQLite affinity.""" + + @property + def col_name(self) -> str: + """Cast a flexible attribute column (string) to NUMERIC affinity.""" + col_name = super().col_name + return col_name if self.fast else f"CAST({col_name} AS NUMERIC)" + + +class BooleanQuery(NumericColumnQuery[bool]): """Matches a boolean field. Pattern should either be a boolean or a string reflecting a boolean. """ @@ -357,7 +363,7 @@ class BytesQuery(FieldQuery[bytes]): return pattern == value -class NumericQuery(FieldQuery[str]): +class NumericQuery(NumericColumnQuery[Union[int, float]]): """Matches numeric fields. A syntax using Ruby-style range ellipses (``..``) lets users specify one- or two-sided ranges. For example, ``year:2001..`` finds music released since the turn of the century. 
@@ -483,7 +489,7 @@ class CollectionQuery(Query): def clause_with_joiner( self, joiner: str, - ) -> Tuple[Optional[str], Sequence[SQLiteType]]: + ) -> Tuple[str, Sequence[SQLiteType]]: """Return a clause created by joining together the clauses of all subqueries with the string joiner (padded by spaces). """ @@ -491,9 +497,6 @@ class CollectionQuery(Query): subvals = [] for subq in self.subqueries: subq_clause, subq_subvals = subq.clause() - if not subq_clause: - # Fall back to slow query. - return None, () clause_parts.append("(" + subq_clause + ")") subvals += subq_subvals clause = (" " + joiner + " ").join(clause_parts) @@ -533,7 +536,7 @@ class AnyFieldQuery(CollectionQuery): def field_names(self) -> Set[str]: return set(self.fields) - def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.clause_with_joiner("or") def match(self, obj: Model) -> bool: @@ -572,7 +575,7 @@ class MutableCollectionQuery(CollectionQuery): class AndQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.clause_with_joiner("and") def match(self, obj: Model) -> bool: @@ -582,7 +585,7 @@ class AndQuery(MutableCollectionQuery): class OrQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.clause_with_joiner("or") def match(self, obj: Model) -> bool: @@ -602,14 +605,9 @@ class NotQuery(Query): """Return a set with field names that this query operates on.""" return self.subquery.field_names - def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + def clause(self) -> Tuple[str, Sequence[SQLiteType]]: clause, subvals = self.subquery.clause() - if clause: - return f"not ({clause})", subvals - else: 
- # If there is no clause, there is nothing to negate. All the logic - # is handled by match() for slow queries. - return clause, subvals + return f"not ({clause})", subvals def match(self, obj: Model) -> bool: return not self.subquery.match(obj) @@ -816,7 +814,7 @@ class DateInterval: return f"[{self.start}, {self.end})" -class DateQuery(FieldQuery[str]): +class DateQuery(NumericColumnQuery[int]): """Matches date fields stored as seconds since Unix epoch time. Dates can be specified as ``year-month-day`` strings where only year @@ -910,7 +908,7 @@ class Sort: return sorted(items) def is_slow(self) -> bool: - """Indicate whether this query is *slow*, meaning that it cannot + """Indicate whether this sort is *slow*, meaning that it cannot be executed in SQL and must be executed in Python. """ return False diff --git a/beets/importer.py b/beets/importer.py index f6517b515..f7f35935f 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask): # temporary `Item` object to generate any computed fields. 
tmp_item = library.Item(lib, **info) keys = config["import"]["duplicate_keys"]["item"].as_str_seq() - dup_query = library.Album.all_fields_query( + dup_query = library.Item.all_fields_query( {key: tmp_item.get(key) for key in keys} ) diff --git a/test/test_query.py b/test/test_query.py index 69277cfcd..109645374 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -857,17 +857,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper): def test_match_slow(self): item = self.add_item() - matched = self.lib.items(NoneQuery("rg_track_peak", fast=False)) + matched = self.lib.items(NoneQuery("rg_track_peak")) self.assertInResult(item, matched) def test_match_slow_after_set_none(self): item = self.add_item(rg_track_gain=0) - matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) + matched = self.lib.items(NoneQuery("rg_track_gain")) self.assertNotInResult(item, matched) item["rg_track_gain"] = None item.store() - matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) + matched = self.lib.items(NoneQuery("rg_track_gain")) self.assertInResult(item, matched) @@ -1097,37 +1097,6 @@ class NotQueryTest(DummyDataTestCase): results = self.lib.items(q) self.assert_items_matched(results, ["baz qux"]) - def test_fast_vs_slow(self): - """Test that the results are the same regardless of the `fast` flag - for negated `FieldQuery`s. - - TODO: investigate NoneQuery(fast=False), as it is raising - AttributeError: type object 'NoneQuery' has no attribute 'field' - at NoneQuery.match() (due to being @classmethod, and no self?) 
- """ - classes = [ - (dbcore.query.DateQuery, ["added", "2001-01-01"]), - (dbcore.query.MatchQuery, ["artist", "one"]), - # (dbcore.query.NoneQuery, ['rg_track_gain']), - (dbcore.query.NumericQuery, ["year", "2002"]), - (dbcore.query.StringFieldQuery, ["year", "2001"]), - (dbcore.query.RegexpQuery, ["album", "^.a"]), - (dbcore.query.SubstringQuery, ["title", "x"]), - ] - - for klass, args in classes: - q_fast = dbcore.query.NotQuery(klass(*(args + [True]))) - q_slow = dbcore.query.NotQuery(klass(*(args + [False]))) - - try: - self.assertEqual( - [i.title for i in self.lib.items(q_fast)], - [i.title for i in self.lib.items(q_slow)], - ) - except NotImplementedError: - # ignore classes that do not provide `fast` implementation - pass - class RelatedQueriesTest(_common.TestCase, AssertsMixin): """Test album-level queries with track-level filters and vice-versa.""" @@ -1143,12 +1112,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin): for item_idx in range(1, 3): item = _common.item() item.album = album_name - item.title = f"{album_name} Item{item_idx}" + title = f"{album_name} Item{item_idx}" + item.title = title + item.item_flex1 = f"{title} Flex1" + item.item_flex2 = f"{title} Flex2" self.lib.add(item) album_items.append(item) album = self.lib.add_album(album_items) album.artpath = f"{album_name} Artpath" album.catalognum = "ABC" + album.album_flex = f"{album_name} Flex" album.store() albums.append(album) @@ -1169,6 +1142,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin): results = self.lib.albums(q) self.assert_albums_matched(results, ["Album1"]) + def test_get_items_filter_by_track_flex(self): + q = "item_flex1:Item1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"]) + + def test_get_albums_filter_by_album_flex(self): + q = "album_flex:Album1" + results = self.lib.albums(q) + self.assert_albums_matched(results, ["Album1"]) + def suite(): return 
unittest.TestLoader().loadTestsFromName(__name__)