Revert "Make queries fast, filter all flexible attributes (#5240)"

This reverts commit 143b9202f3, reversing
changes made to 8508a57d77.
This commit is contained in:
Šarūnas Nejus 2024-06-19 21:51:44 +01:00
parent 143b9202f3
commit 2800a323a2
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
19 changed files with 355 additions and 696 deletions

View file

@ -39,7 +39,7 @@ from unidecode import unidecode
from beets import config, logging, plugins
from beets.autotag import mb
from beets.library import Item
from beets.util import as_string, cached_classproperty
from beets.util import as_string
log = logging.getLogger("beets")
@ -413,6 +413,23 @@ def string_dist(str1: Optional[str], str2: Optional[str]) -> float:
return base_dist + penalty
class LazyClassProperty:
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
def __get__(self, obj, owner):
if not self.computed:
self.value = self.getter(owner)
self.computed = True
return self.value
@total_ordering
class Distance:
"""Keeps track of multiple distance penalties. Provides a single
@ -424,7 +441,7 @@ class Distance:
self._penalties = {}
self.tracks: Dict[TrackInfo, Distance] = {}
@cached_classproperty
@LazyClassProperty
def _weights(cls) -> Dict[str, float]: # noqa: N805
"""A dictionary from keys to floating-point weights."""
weights_view = config["match"]["distance_weights"]

View file

@ -17,16 +17,14 @@
from __future__ import annotations
import contextlib
import json
import os
import re
import sqlite3
import sys
import threading
import time
from abc import ABC
from collections import defaultdict
from sqlite3 import Connection, sqlite_version
from sqlite3 import Connection
from types import TracebackType
from typing import (
Any,
@ -50,31 +48,22 @@ from typing import (
cast,
)
from packaging.version import Version
from rich import print
from rich_tables.generic import flexitable
from unidecode import unidecode
import beets
from beets.util import functemplate
from ..util import cached_classproperty, functemplate
from ..util.functemplate import Template
from . import types
from .query import FieldQuery, MatchQuery, NullSort, Query, Sort, TrueQuery
# convert data under 'json_str' type name to Python dictionary automatically
sqlite3.register_converter("json_str", json.loads)
DEBUG = bool(os.getenv("BEETS_DEBUG", False))
def print_query(sql, subvals=None):
"""If debugging, replace placeholders and print the query."""
if not DEBUG:
return
topr = sql
for val in subvals or []:
topr = topr.replace("?", str(val), 1)
print(flexitable({"sql": topr}), file=sys.stderr)
from .query import (
AndQuery,
FieldQuery,
MatchQuery,
NullSort,
Query,
Sort,
TrueQuery,
)
class DBAccessError(Exception):
@ -334,64 +323,6 @@ class Model(ABC):
to the database.
"""
@cached_classproperty
def _relation(cls) -> Type[Model]:
"""The model that this model is closely related to."""
return cls
@cached_classproperty
def relation_join(cls) -> str:
"""Return the join required to include the related table in the query.
This is intended to be used as a FROM clause in the SQL query.
"""
return ""
@cached_classproperty
def table_with_flex_attrs(cls) -> str:
"""Return a SQL for entity table which includes aggregated flexible attributes.
The clause selects entity rows, flexible attributes rows and LEFT JOINs
them on entity id and 'entity_id' field respectively.
'json_group_object' aggregate function groups flexible attributes into a
single JSON object 'flex_attrs [json_str]'. The column name ending with
' [json_str]' means that this column is converted to a Python dictionary
automatically (see 'register_converter' call at the top of this module).
'REPLACE' function handles absence of flexible attributes and replaces
some weird null JSON object (that SQLite gives us by default) with an
empty JSON object.
Availability of the 'flex_attrs' means we can query flexible attributes
in the same manner we query other entity fields, see
`FieldQuery.field`. This way, we also remove the need for an
additional query to fetch them.
Note: we use LEFT join to include entities without flexible attributes.
Note: we name this SELECT clause after the original entity table name
so that we can query it in the way like the original table.
"""
flex_attrs = "REPLACE(json_group_object(key, value), '{:null}', '{}')"
return f"""(
SELECT
*,
{flex_attrs} AS "flex_attrs [json_str]"
FROM {cls._table} LEFT JOIN (
SELECT
entity_id,
key,
CAST(value AS text) AS value
FROM {cls._flex_table}
) ON entity_id == {cls._table}.id
GROUP BY {cls._table}.id
) {cls._table}
"""
@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return set()
@classmethod
def _getters(cls: Type["Model"]):
"""Return a mapping from field names to getter functions."""
@ -737,7 +668,7 @@ class Model(ABC):
def evaluate_template(
self,
template: Union[str, functemplate.Template],
template: Union[str, Template],
for_path: bool = False,
) -> str:
"""Evaluate a template (a string or a `Template` object) using
@ -768,6 +699,33 @@ class Model(ABC):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
# Convenient queries.
@classmethod
def field_query(
cls,
field,
pattern,
query_cls: Type[FieldQuery] = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(
cls: Type["Model"],
pats: Mapping,
query_cls: Type[FieldQuery] = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)
# Database controller and supporting interfaces.
@ -785,6 +743,8 @@ class Results(Generic[AnyModel]):
model_class: Type[AnyModel],
rows: List[Mapping],
db: "Database",
flex_rows,
query: Optional[Query] = None,
sort=None,
):
"""Create a result set that will construct objects of type
@ -794,7 +754,9 @@ class Results(Generic[AnyModel]):
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
If `sort` is provided, it is used to sort the
If `query` is provided, it is used as a predicate to filter the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first
one.
@ -802,7 +764,9 @@ class Results(Generic[AnyModel]):
self.model_class = model_class
self.rows = rows
self.db = db
self.query = query
self.sort = sort
self.flex_rows = flex_rows
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
@ -824,6 +788,10 @@ class Results(Generic[AnyModel]):
a `Results` object a second time should be much faster than the
first.
"""
# Index flexible attributes by the item ID, so we have easier access
flex_attrs = self._get_indexed_flex_attrs()
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
# Are there previously-materialized objects to produce?
@ -836,11 +804,14 @@ class Results(Generic[AnyModel]):
else:
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
self._objects.append(obj)
index += 1
yield obj
break
obj = self._make_model(row, flex_attrs.get(row["id"], {}))
# If there is a slow-query predicate, ensurer that the
# object passes it.
if not self.query or self.query.match(obj):
self._objects.append(obj)
index += 1
yield obj
break
def __iter__(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for all matching
@ -855,10 +826,21 @@ class Results(Generic[AnyModel]):
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _make_model(self, row) -> AnyModel:
def _get_indexed_flex_attrs(self) -> Mapping:
"""Index flexible attributes by the entity id they belong to"""
flex_values: Dict[int, Dict[str, Any]] = {}
for row in self.flex_rows:
if row["entity_id"] not in flex_values:
flex_values[row["entity_id"]] = {}
flex_values[row["entity_id"]][row["key"]] = row["value"]
return flex_values
def _make_model(self, row, flex_values: Dict = {}) -> AnyModel:
"""Create a Model object for the given row"""
values = dict(row)
flex_values = values.pop("flex_attrs") or {}
cols = dict(row)
values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"}
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
@ -869,8 +851,16 @@ class Results(Generic[AnyModel]):
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
count += 1
return count
else:
# Just count the rows.
# A fast query. Just count the rows.
return self._row_count
def __nonzero__(self) -> bool:
@ -960,7 +950,6 @@ class Transaction:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
print_query(statement, subvals)
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
@ -969,7 +958,6 @@ class Transaction:
the row ID of the last affected row.
"""
try:
print_query(statement, subvals)
cursor = self.db._connection().execute(statement, subvals)
except sqlite3.OperationalError as e:
# In two specific cases, SQLite reports an error while accessing
@ -990,7 +978,6 @@ class Transaction:
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
print_query(statements)
self.db._connection().executescript(statements)
@ -1079,8 +1066,6 @@ class Database:
# We have our own same-thread checks in _connection(), but need to
# call conn.close() in _close()
check_same_thread=False,
# enable type name "col [type]" conversion (`register_converter`)
detect_types=sqlite3.PARSE_COLNAMES,
)
self.add_functions(conn)
@ -1099,9 +1084,7 @@ class Database:
def regexp(value, pattern):
if isinstance(value, bytes):
value = value.decode()
return (
value is not None and re.search(pattern, str(value)) is not None
)
return re.search(pattern, str(value)) is not None
def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]:
"""A custom ``bytelower`` sqlite function so we can compare
@ -1116,71 +1099,9 @@ class Database:
return bytestring
def json_patch(first: str, second: str) -> str:
"""Implementation of the 'json_patch' SQL function.
This function merges two JSON strings together.
"""
first_dict = json.loads(first)
second_dict = json.loads(second)
first_dict.update(second_dict)
return json.dumps(first_dict)
def json_extract(json_str: str, key: str) -> Optional[str]:
"""Simple implementation of the 'json_extract' SQLite function.
The original implementation in SQLite allows traversing objects of
any depth. Here, we only ever deal with a flat dictionary, thus
we can simplify the implementation to a single 'get' call.
"""
if json_str:
return json.loads(json_str).get(key.replace("$.", ""))
return None
class JSONGroupObject:
"""Implementation of the 'json_group_object' SQLite aggregate.
An aggregate function which accepts two values (key, val) and
groups all {key: val} pairs into a single object.
It is found in the json1 extension which is included in SQLite
by default since version 3.38.0 (2022-02-22). To ensure support
for older SQLite versions, we add our implementation.
Notably, it does not exist on Windows in Python 3.8.
Consider the following table
id key val
1 plays "10"
1 skips "20"
2 city "London"
SELECT id, group_to_json(key, val) GROUP BY id
1, '{"plays": "10", "skips": "20"}'
2, '{"city": "London"}'
"""
def __init__(self):
self.flex = {}
def step(self, field, value):
if field:
self.flex[field] = value
def finalize(self):
return json.dumps(self.flex)
conn.create_function("regexp", 2, regexp)
conn.create_function("unidecode", 1, unidecode)
conn.create_function("bytelower", 1, bytelower)
if Version(sqlite_version) < Version("3.38.0"):
# create 'json_group_object' for older SQLite versions that do
# not include the json1 extension by default
conn.create_aggregate("json_group_object", 2, JSONGroupObject)
conn.create_function("json_patch", 2, json_patch)
conn.create_function("json_extract", 2, json_extract)
def _close(self):
"""Close the all connections to the underlying SQLite database
@ -1302,42 +1223,34 @@ class Database:
where, subvals = query.clause()
order_by = sort.order_clause()
this_table = model_cls._table
select_fields = [f"{this_table}.*"]
_from = model_cls.table_with_flex_attrs
sql = ("SELECT * FROM {} WHERE {} {}").format(
model_cls._table,
where or "1",
f"ORDER BY {order_by}" if order_by else "",
)
required_fields = query.field_names
if required_fields - model_cls._fields.keys():
_from += f" {model_cls.relation_join}"
if required_fields - model_cls.all_model_db_fields:
# merge all flexible attribute into a single JSON field
select_fields.append(
f"""
json_patch(
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
) AS all_flex_attrs
""" # noqa: E501
)
sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id" # noqa: E501
if order_by:
# the sort field may exist in both 'items' and 'albums' tables
# (when they are joined), causing ambiguous column OperationalError
# if we try to order directly.
# Since the join is required only for filtering, we can filter in
# a subquery and order the result, which returns unique fields.
sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}"
# Fetch flexible attributes for items matching the main query.
# Doing the per-item filtering in python is faster than issuing
# one query per item to sqlite.
flex_sql = """
SELECT * FROM {} WHERE entity_id IN
(SELECT id FROM {} WHERE {});
""".format(
model_cls._flex_table,
model_cls._table,
where or "1",
)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)
return Results(
model_cls,
rows,
self,
flex_rows,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)

View file

@ -21,7 +21,7 @@ import unicodedata
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from functools import reduce
from operator import mul, or_
from operator import mul
from typing import (
TYPE_CHECKING,
Any,
@ -33,7 +33,6 @@ from typing import (
Optional,
Pattern,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@ -82,19 +81,17 @@ class InvalidQueryArgumentValueError(ParsingError):
class Query(ABC):
"""An abstract class representing a query into the database."""
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return set()
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
The default implementation returns None, falling back to a slow query
using `match()`.
"""
raise NotImplementedError
return None, ()
@abstractmethod
def match(self, obj: Model):
@ -131,30 +128,20 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""
def __init__(self, field_name: str, pattern: P, fast: bool = True):
self.table, _, self.field_name = field_name.rpartition(".")
def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
self.pattern = pattern
self.fast = fast
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field_name}
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return None, ()
@property
def field(self) -> str:
if not self.fast:
return f'json_extract(all_flex_attrs, "$.{self.field_name}")'
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name
)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
raise NotImplementedError
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_clause()
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
return self.col_clause()
else:
# Matching a flexattr. This is a slow query.
return None, ()
@classmethod
def value_match(cls, pattern: P, value: Any):
@ -162,23 +149,23 @@ class FieldQuery(Query, Generic[P]):
raise NotImplementedError()
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field_name))
return self.value_match(self.pattern, obj.get(self.field))
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
f"fast={self.fast})"
)
def __eq__(self, other) -> bool:
return (
super().__eq__(other)
and self.field_name == other.field_name
and self.field == other.field
and self.pattern == other.pattern
)
def __hash__(self) -> int:
return hash((self.field_name, hash(self.pattern)))
return hash((self.field, hash(self.pattern)))
class MatchQuery(FieldQuery[AnySQLiteType]):
@ -202,10 +189,10 @@ class NoneQuery(FieldQuery[None]):
return self.field + " IS NULL", ()
def match(self, obj: Model) -> bool:
return obj.get(self.field_name) is None
return obj.get(self.field) is None
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
class StringFieldQuery(FieldQuery[P]):
@ -276,7 +263,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
expression.
"""
def __init__(self, field_name: str, pattern: str, fast: bool = True):
def __init__(self, field: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
pattern_re = re.compile(pattern)
@ -286,7 +273,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
pattern, "a regular expression", format(exc)
)
super().__init__(field_name, pattern_re, fast)
super().__init__(field, pattern_re, fast)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@ -303,24 +290,14 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
return pattern.search(cls._normalize(value)) is not None
class NumericColumnQuery(MatchQuery[AnySQLiteType]):
"""A base class for queries that work with NUMERIC SQLite affinity."""
@property
def field(self) -> str:
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
field = super().field
return field if self.fast else f"CAST({field} AS NUMERIC)"
class BooleanQuery(NumericColumnQuery[bool]):
class BooleanQuery(MatchQuery[int]):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
def __init__(
self,
field_name: str,
field: str,
pattern: bool,
fast: bool = True,
):
@ -329,7 +306,7 @@ class BooleanQuery(NumericColumnQuery[bool]):
pattern_int = int(pattern)
super().__init__(field_name, pattern_int, fast)
super().__init__(field, pattern_int, fast)
class BytesQuery(FieldQuery[bytes]):
@ -339,7 +316,7 @@ class BytesQuery(FieldQuery[bytes]):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
@ -355,7 +332,7 @@ class BytesQuery(FieldQuery[bytes]):
else:
raise ValueError("pattern must be bytes, str, or memoryview")
super().__init__(field_name, bytes_pattern)
super().__init__(field, bytes_pattern)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
@ -365,7 +342,7 @@ class BytesQuery(FieldQuery[bytes]):
return pattern == value
class NumericQuery(NumericColumnQuery[Union[int, float]]):
class NumericQuery(FieldQuery[str]):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
@ -391,8 +368,8 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
parts = pattern.split("..", 1)
if len(parts) == 1:
@ -407,9 +384,9 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
self.rangemax = self._convert(parts[1])
def match(self, obj: Model) -> bool:
if self.field_name not in obj:
if self.field not in obj:
return False
value = obj[self.field_name]
value = obj[self.field]
if isinstance(value, str):
value = self._convert(value)
@ -442,7 +419,7 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field_name: str
field: str
pattern: Sequence[AnySQLiteType]
fast: bool = True
@ -452,7 +429,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field_name} IN ({placeholders})", self.subvals
return f"{self.field} IN ({placeholders})", self.subvals
@classmethod
def value_match(
@ -469,11 +446,6 @@ class CollectionQuery(Query):
def __init__(self, subqueries: Sequence = ()):
self.subqueries = subqueries
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return reduce(or_, (sq.field_names for sq in self.subqueries))
# Act like a sequence.
def __len__(self) -> int:
@ -491,7 +463,7 @@ class CollectionQuery(Query):
def clause_with_joiner(
self,
joiner: str,
) -> Tuple[str, Sequence[SQLiteType]]:
) -> Tuple[Optional[str], Sequence[SQLiteType]]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@ -499,6 +471,9 @@ class CollectionQuery(Query):
subvals = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
# Fall back to slow query.
return None, ()
clause_parts.append("(" + subq_clause + ")")
subvals += subq_subvals
clause = (" " + joiner + " ").join(clause_parts)
@ -517,6 +492,45 @@ class CollectionQuery(Query):
return reduce(mul, map(hash, self.subqueries), 1)
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(obj):
return True
return False
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
f"{self.query_class.__name__})"
)
def __eq__(self, other) -> bool:
return super().__eq__(other) and self.query_class == other.query_class
def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
@ -534,7 +548,7 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner("and")
def match(self, obj: Model) -> bool:
@ -544,7 +558,7 @@ class AndQuery(MutableCollectionQuery):
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
@ -559,14 +573,14 @@ class NotQuery(Query):
def __init__(self, subquery):
self.subquery = subquery
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return self.subquery.field_names
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
clause, subvals = self.subquery.clause()
return f"not ({clause})", subvals
if clause:
return f"not ({clause})", subvals
else:
# If there is no clause, there is nothing to negate. All the logic
# is handled by match() for slow queries.
return clause, subvals
def match(self, obj: Model) -> bool:
return not self.subquery.match(obj)
@ -773,7 +787,7 @@ class DateInterval:
return f"[{self.start}, {self.end})"
class DateQuery(NumericColumnQuery[int]):
class DateQuery(FieldQuery[str]):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
@ -783,15 +797,15 @@ class DateQuery(NumericColumnQuery[int]):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, obj: Model) -> bool:
if self.field_name not in obj:
if self.field not in obj:
return False
timestamp = float(obj[self.field_name])
timestamp = float(obj[self.field])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
@ -867,7 +881,7 @@ class Sort:
return sorted(items)
def is_slow(self) -> bool:
"""Indicate whether this sort is *slow*, meaning that it cannot
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False

View file

@ -16,23 +16,11 @@
import itertools
import re
from typing import (
TYPE_CHECKING,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
)
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
from . import Model, query
from .query import Sort
if TYPE_CHECKING:
from ..library import LibModel
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
@ -116,7 +104,7 @@ def parse_query_part(
def construct_query_part(
model_cls: Type["LibModel"],
model_cls: Type[Model],
prefixes: Dict,
query_part: str,
) -> query.Query:
@ -151,14 +139,20 @@ def construct_query_part(
query_part, query_classes, prefixes
)
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(query_class, pattern)
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
# Field queries get constructed according to the name of the field
# they are querying.
else:
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Apply negation.
if negate:

View file

@ -708,7 +708,7 @@ class ImportTask(BaseImportTask):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.match_all_query(
dup_query = library.Album.all_fields_query(
{key: tmp_album.get(key) for key in keys}
)
@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Item.match_all_query(
dup_query = library.Album.all_fields_query(
{key: tmp_item.get(key) for key in keys}
)

View file

@ -14,7 +14,6 @@
"""The core data store and collection logic for beets.
"""
from __future__ import annotations
import os
import re
@ -24,7 +23,6 @@ import sys
import time
import unicodedata
from functools import cached_property
from typing import Mapping, Set, Type
from mediafile import MediaFile, UnreadableFileError
@ -34,7 +32,6 @@ from beets.dbcore import Results, types
from beets.util import (
MoveOperation,
bytestring_path,
cached_classproperty,
normpath,
samefile,
syspath,
@ -389,18 +386,6 @@ class LibModel(dbcore.Model):
# Config key that specifies how an instance should be formatted.
_format_config_key: str
@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() | cls._relation._fields.keys()
@cached_classproperty
def shared_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() & cls._relation._fields.keys()
@cached_classproperty
def writable_fields(cls) -> Set[str]:
return MediaFile.fields() & cls._relation._fields.keys()
def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
funcs.update(plugins.template_funcs())
@ -430,61 +415,6 @@ class LibModel(dbcore.Model):
def __bytes__(self):
return self.__str__().encode("utf-8")
# Convenient queries.
@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery]
) -> dbcore.Query:
"""Get a `FieldQuery` for this model."""
fast = field in cls.all_model_db_fields
if field in cls.shared_model_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"
return query_cls(field, pattern, fast)
@classmethod
def any_field_query(
cls, query_class: Type[dbcore.FieldQuery], pattern: str
) -> dbcore.OrQuery:
return dbcore.OrQuery(
[
cls.field_query(f, pattern, query_class)
for f in cls._search_fields
]
)
@classmethod
def any_writable_field_query(
cls, query_class: Type[dbcore.FieldQuery], pattern: str
) -> dbcore.OrQuery:
return dbcore.OrQuery(
[
cls.field_query(f, pattern, query_class)
for f in cls.writable_fields
]
)
@classmethod
def match_all_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.
`pattern_by_field` should be a mapping from field names to patterns.
The resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
return dbcore.AndQuery(
[
cls.field_query(f, p, dbcore.MatchQuery)
for f, p in pattern_by_field.items()
]
)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
@ -710,22 +640,6 @@ class Item(LibModel):
# Cached album object. Read-only.
__album = None
@cached_classproperty
def _relation(cls) -> type[Album]:
return Album
@cached_classproperty
def relation_join(cls) -> str:
"""Return the FROM clause which includes related albums.
We need to use a LEFT JOIN here, otherwise items that are not part of
an album (e.g. singletons) would be left out.
"""
return (
f"LEFT JOIN {cls._relation.table_with_flex_attrs}"
f" ON {cls._table}.album_id = {cls._relation._table}.id"
)
@property
def _cached_album(self):
"""The Album object that this item belongs to, if any, or
@ -1326,22 +1240,6 @@ class Album(LibModel):
_format_config_key = "format_album"
@cached_classproperty
def _relation(cls) -> type[Item]:
return Item
@cached_classproperty
def relation_join(cls) -> str:
"""Return FROM clause which joins on related album items.
Here we can use INNER JOIN (which is more performant than LEFT JOIN),
since we only want to see albums that have at least one Item in them.
"""
return (
f"INNER JOIN {cls._relation.table_with_flex_attrs}"
f" ON {cls._table}.id = {cls._relation._table}.album_id"
)
@classmethod
def _getters(cls):
# In addition to plugin-provided computed fields, also expose
@ -2030,10 +1928,9 @@ class DefaultTemplateFunctions:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
subqueries.append(
db_item.field_query(key, value, dbcore.MatchQuery)
)
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
query = dbcore.AndQuery(subqueries)
ambigous_items = (
self.lib.items(query)

View file

@ -1055,20 +1055,3 @@ def par_map(transform: Callable, items: Iterable):
pool.map(transform, items)
pool.close()
pool.join()
class cached_classproperty: # noqa: N801
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.cache = {}
def __get__(self, instance, owner):
if owner not in self.cache:
self.cache[owner] = self.getter(owner)
return self.cache[owner]

View file

@ -180,9 +180,8 @@ class AURADocument:
converter = self.get_attribute_converter(beets_attr)
value = converter(value)
# Add exact match query to list
queries.append(
self.model_cls.field_query(beets_attr, value, MatchQuery)
)
# Use a slow query so it works with all fields
queries.append(MatchQuery(beets_attr, value, fast=False))
# NOTE: AURA doesn't officially support multiple queries
return AndQuery(queries)

View file

@ -29,6 +29,8 @@ import traceback
from string import Template
from typing import List
from mediafile import MediaFile
import beets
import beets.ui
from beets import dbcore, vfs
@ -91,6 +93,8 @@ SUBSYSTEMS = [
"partition",
]
ITEM_KEYS_WRITABLE = set(MediaFile.fields()).intersection(Item._fields.keys())
# Gstreamer import error.
class NoGstreamerError(Exception):
@ -1397,7 +1401,7 @@ class Server(BaseServer):
return test_tag, key
raise BPDError(ERROR_UNKNOWN, "no such tagtype")
def _metadata_query(self, query_type, kv, allow_any_query: bool = False):
def _metadata_query(self, query_type, any_query_type, kv):
"""Helper function returns a query object that will find items
according to the library query type provided and the key-value
pairs specified. The any_query_type is used for queries of
@ -1409,9 +1413,11 @@ class Server(BaseServer):
it = iter(kv)
for tag, value in zip(it, it):
if tag.lower() == "any":
if allow_any_query:
if any_query_type:
queries.append(
Item.any_writable_field_query(query_type, value)
any_query_type(
value, ITEM_KEYS_WRITABLE, query_type
)
)
else:
raise BPDError(ERROR_UNKNOWN, "no such tagtype")
@ -1425,14 +1431,14 @@ class Server(BaseServer):
def cmd_search(self, conn, *kv):
"""Perform a substring match for items."""
query = self._metadata_query(
dbcore.query.SubstringQuery, kv, allow_any_query=True
dbcore.query.SubstringQuery, dbcore.query.AnyFieldQuery, kv
)
for item in self.lib.items(query):
yield self._item_info(item)
def cmd_find(self, conn, *kv):
"""Perform an exact match for items."""
query = self._metadata_query(dbcore.query.MatchQuery, kv)
query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
for item in self.lib.items(query):
yield self._item_info(item)
@ -1452,7 +1458,7 @@ class Server(BaseServer):
raise BPDError(ERROR_ARG, 'should be "Album" for 3 arguments')
elif len(kv) % 2 != 0:
raise BPDError(ERROR_ARG, "Incorrect number of filter arguments")
query = self._metadata_query(dbcore.query.MatchQuery, kv)
query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
clause, subvals = query.clause()
statement = (

View file

@ -6,16 +6,6 @@ Unreleased
Changelog goes here! Please add your entry to the bottom of one of the lists below!
New features:
* Ability to query albums with track-level (and vice-versa) **db** or
**flexible** field queries, for example `beet list -a title:something`, `beet
list artpath:cover`.
* Queries have been made faster, and their speed is constant regardless of
their complexity or the type of queried fields. Notably, album queries for
the `path` field and those that involve flexible attributes have seen the
most significant speedup.
Bug fixes:
* Improved naming of temporary files by separating the random part with the file extension.

View file

@ -17,9 +17,7 @@ This command::
$ beet list love
will show all tracks matching the query string ``love``. By default any
unadorned word like this matches in a track's title, artist, album name, album
artist, genre and comments. See below on how to search other fields.
will show all tracks matching the query string ``love``. By default any unadorned word like this matches in a track's title, artist, album name, album artist, genre and comments. See below on how to search other fields.
For example, this is what I might see when I run the command above::
@ -85,15 +83,6 @@ For multi-valued tags (such as ``artists`` or ``albumartists``), a regular
expression search must be used to search for a single value within the
multi-valued tag.
Note that you can filter albums by querying their tracks fields, including
flexible attributes::
$ beet list -a title:love
and vice versa::
$ beet list art_path::love
Phrases
-------
@ -126,9 +115,9 @@ the field name's colon and before the expression::
$ beet list artist:=AIR
The first query is a simple substring one that returns tracks by Air, AIR, and
Air Supply. The second query returns tracks by Air and AIR, since both are a
Air Supply. The second query returns tracks by Air and AIR, since both are a
case-insensitive match for the entire expression, but does not return anything
by Air Supply. The third query, which requires a case-sensitive exact match,
by Air Supply. The third query, which requires a case-sensitive exact match,
returns tracks by AIR only.
Exact matches may be performed on phrases as well::
@ -369,7 +358,7 @@ result in lower-case values being placed after upper-case values, e.g.,
``Bar Qux foo``.
Note that when sorting by fields that are not present on all items (such as
flexible fields, or those defined by plugins) in *ascending* order, the items
flexible fields, or those defined by plugins) in *ascending* order, the items
that lack that particular field will be listed at the *beginning* of the list.
You can set the default sorting behavior with the :ref:`sort_item` and

115
poetry.lock generated
View file

@ -685,17 +685,6 @@ files = [
[package.dependencies]
Flask = ">=0.9"
[[package]]
name = "funcy"
version = "2.0"
description = "A fancy and practical functional tools"
optional = false
python-versions = "*"
files = [
{file = "funcy-2.0-py2.py3-none-any.whl", hash = "sha256:53df23c8bb1651b12f095df764bfb057935d49537a56de211b098f4c79614bb0"},
{file = "funcy-2.0.tar.gz", hash = "sha256:3963315d59d41c6f30c04bc910e10ab50a3ac4a225868bfa96feed133df075cb"},
]
[[package]]
name = "h11"
version = "0.14.0"
@ -1181,30 +1170,6 @@ html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"]
source = ["Cython (>=3.0.10)"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "markupsafe"
version = "2.1.5"
@ -1285,17 +1250,6 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "mediafile"
version = "0.12.0"
@ -1330,17 +1284,6 @@ build = ["blurb", "twine", "wheel"]
docs = ["sphinx"]
test = ["pytest", "pytest-cov"]
[[package]]
name = "multimethod"
version = "1.10"
description = "Multiple argument dispatching."
optional = false
python-versions = ">=3.8"
files = [
{file = "multimethod-1.10-py3-none-any.whl", hash = "sha256:afd84da9c3d0445c84f827e4d63ad42d17c6d29b122427c6dee9032ac2d2a0d4"},
{file = "multimethod-1.10.tar.gz", hash = "sha256:daa45af3fe257f73abb69673fd54ddeaf31df0eb7363ad6e1251b7c9b192d8c5"},
]
[[package]]
name = "multivolumefile"
version = "0.2.3"
@ -2369,47 +2312,6 @@ urllib3 = ">=1.25.10,<3.0"
[package.extras]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"]
[[package]]
name = "rich"
version = "13.7.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "rich-tables"
version = "0.5.1"
description = "Ready-made rich tables for various purposes"
optional = false
python-versions = "<4,>=3.8"
files = [
{file = "rich_tables-0.5.1-py3-none-any.whl", hash = "sha256:26980f9881a44cd5a530f634c17fa4bed40875ee962127bbdafec9c237589b8d"},
{file = "rich_tables-0.5.1.tar.gz", hash = "sha256:7cc9887f380d773aa0e2da05256970bcbb61bc40445193f32a1f7e167e77a971"},
]
[package.dependencies]
funcy = ">=2.0"
multimethod = "*"
platformdirs = ">=4.2.0"
rich = ">=12.3.0"
sqlparse = ">=0.4.4"
typing-extensions = ">=4.7.1"
[package.extras]
hue = ["rgbxy (>=0.5)"]
[[package]]
name = "six"
version = "1.16.0"
@ -2600,21 +2502,6 @@ files = [
lint = ["docutils-stubs", "flake8", "mypy"]
test = ["pytest"]
[[package]]
name = "sqlparse"
version = "0.5.0"
description = "A non-validating SQL parser."
optional = false
python-versions = ">=3.8"
files = [
{file = "sqlparse-0.5.0-py3-none-any.whl", hash = "sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663"},
{file = "sqlparse-0.5.0.tar.gz", hash = "sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93"},
]
[package.extras]
dev = ["build", "hatch"]
doc = ["sphinx"]
[[package]]
name = "texttable"
version = "1.7.0"
@ -2833,4 +2720,4 @@ web = ["flask", "flask-cors"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8,<4"
content-hash = "0de3f4cf9e0fc7ace1de5e9c3aa859cb2b5b2a42d0a58e4b1d96a4dc251bde07"
content-hash = "740281ee3ddba4c6015eab9cfc24bb947e8816e3b7f5a6bebeb39ff2413d7ac3"

View file

@ -41,7 +41,6 @@ mediafile = ">=0.12.0"
munkres = ">=1.0.0"
musicbrainzngs = ">=0.4"
pyyaml = "*"
rich-tables = ">=0.5.1"
typing_extensions = "*"
unidecode = ">=1.3.6"
beautifulsoup4 = { version = "*", optional = true }

View file

@ -31,9 +31,7 @@ show_contexts = true
min-version = 3.8
accept-encodings = utf-8
max-line-length = 88
classmethod-decorators =
classmethod
cached_classproperty
docstring-convention = google
# errors we ignore; see https://www.flake8rules.com/ for more info
ignore =
# pycodestyle errors

View file

@ -15,8 +15,6 @@
import unittest
import pytest
from beets.test.helper import TestHelper
@ -81,17 +79,11 @@ class LimitPluginTest(unittest.TestCase, TestHelper):
)
self.assertEqual(result.count("\n"), self.num_limit)
@pytest.mark.xfail(
reason="Will be restored together with removal of slow sorts"
)
def test_prefix(self):
"""Returns the expected number with the query prefix."""
result = self.lib.items(self.num_limit_prefix)
self.assertEqual(len(result), self.num_limit)
@pytest.mark.xfail(
reason="Will be restored together with removal of slow sorts"
)
def test_prefix_when_correctly_ordered(self):
"""Returns the expected number with the query prefix and filter when
the prefix portion (correctly) appears last."""
@ -99,9 +91,6 @@ class LimitPluginTest(unittest.TestCase, TestHelper):
result = self.lib.items(correct_order)
self.assertEqual(len(result), self.num_limit)
@pytest.mark.xfail(
reason="Will be restored together with removal of slow sorts"
)
def test_prefix_when_incorrectly_ordred(self):
"""Returns no results with the query prefix and filter when the prefix
portion (incorrectly) appears first."""

View file

@ -5,7 +5,6 @@ import os.path
import platform
import shutil
import unittest
from pathlib import Path
from beets import logging
from beets.library import Album, Item
@ -30,38 +29,36 @@ class WebPluginTest(_common.LibTestCase):
# Add library elements. Note that self.lib.add overrides any "id=<n>"
# and assigns the next free id number.
# The following adds will create items #1, #2 and #3
base_path = Path(self.path_prefix + os.sep)
album2_item1 = Item(
title="title",
path=str(base_path / "path_1"),
album_id=2,
artist="AAA Singers",
path1 = (
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
)
album1_item = Item(
title="another title",
path=str(base_path / "somewhere" / "a"),
artist="AAA Singers",
self.lib.add(
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
)
album2_item2 = Item(
title="and a third",
testattr="ABC",
path=str(base_path / "somewhere" / "abc"),
album_id=2,
path2 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"a").decode("utf-8")
)
self.lib.add(
Item(title="another title", path=path2, artist="AAA Singers")
)
path3 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
)
self.lib.add(
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
)
self.lib.add(album2_item1)
self.lib.add(album1_item)
self.lib.add(album2_item2)
# The following adds will create albums #1 and #2
album1 = self.lib.add_album([album1_item])
album1.album = "album"
album1.albumtest = "xyz"
album1.store()
album2 = self.lib.add_album([album2_item1, album2_item2])
album2.album = "other album"
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
album2.store()
self.lib.add(Album(album="album", albumtest="xyz"))
path4 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
)
self.lib.add(Album(album="other album", artpath=path4))
web.app.config["TESTING"] = True
web.app.config["lib"] = self.lib

View file

@ -143,7 +143,7 @@ def _clear_weights():
"""Hack around the lazy descriptor used to cache weights for
Distance calculations.
"""
Distance.__dict__["_weights"].cache = {}
Distance.__dict__["_weights"].computed = False
class DistanceTest(_common.TestCase):

View file

@ -21,7 +21,6 @@ import unittest
from tempfile import mkstemp
from beets import dbcore
from beets.library import LibModel
from beets.test import _common
# Fixture: concrete database and model classes. For migration tests, we
@ -43,7 +42,7 @@ class QueryFixture(dbcore.query.FieldQuery):
return True
class ModelFixture1(LibModel):
class ModelFixture1(dbcore.Model):
_table = "test"
_flex_table = "testflex"
_fields = {
@ -590,7 +589,7 @@ class QueryFromStringsTest(unittest.TestCase):
q = self.qfs(["foo", "bar:baz"])
self.assertIsInstance(q, dbcore.query.AndQuery)
self.assertEqual(len(q.subqueries), 2)
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
def test_parse_fixed_type_query(self):

View file

@ -48,6 +48,40 @@ class TestHelper(helper.TestHelper):
self.assertNotIn(item.id, result_ids)
class AnyFieldQueryTest(_common.LibTestCase):
def test_no_restriction(self):
q = dbcore.query.AnyFieldQuery(
"title",
beets.library.Item._fields.keys(),
dbcore.query.SubstringQuery,
)
self.assertEqual(self.lib.items(q).get().title, "the title")
def test_restriction_completeness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["title"], dbcore.query.SubstringQuery
)
self.assertEqual(self.lib.items(q).get().title, "the title")
def test_restriction_soundness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["artist"], dbcore.query.SubstringQuery
)
self.assertIsNone(self.lib.items(q).get())
def test_eq(self):
q1 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
q2 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
self.assertEqual(q1, q2)
q2.query_class = None
self.assertNotEqual(q1, q2)
class AssertsMixin:
def assert_items_matched(self, results, titles):
self.assertEqual({i.title for i in results}, set(titles))
@ -487,7 +521,7 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, ["path album"])
self.assert_albums_matched(results, [])
# FIXME: fails on windows
@unittest.skipIf(sys.platform == "win32", "win32")
@ -570,9 +604,6 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["path item"])
results = self.lib.albums(q)
self.assert_albums_matched(results, ["path album"])
def test_path_album_regex(self):
q = "path::b"
results = self.lib.albums(q)
@ -823,17 +854,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper):
def test_match_slow(self):
item = self.add_item()
matched = self.lib.items(NoneQuery("rg_track_peak"))
matched = self.lib.items(NoneQuery("rg_track_peak", fast=False))
self.assertInResult(item, matched)
def test_match_slow_after_set_none(self):
item = self.add_item(rg_track_gain=0)
matched = self.lib.items(NoneQuery("rg_track_gain"))
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
self.assertNotInResult(item, matched)
item["rg_track_gain"] = None
item.store()
matched = self.lib.items(NoneQuery("rg_track_gain"))
matched = self.lib.items(NoneQuery("rg_track_gain", fast=False))
self.assertInResult(item, matched)
@ -947,6 +978,14 @@ class NotQueryTest(DummyDataTestCase):
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
self.assertNegationProperties(q)
def test_type_anyfield(self):
q = dbcore.query.AnyFieldQuery(
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
)
not_results = self.lib.items(dbcore.query.NotQuery(q))
self.assert_items_matched(not_results, ["baz qux"])
self.assertNegationProperties(q)
def test_type_boolean(self):
q = dbcore.query.BooleanQuery("comp", True)
not_results = self.lib.items(dbcore.query.NotQuery(q))
@ -1055,87 +1094,36 @@ class NotQueryTest(DummyDataTestCase):
results = self.lib.items(q)
self.assert_items_matched(results, ["baz qux"])
def test_fast_vs_slow(self):
"""Test that the results are the same regardless of the `fast` flag
for negated `FieldQuery`s.
class RelatedQueriesTest(_common.TestCase, AssertsMixin):
"""Test album-level queries with track-level filters and vice-versa."""
TODO: investigate NoneQuery(fast=False), as it is raising
AttributeError: type object 'NoneQuery' has no attribute 'field'
at NoneQuery.match() (due to being @classmethod, and no self?)
"""
classes = [
(dbcore.query.DateQuery, ["added", "2001-01-01"]),
(dbcore.query.MatchQuery, ["artist", "one"]),
# (dbcore.query.NoneQuery, ['rg_track_gain']),
(dbcore.query.NumericQuery, ["year", "2002"]),
(dbcore.query.StringFieldQuery, ["year", "2001"]),
(dbcore.query.RegexpQuery, ["album", "^.a"]),
(dbcore.query.SubstringQuery, ["title", "x"]),
]
def setUp(self):
super().setUp()
self.lib = beets.library.Library(":memory:")
for klass, args in classes:
q_fast = dbcore.query.NotQuery(klass(*(args + [True])))
q_slow = dbcore.query.NotQuery(klass(*(args + [False])))
albums = []
for album_idx in range(1, 3):
album_name = f"Album{album_idx}"
album_items = []
for item_idx in range(1, 3):
item = _common.item()
item.album = album_name
title = f"{album_name} Item{item_idx}"
item.title = title
item.item_flex1 = f"{title} Flex1"
item.item_flex2 = f"{title} Flex2"
self.lib.add(item)
album_items.append(item)
album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath"
album.catalognum = "ABC"
album.album_flex = f"{album_name} Flex"
album.store()
albums.append(album)
self.album, self.another_album = albums
def test_get_albums_filter_by_track_field(self):
q = "title:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_album_field(self):
q = "artpath::Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_albums_by_common_field(self):
# title:Album1 ensures that the items table is joined for the query
q = "title:Album1 catalognum:ABC"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_filter_items_by_common_field(self):
# artpath::A ensures that the albums table is joined for the query
q = "artpath::A Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_get_albums_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_albums_filter_by_track_flex(self):
q = "item_flex1:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_get_items_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_flex(self):
q = "item_flex1:'Item1 Flex1'"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
def test_filter_by_many_flex(self):
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1"])
try:
self.assertEqual(
[i.title for i in self.lib.items(q_fast)],
[i.title for i in self.lib.items(q_slow)],
)
except NotImplementedError:
# ignore classes that do not provide `fast` implementation
pass
def suite():