mirror of
https://github.com/beetbox/beets.git
synced 2026-01-08 00:45:55 +01:00
Enable querying related flexible attributes
Unify query creation logic from - queryparse.py:construct_query_part, - Model.field_query, - DefaultTemplateFunctions._tmpl_unique to a single implementation under `LibModel.field_query` class method. This method should be used for query resolution for model (flex)fields. Allow filtering item attributes in album queries and vice versa by merging `flex_attrs` from Album and Item together as `all_flex_attrs`. This field is only used for filtering and is discarded after.
This commit is contained in:
parent
fb4834e0ab
commit
10ce22e289
6 changed files with 120 additions and 83 deletions
|
|
@ -39,8 +39,6 @@ import beets
|
|||
from ..util import cached_classproperty, functemplate
|
||||
from . import types
|
||||
from .query import (
|
||||
AndQuery,
|
||||
FieldQuery,
|
||||
FieldQueryType,
|
||||
FieldSort,
|
||||
MatchQuery,
|
||||
|
|
@ -775,33 +773,6 @@ class Model(ABC, Generic[D]):
|
|||
"""Set the object's key to a value represented by a string."""
|
||||
self[key] = self._parse(key, string)
|
||||
|
||||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(
|
||||
cls,
|
||||
field,
|
||||
pattern,
|
||||
query_cls: FieldQueryType = MatchQuery,
|
||||
) -> FieldQuery:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
return query_cls(field, pattern, field in cls._fields)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(
|
||||
cls: type[Model],
|
||||
pats: Mapping[str, str],
|
||||
query_cls: FieldQueryType = MatchQuery,
|
||||
):
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pats` should be a mapping from field names to patterns. The
|
||||
resulting query is a conjunction ("and") of per-field queries
|
||||
for all of these field/pattern pairs.
|
||||
"""
|
||||
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
|
||||
return AndQuery(subqueries)
|
||||
|
||||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
|
|
@ -1274,17 +1245,30 @@ class Database:
|
|||
where, subvals = query.clause()
|
||||
order_by = sort.order_clause()
|
||||
|
||||
this_table = model_cls._table
|
||||
select_fields = [f"{this_table}.*"]
|
||||
_from = model_cls.table_with_flex_attrs
|
||||
if query.field_names & model_cls.other_db_fields:
|
||||
|
||||
required_fields = query.field_names
|
||||
if required_fields - model_cls._fields.keys():
|
||||
_from += f" {model_cls.relation_join}"
|
||||
|
||||
table = model_cls._table
|
||||
# group by id to avoid duplicates when joining with the relation
|
||||
if required_fields - model_cls.all_db_fields:
|
||||
# merge all flexible attribute into a single JSON field
|
||||
select_fields.append(
|
||||
f"""
|
||||
json_patch(
|
||||
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
|
||||
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
|
||||
) AS all_flex_attrs
|
||||
""" # noqa: E501
|
||||
)
|
||||
|
||||
sql = (
|
||||
f"SELECT {table}.* "
|
||||
f"SELECT {', '.join(select_fields)} "
|
||||
f"FROM ({_from}) "
|
||||
f"WHERE {where or 1} "
|
||||
f"GROUP BY {table}.id"
|
||||
f"GROUP BY {this_table}.id"
|
||||
)
|
||||
|
||||
if order_by:
|
||||
|
|
|
|||
|
|
@ -125,9 +125,7 @@ class FieldQuery(Query, Generic[P]):
|
|||
@property
|
||||
def field(self) -> str:
|
||||
if not self.fast:
|
||||
return (
|
||||
f'json_extract("flex_attrs [json_str]", "$.{self.field_name}")'
|
||||
)
|
||||
return f'json_extract(all_flex_attrs, "$.{self.field_name}")'
|
||||
|
||||
return (
|
||||
f"{self.table}.{self.field_name}" if self.table else self.field_name
|
||||
|
|
|
|||
|
|
@ -20,15 +20,17 @@ import itertools
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import Model, query
|
||||
from . import query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Sequence
|
||||
|
||||
from ..library import LibModel
|
||||
from .query import FieldQueryType, Sort
|
||||
|
||||
Prefixes = dict[str, FieldQueryType]
|
||||
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
r"(-|\^)?" # Negation prefixes.
|
||||
|
|
@ -112,7 +114,7 @@ def parse_query_part(
|
|||
|
||||
|
||||
def construct_query_part(
|
||||
model_cls: type[Model],
|
||||
model_cls: type[LibModel],
|
||||
prefixes: Prefixes,
|
||||
query_part: str,
|
||||
) -> query.Query:
|
||||
|
|
@ -160,15 +162,7 @@ def construct_query_part(
|
|||
# Field queries get constructed according to the name of the field
|
||||
# they are querying.
|
||||
else:
|
||||
field = table = key.lower()
|
||||
if field in model_cls.shared_db_fields:
|
||||
# This field exists in both tables, so SQLite will encounter
|
||||
# an OperationalError if we try to query it in a join.
|
||||
# Using an explicit table name resolves this.
|
||||
table = f"{model_cls._table}.{field}"
|
||||
|
||||
field_in_db = field in model_cls.all_db_fields
|
||||
out_query = query_class(table, pattern, field_in_db)
|
||||
out_query = model_cls.field_query(key.lower(), pattern, query_class)
|
||||
|
||||
# Apply negation.
|
||||
if negate:
|
||||
|
|
@ -180,7 +174,7 @@ def construct_query_part(
|
|||
# TYPING ERROR
|
||||
def query_from_strings(
|
||||
query_cls: type[query.CollectionQuery],
|
||||
model_cls: type[Model],
|
||||
model_cls: type[LibModel],
|
||||
prefixes: Prefixes,
|
||||
query_parts: Collection[str],
|
||||
) -> query.Query:
|
||||
|
|
@ -197,7 +191,7 @@ def query_from_strings(
|
|||
|
||||
|
||||
def construct_sort_part(
|
||||
model_cls: type[Model],
|
||||
model_cls: type[LibModel],
|
||||
part: str,
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
|
|
@ -228,7 +222,7 @@ def construct_sort_part(
|
|||
|
||||
|
||||
def sort_from_strings(
|
||||
model_cls: type[Model],
|
||||
model_cls: type[LibModel],
|
||||
sort_parts: Sequence[str],
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
|
|
@ -247,7 +241,7 @@ def sort_from_strings(
|
|||
|
||||
|
||||
def parse_sorted_query(
|
||||
model_cls: type[Model],
|
||||
model_cls: type[LibModel],
|
||||
parts: list[str],
|
||||
prefixes: Prefixes = {},
|
||||
case_insensitive: bool = True,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import time
|
|||
import unicodedata
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Mapping
|
||||
|
||||
import platformdirs
|
||||
from mediafile import MediaFile, UnreadableFileError
|
||||
|
|
@ -42,6 +43,9 @@ from beets.util import (
|
|||
)
|
||||
from beets.util.functemplate import Template, template
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from beets.dbcore.query import FieldQueryType
|
||||
|
||||
# To use the SQLite "blob" type, it doesn't suffice to provide a byte
|
||||
# string; SQLite treats that as encoded text. Wrapping it in a
|
||||
# `memoryview` tells it that we actually mean non-text data.
|
||||
|
|
@ -375,6 +379,39 @@ class LibModel(dbcore.Model["Library"]):
|
|||
def __bytes__(self):
|
||||
return self.__str__().encode("utf-8")
|
||||
|
||||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(
|
||||
cls, field: str, pattern: str, query_cls: FieldQueryType
|
||||
) -> dbcore.FieldQuery:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
fast = field in cls.all_db_fields
|
||||
if field in cls.shared_db_fields:
|
||||
# This field exists in both tables, so SQLite will encounter
|
||||
# an OperationalError if we try to use it in a query.
|
||||
# Using an explicit table name resolves this.
|
||||
field = f"{cls._table}.{field}"
|
||||
|
||||
return query_cls(field, pattern, fast)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(
|
||||
cls, pattern_by_field: Mapping[str, str]
|
||||
) -> dbcore.AndQuery:
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pattern_by_field` should be a mapping from field names to patterns.
|
||||
The resulting query is a conjunction ("and") of per-field queries
|
||||
for all of these field/pattern pairs.
|
||||
"""
|
||||
return dbcore.AndQuery(
|
||||
[
|
||||
cls.field_query(f, p, dbcore.MatchQuery)
|
||||
for f, p in pattern_by_field.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FormattedItemMapping(dbcore.db.FormattedMapping):
|
||||
"""Add lookup for album-level fields.
|
||||
|
|
@ -612,7 +649,7 @@ class Item(LibModel):
|
|||
an album (e.g. singletons) would be left out.
|
||||
"""
|
||||
return (
|
||||
f"LEFT JOIN {cls._relation._table} "
|
||||
f"LEFT JOIN {cls._relation.table_with_flex_attrs} "
|
||||
f"ON {cls._table}.album_id = {cls._relation._table}.id"
|
||||
)
|
||||
|
||||
|
|
@ -1233,7 +1270,7 @@ class Album(LibModel):
|
|||
any items.
|
||||
"""
|
||||
return (
|
||||
f"LEFT JOIN {cls._relation._table} "
|
||||
f"LEFT JOIN {cls._relation.table_with_flex_attrs} "
|
||||
f"ON {cls._table}.id = {cls._relation._table}.album_id"
|
||||
)
|
||||
|
||||
|
|
@ -1937,9 +1974,10 @@ class DefaultTemplateFunctions:
|
|||
subqueries.extend(initial_subqueries)
|
||||
for key in keys:
|
||||
value = db_item.get(key, "")
|
||||
# Use slow queries for flexible attributes.
|
||||
fast = key in item_keys
|
||||
subqueries.append(dbcore.MatchQuery(key, value, fast))
|
||||
subqueries.append(
|
||||
db_item.field_query(key, value, dbcore.MatchQuery)
|
||||
)
|
||||
|
||||
query = dbcore.AndQuery(subqueries)
|
||||
ambigous_items = (
|
||||
self.lib.items(query)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import os.path
|
|||
import platform
|
||||
import shutil
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
from beets import logging
|
||||
from beets.library import Album, Item
|
||||
|
|
@ -30,36 +31,38 @@ class WebPluginTest(ItemInDBTestCase):
|
|||
# Add library elements. Note that self.lib.add overrides any "id=<n>"
|
||||
# and assigns the next free id number.
|
||||
# The following adds will create items #1, #2 and #3
|
||||
path1 = (
|
||||
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
|
||||
base_path = Path(self.path_prefix + os.sep)
|
||||
album2_item1 = Item(
|
||||
title="title",
|
||||
path=str(base_path / "path_1"),
|
||||
album_id=2,
|
||||
artist="AAA Singers",
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
|
||||
album1_item = Item(
|
||||
title="another title",
|
||||
path=str(base_path / "somewhere" / "a"),
|
||||
artist="AAA Singers",
|
||||
)
|
||||
path2 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere", b"a").decode("utf-8")
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="another title", path=path2, artist="AAA Singers")
|
||||
)
|
||||
path3 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
|
||||
album2_item2 = Item(
|
||||
title="and a third",
|
||||
testattr="ABC",
|
||||
path=str(base_path / "somewhere" / "abc"),
|
||||
album_id=2,
|
||||
)
|
||||
self.lib.add(album2_item1)
|
||||
self.lib.add(album1_item)
|
||||
self.lib.add(album2_item2)
|
||||
|
||||
# The following adds will create albums #1 and #2
|
||||
self.lib.add(Album(album="album", albumtest="xyz"))
|
||||
path4 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
|
||||
)
|
||||
self.lib.add(Album(album="other album", artpath=path4))
|
||||
album1 = self.lib.add_album([album1_item])
|
||||
album1.album = "album"
|
||||
album1.albumtest = "xyz"
|
||||
album1.store()
|
||||
|
||||
album2 = self.lib.add_album([album2_item1, album2_item2])
|
||||
album2.album = "other album"
|
||||
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
|
||||
album2.store()
|
||||
|
||||
web.app.config["TESTING"] = True
|
||||
web.app.config["lib"] = self.lib
|
||||
|
|
|
|||
|
|
@ -1134,3 +1134,23 @@ class RelatedQueriesTest(BeetsTestCase, AssertsMixin):
|
|||
q = "album_flex:Album1"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_get_albums_filter_by_track_flex(self):
|
||||
q = "item_flex1:Album1"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_get_items_filter_by_album_flex(self):
|
||||
q = "album_flex:Album1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
|
||||
|
||||
def test_filter_by_flex(self):
|
||||
q = "item_flex1:'Item1 Flex1'"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
|
||||
|
||||
def test_filter_by_many_flex(self):
|
||||
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue