Enable querying related flexible attributes

Unify query creation logic from
- queryparse.py:construct_query_part,
- Model.field_query,
- DefaultTemplateFunctions._tmpl_unique

to a single implementation under `LibModel.field_query` class method.
This method should be used for query resolution for model (flex)fields.

Allow filtering item attributes in album queries and vice versa by
merging `flex_attrs` from Album and Item together as `all_flex_attrs`.
This field is only used for filtering and is discarded after.
This commit is contained in:
Šarūnas Nejus 2024-05-08 11:36:57 +01:00
parent fb4834e0ab
commit 10ce22e289
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 120 additions and 83 deletions

View file

@ -39,8 +39,6 @@ import beets
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
@ -775,33 +773,6 @@ class Model(ABC, Generic[D]):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
# Convenient queries.
@classmethod
def field_query(
cls,
field,
pattern,
query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(
cls: type[Model],
pats: Mapping[str, str],
query_cls: FieldQueryType = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)
# Database controller and supporting interfaces.
@ -1274,17 +1245,30 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
this_table = model_cls._table
select_fields = [f"{this_table}.*"]
_from = model_cls.table_with_flex_attrs
if query.field_names & model_cls.other_db_fields:
required_fields = query.field_names
if required_fields - model_cls._fields.keys():
_from += f" {model_cls.relation_join}"
table = model_cls._table
# group by id to avoid duplicates when joining with the relation
if required_fields - model_cls.all_db_fields:
# merge all flexible attribute into a single JSON field
select_fields.append(
f"""
json_patch(
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
) AS all_flex_attrs
""" # noqa: E501
)
sql = (
f"SELECT {table}.* "
f"SELECT {', '.join(select_fields)} "
f"FROM ({_from}) "
f"WHERE {where or 1} "
f"GROUP BY {table}.id"
f"GROUP BY {this_table}.id"
)
if order_by:

View file

@ -125,9 +125,7 @@ class FieldQuery(Query, Generic[P]):
@property
def field(self) -> str:
if not self.fast:
return (
f'json_extract("flex_attrs [json_str]", "$.{self.field_name}")'
)
return f'json_extract(all_flex_attrs, "$.{self.field_name}")'
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name

View file

@ -20,15 +20,17 @@ import itertools
import re
from typing import TYPE_CHECKING
from . import Model, query
from . import query
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
from ..library import LibModel
from .query import FieldQueryType, Sort
Prefixes = dict[str, FieldQueryType]
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
@ -112,7 +114,7 @@ def parse_query_part(
def construct_query_part(
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_part: str,
) -> query.Query:
@ -160,15 +162,7 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
field = table = key.lower()
if field in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
table = f"{model_cls._table}.{field}"
field_in_db = field in model_cls.all_db_fields
out_query = query_class(table, pattern, field_in_db)
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.
if negate:
@ -180,7 +174,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: type[query.CollectionQuery],
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_parts: Collection[str],
) -> query.Query:
@ -197,7 +191,7 @@ def query_from_strings(
def construct_sort_part(
model_cls: type[Model],
model_cls: type[LibModel],
part: str,
case_insensitive: bool = True,
) -> Sort:
@ -228,7 +222,7 @@ def construct_sort_part(
def sort_from_strings(
model_cls: type[Model],
model_cls: type[LibModel],
sort_parts: Sequence[str],
case_insensitive: bool = True,
) -> Sort:
@ -247,7 +241,7 @@ def sort_from_strings(
def parse_sorted_query(
model_cls: type[Model],
model_cls: type[LibModel],
parts: list[str],
prefixes: Prefixes = {},
case_insensitive: bool = True,

View file

@ -25,6 +25,7 @@ import time
import unicodedata
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Mapping
import platformdirs
from mediafile import MediaFile, UnreadableFileError
@ -42,6 +43,9 @@ from beets.util import (
)
from beets.util.functemplate import Template, template
if TYPE_CHECKING:
from beets.dbcore.query import FieldQueryType
# To use the SQLite "blob" type, it doesn't suffice to provide a byte
# string; SQLite treats that as encoded text. Wrapping it in a
# `memoryview` tells it that we actually mean non-text data.
@ -375,6 +379,39 @@ class LibModel(dbcore.Model["Library"]):
def __bytes__(self):
return self.__str__().encode("utf-8")
# Convenient queries.
@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: FieldQueryType
) -> dbcore.FieldQuery:
"""Get a `FieldQuery` for this model."""
fast = field in cls.all_db_fields
if field in cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"
return query_cls(field, pattern, fast)
@classmethod
def all_fields_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.
`pattern_by_field` should be a mapping from field names to patterns.
The resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
return dbcore.AndQuery(
[
cls.field_query(f, p, dbcore.MatchQuery)
for f, p in pattern_by_field.items()
]
)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
@ -612,7 +649,7 @@ class Item(LibModel):
an album (e.g. singletons) would be left out.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"LEFT JOIN {cls._relation.table_with_flex_attrs} "
f"ON {cls._table}.album_id = {cls._relation._table}.id"
)
@ -1233,7 +1270,7 @@ class Album(LibModel):
any items.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"LEFT JOIN {cls._relation.table_with_flex_attrs} "
f"ON {cls._table}.id = {cls._relation._table}.album_id"
)
@ -1937,9 +1974,10 @@ class DefaultTemplateFunctions:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
subqueries.append(
db_item.field_query(key, value, dbcore.MatchQuery)
)
query = dbcore.AndQuery(subqueries)
ambigous_items = (
self.lib.items(query)

View file

@ -5,6 +5,7 @@ import os.path
import platform
import shutil
from collections import Counter
from pathlib import Path
from beets import logging
from beets.library import Album, Item
@ -30,36 +31,38 @@ class WebPluginTest(ItemInDBTestCase):
# Add library elements. Note that self.lib.add overrides any "id=<n>"
# and assigns the next free id number.
# The following adds will create items #1, #2 and #3
path1 = (
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
base_path = Path(self.path_prefix + os.sep)
album2_item1 = Item(
title="title",
path=str(base_path / "path_1"),
album_id=2,
artist="AAA Singers",
)
self.lib.add(
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
album1_item = Item(
title="another title",
path=str(base_path / "somewhere" / "a"),
artist="AAA Singers",
)
path2 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"a").decode("utf-8")
)
self.lib.add(
Item(title="another title", path=path2, artist="AAA Singers")
)
path3 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
)
self.lib.add(
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
album2_item2 = Item(
title="and a third",
testattr="ABC",
path=str(base_path / "somewhere" / "abc"),
album_id=2,
)
self.lib.add(album2_item1)
self.lib.add(album1_item)
self.lib.add(album2_item2)
# The following adds will create albums #1 and #2
self.lib.add(Album(album="album", albumtest="xyz"))
path4 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
)
self.lib.add(Album(album="other album", artpath=path4))
album1 = self.lib.add_album([album1_item])
album1.album = "album"
album1.albumtest = "xyz"
album1.store()
album2 = self.lib.add_album([album2_item1, album2_item2])
album2.album = "other album"
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
album2.store()
web.app.config["TESTING"] = True
web.app.config["lib"] = self.lib

View file

@ -1134,3 +1134,23 @@ class RelatedQueriesTest(BeetsTestCase, AssertsMixin):
q = "album_flex:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_albums_filter_by_track_flex(self):
q = "item_flex1:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_flex(self):
q = "item_flex1:'Item1 Flex1'"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_filter_by_many_flex(self):
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1"])