diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index a73b4515f..c9fd5def8 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -58,15 +58,7 @@ import beets from ..util import cached_classproperty, functemplate from . import types -from .query import ( - AndQuery, - FieldQuery, - MatchQuery, - NullSort, - Query, - Sort, - TrueQuery, -) +from .query import FieldQuery, MatchQuery, NullSort, Query, Sort, TrueQuery # convert data under 'json_str' type name to Python dictionary automatically sqlite3.register_converter("json_str", json.loads) @@ -395,6 +387,10 @@ class Model(ABC): ) {cls._table} """ + @cached_classproperty + def all_model_db_fields(cls) -> Set[str]: + return set() + @classmethod def _getters(cls: Type["Model"]): """Return a mapping from field names to getter functions.""" @@ -771,33 +767,6 @@ class Model(ABC): """Set the object's key to a value represented by a string.""" self[key] = self._parse(key, string) - # Convenient queries. - - @classmethod - def field_query( - cls, - field, - pattern, - query_cls: Type[FieldQuery] = MatchQuery, - ) -> FieldQuery: - """Get a `FieldQuery` for this model.""" - return query_cls(field, pattern, field in cls._fields) - - @classmethod - def all_fields_query( - cls: Type["Model"], - pats: Mapping, - query_cls: Type[FieldQuery] = MatchQuery, - ): - """Get a query that matches many fields with different patterns. - - `pats` should be a mapping from field names to patterns. The - resulting query is a conjunction ("and") of per-field queries - for all of these field/pattern pairs. - """ - subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()] - return AndQuery(subqueries) - # Database controller and supporting interfaces. @@ -1270,13 +1239,26 @@ class Database: where, subvals = query.clause() order_by = sort.order_clause() + this_table = model_cls._table + select_fields = [f"{this_table}.*"] _from = model_cls.table_with_flex_attrs + required_fields = query.field_names if required_fields - model_cls._fields.keys(): _from += f" {model_cls.relation_join}" - table = model_cls._table - sql = f"SELECT {table}.* FROM {_from} WHERE {where or 1} GROUP BY {table}.id" + if required_fields - model_cls.all_model_db_fields: + # merge all flexible attribute into a single JSON field + select_fields.append( + f""" + json_patch( + COALESCE({this_table}."flex_attrs [json_str]", '{{}}'), + COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}') + ) AS all_flex_attrs + """ # noqa: E501 + ) + + sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id" # noqa: E501 if order_by: # the sort field may exist in both 'items' and 'albums' tables diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index d4a4fd4f7..1797ddc71 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -144,7 +144,7 @@ class FieldQuery(Query, Generic[P]): @property def col_name(self) -> str: if not self.fast: - return f'json_extract("flex_attrs [json_str]", "$.{self.field}")' + return f'json_extract(all_flex_attrs, "$.{self.field}")' return f"{self.table}.{self.field}" if self.table else self.field diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index caea88e5d..db989c5d0 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -16,12 +16,23 @@ import itertools import re -from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, +) -from .. import library from . import Model, query from .query import Sort +if TYPE_CHECKING: + from ..library import LibModel + PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. r"(-|\^)?" # Negation prefixes. @@ -105,7 +116,7 @@ def parse_query_part( def construct_query_part( - model_cls: Type[Model], + model_cls: Type["LibModel"], prefixes: Dict, query_part: str, ) -> query.Query: @@ -153,17 +164,7 @@ def construct_query_part( # Field queries get constructed according to the name of the field # they are querying. else: - key = key.lower() - album_fields = library.Album._fields.keys() - item_fields = library.Item._fields.keys() - fast = key in album_fields | item_fields - - if key in album_fields & item_fields: - # This field exists in both tables, so SQLite will encounter - # an OperationalError. Using an explicit table name resolves this. - key = f"{model_cls._table}.{key}" - - out_query = query_class(key, pattern, fast) + out_query = model_cls.field_query(key.lower(), pattern, query_class) # Apply negation. if negate: diff --git a/beets/library.py b/beets/library.py index 433393ccb..2871f7807 100644 --- a/beets/library.py +++ b/beets/library.py @@ -14,6 +14,7 @@ """The core data store and collection logic for beets. """ +from __future__ import annotations import os import re @@ -23,6 +24,7 @@ import sys import time import unicodedata from functools import cached_property +from typing import Mapping, Set, Type from mediafile import MediaFile, UnreadableFileError @@ -387,6 +389,14 @@ class LibModel(dbcore.Model): # Config key that specifies how an instance should be formatted. _format_config_key: str + @cached_classproperty + def all_model_db_fields(cls) -> Set[str]: + return cls._fields.keys() | cls._relation._fields.keys() + + @cached_classproperty + def shared_model_db_fields(cls) -> Set[str]: + return cls._fields.keys() & cls._relation._fields.keys() + def _template_funcs(self): funcs = DefaultTemplateFunctions(self, self._db).functions() funcs.update(plugins.template_funcs()) @@ -416,6 +426,39 @@ class LibModel(dbcore.Model): def __bytes__(self): return self.__str__().encode("utf-8") + # Convenient queries. + + @classmethod + def field_query( + cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery] + ) -> dbcore.Query: + """Get a `FieldQuery` for this model.""" + fast = field in cls.all_model_db_fields + if field in cls.shared_model_db_fields: + # This field exists in both tables, so SQLite will encounter + # an OperationalError if we try to use it in a query. + # Using an explicit table name resolves this. + field = f"{cls._table}.{field}" + + return query_cls(field, pattern, fast) + + @classmethod + def all_fields_query( + cls, pattern_by_field: Mapping[str, str] + ) -> dbcore.AndQuery: + """Get a query that matches many fields with different patterns. + + `pattern_by_field` should be a mapping from field names to patterns. + The resulting query is a conjunction ("and") of per-field queries + for all of these field/pattern pairs. + """ + return dbcore.AndQuery( + [ + cls.field_query(f, p, dbcore.MatchQuery) + for f, p in pattern_by_field.items() + ] + ) + class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -652,7 +695,10 @@ class Item(LibModel): We need to use a LEFT JOIN here, otherwise items that are not part of an album (e.g. singletons) would be left out. """ - return f"LEFT JOIN {cls._relation._table} ON {cls._table}.album_id = {cls._relation._table}.id" + return ( + f"LEFT JOIN {cls._relation.table_with_flex_attrs}" + f" ON {cls._table}.album_id = {cls._relation._table}.id" + ) @property def _cached_album(self): @@ -1265,7 +1311,10 @@ class Album(LibModel): Here we can use INNER JOIN (which is more performant than LEFT JOIN), since we only want to see albums that have at least one Item in them. """ - return f"INNER JOIN {cls._relation._table} ON {cls._table}.id = {cls._relation._table}.album_id" + return ( + f"INNER JOIN {cls._relation.table_with_flex_attrs}" + f" ON {cls._table}.id = {cls._relation._table}.album_id" + ) @classmethod def _getters(cls): @@ -1955,9 +2004,10 @@ class DefaultTemplateFunctions: subqueries.extend(initial_subqueries) for key in keys: value = db_item.get(key, "") - # Use slow queries for flexible attributes. - fast = key in item_keys - subqueries.append(dbcore.MatchQuery(key, value, fast)) + subqueries.append( + db_item.field_query(key, value, dbcore.MatchQuery) + ) + query = dbcore.AndQuery(subqueries) ambigous_items = ( self.lib.items(query) diff --git a/test/plugins/test_web.py b/test/plugins/test_web.py index afd1ed706..7dfee8321 100644 --- a/test/plugins/test_web.py +++ b/test/plugins/test_web.py @@ -5,6 +5,7 @@ import os.path import platform import shutil import unittest +from pathlib import Path from beets import logging from beets.library import Album, Item @@ -29,36 +30,38 @@ class WebPluginTest(_common.LibTestCase): # Add library elements. Note that self.lib.add overrides any "id=" # and assigns the next free id number. # The following adds will create items #1, #2 and #3 - path1 = ( - self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8") + base_path = Path(self.path_prefix + os.sep) + album2_item1 = Item( + title="title", + path=str(base_path / "path_1"), + album_id=2, + artist="AAA Singers", ) - self.lib.add( - Item(title="title", path=path1, album_id=2, artist="AAA Singers") + album1_item = Item( + title="another title", + path=str(base_path / "somewhere" / "a"), + artist="AAA Singers", ) - path2 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere", b"a").decode("utf-8") - ) - self.lib.add( - Item(title="another title", path=path2, artist="AAA Singers") - ) - path3 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere", b"abc").decode("utf-8") - ) - self.lib.add( - Item(title="and a third", testattr="ABC", path=path3, album_id=2) + album2_item2 = Item( + title="and a third", + testattr="ABC", + path=str(base_path / "somewhere" / "abc"), + album_id=2, ) + self.lib.add(album2_item1) + self.lib.add(album1_item) + self.lib.add(album2_item2) + # The following adds will create albums #1 and #2 - self.lib.add(Album(album="album", albumtest="xyz")) - path4 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere2", b"art_path_2").decode("utf-8") - ) - self.lib.add(Album(album="other album", artpath=path4)) + album1 = self.lib.add_album([album1_item]) + album1.album = "album" + album1.albumtest = "xyz" + album1.store() + + album2 = self.lib.add_album([album2_item1, album2_item2]) + album2.album = "other album" + album2.artpath = str(base_path / "somewhere2" / "art_path_2") + album2.store() web.app.config["TESTING"] = True web.app.config["lib"] = self.lib diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 763601b7f..67165f7f4 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -21,6 +21,7 @@ import unittest from tempfile import mkstemp from beets import dbcore +from beets.library import LibModel from beets.test import _common # Fixture: concrete database and model classes. For migration tests, we @@ -42,7 +43,7 @@ class QueryFixture(dbcore.query.FieldQuery): return True -class ModelFixture1(dbcore.Model): +class ModelFixture1(LibModel): _table = "test" _flex_table = "testflex" _fields = { diff --git a/test/test_query.py b/test/test_query.py index 109645374..d2a763b04 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1152,6 +1152,26 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin): results = self.lib.albums(q) self.assert_albums_matched(results, ["Album1"]) + def test_get_albums_filter_by_track_flex(self): + q = "item_flex1:Album1" + results = self.lib.albums(q) + self.assert_albums_matched(results, ["Album1"]) + + def test_get_items_filter_by_album_flex(self): + q = "album_flex:Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) + + def test_filter_by_flex(self): + q = "item_flex1:'Item1 Flex1'" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"]) + + def test_filter_by_many_flex(self): + q = "item_flex1:'Item1 Flex1' item_flex2:Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1"]) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)