From 981a61bd56559c28b500b3ceafd7b0257f0ceb60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Wed, 19 Jun 2024 22:33:33 +0100 Subject: [PATCH] Add support for filtering relations --- beets/dbcore/db.py | 69 ++++++++++++++++++++++++++++---------- beets/dbcore/query.py | 28 +++++++++++++++- beets/dbcore/queryparse.py | 3 +- beets/library.py | 34 +++++++++++++++++++ beetsplug/limit.py | 5 +++ docs/changelog.rst | 7 ++++ docs/reference/query.rst | 18 +++++++--- test/test_query.py | 40 +++++++++++++++++++++- 8 files changed, 180 insertions(+), 24 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 4f7665e37..566c11631 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -40,7 +40,6 @@ from typing import ( Mapping, Optional, Sequence, - Set, Tuple, Type, TypeVar, @@ -51,9 +50,8 @@ from typing import ( from unidecode import unidecode import beets -from beets.util import functemplate -from ..util.functemplate import Template +from ..util import cached_classproperty, functemplate from . import types from .query import ( AndQuery, @@ -323,6 +321,32 @@ class Model(ABC): to the database. """ + @cached_classproperty + def _relation(cls) -> type[Model]: + """The model that this model is closely related to.""" + return cls + + @cached_classproperty + def relation_join(cls) -> str: + """Return the join required to include the related table in the query. + + This is intended to be used as a FROM clause in the SQL query. 
+ """ + return "" + + @cached_classproperty + def all_db_fields(cls) -> set[str]: + return cls._fields.keys() | cls._relation._fields.keys() + + @cached_classproperty + def shared_db_fields(cls) -> set[str]: + return cls._fields.keys() & cls._relation._fields.keys() + + @cached_classproperty + def other_db_fields(cls) -> set[str]: + """Fields in the related table.""" + return cls._relation._fields.keys() - cls.shared_db_fields + @classmethod def _getters(cls: Type["Model"]): """Return a mapping from field names to getter functions.""" @@ -344,7 +368,7 @@ class Model(ABC): initial field values. """ self._db = db - self._dirty: Set[str] = set() + self._dirty: set[str] = set() self._values_fixed = LazyConvertDict(self) self._values_flex = LazyConvertDict(self) @@ -668,7 +692,7 @@ class Model(ABC): def evaluate_template( self, - template: Union[str, Template], + template: Union[str, functemplate.Template], for_path: bool = False, ) -> str: """Evaluate a template (a string or a `Template` object) using @@ -1223,24 +1247,35 @@ class Database: where, subvals = query.clause() order_by = sort.order_clause() - sql = ("SELECT * FROM {} WHERE {} {}").format( - model_cls._table, - where or "1", - f"ORDER BY {order_by}" if order_by else "", - ) + table = model_cls._table + _from = table + if query.field_names & model_cls.other_db_fields: + _from += f" {model_cls.relation_join}" + # group by id to avoid duplicates when joining with the relation + sql = ( + f"SELECT {table}.* " + f"FROM ({_from}) " + f"WHERE {where or 1} " + f"GROUP BY {table}.id" + ) # Fetch flexible attributes for items matching the main query. # Doing the per-item filtering in python is faster than issuing # one query per item to sqlite. 
- flex_sql = """ - SELECT * FROM {} WHERE entity_id IN - (SELECT id FROM {} WHERE {}); - """.format( - model_cls._flex_table, - model_cls._table, - where or "1", + flex_sql = ( + "SELECT * " + f"FROM {model_cls._flex_table} " + f"WHERE entity_id IN (SELECT id FROM ({sql}))" ) + if order_by: + # the sort field may exist in both 'items' and 'albums' tables + # (when they are joined), causing ambiguous column OperationalError + # if we try to order directly. + # Since the join is required only for filtering, we can filter in + # a subquery and order the result, which returns unique fields. + sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}" + with self.transaction() as tx: rows = tx.query(sql, subvals) flex_rows = tx.query(flex_sql, subvals) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 2e1385ca2..5309ebaf3 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -21,7 +21,7 @@ import unicodedata from abc import ABC, abstractmethod from datetime import datetime, timedelta from functools import reduce -from operator import mul +from operator import mul, or_ from typing import ( TYPE_CHECKING, Any, @@ -33,6 +33,7 @@ from typing import ( Optional, Pattern, Sequence, + Set, Tuple, Type, TypeVar, @@ -81,6 +82,11 @@ class InvalidQueryArgumentValueError(ParsingError): class Query(ABC): """An abstract class representing a query into the database.""" + @property + def field_names(self) -> Set[str]: + """Return a set with field names that this query operates on.""" + return set() + def clause(self) -> Tuple[Optional[str], Sequence[Any]]: """Generate an SQLite expression implementing the query. @@ -128,6 +134,11 @@ class FieldQuery(Query, Generic[P]): same matching functionality in SQLite. 
""" + @property + def field_names(self) -> Set[str]: + """Return a set with field names that this query operates on.""" + return {self.field} + def __init__(self, field: str, pattern: P, fast: bool = True): self.field = field self.pattern = pattern @@ -443,6 +454,11 @@ class CollectionQuery(Query): indexed like a list to access the sub-queries. """ + @property + def field_names(self) -> Set[str]: + """Return a set with field names that this query operates on.""" + return reduce(or_, (sq.field_names for sq in self.subqueries)) + def __init__(self, subqueries: Sequence = ()): self.subqueries = subqueries @@ -498,6 +514,11 @@ class AnyFieldQuery(CollectionQuery): constructor. """ + @property + def field_names(self) -> Set[str]: + """Return a set with field names that this query operates on.""" + return set(self.fields) + def __init__(self, pattern, fields, cls: Type[FieldQuery]): self.pattern = pattern self.fields = fields @@ -570,6 +591,11 @@ class NotQuery(Query): performing `not(subquery)` without using regular expressions. """ + @property + def field_names(self) -> Set[str]: + """Return a set with field names that this query operates on.""" + return self.subquery.field_names + def __init__(self, subquery): self.subquery = subquery diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index ea6f16922..fd29aedff 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -152,7 +152,8 @@ def construct_query_part( # Field queries get constructed according to the name of the field # they are querying. else: - out_query = query_class(key.lower(), pattern, key in model_cls._fields) + key = key.lower() + out_query = query_class(key, pattern, key in model_cls.all_db_fields) # Apply negation. if negate: diff --git a/beets/library.py b/beets/library.py index 68789bf84..6d0ee613b 100644 --- a/beets/library.py +++ b/beets/library.py @@ -14,6 +14,7 @@ """The core data store and collection logic for beets. 
""" +from __future__ import annotations import os import re @@ -32,6 +33,7 @@ from beets.dbcore import Results, types from beets.util import ( MoveOperation, bytestring_path, + cached_classproperty, normpath, samefile, syspath, @@ -640,6 +642,22 @@ class Item(LibModel): # Cached album object. Read-only. __album = None + @cached_classproperty + def _relation(cls) -> type[Album]: + return Album + + @cached_classproperty + def relation_join(cls) -> str: + """Return the FROM clause which includes related albums. + + We need to use a LEFT JOIN here, otherwise items that are not part of + an album (e.g. singletons) would be left out. + """ + return ( + f"LEFT JOIN {cls._relation._table} " + f"ON {cls._table}.album_id = {cls._relation._table}.id" + ) + @property def _cached_album(self): """The Album object that this item belongs to, if any, or @@ -1240,6 +1258,22 @@ class Album(LibModel): _format_config_key = "format_album" + @cached_classproperty + def _relation(cls) -> type[Item]: + return Item + + @cached_classproperty + def relation_join(cls) -> str: + """Return FROM clause which joins on related album items. + + Use LEFT join to select all albums, including those that do not have + any items. 
+ """ + return ( + f"LEFT JOIN {cls._relation._table} " + f"ON {cls._table}.id = {cls._relation._table}.album_id" + ) + @classmethod def _getters(cls): # In addition to plugin-provided computed fields, also expose diff --git a/beetsplug/limit.py b/beetsplug/limit.py index 5c351a1a4..0a13a78aa 100644 --- a/beetsplug/limit.py +++ b/beetsplug/limit.py @@ -79,6 +79,11 @@ class LimitPlugin(BeetsPlugin): n = 0 N = None + def __init__(self, *args, **kwargs) -> None: + """Force the query to be slow so that 'value_match' is called.""" + super().__init__(*args, **kwargs) + self.fast = False + @classmethod def value_match(cls, pattern, value): if cls.N is None: diff --git a/docs/changelog.rst b/docs/changelog.rst index 3725e4993..cd08fe43c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,13 @@ Unreleased Changelog goes here! Please add your entry to the bottom of one of the lists below! +New features: + +* Ability to query albums with track db fields and vice versa, for example + ``beet list -a title:something`` or ``beet list artpath:cover``. Consequently + album queries involving the ``path`` field have been sped up, like + ``beet list -a path:/path/``. + Bug fixes: * Improved naming of temporary files by separating the random part with the file extension. diff --git a/docs/reference/query.rst b/docs/reference/query.rst index 2bed2ed68..eaa2d6701 100644 --- a/docs/reference/query.rst +++ b/docs/reference/query.rst @@ -17,7 +17,9 @@ This command:: $ beet list love -will show all tracks matching the query string ``love``. By default any unadorned word like this matches in a track's title, artist, album name, album artist, genre and comments. See below on how to search other fields. +will show all tracks matching the query string ``love``. By default any +unadorned word like this matches in a track's title, artist, album name, album +artist, genre and comments. See below on how to search other fields.
For example, this is what I might see when I run the command above:: @@ -83,6 +85,14 @@ For multi-valued tags (such as ``artists`` or ``albumartists``), a regular expression search must be used to search for a single value within the multi-valued tag. +Note that you can filter albums by querying track fields:: + + $ beet list -a title:love + +and, vice versa, filter tracks by querying album fields:: + + $ beet list artpath::love + Phrases ------- @@ -115,9 +125,9 @@ the field name's colon and before the expression:: $ beet list artist:=AIR The first query is a simple substring one that returns tracks by Air, AIR, and -Air Supply. The second query returns tracks by Air and AIR, since both are a +Air Supply. The second query returns tracks by Air and AIR, since both are a case-insensitive match for the entire expression, but does not return anything -by Air Supply. The third query, which requires a case-sensitive exact match, +by Air Supply. The third query, which requires a case-sensitive exact match, returns tracks by AIR only. Exact matches may be performed on phrases as well:: @@ -358,7 +368,7 @@ result in lower-case values being placed after upper-case values, e.g., ``Bar Qux foo``. Note that when sorting by fields that are not present on all items (such as -flexible fields, or those defined by plugins) in *ascending* order, the items +flexible fields, or those defined by plugins) in *ascending* order, the items that lack that particular field will be listed at the *beginning* of the list.
You can set the default sorting behavior with the :ref:`sort_item` and diff --git a/test/test_query.py b/test/test_query.py index b710da13b..41b3a2de3 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -521,7 +521,7 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): self.assert_items_matched(results, ["path item"]) results = self.lib.albums(q) - self.assert_albums_matched(results, []) + self.assert_albums_matched(results, ["path album"]) # FIXME: fails on windows @unittest.skipIf(sys.platform == "win32", "win32") @@ -604,6 +604,9 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): results = self.lib.items(q) self.assert_items_matched(results, ["path item"]) + results = self.lib.albums(q) + self.assert_albums_matched(results, ["path album"]) + def test_path_album_regex(self): q = "path::b" results = self.lib.albums(q) @@ -1126,6 +1129,41 @@ class NotQueryTest(DummyDataTestCase): pass +class RelatedQueriesTest(_common.TestCase, AssertsMixin): + """Test album-level queries with track-level filters and vice-versa.""" + + def setUp(self): + super().setUp() + self.lib = beets.library.Library(":memory:") + + albums = [] + for album_idx in range(1, 3): + album_name = f"Album{album_idx}" + album_items = [] + for item_idx in range(1, 3): + item = _common.item() + item.album = album_name + item.title = f"{album_name} Item{item_idx}" + self.lib.add(item) + album_items.append(item) + album = self.lib.add_album(album_items) + album.artpath = f"{album_name} Artpath" + album.store() + albums.append(album) + + self.album, self.another_album = albums + + def test_get_albums_filter_by_track_field(self): + q = "title:Album1" + results = self.lib.albums(q) + self.assert_albums_matched(results, ["Album1"]) + + def test_get_items_filter_by_album_field(self): + q = "artpath::Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) + + def suite(): return 
unittest.TestLoader().loadTestsFromName(__name__)