Add support for filtering relations

This commit is contained in:
Šarūnas Nejus 2024-06-19 22:33:33 +01:00
parent 8e237d62c8
commit 981a61bd56
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
8 changed files with 180 additions and 24 deletions

View file

@ -40,7 +40,6 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -51,9 +50,8 @@ from typing import (
from unidecode import unidecode
import beets
from beets.util import functemplate
from ..util.functemplate import Template
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
@ -323,6 +321,32 @@ class Model(ABC):
to the database.
"""
@cached_classproperty
def _relation(cls) -> type[Model]:
"""The model that this model is closely related to."""
return cls
@cached_classproperty
def relation_join(cls) -> str:
"""Return the join required to include the related table in the query.
This is intended to be used as a FROM clause in the SQL query.
"""
return ""
@cached_classproperty
def all_db_fields(cls) -> set[str]:
return cls._fields.keys() | cls._relation._fields.keys()
@cached_classproperty
def shared_db_fields(cls) -> set[str]:
return cls._fields.keys() & cls._relation._fields.keys()
@cached_classproperty
def other_db_fields(cls) -> set[str]:
"""Fields in the related table."""
return cls._relation._fields.keys() - cls.shared_db_fields
@classmethod
def _getters(cls: Type["Model"]):
"""Return a mapping from field names to getter functions."""
@ -344,7 +368,7 @@ class Model(ABC):
initial field values.
"""
self._db = db
self._dirty: Set[str] = set()
self._dirty: set[str] = set()
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@ -668,7 +692,7 @@ class Model(ABC):
def evaluate_template(
self,
template: Union[str, Template],
template: Union[str, functemplate.Template],
for_path: bool = False,
) -> str:
"""Evaluate a template (a string or a `Template` object) using
@ -1223,24 +1247,35 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
sql = ("SELECT * FROM {} WHERE {} {}").format(
model_cls._table,
where or "1",
f"ORDER BY {order_by}" if order_by else "",
)
table = model_cls._table
_from = table
if query.field_names & model_cls.other_db_fields:
_from += f" {model_cls.relation_join}"
# group by id to avoid duplicates when joining with the relation
sql = (
f"SELECT {table}.* "
f"FROM ({_from}) "
f"WHERE {where or 1} "
f"GROUP BY {table}.id"
)
# Fetch flexible attributes for items matching the main query.
# Doing the per-item filtering in python is faster than issuing
# one query per item to sqlite.
flex_sql = """
SELECT * FROM {} WHERE entity_id IN
(SELECT id FROM {} WHERE {});
""".format(
model_cls._flex_table,
model_cls._table,
where or "1",
flex_sql = (
"SELECT * "
f"FROM {model_cls._flex_table} "
f"WHERE entity_id IN (SELECT id FROM ({sql}))"
)
if order_by:
# the sort field may exist in both 'items' and 'albums' tables
# (when they are joined), causing ambiguous column OperationalError
# if we try to order directly.
# Since the join is required only for filtering, we can filter in
# a subquery and order the result, which returns unique fields.
sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}"
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)

View file

@ -21,7 +21,7 @@ import unicodedata
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from functools import reduce
from operator import mul
from operator import mul, or_
from typing import (
TYPE_CHECKING,
Any,
@ -33,6 +33,7 @@ from typing import (
Optional,
Pattern,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -81,6 +82,11 @@ class InvalidQueryArgumentValueError(ParsingError):
class Query(ABC):
"""An abstract class representing a query into the database."""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return set()
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
@ -128,6 +134,11 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field}
def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
self.pattern = pattern
@ -443,6 +454,11 @@ class CollectionQuery(Query):
indexed like a list to access the sub-queries.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return reduce(or_, (sq.field_names for sq in self.subqueries))
def __init__(self, subqueries: Sequence = ()):
self.subqueries = subqueries
@ -498,6 +514,11 @@ class AnyFieldQuery(CollectionQuery):
constructor.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
@ -570,6 +591,11 @@ class NotQuery(Query):
performing `not(subquery)` without using regular expressions.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return self.subquery.field_names
def __init__(self, subquery):
self.subquery = subquery

View file

@ -152,7 +152,8 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
key = key.lower()
out_query = query_class(key, pattern, key in model_cls.all_db_fields)
# Apply negation.
if negate:

View file

@ -14,6 +14,7 @@
"""The core data store and collection logic for beets.
"""
from __future__ import annotations
import os
import re
@ -32,6 +33,7 @@ from beets.dbcore import Results, types
from beets.util import (
MoveOperation,
bytestring_path,
cached_classproperty,
normpath,
samefile,
syspath,
@ -640,6 +642,22 @@ class Item(LibModel):
# Cached album object. Read-only.
__album = None
@cached_classproperty
def _relation(cls) -> type[Album]:
return Album
@cached_classproperty
def relation_join(cls) -> str:
"""Return the FROM clause which includes related albums.
We need to use a LEFT JOIN here, otherwise items that are not part of
an album (e.g. singletons) would be left out.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"ON {cls._table}.album_id = {cls._relation._table}.id"
)
@property
def _cached_album(self):
"""The Album object that this item belongs to, if any, or
@ -1240,6 +1258,22 @@ class Album(LibModel):
_format_config_key = "format_album"
@cached_classproperty
def _relation(cls) -> type[Item]:
return Item
@cached_classproperty
def relation_join(cls) -> str:
"""Return FROM clause which joins on related album items.
Use LEFT join to select all albums, including those that do not have
any items.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"ON {cls._table}.id = {cls._relation._table}.album_id"
)
@classmethod
def _getters(cls):
# In addition to plugin-provided computed fields, also expose

View file

@ -79,6 +79,11 @@ class LimitPlugin(BeetsPlugin):
n = 0
N = None
def __init__(self, *args, **kwargs) -> None:
"""Force the query to be slow so that 'value_match' is called."""
super().__init__(*args, **kwargs)
self.fast = False
@classmethod
def value_match(cls, pattern, value):
if cls.N is None:

View file

@ -6,6 +6,13 @@ Unreleased
Changelog goes here! Please add your entry to the bottom of one of the lists below!
New features:
* Ability to query albums with track db fields and vice-versa, for example
`beet list -a title:something` or `beet list artpath:cover`. Consequently
album queries involving `path` field have been sped up, like `beet list -a
path:/path/`.
Bug fixes:
* Improved naming of temporary files by separating the random part with the file extension.

View file

@ -17,7 +17,9 @@ This command::
$ beet list love
will show all tracks matching the query string ``love``. By default any unadorned word like this matches in a track's title, artist, album name, album artist, genre and comments. See below on how to search other fields.
will show all tracks matching the query string ``love``. By default any
unadorned word like this matches in a track's title, artist, album name, album
artist, genre and comments. See below on how to search other fields.
For example, this is what I might see when I run the command above::
@ -83,6 +85,14 @@ For multi-valued tags (such as ``artists`` or ``albumartists``), a regular
expression search must be used to search for a single value within the
multi-valued tag.
Note that you can filter albums by querying tracks fields and vice versa::
$ beet list -a title:love
and vice versa::
$ beet list art_path::love
Phrases
-------
@ -115,9 +125,9 @@ the field name's colon and before the expression::
$ beet list artist:=AIR
The first query is a simple substring one that returns tracks by Air, AIR, and
Air Supply. The second query returns tracks by Air and AIR, since both are a
Air Supply. The second query returns tracks by Air and AIR, since both are a
case-insensitive match for the entire expression, but does not return anything
by Air Supply. The third query, which requires a case-sensitive exact match,
by Air Supply. The third query, which requires a case-sensitive exact match,
returns tracks by AIR only.
Exact matches may be performed on phrases as well::
@ -358,7 +368,7 @@ result in lower-case values being placed after upper-case values, e.g.,
``Bar Qux foo``.
Note that when sorting by fields that are not present on all items (such as
flexible fields, or those defined by plugins) in *ascending* order, the items
flexible fields, or those defined by plugins) in *ascending* order, the items
that lack that particular field will be listed at the *beginning* of the list.
You can set the default sorting behavior with the :ref:`sort_item` and

View file

@ -521,7 +521,7 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, [])
self.assert_albums_matched(results, ["path album"])
# FIXME: fails on windows
@unittest.skipIf(sys.platform == "win32", "win32")
@ -604,6 +604,9 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, ["path album"])
def test_path_album_regex(self):
q = "path::b"
results = self.lib.albums(q)
@ -1126,6 +1129,41 @@ class NotQueryTest(DummyDataTestCase):
pass
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
"""Test album-level queries with track-level filters and vice-versa."""
def setUp(self):
super().setUp()
self.lib = beets.library.Library(":memory:")
albums = []
for album_idx in range(1, 3):
album_name = f"Album{album_idx}"
album_items = []
for item_idx in range(1, 3):
item = _common.item()
item.album = album_name
item.title = f"{album_name} Item{item_idx}"
self.lib.add(item)
album_items.append(item)
album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath"
album.store()
albums.append(album)
self.album, self.another_album = albums
def test_get_albums_filter_by_track_field(self):
q = "title:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_album_field(self):
q = "artpath::Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)