mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 08:39:17 +01:00
Add ability to filter flexible attributes through the Query
For a flexible attribute query, replace the `col_name` property with a function call that extracts that attribute from the `flex_attrs` field introduced in the earlier commit. Additionally, for boolean, numeric and date queries CAST the value to NUMERIC SQLite affinity to ensure that our queries like 'flex:1..5' and 'flex:true' continue working fine. This removes the concept of 'slow query', since every query for any field now has an SQL clause.
This commit is contained in:
parent
969d847110
commit
484c00e223
4 changed files with 55 additions and 88 deletions
|
|
@ -815,7 +815,6 @@ class Results(Generic[AnyModel]):
|
||||||
model_class: Type[AnyModel],
|
model_class: Type[AnyModel],
|
||||||
rows: List[Mapping],
|
rows: List[Mapping],
|
||||||
db: "Database",
|
db: "Database",
|
||||||
query: Optional[Query] = None,
|
|
||||||
sort=None,
|
sort=None,
|
||||||
):
|
):
|
||||||
"""Create a result set that will construct objects of type
|
"""Create a result set that will construct objects of type
|
||||||
|
|
@ -825,9 +824,7 @@ class Results(Generic[AnyModel]):
|
||||||
constructed. `rows` is a query result: a list of mappings. The
|
constructed. `rows` is a query result: a list of mappings. The
|
||||||
new objects will be associated with the database `db`.
|
new objects will be associated with the database `db`.
|
||||||
|
|
||||||
If `query` is provided, it is used as a predicate to filter the
|
If `sort` is provided, it is used to sort the
|
||||||
results for a "slow query" that cannot be evaluated by the
|
|
||||||
database directly. If `sort` is provided, it is used to sort the
|
|
||||||
full list of results before returning. This means it is a "slow
|
full list of results before returning. This means it is a "slow
|
||||||
sort" and all objects must be built before returning the first
|
sort" and all objects must be built before returning the first
|
||||||
one.
|
one.
|
||||||
|
|
@ -835,7 +832,6 @@ class Results(Generic[AnyModel]):
|
||||||
self.model_class = model_class
|
self.model_class = model_class
|
||||||
self.rows = rows
|
self.rows = rows
|
||||||
self.db = db
|
self.db = db
|
||||||
self.query = query
|
|
||||||
self.sort = sort
|
self.sort = sort
|
||||||
|
|
||||||
# We keep a queue of rows we haven't yet consumed for
|
# We keep a queue of rows we haven't yet consumed for
|
||||||
|
|
@ -871,13 +867,10 @@ class Results(Generic[AnyModel]):
|
||||||
while self._rows:
|
while self._rows:
|
||||||
row = self._rows.pop(0)
|
row = self._rows.pop(0)
|
||||||
obj = self._make_model(row)
|
obj = self._make_model(row)
|
||||||
# If there is a slow-query predicate, ensurer that the
|
self._objects.append(obj)
|
||||||
# object passes it.
|
index += 1
|
||||||
if not self.query or self.query.match(obj):
|
yield obj
|
||||||
self._objects.append(obj)
|
break
|
||||||
index += 1
|
|
||||||
yield obj
|
|
||||||
break
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[AnyModel]:
|
def __iter__(self) -> Iterator[AnyModel]:
|
||||||
"""Construct and generate Model objects for all matching
|
"""Construct and generate Model objects for all matching
|
||||||
|
|
@ -906,16 +899,8 @@ class Results(Generic[AnyModel]):
|
||||||
if not self._rows:
|
if not self._rows:
|
||||||
# Fully materialized. Just count the objects.
|
# Fully materialized. Just count the objects.
|
||||||
return len(self._objects)
|
return len(self._objects)
|
||||||
|
|
||||||
elif self.query:
|
|
||||||
# A slow query. Fall back to testing every object.
|
|
||||||
count = 0
|
|
||||||
for obj in self:
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# A fast query. Just count the rows.
|
# Just count the rows.
|
||||||
return self._row_count
|
return self._row_count
|
||||||
|
|
||||||
def __nonzero__(self) -> bool:
|
def __nonzero__(self) -> bool:
|
||||||
|
|
@ -1144,7 +1129,9 @@ class Database:
|
||||||
def regexp(value, pattern):
|
def regexp(value, pattern):
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
value = value.decode()
|
value = value.decode()
|
||||||
return re.search(pattern, str(value)) is not None
|
return (
|
||||||
|
value is not None and re.search(pattern, str(value)) is not None
|
||||||
|
)
|
||||||
|
|
||||||
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
|
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
|
||||||
"""A custom ``bytelower`` sqlite function so we can compare
|
"""A custom ``bytelower`` sqlite function so we can compare
|
||||||
|
|
@ -1306,7 +1293,6 @@ class Database:
|
||||||
model_cls,
|
model_cls,
|
||||||
rows,
|
rows,
|
||||||
self,
|
self,
|
||||||
None if where else query, # Slow query component.
|
|
||||||
sort if sort.is_slow() else None, # Slow sort component.
|
sort if sort.is_slow() else None, # Slow sort component.
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -93,11 +93,8 @@ class Query(ABC):
|
||||||
Return (clause, subvals) where clause is a valid sqlite
|
Return (clause, subvals) where clause is a valid sqlite
|
||||||
WHERE clause implementing the query and subvals is a list of
|
WHERE clause implementing the query and subvals is a list of
|
||||||
items to be substituted for ?s in the clause.
|
items to be substituted for ?s in the clause.
|
||||||
|
|
||||||
The default implementation returns None, falling back to a slow query
|
|
||||||
using `match()`.
|
|
||||||
"""
|
"""
|
||||||
return None, ()
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def match(self, obj: Model):
|
def match(self, obj: Model):
|
||||||
|
|
@ -146,17 +143,16 @@ class FieldQuery(Query, Generic[P]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def col_name(self) -> str:
|
def col_name(self) -> str:
|
||||||
|
if not self.fast:
|
||||||
|
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
|
||||||
|
|
||||||
return f"{self.table}.{self.field}" if self.table else self.field
|
return f"{self.table}.{self.field}" if self.table else self.field
|
||||||
|
|
||||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
return self.col_name, ()
|
raise NotImplementedError
|
||||||
|
|
||||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
if self.fast:
|
return self.col_clause()
|
||||||
return self.col_clause()
|
|
||||||
else:
|
|
||||||
# Matching a flexattr. This is a slow query.
|
|
||||||
return None, ()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_match(cls, pattern: P, value: Any):
|
def value_match(cls, pattern: P, value: Any):
|
||||||
|
|
@ -305,7 +301,17 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
|
||||||
return pattern.search(cls._normalize(value)) is not None
|
return pattern.search(cls._normalize(value)) is not None
|
||||||
|
|
||||||
|
|
||||||
class BooleanQuery(MatchQuery[int]):
|
class NumericColumnQuery(MatchQuery[AnySQLiteType]):
|
||||||
|
"""A base class for queries that work with NUMERIC SQLite affinity."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def col_name(self) -> str:
|
||||||
|
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
|
||||||
|
col_name = super().col_name
|
||||||
|
return col_name if self.fast else f"CAST({col_name} AS NUMERIC)"
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanQuery(NumericColumnQuery[bool]):
|
||||||
"""Matches a boolean field. Pattern should either be a boolean or a
|
"""Matches a boolean field. Pattern should either be a boolean or a
|
||||||
string reflecting a boolean.
|
string reflecting a boolean.
|
||||||
"""
|
"""
|
||||||
|
|
@ -357,7 +363,7 @@ class BytesQuery(FieldQuery[bytes]):
|
||||||
return pattern == value
|
return pattern == value
|
||||||
|
|
||||||
|
|
||||||
class NumericQuery(FieldQuery[str]):
|
class NumericQuery(NumericColumnQuery[Union[int, float]]):
|
||||||
"""Matches numeric fields. A syntax using Ruby-style range ellipses
|
"""Matches numeric fields. A syntax using Ruby-style range ellipses
|
||||||
(``..``) lets users specify one- or two-sided ranges. For example,
|
(``..``) lets users specify one- or two-sided ranges. For example,
|
||||||
``year:2001..`` finds music released since the turn of the century.
|
``year:2001..`` finds music released since the turn of the century.
|
||||||
|
|
@ -483,7 +489,7 @@ class CollectionQuery(Query):
|
||||||
def clause_with_joiner(
|
def clause_with_joiner(
|
||||||
self,
|
self,
|
||||||
joiner: str,
|
joiner: str,
|
||||||
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
"""Return a clause created by joining together the clauses of
|
"""Return a clause created by joining together the clauses of
|
||||||
all subqueries with the string joiner (padded by spaces).
|
all subqueries with the string joiner (padded by spaces).
|
||||||
"""
|
"""
|
||||||
|
|
@ -491,9 +497,6 @@ class CollectionQuery(Query):
|
||||||
subvals = []
|
subvals = []
|
||||||
for subq in self.subqueries:
|
for subq in self.subqueries:
|
||||||
subq_clause, subq_subvals = subq.clause()
|
subq_clause, subq_subvals = subq.clause()
|
||||||
if not subq_clause:
|
|
||||||
# Fall back to slow query.
|
|
||||||
return None, ()
|
|
||||||
clause_parts.append("(" + subq_clause + ")")
|
clause_parts.append("(" + subq_clause + ")")
|
||||||
subvals += subq_subvals
|
subvals += subq_subvals
|
||||||
clause = (" " + joiner + " ").join(clause_parts)
|
clause = (" " + joiner + " ").join(clause_parts)
|
||||||
|
|
@ -533,7 +536,7 @@ class AnyFieldQuery(CollectionQuery):
|
||||||
def field_names(self) -> Set[str]:
|
def field_names(self) -> Set[str]:
|
||||||
return set(self.fields)
|
return set(self.fields)
|
||||||
|
|
||||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
return self.clause_with_joiner("or")
|
return self.clause_with_joiner("or")
|
||||||
|
|
||||||
def match(self, obj: Model) -> bool:
|
def match(self, obj: Model) -> bool:
|
||||||
|
|
@ -572,7 +575,7 @@ class MutableCollectionQuery(CollectionQuery):
|
||||||
class AndQuery(MutableCollectionQuery):
|
class AndQuery(MutableCollectionQuery):
|
||||||
"""A conjunction of a list of other queries."""
|
"""A conjunction of a list of other queries."""
|
||||||
|
|
||||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
return self.clause_with_joiner("and")
|
return self.clause_with_joiner("and")
|
||||||
|
|
||||||
def match(self, obj: Model) -> bool:
|
def match(self, obj: Model) -> bool:
|
||||||
|
|
@ -582,7 +585,7 @@ class AndQuery(MutableCollectionQuery):
|
||||||
class OrQuery(MutableCollectionQuery):
|
class OrQuery(MutableCollectionQuery):
|
||||||
"""A conjunction of a list of other queries."""
|
"""A conjunction of a list of other queries."""
|
||||||
|
|
||||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
return self.clause_with_joiner("or")
|
return self.clause_with_joiner("or")
|
||||||
|
|
||||||
def match(self, obj: Model) -> bool:
|
def match(self, obj: Model) -> bool:
|
||||||
|
|
@ -602,14 +605,9 @@ class NotQuery(Query):
|
||||||
"""Return a set with field names that this query operates on."""
|
"""Return a set with field names that this query operates on."""
|
||||||
return self.subquery.field_names
|
return self.subquery.field_names
|
||||||
|
|
||||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||||
clause, subvals = self.subquery.clause()
|
clause, subvals = self.subquery.clause()
|
||||||
if clause:
|
return f"not ({clause})", subvals
|
||||||
return f"not ({clause})", subvals
|
|
||||||
else:
|
|
||||||
# If there is no clause, there is nothing to negate. All the logic
|
|
||||||
# is handled by match() for slow queries.
|
|
||||||
return clause, subvals
|
|
||||||
|
|
||||||
def match(self, obj: Model) -> bool:
|
def match(self, obj: Model) -> bool:
|
||||||
return not self.subquery.match(obj)
|
return not self.subquery.match(obj)
|
||||||
|
|
@ -816,7 +814,7 @@ class DateInterval:
|
||||||
return f"[{self.start}, {self.end})"
|
return f"[{self.start}, {self.end})"
|
||||||
|
|
||||||
|
|
||||||
class DateQuery(FieldQuery[str]):
|
class DateQuery(NumericColumnQuery[int]):
|
||||||
"""Matches date fields stored as seconds since Unix epoch time.
|
"""Matches date fields stored as seconds since Unix epoch time.
|
||||||
|
|
||||||
Dates can be specified as ``year-month-day`` strings where only year
|
Dates can be specified as ``year-month-day`` strings where only year
|
||||||
|
|
@ -910,7 +908,7 @@ class Sort:
|
||||||
return sorted(items)
|
return sorted(items)
|
||||||
|
|
||||||
def is_slow(self) -> bool:
|
def is_slow(self) -> bool:
|
||||||
"""Indicate whether this query is *slow*, meaning that it cannot
|
"""Indicate whether this sort is *slow*, meaning that it cannot
|
||||||
be executed in SQL and must be executed in Python.
|
be executed in SQL and must be executed in Python.
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
|
||||||
# temporary `Item` object to generate any computed fields.
|
# temporary `Item` object to generate any computed fields.
|
||||||
tmp_item = library.Item(lib, **info)
|
tmp_item = library.Item(lib, **info)
|
||||||
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
|
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
|
||||||
dup_query = library.Album.all_fields_query(
|
dup_query = library.Item.all_fields_query(
|
||||||
{key: tmp_item.get(key) for key in keys}
|
{key: tmp_item.get(key) for key in keys}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -857,17 +857,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper):
|
||||||
|
|
||||||
def test_match_slow(self):
|
def test_match_slow(self):
|
||||||
item = self.add_item()
|
item = self.add_item()
|
||||||
matched = self.lib.items(NoneQuery("rg_track_peak", fast=False))
|
matched = self.lib.items(NoneQuery("rg_track_peak"))
|
||||||
self.assertInResult(item, matched)
|
self.assertInResult(item, matched)
|
||||||
|
|
||||||
def test_match_slow_after_set_none(self):
|
def test_match_slow_after_set_none(self):
|
||||||
item = self.add_item(rg_track_gain=0)
|
item = self.add_item(rg_track_gain=0)
|
||||||
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
|
matched = self.lib.items(NoneQuery("rg_track_gain"))
|
||||||
self.assertNotInResult(item, matched)
|
self.assertNotInResult(item, matched)
|
||||||
|
|
||||||
item["rg_track_gain"] = None
|
item["rg_track_gain"] = None
|
||||||
item.store()
|
item.store()
|
||||||
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
|
matched = self.lib.items(NoneQuery("rg_track_gain"))
|
||||||
self.assertInResult(item, matched)
|
self.assertInResult(item, matched)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1097,37 +1097,6 @@ class NotQueryTest(DummyDataTestCase):
|
||||||
results = self.lib.items(q)
|
results = self.lib.items(q)
|
||||||
self.assert_items_matched(results, ["baz qux"])
|
self.assert_items_matched(results, ["baz qux"])
|
||||||
|
|
||||||
def test_fast_vs_slow(self):
|
|
||||||
"""Test that the results are the same regardless of the `fast` flag
|
|
||||||
for negated `FieldQuery`s.
|
|
||||||
|
|
||||||
TODO: investigate NoneQuery(fast=False), as it is raising
|
|
||||||
AttributeError: type object 'NoneQuery' has no attribute 'field'
|
|
||||||
at NoneQuery.match() (due to being @classmethod, and no self?)
|
|
||||||
"""
|
|
||||||
classes = [
|
|
||||||
(dbcore.query.DateQuery, ["added", "2001-01-01"]),
|
|
||||||
(dbcore.query.MatchQuery, ["artist", "one"]),
|
|
||||||
# (dbcore.query.NoneQuery, ['rg_track_gain']),
|
|
||||||
(dbcore.query.NumericQuery, ["year", "2002"]),
|
|
||||||
(dbcore.query.StringFieldQuery, ["year", "2001"]),
|
|
||||||
(dbcore.query.RegexpQuery, ["album", "^.a"]),
|
|
||||||
(dbcore.query.SubstringQuery, ["title", "x"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
for klass, args in classes:
|
|
||||||
q_fast = dbcore.query.NotQuery(klass(*(args + [True])))
|
|
||||||
q_slow = dbcore.query.NotQuery(klass(*(args + [False])))
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.assertEqual(
|
|
||||||
[i.title for i in self.lib.items(q_fast)],
|
|
||||||
[i.title for i in self.lib.items(q_slow)],
|
|
||||||
)
|
|
||||||
except NotImplementedError:
|
|
||||||
# ignore classes that do not provide `fast` implementation
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
||||||
"""Test album-level queries with track-level filters and vice-versa."""
|
"""Test album-level queries with track-level filters and vice-versa."""
|
||||||
|
|
@ -1143,12 +1112,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
||||||
for item_idx in range(1, 3):
|
for item_idx in range(1, 3):
|
||||||
item = _common.item()
|
item = _common.item()
|
||||||
item.album = album_name
|
item.album = album_name
|
||||||
item.title = f"{album_name} Item{item_idx}"
|
title = f"{album_name} Item{item_idx}"
|
||||||
|
item.title = title
|
||||||
|
item.item_flex1 = f"{title} Flex1"
|
||||||
|
item.item_flex2 = f"{title} Flex2"
|
||||||
self.lib.add(item)
|
self.lib.add(item)
|
||||||
album_items.append(item)
|
album_items.append(item)
|
||||||
album = self.lib.add_album(album_items)
|
album = self.lib.add_album(album_items)
|
||||||
album.artpath = f"{album_name} Artpath"
|
album.artpath = f"{album_name} Artpath"
|
||||||
album.catalognum = "ABC"
|
album.catalognum = "ABC"
|
||||||
|
album.album_flex = f"{album_name} Flex"
|
||||||
album.store()
|
album.store()
|
||||||
albums.append(album)
|
albums.append(album)
|
||||||
|
|
||||||
|
|
@ -1169,6 +1142,16 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
||||||
results = self.lib.albums(q)
|
results = self.lib.albums(q)
|
||||||
self.assert_albums_matched(results, ["Album1"])
|
self.assert_albums_matched(results, ["Album1"])
|
||||||
|
|
||||||
|
def test_get_items_filter_by_track_flex(self):
|
||||||
|
q = "item_flex1:Item1"
|
||||||
|
results = self.lib.items(q)
|
||||||
|
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
|
||||||
|
|
||||||
|
def test_get_albums_filter_by_album_flex(self):
|
||||||
|
q = "album_flex:Album1"
|
||||||
|
results = self.lib.albums(q)
|
||||||
|
self.assert_albums_matched(results, ["Album1"])
|
||||||
|
|
||||||
|
|
||||||
def suite():
|
def suite():
|
||||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue