Filter albums by track fields and vice versa (#5327)

Fixes #4360

This PR enables querying albums by track fields and tracks by album
fields, and speeds up querying albums by `path` field.

It originally was part of #5240, however we found that the changes
related to the flexible attributes caused degradation in performance. So
this PR contains the first part of #5240 which joined `items` and
`albums` tables in queries.
This commit is contained in:
Šarūnas Nejus 2024-06-25 02:04:45 +01:00 committed by GitHub
commit 4e06b59b60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 247 additions and 71 deletions

View file

@ -39,7 +39,7 @@ from unidecode import unidecode
from beets import config, logging, plugins
from beets.autotag import mb
from beets.library import Item
from beets.util import as_string
from beets.util import as_string, cached_classproperty
log = logging.getLogger("beets")
@ -413,23 +413,6 @@ def string_dist(str1: Optional[str], str2: Optional[str]) -> float:
return base_dist + penalty
class LazyClassProperty:
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
def __get__(self, obj, owner):
if not self.computed:
self.value = self.getter(owner)
self.computed = True
return self.value
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
@ -441,7 +424,7 @@ class Distance:
self._penalties = {}
self.tracks: Dict[TrackInfo, Distance] = {}
@LazyClassProperty
@cached_classproperty
def _weights(cls) -> Dict[str, float]: # noqa: N805
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]

View file

@ -40,7 +40,6 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -51,9 +50,8 @@ from typing import (
from unidecode import unidecode
import beets
from beets.util import functemplate
from ..util.functemplate import Template
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
@ -323,6 +321,32 @@ class Model(ABC):
to the database.
"""
@cached_classproperty
def _relation(cls) -> type[Model]:
"""The model that this model is closely related to."""
return cls
@cached_classproperty
def relation_join(cls) -> str:
"""Return the join required to include the related table in the query.
This is intended to be used as a FROM clause in the SQL query.
"""
return ""
@cached_classproperty
def all_db_fields(cls) -> set[str]:
return cls._fields.keys() | cls._relation._fields.keys()
@cached_classproperty
def shared_db_fields(cls) -> set[str]:
return cls._fields.keys() & cls._relation._fields.keys()
@cached_classproperty
def other_db_fields(cls) -> set[str]:
"""Fields in the related table."""
return cls._relation._fields.keys() - cls.shared_db_fields
@classmethod
def _getters(cls: Type["Model"]):
"""Return a mapping from field names to getter functions."""
@ -344,7 +368,7 @@ class Model(ABC):
initial field values.
"""
self._db = db
self._dirty: Set[str] = set()
self._dirty: set[str] = set()
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@ -668,7 +692,7 @@ class Model(ABC):
def evaluate_template(
self,
template: Union[str, Template],
template: Union[str, functemplate.Template],
for_path: bool = False,
) -> str:
"""Evaluate a template (a string or a `Template` object) using
@ -1223,24 +1247,35 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
sql = ("SELECT * FROM {} WHERE {} {}").format(
model_cls._table,
where or "1",
f"ORDER BY {order_by}" if order_by else "",
)
table = model_cls._table
_from = table
if query.field_names & model_cls.other_db_fields:
_from += f" {model_cls.relation_join}"
# group by id to avoid duplicates when joining with the relation
sql = (
f"SELECT {table}.* "
f"FROM ({_from}) "
f"WHERE {where or 1} "
f"GROUP BY {table}.id"
)
# Fetch flexible attributes for items matching the main query.
# Doing the per-item filtering in python is faster than issuing
# one query per item to sqlite.
flex_sql = """
SELECT * FROM {} WHERE entity_id IN
(SELECT id FROM {} WHERE {});
""".format(
model_cls._flex_table,
model_cls._table,
where or "1",
flex_sql = (
"SELECT * "
f"FROM {model_cls._flex_table} "
f"WHERE entity_id IN (SELECT id FROM ({sql}))"
)
if order_by:
# the sort field may exist in both 'items' and 'albums' tables
# (when they are joined), causing ambiguous column OperationalError
# if we try to order directly.
# Since the join is required only for filtering, we can filter in
# a subquery and order the result, which returns unique fields.
sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}"
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)

View file

@ -21,7 +21,7 @@ import unicodedata
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from functools import reduce
from operator import mul
from operator import mul, or_
from typing import (
TYPE_CHECKING,
Any,
@ -33,6 +33,7 @@ from typing import (
Optional,
Pattern,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -81,6 +82,11 @@ class InvalidQueryArgumentValueError(ParsingError):
class Query(ABC):
"""An abstract class representing a query into the database."""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return set()
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
@ -128,13 +134,24 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
@property
def field(self) -> str:
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name
)
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field_name}
def __init__(self, field_name: str, pattern: P, fast: bool = True):
self.table, _, self.field_name = field_name.rpartition(".")
self.pattern = pattern
self.fast = fast
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return None, ()
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field, ()
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
@ -149,23 +166,23 @@ class FieldQuery(Query, Generic[P]):
raise NotImplementedError()
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
return self.value_match(self.pattern, obj.get(self.field_name))
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
f"fast={self.fast})"
)
def __eq__(self, other) -> bool:
return (
super().__eq__(other)
and self.field == other.field
and self.field_name == other.field_name
and self.pattern == other.pattern
)
def __hash__(self) -> int:
return hash((self.field, hash(self.pattern)))
return hash((self.field_name, hash(self.pattern)))
class MatchQuery(FieldQuery[AnySQLiteType]):
@ -189,10 +206,10 @@ class NoneQuery(FieldQuery[None]):
return self.field + " IS NULL", ()
def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
return obj.get(self.field_name) is None
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
class StringFieldQuery(FieldQuery[P]):
@ -263,7 +280,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
expression.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
def __init__(self, field_name: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
pattern_re = re.compile(pattern)
@ -273,7 +290,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
pattern, "a regular expression", format(exc)
)
super().__init__(field, pattern_re, fast)
super().__init__(field_name, pattern_re, fast)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@ -297,7 +314,7 @@ class BooleanQuery(MatchQuery[int]):
def __init__(
self,
field: str,
field_name: str,
pattern: bool,
fast: bool = True,
):
@ -306,7 +323,7 @@ class BooleanQuery(MatchQuery[int]):
pattern_int = int(pattern)
super().__init__(field, pattern_int, fast)
super().__init__(field_name, pattern_int, fast)
class BytesQuery(FieldQuery[bytes]):
@ -316,7 +333,7 @@ class BytesQuery(FieldQuery[bytes]):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
@ -332,7 +349,7 @@ class BytesQuery(FieldQuery[bytes]):
else:
raise ValueError("pattern must be bytes, str, or memoryview")
super().__init__(field, bytes_pattern)
super().__init__(field_name, bytes_pattern)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
@ -368,8 +385,8 @@ class NumericQuery(FieldQuery[str]):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
parts = pattern.split("..", 1)
if len(parts) == 1:
@ -384,9 +401,9 @@ class NumericQuery(FieldQuery[str]):
self.rangemax = self._convert(parts[1])
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
value = obj[self.field]
value = obj[self.field_name]
if isinstance(value, str):
value = self._convert(value)
@ -419,7 +436,7 @@ class NumericQuery(FieldQuery[str]):
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field: str
field_name: str
pattern: Sequence[AnySQLiteType]
fast: bool = True
@ -429,7 +446,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals
return f"{self.field_name} IN ({placeholders})", self.subvals
@classmethod
def value_match(
@ -443,6 +460,11 @@ class CollectionQuery(Query):
indexed like a list to access the sub-queries.
"""
@property
def field_names(self) -> Set[str]:
    """Return a set with field names that this query operates on.

    The result is the union of the ``field_names`` of every subquery. It is
    always a newly created set, and a collection without any subqueries
    yields an empty set.
    """
    # set().union() accepts zero iterables, whereas reduce() without an
    # initial value raises TypeError when there are no subqueries.
    return set().union(*(sq.field_names for sq in self.subqueries))
def __init__(self, subqueries: Sequence = ()):
self.subqueries = subqueries
@ -498,6 +520,11 @@ class AnyFieldQuery(CollectionQuery):
constructor.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
@ -570,6 +597,11 @@ class NotQuery(Query):
performing `not(subquery)` without using regular expressions.
"""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return self.subquery.field_names
def __init__(self, subquery):
self.subquery = subquery
@ -797,15 +829,15 @@ class DateQuery(FieldQuery[str]):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
timestamp = float(obj[self.field])
timestamp = float(obj[self.field_name])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)

View file

@ -152,7 +152,14 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Work with the bare, lower-cased field name for all membership tests.
field_name = key.lower()
# Decide whether the query can run in SQL *before* the key is qualified
# with a table name: `all_db_fields` holds bare column names, so a
# "<table>.<field>" key would never be found in it and every shared field
# would be forced onto the slow Python-side matching path.
is_db_field = field_name in model_cls.all_db_fields
key = field_name
if field_name in model_cls.shared_db_fields:
    # This field exists in both tables, so SQLite will encounter
    # an OperationalError if we try to query it in a join.
    # Using an explicit table name resolves this.
    key = f"{model_cls._table}.{field_name}"
out_query = query_class(key, pattern, is_db_field)
# Apply negation.
if negate:

View file

@ -14,6 +14,7 @@
"""The core data store and collection logic for beets.
"""
from __future__ import annotations
import os
import re
@ -32,6 +33,7 @@ from beets.dbcore import Results, types
from beets.util import (
MoveOperation,
bytestring_path,
cached_classproperty,
normpath,
samefile,
syspath,
@ -640,6 +642,22 @@ class Item(LibModel):
# Cached album object. Read-only.
__album = None
@cached_classproperty
def _relation(cls) -> type[Album]:
return Album
@cached_classproperty
def relation_join(cls) -> str:
"""Return the FROM clause which includes related albums.
We need to use a LEFT JOIN here, otherwise items that are not part of
an album (e.g. singletons) would be left out.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"ON {cls._table}.album_id = {cls._relation._table}.id"
)
@property
def _cached_album(self):
"""The Album object that this item belongs to, if any, or
@ -1240,6 +1258,22 @@ class Album(LibModel):
_format_config_key = "format_album"
@cached_classproperty
def _relation(cls) -> type[Item]:
return Item
@cached_classproperty
def relation_join(cls) -> str:
"""Return FROM clause which joins on related album items.
Use LEFT join to select all albums, including those that do not have
any items.
"""
return (
f"LEFT JOIN {cls._relation._table} "
f"ON {cls._table}.id = {cls._relation._table}.album_id"
)
@classmethod
def _getters(cls):
# In addition to plugin-provided computed fields, also expose

View file

@ -1059,3 +1059,20 @@ def par_map(transform: Callable, items: Iterable):
pool.map(transform, items)
pool.close()
pool.join()
class cached_classproperty:  # noqa: N801
    """Read-only, class-level property whose getter result is memoised.

    The wrapped getter is called with the owning class. Its result is kept
    in ``cache`` keyed by that class, so as long as the entry stays in the
    cache the getter runs only once per owner class; later lookups made on
    the class or on any of its instances are answered from the cache.
    """

    def __init__(self, getter):
        self.getter = getter
        # Maps an owner class to the value already computed for it.
        self.cache = {}

    def __get__(self, instance, owner):
        if owner in self.cache:
            return self.cache[owner]

        value = self.getter(owner)
        self.cache[owner] = value
        return value

View file

@ -79,6 +79,11 @@ class LimitPlugin(BeetsPlugin):
n = 0
N = None
def __init__(self, *args, **kwargs) -> None:
"""Force the query to be slow so that 'value_match' is called."""
super().__init__(*args, **kwargs)
self.fast = False
@classmethod
def value_match(cls, pattern, value):
if cls.N is None:

View file

@ -6,6 +6,13 @@ Unreleased
Changelog goes here! Please add your entry to the bottom of one of the lists below!
New features:
* Ability to query albums by track db fields and vice versa, for example
`beet list -a title:something` or `beet list artpath:cover`. As a
consequence, album queries involving the `path` field, such as `beet list -a
path:/path/`, are now faster.
Bug fixes:
* Improved naming of temporary files by separating the random part with the file extension.

View file

@ -17,7 +17,9 @@ This command::
$ beet list love
will show all tracks matching the query string ``love``. By default any unadorned word like this matches in a track's title, artist, album name, album artist, genre and comments. See below on how to search other fields.
will show all tracks matching the query string ``love``. By default any
unadorned word like this matches in a track's title, artist, album name, album
artist, genre and comments. See below on how to search other fields.
For example, this is what I might see when I run the command above::
@ -83,6 +85,14 @@ For multi-valued tags (such as ``artists`` or ``albumartists``), a regular
expression search must be used to search for a single value within the
multi-valued tag.
Note that you can filter albums by querying track fields::
$ beet list -a title:love
and vice versa, filter tracks by querying album fields::
$ beet list artpath::love
-------
@ -115,9 +125,9 @@ the field name's colon and before the expression::
$ beet list artist:=AIR
The first query is a simple substring one that returns tracks by Air, AIR, and
Air Supply. The second query returns tracks by Air and AIR, since both are a
Air Supply. The second query returns tracks by Air and AIR, since both are a
case-insensitive match for the entire expression, but does not return anything
by Air Supply. The third query, which requires a case-sensitive exact match,
by Air Supply. The third query, which requires a case-sensitive exact match,
returns tracks by AIR only.
Exact matches may be performed on phrases as well::
@ -358,7 +368,7 @@ result in lower-case values being placed after upper-case values, e.g.,
``Bar Qux foo``.
Note that when sorting by fields that are not present on all items (such as
flexible fields, or those defined by plugins) in *ascending* order, the items
flexible fields, or those defined by plugins) in *ascending* order, the items
that lack that particular field will be listed at the *beginning* of the list.
You can set the default sorting behavior with the :ref:`sort_item` and

View file

@ -31,7 +31,9 @@ show_contexts = true
min-version = 3.8
accept-encodings = utf-8
max-line-length = 88
docstring-convention = google
classmethod-decorators =
classmethod
cached_classproperty
# errors we ignore; see https://www.flake8rules.com/ for more info
ignore =
# pycodestyle errors

View file

@ -143,7 +143,7 @@ def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].computed = False
Distance.__dict__["_weights"].cache = {}
class DistanceTest(_common.TestCase):

View file

@ -521,7 +521,7 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, [])
self.assert_albums_matched(results, ["path album"])
# FIXME: fails on windows
@unittest.skipIf(sys.platform == "win32", "win32")
@ -604,6 +604,9 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, ["path album"])
def test_path_album_regex(self):
q = "path::b"
results = self.lib.albums(q)
@ -1126,6 +1129,47 @@ class NotQueryTest(DummyDataTestCase):
pass
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
    """Check that albums can be filtered by track fields and tracks by
    album fields.
    """

    def _add_album(self, album_name):
        """Add two items named after `album_name` to the library, group
        them into an album and return the stored album.
        """
        members = []
        for number in (1, 2):
            track = _common.item()
            track.album = album_name
            track.title = f"{album_name} Item{number}"
            self.lib.add(track)
            members.append(track)

        created = self.lib.add_album(members)
        created.artpath = f"{album_name} Artpath"
        # Both albums share this value, so it cannot tell them apart.
        created.catalognum = "ABC"
        created.store()
        return created

    def setUp(self):
        super().setUp()
        self.lib = beets.library.Library(":memory:")

        self.album, self.another_album = [
            self._add_album(f"Album{number}") for number in (1, 2)
        ]

    def test_get_albums_filter_by_track_field(self):
        matched = self.lib.albums("title:Album1")
        self.assert_albums_matched(matched, ["Album1"])

    def test_get_items_filter_by_album_field(self):
        matched = self.lib.items("artpath::Album1")
        self.assert_items_matched(matched, ["Album1 Item1", "Album1 Item2"])

    def test_filter_by_common_field(self):
        matched = self.lib.albums("catalognum:ABC Album1")
        self.assert_albums_matched(matched, ["Album1"])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)