diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 1f7f2cf27..b38ba1daa 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -39,8 +39,6 @@ import beets from ..util import cached_classproperty, functemplate from . import types from .query import ( - AndQuery, - FieldQuery, FieldQueryType, FieldSort, MatchQuery, @@ -775,33 +773,6 @@ class Model(ABC, Generic[D]): """Set the object's key to a value represented by a string.""" self[key] = self._parse(key, string) - # Convenient queries. - - @classmethod - def field_query( - cls, - field, - pattern, - query_cls: FieldQueryType = MatchQuery, - ) -> FieldQuery: - """Get a `FieldQuery` for this model.""" - return query_cls(field, pattern, field in cls._fields) - - @classmethod - def all_fields_query( - cls: type[Model], - pats: Mapping[str, str], - query_cls: FieldQueryType = MatchQuery, - ): - """Get a query that matches many fields with different patterns. - - `pats` should be a mapping from field names to patterns. The - resulting query is a conjunction ("and") of per-field queries - for all of these field/pattern pairs. - """ - subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()] - return AndQuery(subqueries) - # Database controller and supporting interfaces. @@ -1274,17 +1245,30 @@ class Database: where, subvals = query.clause() order_by = sort.order_clause() + this_table = model_cls._table + select_fields = [f"{this_table}.*"] _from = model_cls.table_with_flex_attrs - if query.field_names & model_cls.other_db_fields: + + required_fields = query.field_names + if required_fields - model_cls._fields.keys(): _from += f" {model_cls.relation_join}" - table = model_cls._table - # group by id to avoid duplicates when joining with the relation + if required_fields - model_cls.all_db_fields: + # merge all flexible attribute into a single JSON field + select_fields.append( + f""" + json_patch( + COALESCE({this_table}."flex_attrs [json_str]", '{{}}'), + COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}') + ) AS all_flex_attrs + """ # noqa: E501 + ) + sql = ( - f"SELECT {table}.* " + f"SELECT {', '.join(select_fields)} " f"FROM ({_from}) " f"WHERE {where or 1} " - f"GROUP BY {table}.id" + f"GROUP BY {this_table}.id" ) if order_by: diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index fcaaa0a93..1f5cdb3f6 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -125,9 +125,7 @@ class FieldQuery(Query, Generic[P]): @property def field(self) -> str: if not self.fast: - return ( - f'json_extract("flex_attrs [json_str]", "$.{self.field_name}")' - ) + return f'json_extract(all_flex_attrs, "$.{self.field_name}")' return ( f"{self.table}.{self.field_name}" if self.table else self.field_name diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 289632668..c4ad9c4c9 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -20,15 +20,17 @@ import itertools import re from typing import TYPE_CHECKING -from . import Model, query +from . import query if TYPE_CHECKING: from collections.abc import Collection, Sequence + from ..library import LibModel from .query import FieldQueryType, Sort Prefixes = dict[str, FieldQueryType] + PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. r"(-|\^)?" # Negation prefixes. @@ -112,7 +114,7 @@ def parse_query_part( def construct_query_part( - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_part: str, ) -> query.Query: @@ -160,15 +162,7 @@ def construct_query_part( # Field queries get constructed according to the name of the field # they are querying. else: - field = table = key.lower() - if field in model_cls.shared_db_fields: - # This field exists in both tables, so SQLite will encounter - # an OperationalError if we try to query it in a join. - # Using an explicit table name resolves this. - table = f"{model_cls._table}.{field}" - - field_in_db = field in model_cls.all_db_fields - out_query = query_class(table, pattern, field_in_db) + out_query = model_cls.field_query(key.lower(), pattern, query_class) # Apply negation. if negate: @@ -180,7 +174,7 @@ def construct_query_part( # TYPING ERROR def query_from_strings( query_cls: type[query.CollectionQuery], - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_parts: Collection[str], ) -> query.Query: @@ -197,7 +191,7 @@ def query_from_strings( def construct_sort_part( - model_cls: type[Model], + model_cls: type[LibModel], part: str, case_insensitive: bool = True, ) -> Sort: @@ -228,7 +222,7 @@ def construct_sort_part( def sort_from_strings( - model_cls: type[Model], + model_cls: type[LibModel], sort_parts: Sequence[str], case_insensitive: bool = True, ) -> Sort: @@ -247,7 +241,7 @@ def sort_from_strings( def parse_sorted_query( - model_cls: type[Model], + model_cls: type[LibModel], parts: list[str], prefixes: Prefixes = {}, case_insensitive: bool = True, diff --git a/beets/library.py b/beets/library.py index 2430f7125..de56669c2 100644 --- a/beets/library.py +++ b/beets/library.py @@ -25,6 +25,7 @@ import time import unicodedata from functools import cached_property from pathlib import Path +from typing import TYPE_CHECKING, Mapping import platformdirs from mediafile import MediaFile, UnreadableFileError @@ -42,6 +43,9 @@ from beets.util import ( ) from beets.util.functemplate import Template, template +if TYPE_CHECKING: + from beets.dbcore.query import FieldQueryType + # To use the SQLite "blob" type, it doesn't suffice to provide a byte # string; SQLite treats that as encoded text. Wrapping it in a # `memoryview` tells it that we actually mean non-text data. @@ -375,6 +379,39 @@ class LibModel(dbcore.Model["Library"]): def __bytes__(self): return self.__str__().encode("utf-8") + # Convenient queries. + + @classmethod + def field_query( + cls, field: str, pattern: str, query_cls: FieldQueryType + ) -> dbcore.FieldQuery: + """Get a `FieldQuery` for this model.""" + fast = field in cls.all_db_fields + if field in cls.shared_db_fields: + # This field exists in both tables, so SQLite will encounter + # an OperationalError if we try to use it in a query. + # Using an explicit table name resolves this. + field = f"{cls._table}.{field}" + + return query_cls(field, pattern, fast) + + @classmethod + def all_fields_query( + cls, pattern_by_field: Mapping[str, str] + ) -> dbcore.AndQuery: + """Get a query that matches many fields with different patterns. + + `pattern_by_field` should be a mapping from field names to patterns. + The resulting query is a conjunction ("and") of per-field queries + for all of these field/pattern pairs. + """ + return dbcore.AndQuery( + [ + cls.field_query(f, p, dbcore.MatchQuery) + for f, p in pattern_by_field.items() + ] + ) + class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -612,7 +649,7 @@ class Item(LibModel): an album (e.g. singletons) would be left out. """ return ( - f"LEFT JOIN {cls._relation._table} " + f"LEFT JOIN {cls._relation.table_with_flex_attrs} " f"ON {cls._table}.album_id = {cls._relation._table}.id" ) @@ -1233,7 +1270,7 @@ class Album(LibModel): any items. """ return ( - f"LEFT JOIN {cls._relation._table} " + f"LEFT JOIN {cls._relation.table_with_flex_attrs} " f"ON {cls._table}.id = {cls._relation._table}.album_id" ) @@ -1937,9 +1974,10 @@ class DefaultTemplateFunctions: subqueries.extend(initial_subqueries) for key in keys: value = db_item.get(key, "") - # Use slow queries for flexible attributes. - fast = key in item_keys - subqueries.append(dbcore.MatchQuery(key, value, fast)) + subqueries.append( + db_item.field_query(key, value, dbcore.MatchQuery) + ) + query = dbcore.AndQuery(subqueries) ambigous_items = ( self.lib.items(query) diff --git a/test/plugins/test_web.py b/test/plugins/test_web.py index 2ad07bbe5..9f3af0fb9 100644 --- a/test/plugins/test_web.py +++ b/test/plugins/test_web.py @@ -5,6 +5,7 @@ import os.path import platform import shutil from collections import Counter +from pathlib import Path from beets import logging from beets.library import Album, Item @@ -30,36 +31,38 @@ class WebPluginTest(ItemInDBTestCase): # Add library elements. Note that self.lib.add overrides any "id=" # and assigns the next free id number. # The following adds will create items #1, #2 and #3 - path1 = ( - self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8") + base_path = Path(self.path_prefix + os.sep) + album2_item1 = Item( + title="title", + path=str(base_path / "path_1"), + album_id=2, + artist="AAA Singers", ) - self.lib.add( - Item(title="title", path=path1, album_id=2, artist="AAA Singers") + album1_item = Item( + title="another title", + path=str(base_path / "somewhere" / "a"), + artist="AAA Singers", ) - path2 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere", b"a").decode("utf-8") - ) - self.lib.add( - Item(title="another title", path=path2, artist="AAA Singers") - ) - path3 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere", b"abc").decode("utf-8") - ) - self.lib.add( - Item(title="and a third", testattr="ABC", path=path3, album_id=2) + album2_item2 = Item( + title="and a third", + testattr="ABC", + path=str(base_path / "somewhere" / "abc"), + album_id=2, ) + self.lib.add(album2_item1) + self.lib.add(album1_item) + self.lib.add(album2_item2) + # The following adds will create albums #1 and #2 - self.lib.add(Album(album="album", albumtest="xyz")) - path4 = ( - self.path_prefix - + os.sep - + os.path.join(b"somewhere2", b"art_path_2").decode("utf-8") - ) - self.lib.add(Album(album="other album", artpath=path4)) + album1 = self.lib.add_album([album1_item]) + album1.album = "album" + album1.albumtest = "xyz" + album1.store() + + album2 = self.lib.add_album([album2_item1, album2_item2]) + album2.album = "other album" + album2.artpath = str(base_path / "somewhere2" / "art_path_2") + album2.store() web.app.config["TESTING"] = True web.app.config["lib"] = self.lib diff --git a/test/test_query.py b/test/test_query.py index 48df82b2f..6d57215fe 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1134,3 +1134,23 @@ class RelatedQueriesTest(BeetsTestCase, AssertsMixin): q = "album_flex:Album1" results = self.lib.albums(q) self.assert_albums_matched(results, ["Album1"]) + + def test_get_albums_filter_by_track_flex(self): + q = "item_flex1:Album1" + results = self.lib.albums(q) + self.assert_albums_matched(results, ["Album1"]) + + def test_get_items_filter_by_album_flex(self): + q = "album_flex:Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) + + def test_filter_by_flex(self): + q = "item_flex1:'Item1 Flex1'" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"]) + + def test_filter_by_many_flex(self): + q = "item_flex1:'Item1 Flex1' item_flex2:Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1"])