Enable querying related flexible attributes

Unify query creation logic from
- queryparse.py:construct_query_part,
- Model.field_query,
- DefaultTemplateFunctions._tmpl_unique

to a single implementation under `LibModel.field_query` class method.
This method should be used for query resolution for model (flex)fields.

Allow filtering item attributes in album queries and vice versa by
merging `flex_attrs` from Album and Item together as `all_flex_attrs`.
This field is only used for filtering and is discarded after.
This commit is contained in:
Šarūnas Nejus 2024-05-08 11:36:57 +01:00
parent 484c00e223
commit 9207b17d13
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
7 changed files with 142 additions and 85 deletions

View file

@ -58,15 +58,7 @@ import beets
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
FieldQuery,
MatchQuery,
NullSort,
Query,
Sort,
TrueQuery,
)
from .query import FieldQuery, MatchQuery, NullSort, Query, Sort, TrueQuery
# convert data under 'json_str' type name to Python dictionary automatically
sqlite3.register_converter("json_str", json.loads)
@ -395,6 +387,10 @@ class Model(ABC):
) {cls._table}
"""
@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return set()
@classmethod
def _getters(cls: Type["Model"]):
"""Return a mapping from field names to getter functions."""
@ -771,33 +767,6 @@ class Model(ABC):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
# Convenient queries.
@classmethod
def field_query(
cls,
field,
pattern,
query_cls: Type[FieldQuery] = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(
cls: Type["Model"],
pats: Mapping,
query_cls: Type[FieldQuery] = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)
# Database controller and supporting interfaces.
@ -1270,13 +1239,26 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
this_table = model_cls._table
select_fields = [f"{this_table}.*"]
_from = model_cls.table_with_flex_attrs
required_fields = query.field_names
if required_fields - model_cls._fields.keys():
_from += f" {model_cls.relation_join}"
table = model_cls._table
sql = f"SELECT {table}.* FROM {_from} WHERE {where or 1} GROUP BY {table}.id"
if required_fields - model_cls.all_model_db_fields:
# merge all flexible attribute into a single JSON field
select_fields.append(
f"""
json_patch(
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
) AS all_flex_attrs
""" # noqa: E501
)
sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id" # noqa: E501
if order_by:
# the sort field may exist in both 'items' and 'albums' tables

View file

@ -144,7 +144,7 @@ class FieldQuery(Query, Generic[P]):
@property
def col_name(self) -> str:
if not self.fast:
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
return f'json_extract(all_flex_attrs, "$.{self.field}")'
return f"{self.table}.{self.field}" if self.table else self.field

View file

@ -16,12 +16,23 @@
import itertools
import re
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
from typing import (
TYPE_CHECKING,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
)
from .. import library
from . import Model, query
from .query import Sort
if TYPE_CHECKING:
from ..library import LibModel
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
@ -105,7 +116,7 @@ def parse_query_part(
def construct_query_part(
model_cls: Type[Model],
model_cls: Type["LibModel"],
prefixes: Dict,
query_part: str,
) -> query.Query:
@ -153,17 +164,7 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
key = key.lower()
album_fields = library.Album._fields.keys()
item_fields = library.Item._fields.keys()
fast = key in album_fields | item_fields
if key in album_fields & item_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError. Using an explicit table name resolves this.
key = f"{model_cls._table}.{key}"
out_query = query_class(key, pattern, fast)
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.
if negate:

View file

@ -14,6 +14,7 @@
"""The core data store and collection logic for beets.
"""
from __future__ import annotations
import os
import re
@ -23,6 +24,7 @@ import sys
import time
import unicodedata
from functools import cached_property
from typing import Mapping, Set, Type
from mediafile import MediaFile, UnreadableFileError
@ -387,6 +389,14 @@ class LibModel(dbcore.Model):
# Config key that specifies how an instance should be formatted.
_format_config_key: str
@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() | cls._relation._fields.keys()
@cached_classproperty
def shared_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() & cls._relation._fields.keys()
def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
funcs.update(plugins.template_funcs())
@ -416,6 +426,39 @@ class LibModel(dbcore.Model):
def __bytes__(self):
return self.__str__().encode("utf-8")
# Convenient queries.
@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery]
) -> dbcore.Query:
"""Get a `FieldQuery` for this model."""
fast = field in cls.all_model_db_fields
if field in cls.shared_model_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"
return query_cls(field, pattern, fast)
@classmethod
def all_fields_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.
`pattern_by_field` should be a mapping from field names to patterns.
The resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
return dbcore.AndQuery(
[
cls.field_query(f, p, dbcore.MatchQuery)
for f, p in pattern_by_field.items()
]
)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
@ -652,7 +695,10 @@ class Item(LibModel):
We need to use a LEFT JOIN here, otherwise items that are not part of
an album (e.g. singletons) would be left out.
"""
return f"LEFT JOIN {cls._relation._table} ON {cls._table}.album_id = {cls._relation._table}.id"
return (
f"LEFT JOIN {cls._relation.table_with_flex_attrs}"
f" ON {cls._table}.album_id = {cls._relation._table}.id"
)
@property
def _cached_album(self):
@ -1265,7 +1311,10 @@ class Album(LibModel):
Here we can use INNER JOIN (which is more performant than LEFT JOIN),
since we only want to see albums that have at least one Item in them.
"""
return f"INNER JOIN {cls._relation._table} ON {cls._table}.id = {cls._relation._table}.album_id"
return (
f"INNER JOIN {cls._relation.table_with_flex_attrs}"
f" ON {cls._table}.id = {cls._relation._table}.album_id"
)
@classmethod
def _getters(cls):
@ -1955,9 +2004,10 @@ class DefaultTemplateFunctions:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
subqueries.append(
db_item.field_query(key, value, dbcore.MatchQuery)
)
query = dbcore.AndQuery(subqueries)
ambigous_items = (
self.lib.items(query)

View file

@ -5,6 +5,7 @@ import os.path
import platform
import shutil
import unittest
from pathlib import Path
from beets import logging
from beets.library import Album, Item
@ -29,36 +30,38 @@ class WebPluginTest(_common.LibTestCase):
# Add library elements. Note that self.lib.add overrides any "id=<n>"
# and assigns the next free id number.
# The following adds will create items #1, #2 and #3
path1 = (
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
base_path = Path(self.path_prefix + os.sep)
album2_item1 = Item(
title="title",
path=str(base_path / "path_1"),
album_id=2,
artist="AAA Singers",
)
self.lib.add(
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
album1_item = Item(
title="another title",
path=str(base_path / "somewhere" / "a"),
artist="AAA Singers",
)
path2 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"a").decode("utf-8")
)
self.lib.add(
Item(title="another title", path=path2, artist="AAA Singers")
)
path3 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
)
self.lib.add(
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
album2_item2 = Item(
title="and a third",
testattr="ABC",
path=str(base_path / "somewhere" / "abc"),
album_id=2,
)
self.lib.add(album2_item1)
self.lib.add(album1_item)
self.lib.add(album2_item2)
# The following adds will create albums #1 and #2
self.lib.add(Album(album="album", albumtest="xyz"))
path4 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
)
self.lib.add(Album(album="other album", artpath=path4))
album1 = self.lib.add_album([album1_item])
album1.album = "album"
album1.albumtest = "xyz"
album1.store()
album2 = self.lib.add_album([album2_item1, album2_item2])
album2.album = "other album"
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
album2.store()
web.app.config["TESTING"] = True
web.app.config["lib"] = self.lib

View file

@ -21,6 +21,7 @@ import unittest
from tempfile import mkstemp
from beets import dbcore
from beets.library import LibModel
from beets.test import _common
# Fixture: concrete database and model classes. For migration tests, we
@ -42,7 +43,7 @@ class QueryFixture(dbcore.query.FieldQuery):
return True
class ModelFixture1(dbcore.Model):
class ModelFixture1(LibModel):
_table = "test"
_flex_table = "testflex"
_fields = {

View file

@ -1152,6 +1152,26 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_albums_filter_by_track_flex(self):
q = "item_flex1:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_flex(self):
q = "item_flex1:'Item1 Flex1'"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_filter_by_many_flex(self):
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1"])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)