mirror of
https://github.com/beetbox/beets.git
synced 2026-01-04 15:03:22 +01:00
Enable querying related flexible attributes
Unify query creation logic from - queryparse.py:construct_query_part, - Model.field_query, - DefaultTemplateFunctions._tmpl_unique to a single implementation under `LibModel.field_query` class method. This method should be used for query resolution for model (flex)fields. Allow filtering item attributes in album queries and vice versa by merging `flex_attrs` from Album and Item together as `all_flex_attrs`. This field is only used for filtering and is discarded after.
This commit is contained in:
parent
484c00e223
commit
9207b17d13
7 changed files with 142 additions and 85 deletions
|
|
@ -58,15 +58,7 @@ import beets
|
|||
|
||||
from ..util import cached_classproperty, functemplate
|
||||
from . import types
|
||||
from .query import (
|
||||
AndQuery,
|
||||
FieldQuery,
|
||||
MatchQuery,
|
||||
NullSort,
|
||||
Query,
|
||||
Sort,
|
||||
TrueQuery,
|
||||
)
|
||||
from .query import FieldQuery, MatchQuery, NullSort, Query, Sort, TrueQuery
|
||||
|
||||
# convert data under 'json_str' type name to Python dictionary automatically
|
||||
sqlite3.register_converter("json_str", json.loads)
|
||||
|
|
@ -395,6 +387,10 @@ class Model(ABC):
|
|||
) {cls._table}
|
||||
"""
|
||||
|
||||
@cached_classproperty
|
||||
def all_model_db_fields(cls) -> Set[str]:
|
||||
return set()
|
||||
|
||||
@classmethod
|
||||
def _getters(cls: Type["Model"]):
|
||||
"""Return a mapping from field names to getter functions."""
|
||||
|
|
@ -771,33 +767,6 @@ class Model(ABC):
|
|||
"""Set the object's key to a value represented by a string."""
|
||||
self[key] = self._parse(key, string)
|
||||
|
||||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(
|
||||
cls,
|
||||
field,
|
||||
pattern,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
) -> FieldQuery:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
return query_cls(field, pattern, field in cls._fields)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(
|
||||
cls: Type["Model"],
|
||||
pats: Mapping,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
):
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pats` should be a mapping from field names to patterns. The
|
||||
resulting query is a conjunction ("and") of per-field queries
|
||||
for all of these field/pattern pairs.
|
||||
"""
|
||||
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
|
||||
return AndQuery(subqueries)
|
||||
|
||||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
|
|
@ -1270,13 +1239,26 @@ class Database:
|
|||
where, subvals = query.clause()
|
||||
order_by = sort.order_clause()
|
||||
|
||||
this_table = model_cls._table
|
||||
select_fields = [f"{this_table}.*"]
|
||||
_from = model_cls.table_with_flex_attrs
|
||||
|
||||
required_fields = query.field_names
|
||||
if required_fields - model_cls._fields.keys():
|
||||
_from += f" {model_cls.relation_join}"
|
||||
|
||||
table = model_cls._table
|
||||
sql = f"SELECT {table}.* FROM {_from} WHERE {where or 1} GROUP BY {table}.id"
|
||||
if required_fields - model_cls.all_model_db_fields:
|
||||
# merge all flexible attribute into a single JSON field
|
||||
select_fields.append(
|
||||
f"""
|
||||
json_patch(
|
||||
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
|
||||
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
|
||||
) AS all_flex_attrs
|
||||
""" # noqa: E501
|
||||
)
|
||||
|
||||
sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id" # noqa: E501
|
||||
|
||||
if order_by:
|
||||
# the sort field may exist in both 'items' and 'albums' tables
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ class FieldQuery(Query, Generic[P]):
|
|||
@property
|
||||
def col_name(self) -> str:
|
||||
if not self.fast:
|
||||
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
|
||||
return f'json_extract(all_flex_attrs, "$.{self.field}")'
|
||||
|
||||
return f"{self.table}.{self.field}" if self.table else self.field
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,23 @@
|
|||
|
||||
import itertools
|
||||
import re
|
||||
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from .. import library
|
||||
from . import Model, query
|
||||
from .query import Sort
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..library import LibModel
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
r"(-|\^)?" # Negation prefixes.
|
||||
|
|
@ -105,7 +116,7 @@ def parse_query_part(
|
|||
|
||||
|
||||
def construct_query_part(
|
||||
model_cls: Type[Model],
|
||||
model_cls: Type["LibModel"],
|
||||
prefixes: Dict,
|
||||
query_part: str,
|
||||
) -> query.Query:
|
||||
|
|
@ -153,17 +164,7 @@ def construct_query_part(
|
|||
# Field queries get constructed according to the name of the field
|
||||
# they are querying.
|
||||
else:
|
||||
key = key.lower()
|
||||
album_fields = library.Album._fields.keys()
|
||||
item_fields = library.Item._fields.keys()
|
||||
fast = key in album_fields | item_fields
|
||||
|
||||
if key in album_fields & item_fields:
|
||||
# This field exists in both tables, so SQLite will encounter
|
||||
# an OperationalError. Using an explicit table name resolves this.
|
||||
key = f"{model_cls._table}.{key}"
|
||||
|
||||
out_query = query_class(key, pattern, fast)
|
||||
out_query = model_cls.field_query(key.lower(), pattern, query_class)
|
||||
|
||||
# Apply negation.
|
||||
if negate:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
"""The core data store and collection logic for beets.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
|
|
@ -23,6 +24,7 @@ import sys
|
|||
import time
|
||||
import unicodedata
|
||||
from functools import cached_property
|
||||
from typing import Mapping, Set, Type
|
||||
|
||||
from mediafile import MediaFile, UnreadableFileError
|
||||
|
||||
|
|
@ -387,6 +389,14 @@ class LibModel(dbcore.Model):
|
|||
# Config key that specifies how an instance should be formatted.
|
||||
_format_config_key: str
|
||||
|
||||
@cached_classproperty
|
||||
def all_model_db_fields(cls) -> Set[str]:
|
||||
return cls._fields.keys() | cls._relation._fields.keys()
|
||||
|
||||
@cached_classproperty
|
||||
def shared_model_db_fields(cls) -> Set[str]:
|
||||
return cls._fields.keys() & cls._relation._fields.keys()
|
||||
|
||||
def _template_funcs(self):
|
||||
funcs = DefaultTemplateFunctions(self, self._db).functions()
|
||||
funcs.update(plugins.template_funcs())
|
||||
|
|
@ -416,6 +426,39 @@ class LibModel(dbcore.Model):
|
|||
def __bytes__(self):
|
||||
return self.__str__().encode("utf-8")
|
||||
|
||||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(
|
||||
cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery]
|
||||
) -> dbcore.Query:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
fast = field in cls.all_model_db_fields
|
||||
if field in cls.shared_model_db_fields:
|
||||
# This field exists in both tables, so SQLite will encounter
|
||||
# an OperationalError if we try to use it in a query.
|
||||
# Using an explicit table name resolves this.
|
||||
field = f"{cls._table}.{field}"
|
||||
|
||||
return query_cls(field, pattern, fast)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(
|
||||
cls, pattern_by_field: Mapping[str, str]
|
||||
) -> dbcore.AndQuery:
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pattern_by_field` should be a mapping from field names to patterns.
|
||||
The resulting query is a conjunction ("and") of per-field queries
|
||||
for all of these field/pattern pairs.
|
||||
"""
|
||||
return dbcore.AndQuery(
|
||||
[
|
||||
cls.field_query(f, p, dbcore.MatchQuery)
|
||||
for f, p in pattern_by_field.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FormattedItemMapping(dbcore.db.FormattedMapping):
|
||||
"""Add lookup for album-level fields.
|
||||
|
|
@ -652,7 +695,10 @@ class Item(LibModel):
|
|||
We need to use a LEFT JOIN here, otherwise items that are not part of
|
||||
an album (e.g. singletons) would be left out.
|
||||
"""
|
||||
return f"LEFT JOIN {cls._relation._table} ON {cls._table}.album_id = {cls._relation._table}.id"
|
||||
return (
|
||||
f"LEFT JOIN {cls._relation.table_with_flex_attrs}"
|
||||
f" ON {cls._table}.album_id = {cls._relation._table}.id"
|
||||
)
|
||||
|
||||
@property
|
||||
def _cached_album(self):
|
||||
|
|
@ -1265,7 +1311,10 @@ class Album(LibModel):
|
|||
Here we can use INNER JOIN (which is more performant than LEFT JOIN),
|
||||
since we only want to see albums that have at least one Item in them.
|
||||
"""
|
||||
return f"INNER JOIN {cls._relation._table} ON {cls._table}.id = {cls._relation._table}.album_id"
|
||||
return (
|
||||
f"INNER JOIN {cls._relation.table_with_flex_attrs}"
|
||||
f" ON {cls._table}.id = {cls._relation._table}.album_id"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
|
|
@ -1955,9 +2004,10 @@ class DefaultTemplateFunctions:
|
|||
subqueries.extend(initial_subqueries)
|
||||
for key in keys:
|
||||
value = db_item.get(key, "")
|
||||
# Use slow queries for flexible attributes.
|
||||
fast = key in item_keys
|
||||
subqueries.append(dbcore.MatchQuery(key, value, fast))
|
||||
subqueries.append(
|
||||
db_item.field_query(key, value, dbcore.MatchQuery)
|
||||
)
|
||||
|
||||
query = dbcore.AndQuery(subqueries)
|
||||
ambigous_items = (
|
||||
self.lib.items(query)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import os.path
|
|||
import platform
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from beets import logging
|
||||
from beets.library import Album, Item
|
||||
|
|
@ -29,36 +30,38 @@ class WebPluginTest(_common.LibTestCase):
|
|||
# Add library elements. Note that self.lib.add overrides any "id=<n>"
|
||||
# and assigns the next free id number.
|
||||
# The following adds will create items #1, #2 and #3
|
||||
path1 = (
|
||||
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
|
||||
base_path = Path(self.path_prefix + os.sep)
|
||||
album2_item1 = Item(
|
||||
title="title",
|
||||
path=str(base_path / "path_1"),
|
||||
album_id=2,
|
||||
artist="AAA Singers",
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
|
||||
album1_item = Item(
|
||||
title="another title",
|
||||
path=str(base_path / "somewhere" / "a"),
|
||||
artist="AAA Singers",
|
||||
)
|
||||
path2 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere", b"a").decode("utf-8")
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="another title", path=path2, artist="AAA Singers")
|
||||
)
|
||||
path3 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
|
||||
)
|
||||
self.lib.add(
|
||||
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
|
||||
album2_item2 = Item(
|
||||
title="and a third",
|
||||
testattr="ABC",
|
||||
path=str(base_path / "somewhere" / "abc"),
|
||||
album_id=2,
|
||||
)
|
||||
self.lib.add(album2_item1)
|
||||
self.lib.add(album1_item)
|
||||
self.lib.add(album2_item2)
|
||||
|
||||
# The following adds will create albums #1 and #2
|
||||
self.lib.add(Album(album="album", albumtest="xyz"))
|
||||
path4 = (
|
||||
self.path_prefix
|
||||
+ os.sep
|
||||
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
|
||||
)
|
||||
self.lib.add(Album(album="other album", artpath=path4))
|
||||
album1 = self.lib.add_album([album1_item])
|
||||
album1.album = "album"
|
||||
album1.albumtest = "xyz"
|
||||
album1.store()
|
||||
|
||||
album2 = self.lib.add_album([album2_item1, album2_item2])
|
||||
album2.album = "other album"
|
||||
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
|
||||
album2.store()
|
||||
|
||||
web.app.config["TESTING"] = True
|
||||
web.app.config["lib"] = self.lib
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import unittest
|
|||
from tempfile import mkstemp
|
||||
|
||||
from beets import dbcore
|
||||
from beets.library import LibModel
|
||||
from beets.test import _common
|
||||
|
||||
# Fixture: concrete database and model classes. For migration tests, we
|
||||
|
|
@ -42,7 +43,7 @@ class QueryFixture(dbcore.query.FieldQuery):
|
|||
return True
|
||||
|
||||
|
||||
class ModelFixture1(dbcore.Model):
|
||||
class ModelFixture1(LibModel):
|
||||
_table = "test"
|
||||
_flex_table = "testflex"
|
||||
_fields = {
|
||||
|
|
|
|||
|
|
@ -1152,6 +1152,26 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_get_albums_filter_by_track_flex(self):
|
||||
q = "item_flex1:Album1"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_get_items_filter_by_album_flex(self):
|
||||
q = "album_flex:Album1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
|
||||
|
||||
def test_filter_by_flex(self):
|
||||
q = "item_flex1:'Item1 Flex1'"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
|
||||
|
||||
def test_filter_by_many_flex(self):
|
||||
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1"])
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue