diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index efd71da9b..c1dabdb09 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -39,7 +39,7 @@ from unidecode import unidecode from beets import config, logging, plugins from beets.autotag import mb from beets.library import Item -from beets.util import as_string, cached_classproperty +from beets.util import as_string log = logging.getLogger("beets") @@ -413,6 +413,23 @@ def string_dist(str1: Optional[str], str2: Optional[str]) -> float: return base_dist + penalty +class LazyClassProperty: + """A decorator implementing a read-only property that is *lazy* in + the sense that the getter is only invoked once. Subsequent accesses + through *any* instance use the cached result. + """ + + def __init__(self, getter): + self.getter = getter + self.computed = False + + def __get__(self, obj, owner): + if not self.computed: + self.value = self.getter(owner) + self.computed = True + return self.value + + @total_ordering class Distance: """Keeps track of multiple distance penalties. 
Provides a single @@ -424,7 +441,7 @@ class Distance: self._penalties = {} self.tracks: Dict[TrackInfo, Distance] = {} - @cached_classproperty + @LazyClassProperty def _weights(cls) -> Dict[str, float]: # noqa: N805 """A dictionary from keys to floating-point weights.""" weights_view = config["match"]["distance_weights"] diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 8b7f3f35c..4f7665e37 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -17,16 +17,14 @@ from __future__ import annotations import contextlib -import json import os import re import sqlite3 -import sys import threading import time from abc import ABC from collections import defaultdict -from sqlite3 import Connection, sqlite_version +from sqlite3 import Connection from types import TracebackType from typing import ( Any, @@ -50,31 +48,22 @@ from typing import ( cast, ) -from packaging.version import Version -from rich import print -from rich_tables.generic import flexitable from unidecode import unidecode import beets +from beets.util import functemplate -from ..util import cached_classproperty, functemplate +from ..util.functemplate import Template from . import types -from .query import FieldQuery, MatchQuery, NullSort, Query, Sort, TrueQuery - -# convert data under 'json_str' type name to Python dictionary automatically -sqlite3.register_converter("json_str", json.loads) - -DEBUG = bool(os.getenv("BEETS_DEBUG", False)) - - -def print_query(sql, subvals=None): - """If debugging, replace placeholders and print the query.""" - if not DEBUG: - return - topr = sql - for val in subvals or []: - topr = topr.replace("?", str(val), 1) - print(flexitable({"sql": topr}), file=sys.stderr) +from .query import ( + AndQuery, + FieldQuery, + MatchQuery, + NullSort, + Query, + Sort, + TrueQuery, +) class DBAccessError(Exception): @@ -334,64 +323,6 @@ class Model(ABC): to the database. 
""" - @cached_classproperty - def _relation(cls) -> Type[Model]: - """The model that this model is closely related to.""" - return cls - - @cached_classproperty - def relation_join(cls) -> str: - """Return the join required to include the related table in the query. - - This is intended to be used as a FROM clause in the SQL query. - """ - return "" - - @cached_classproperty - def table_with_flex_attrs(cls) -> str: - """Return a SQL for entity table which includes aggregated flexible attributes. - - The clause selects entity rows, flexible attributes rows and LEFT JOINs - them on entity id and 'entity_id' field respectively. - - 'json_group_object' aggregate function groups flexible attributes into a - single JSON object 'flex_attrs [json_str]'. The column name ending with - ' [json_str]' means that this column is converted to a Python dictionary - automatically (see 'register_converter' call at the top of this module). - - 'REPLACE' function handles absence of flexible attributes and replaces - some weird null JSON object (that SQLite gives us by default) with an - empty JSON object. - - Availability of the 'flex_attrs' means we can query flexible attributes - in the same manner we query other entity fields, see - `FieldQuery.field`. This way, we also remove the need for an - additional query to fetch them. - - Note: we use LEFT join to include entities without flexible attributes. - Note: we name this SELECT clause after the original entity table name - so that we can query it in the way like the original table. 
- """ - flex_attrs = "REPLACE(json_group_object(key, value), '{:null}', '{}')" - return f"""( - SELECT - *, - {flex_attrs} AS "flex_attrs [json_str]" - FROM {cls._table} LEFT JOIN ( - SELECT - entity_id, - key, - CAST(value AS text) AS value - FROM {cls._flex_table} - ) ON entity_id == {cls._table}.id - GROUP BY {cls._table}.id - ) {cls._table} - """ - - @cached_classproperty - def all_model_db_fields(cls) -> Set[str]: - return set() - @classmethod def _getters(cls: Type["Model"]): """Return a mapping from field names to getter functions.""" @@ -737,7 +668,7 @@ class Model(ABC): def evaluate_template( self, - template: Union[str, functemplate.Template], + template: Union[str, Template], for_path: bool = False, ) -> str: """Evaluate a template (a string or a `Template` object) using @@ -768,6 +699,33 @@ class Model(ABC): """Set the object's key to a value represented by a string.""" self[key] = self._parse(key, string) + # Convenient queries. + + @classmethod + def field_query( + cls, + field, + pattern, + query_cls: Type[FieldQuery] = MatchQuery, + ) -> FieldQuery: + """Get a `FieldQuery` for this model.""" + return query_cls(field, pattern, field in cls._fields) + + @classmethod + def all_fields_query( + cls: Type["Model"], + pats: Mapping, + query_cls: Type[FieldQuery] = MatchQuery, + ): + """Get a query that matches many fields with different patterns. + + `pats` should be a mapping from field names to patterns. The + resulting query is a conjunction ("and") of per-field queries + for all of these field/pattern pairs. + """ + subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()] + return AndQuery(subqueries) + # Database controller and supporting interfaces. 
@@ -785,6 +743,8 @@ class Results(Generic[AnyModel]): model_class: Type[AnyModel], rows: List[Mapping], db: "Database", + flex_rows, + query: Optional[Query] = None, sort=None, ): """Create a result set that will construct objects of type @@ -794,7 +754,9 @@ class Results(Generic[AnyModel]): constructed. `rows` is a query result: a list of mappings. The new objects will be associated with the database `db`. - If `sort` is provided, it is used to sort the + If `query` is provided, it is used as a predicate to filter the + results for a "slow query" that cannot be evaluated by the + database directly. If `sort` is provided, it is used to sort the full list of results before returning. This means it is a "slow sort" and all objects must be built before returning the first one. @@ -802,7 +764,9 @@ class Results(Generic[AnyModel]): self.model_class = model_class self.rows = rows self.db = db + self.query = query self.sort = sort + self.flex_rows = flex_rows # We keep a queue of rows we haven't yet consumed for # materialization. We preserve the original total number of @@ -824,6 +788,10 @@ class Results(Generic[AnyModel]): a `Results` object a second time should be much faster than the first. """ + + # Index flexible attributes by the item ID, so we have easier access + flex_attrs = self._get_indexed_flex_attrs() + index = 0 # Position in the materialized objects. while index < len(self._objects) or self._rows: # Are there previously-materialized objects to produce? @@ -836,11 +804,14 @@ class Results(Generic[AnyModel]): else: while self._rows: row = self._rows.pop(0) - obj = self._make_model(row) - self._objects.append(obj) - index += 1 - yield obj - break + obj = self._make_model(row, flex_attrs.get(row["id"], {})) + # If there is a slow-query predicate, ensure that the + # object passes it. 
+ if not self.query or self.query.match(obj): + self._objects.append(obj) + index += 1 + yield obj + break def __iter__(self) -> Iterator[AnyModel]: """Construct and generate Model objects for all matching @@ -855,10 +826,21 @@ class Results(Generic[AnyModel]): # Objects are pre-sorted (i.e., by the database). return self._get_objects() - def _make_model(self, row) -> AnyModel: + def _get_indexed_flex_attrs(self) -> Mapping: + """Index flexible attributes by the entity id they belong to""" + flex_values: Dict[int, Dict[str, Any]] = {} + for row in self.flex_rows: + if row["entity_id"] not in flex_values: + flex_values[row["entity_id"]] = {} + + flex_values[row["entity_id"]][row["key"]] = row["value"] + + return flex_values + + def _make_model(self, row, flex_values: Dict = {}) -> AnyModel: """Create a Model object for the given row""" - values = dict(row) - flex_values = values.pop("flex_attrs") or {} + cols = dict(row) + values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"} # Construct the Python object obj = self.model_class._awaken(self.db, values, flex_values) @@ -869,8 +851,16 @@ class Results(Generic[AnyModel]): if not self._rows: # Fully materialized. Just count the objects. return len(self._objects) + + elif self.query: + # A slow query. Fall back to testing every object. + count = 0 + for obj in self: + count += 1 + return count + else: - # Just count the rows. + # A fast query. Just count the rows. return self._row_count def __nonzero__(self) -> bool: @@ -960,7 +950,6 @@ class Transaction: """Execute an SQL statement with substitution values and return a list of rows from the database. """ - print_query(statement, subvals) cursor = self.db._connection().execute(statement, subvals) return cursor.fetchall() @@ -969,7 +958,6 @@ class Transaction: the row ID of the last affected row. 
""" try: - print_query(statement, subvals) cursor = self.db._connection().execute(statement, subvals) except sqlite3.OperationalError as e: # In two specific cases, SQLite reports an error while accessing @@ -990,7 +978,6 @@ class Transaction: """Execute a string containing multiple SQL statements.""" # We don't know whether this mutates, but quite likely it does. self._mutated = True - print_query(statements) self.db._connection().executescript(statements) @@ -1079,8 +1066,6 @@ class Database: # We have our own same-thread checks in _connection(), but need to # call conn.close() in _close() check_same_thread=False, - # enable type name "col [type]" conversion (`register_converter`) - detect_types=sqlite3.PARSE_COLNAMES, ) self.add_functions(conn) @@ -1099,9 +1084,7 @@ class Database: def regexp(value, pattern): if isinstance(value, bytes): value = value.decode() - return ( - value is not None and re.search(pattern, str(value)) is not None - ) + return re.search(pattern, str(value)) is not None def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]: """A custom ``bytelower`` sqlite function so we can compare @@ -1116,71 +1099,9 @@ class Database: return bytestring - def json_patch(first: str, second: str) -> str: - """Implementation of the 'json_patch' SQL function. - - This function merges two JSON strings together. - """ - first_dict = json.loads(first) - second_dict = json.loads(second) - first_dict.update(second_dict) - return json.dumps(first_dict) - - def json_extract(json_str: str, key: str) -> Optional[str]: - """Simple implementation of the 'json_extract' SQLite function. - - The original implementation in SQLite allows traversing objects of - any depth. Here, we only ever deal with a flat dictionary, thus - we can simplify the implementation to a single 'get' call. 
- """ - if json_str: - return json.loads(json_str).get(key.replace("$.", "")) - - return None - - class JSONGroupObject: - """Implementation of the 'json_group_object' SQLite aggregate. - - An aggregate function which accepts two values (key, val) and - groups all {key: val} pairs into a single object. - - It is found in the json1 extension which is included in SQLite - by default since version 3.38.0 (2022-02-22). To ensure support - for older SQLite versions, we add our implementation. - - Notably, it does not exist on Windows in Python 3.8. - - Consider the following table - - id key val - 1 plays "10" - 1 skips "20" - 2 city "London" - - SELECT id, group_to_json(key, val) GROUP BY id - 1, '{"plays": "10", "skips": "20"}' - 2, '{"city": "London"}' - """ - - def __init__(self): - self.flex = {} - - def step(self, field, value): - if field: - self.flex[field] = value - - def finalize(self): - return json.dumps(self.flex) - conn.create_function("regexp", 2, regexp) conn.create_function("unidecode", 1, unidecode) conn.create_function("bytelower", 1, bytelower) - if Version(sqlite_version) < Version("3.38.0"): - # create 'json_group_object' for older SQLite versions that do - # not include the json1 extension by default - conn.create_aggregate("json_group_object", 2, JSONGroupObject) - conn.create_function("json_patch", 2, json_patch) - conn.create_function("json_extract", 2, json_extract) def _close(self): """Close the all connections to the underlying SQLite database @@ -1302,42 +1223,34 @@ class Database: where, subvals = query.clause() order_by = sort.order_clause() - this_table = model_cls._table - select_fields = [f"{this_table}.*"] - _from = model_cls.table_with_flex_attrs + sql = ("SELECT * FROM {} WHERE {} {}").format( + model_cls._table, + where or "1", + f"ORDER BY {order_by}" if order_by else "", + ) - required_fields = query.field_names - if required_fields - model_cls._fields.keys(): - _from += f" {model_cls.relation_join}" - - if required_fields - 
model_cls.all_model_db_fields: - # merge all flexible attribute into a single JSON field - select_fields.append( - f""" - json_patch( - COALESCE({this_table}."flex_attrs [json_str]", '{{}}'), - COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}') - ) AS all_flex_attrs - """ # noqa: E501 - ) - - sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id" # noqa: E501 - - if order_by: - # the sort field may exist in both 'items' and 'albums' tables - # (when they are joined), causing ambiguous column OperationalError - # if we try to order directly. - # Since the join is required only for filtering, we can filter in - # a subquery and order the result, which returns unique fields. - sql = f"SELECT * FROM ({sql}) ORDER BY {order_by}" + # Fetch flexible attributes for items matching the main query. + # Doing the per-item filtering in python is faster than issuing + # one query per item to sqlite. + flex_sql = """ + SELECT * FROM {} WHERE entity_id IN + (SELECT id FROM {} WHERE {}); + """.format( + model_cls._flex_table, + model_cls._table, + where or "1", + ) with self.transaction() as tx: rows = tx.query(sql, subvals) + flex_rows = tx.query(flex_sql, subvals) return Results( model_cls, rows, self, + flex_rows, + None if where else query, # Slow query component. sort if sort.is_slow() else None, # Slow sort component. 
) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 2282c7815..2e1385ca2 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -21,7 +21,7 @@ import unicodedata from abc import ABC, abstractmethod from datetime import datetime, timedelta from functools import reduce -from operator import mul, or_ +from operator import mul from typing import ( TYPE_CHECKING, Any, @@ -33,7 +33,6 @@ from typing import ( Optional, Pattern, Sequence, - Set, Tuple, Type, TypeVar, @@ -82,19 +81,17 @@ class InvalidQueryArgumentValueError(ParsingError): class Query(ABC): """An abstract class representing a query into the database.""" - @property - def field_names(self) -> Set[str]: - """Return a set with field names that this query operates on.""" - return set() - def clause(self) -> Tuple[Optional[str], Sequence[Any]]: """Generate an SQLite expression implementing the query. Return (clause, subvals) where clause is a valid sqlite WHERE clause implementing the query and subvals is a list of items to be substituted for ?s in the clause. + + The default implementation returns None, falling back to a slow query + using `match()`. """ - raise NotImplementedError + return None, () @abstractmethod def match(self, obj: Model): @@ -131,30 +128,20 @@ class FieldQuery(Query, Generic[P]): same matching functionality in SQLite. 
""" - def __init__(self, field_name: str, pattern: P, fast: bool = True): - self.table, _, self.field_name = field_name.rpartition(".") + def __init__(self, field: str, pattern: P, fast: bool = True): + self.field = field self.pattern = pattern self.fast = fast - @property - def field_names(self) -> Set[str]: - """Return a set with field names that this query operates on.""" - return {self.field_name} + def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + return None, () - @property - def field(self) -> str: - if not self.fast: - return f'json_extract(all_flex_attrs, "$.{self.field_name}")' - - return ( - f"{self.table}.{self.field_name}" if self.table else self.field_name - ) - - def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - raise NotImplementedError - - def clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return self.col_clause() + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + if self.fast: + return self.col_clause() + else: + # Matching a flexattr. This is a slow query. 
+ return None, () @classmethod def value_match(cls, pattern: P, value: Any): @@ -162,23 +149,23 @@ class FieldQuery(Query, Generic[P]): raise NotImplementedError() def match(self, obj: Model) -> bool: - return self.value_match(self.pattern, obj.get(self.field_name)) + return self.value_match(self.pattern, obj.get(self.field)) def __repr__(self) -> str: return ( - f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, " + f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, " f"fast={self.fast})" ) def __eq__(self, other) -> bool: return ( super().__eq__(other) - and self.field_name == other.field_name + and self.field == other.field and self.pattern == other.pattern ) def __hash__(self) -> int: - return hash((self.field_name, hash(self.pattern))) + return hash((self.field, hash(self.pattern))) class MatchQuery(FieldQuery[AnySQLiteType]): @@ -202,10 +189,10 @@ class NoneQuery(FieldQuery[None]): return self.field + " IS NULL", () def match(self, obj: Model) -> bool: - return obj.get(self.field_name) is None + return obj.get(self.field) is None def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})" + return f"{self.__class__.__name__}({self.field!r}, {self.fast})" class StringFieldQuery(FieldQuery[P]): @@ -276,7 +263,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): expression. 
""" - def __init__(self, field_name: str, pattern: str, fast: bool = True): + def __init__(self, field: str, pattern: str, fast: bool = True): pattern = self._normalize(pattern) try: pattern_re = re.compile(pattern) @@ -286,7 +273,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): pattern, "a regular expression", format(exc) ) - super().__init__(field_name, pattern_re, fast) + super().__init__(field, pattern_re, fast) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return f" regexp({self.field}, ?)", [self.pattern.pattern] @@ -303,24 +290,14 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): return pattern.search(cls._normalize(value)) is not None -class NumericColumnQuery(MatchQuery[AnySQLiteType]): - """A base class for queries that work with NUMERIC SQLite affinity.""" - - @property - def field(self) -> str: - """Cast a flexible attribute column (string) to NUMERIC affinity.""" - field = super().field - return field if self.fast else f"CAST({field} AS NUMERIC)" - - -class BooleanQuery(NumericColumnQuery[bool]): +class BooleanQuery(MatchQuery[int]): """Matches a boolean field. Pattern should either be a boolean or a string reflecting a boolean. """ def __init__( self, - field_name: str, + field: str, pattern: bool, fast: bool = True, ): @@ -329,7 +306,7 @@ class BooleanQuery(NumericColumnQuery[bool]): pattern_int = int(pattern) - super().__init__(field_name, pattern_int, fast) + super().__init__(field, pattern_int, fast) class BytesQuery(FieldQuery[bytes]): @@ -339,7 +316,7 @@ class BytesQuery(FieldQuery[bytes]): `MatchQuery` when matching on BLOB values. """ - def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]): + def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): # Use a buffer/memoryview representation of the pattern for SQLite # matching. This instructs SQLite to treat the blob as binary # rather than encoded Unicode. 
@@ -355,7 +332,7 @@ class BytesQuery(FieldQuery[bytes]): else: raise ValueError("pattern must be bytes, str, or memoryview") - super().__init__(field_name, bytes_pattern) + super().__init__(field, bytes_pattern) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.field + " = ?", [self.buf_pattern] @@ -365,7 +342,7 @@ class BytesQuery(FieldQuery[bytes]): return pattern == value -class NumericQuery(NumericColumnQuery[Union[int, float]]): +class NumericQuery(FieldQuery[str]): """Matches numeric fields. A syntax using Ruby-style range ellipses (``..``) lets users specify one- or two-sided ranges. For example, ``year:2001..`` finds music released since the turn of the century. @@ -391,8 +368,8 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): except ValueError: raise InvalidQueryArgumentValueError(s, "an int or a float") - def __init__(self, field_name: str, pattern: str, fast: bool = True): - super().__init__(field_name, pattern, fast) + def __init__(self, field: str, pattern: str, fast: bool = True): + super().__init__(field, pattern, fast) parts = pattern.split("..", 1) if len(parts) == 1: @@ -407,9 +384,9 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): self.rangemax = self._convert(parts[1]) def match(self, obj: Model) -> bool: - if self.field_name not in obj: + if self.field not in obj: return False - value = obj[self.field_name] + value = obj[self.field] if isinstance(value, str): value = self._convert(value) @@ -442,7 +419,7 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): """Query which matches values in the given set.""" - field_name: str + field: str pattern: Sequence[AnySQLiteType] fast: bool = True @@ -452,7 +429,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: placeholders = ", ".join(["?"] * len(self.subvals)) - return 
f"{self.field_name} IN ({placeholders})", self.subvals + return f"{self.field} IN ({placeholders})", self.subvals @classmethod def value_match( @@ -469,11 +446,6 @@ class CollectionQuery(Query): def __init__(self, subqueries: Sequence = ()): self.subqueries = subqueries - @property - def field_names(self) -> Set[str]: - """Return a set with field names that this query operates on.""" - return reduce(or_, (sq.field_names for sq in self.subqueries)) - # Act like a sequence. def __len__(self) -> int: @@ -491,7 +463,7 @@ class CollectionQuery(Query): def clause_with_joiner( self, joiner: str, - ) -> Tuple[str, Sequence[SQLiteType]]: + ) -> Tuple[Optional[str], Sequence[SQLiteType]]: """Return a clause created by joining together the clauses of all subqueries with the string joiner (padded by spaces). """ @@ -499,6 +471,9 @@ class CollectionQuery(Query): subvals = [] for subq in self.subqueries: subq_clause, subq_subvals = subq.clause() + if not subq_clause: + # Fall back to slow query. + return None, () clause_parts.append("(" + subq_clause + ")") subvals += subq_subvals clause = (" " + joiner + " ").join(clause_parts) @@ -517,6 +492,45 @@ class CollectionQuery(Query): return reduce(mul, map(hash, self.subqueries), 1) +class AnyFieldQuery(CollectionQuery): + """A query that matches if a given FieldQuery subclass matches in + any field. The individual field query class is provided to the + constructor. 
+ """ + + def __init__(self, pattern, fields, cls: Type[FieldQuery]): + self.pattern = pattern + self.fields = fields + self.query_class = cls + + subqueries = [] + for field in self.fields: + subqueries.append(cls(field, pattern, True)) + # TYPING ERROR + super().__init__(subqueries) + + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: + return self.clause_with_joiner("or") + + def match(self, obj: Model) -> bool: + for subq in self.subqueries: + if subq.match(obj): + return True + return False + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, " + f"{self.query_class.__name__})" + ) + + def __eq__(self, other) -> bool: + return super().__eq__(other) and self.query_class == other.query_class + + def __hash__(self) -> int: + return hash((self.pattern, tuple(self.fields), self.query_class)) + + class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the query is initialized. 
@@ -534,7 +548,7 @@ class MutableCollectionQuery(CollectionQuery): class AndQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[str, Sequence[SQLiteType]]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return self.clause_with_joiner("and") def match(self, obj: Model) -> bool: @@ -544,7 +558,7 @@ class AndQuery(MutableCollectionQuery): class OrQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self) -> Tuple[str, Sequence[SQLiteType]]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: return self.clause_with_joiner("or") def match(self, obj: Model) -> bool: @@ -559,14 +573,14 @@ class NotQuery(Query): def __init__(self, subquery): self.subquery = subquery - @property - def field_names(self) -> Set[str]: - """Return a set with field names that this query operates on.""" - return self.subquery.field_names - - def clause(self) -> Tuple[str, Sequence[SQLiteType]]: + def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: clause, subvals = self.subquery.clause() - return f"not ({clause})", subvals + if clause: + return f"not ({clause})", subvals + else: + # If there is no clause, there is nothing to negate. All the logic + # is handled by match() for slow queries. + return clause, subvals def match(self, obj: Model) -> bool: return not self.subquery.match(obj) @@ -773,7 +787,7 @@ class DateInterval: return f"[{self.start}, {self.end})" -class DateQuery(NumericColumnQuery[int]): +class DateQuery(FieldQuery[str]): """Matches date fields stored as seconds since Unix epoch time. Dates can be specified as ``year-month-day`` strings where only year @@ -783,15 +797,15 @@ class DateQuery(NumericColumnQuery[int]): using an ellipsis interval syntax similar to that of NumericQuery. 
""" - def __init__(self, field_name: str, pattern: str, fast: bool = True): - super().__init__(field_name, pattern, fast) + def __init__(self, field: str, pattern: str, fast: bool = True): + super().__init__(field, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) def match(self, obj: Model) -> bool: - if self.field_name not in obj: + if self.field not in obj: return False - timestamp = float(obj[self.field_name]) + timestamp = float(obj[self.field]) date = datetime.fromtimestamp(timestamp) return self.interval.contains(date) @@ -867,7 +881,7 @@ class Sort: return sorted(items) def is_slow(self) -> bool: - """Indicate whether this sort is *slow*, meaning that it cannot + """Indicate whether this sort is *slow*, meaning that it cannot be executed in SQL and must be executed in Python. """ return False diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 1cd373027..ea6f16922 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -16,23 +16,11 @@ import itertools import re -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, -) +from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type from . import Model, query from .query import Sort -if TYPE_CHECKING: - from ..library import LibModel - PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. r"(-|\^)?" # Negation prefixes. @@ -116,7 +104,7 @@ def parse_query_part( def construct_query_part( - model_cls: Type["LibModel"], + model_cls: Type[Model], prefixes: Dict, query_part: str, ) -> query.Query: @@ -151,14 +139,20 @@ def construct_query_part( query_part, query_classes, prefixes ) + # If there's no key (field name) specified, this is a "match + # anything" query. if key is None: - # If there's no key (field name) specified, this is a "match anything" - # query. 
- out_query = model_cls.any_field_query(query_class, pattern) + # The query type matches a specific field, but none was + # specified. So we use a version of the query that matches + # any field. + out_query = query.AnyFieldQuery( + pattern, model_cls._search_fields, query_class + ) + + # Field queries get constructed according to the name of the field + # they are querying. else: - # Field queries get constructed according to the name of the field - # they are querying. - out_query = model_cls.field_query(key.lower(), pattern, query_class) + out_query = query_class(key.lower(), pattern, key in model_cls._fields) # Apply negation. if negate: diff --git a/beets/importer.py b/beets/importer.py index be37c8f58..f6517b515 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -708,7 +708,7 @@ class ImportTask(BaseImportTask): # use a temporary Album object to generate any computed fields. tmp_album = library.Album(lib, **info) keys = config["import"]["duplicate_keys"]["album"].as_str_seq() - dup_query = library.Album.match_all_query( + dup_query = library.Album.all_fields_query( {key: tmp_album.get(key) for key in keys} ) @@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask): # temporary `Item` object to generate any computed fields. tmp_item = library.Item(lib, **info) keys = config["import"]["duplicate_keys"]["item"].as_str_seq() - dup_query = library.Item.match_all_query( + dup_query = library.Album.all_fields_query( {key: tmp_item.get(key) for key in keys} ) diff --git a/beets/library.py b/beets/library.py index 2e0003dd4..68789bf84 100644 --- a/beets/library.py +++ b/beets/library.py @@ -14,7 +14,6 @@ """The core data store and collection logic for beets. 
""" -from __future__ import annotations import os import re @@ -24,7 +23,6 @@ import sys import time import unicodedata from functools import cached_property -from typing import Mapping, Set, Type from mediafile import MediaFile, UnreadableFileError @@ -34,7 +32,6 @@ from beets.dbcore import Results, types from beets.util import ( MoveOperation, bytestring_path, - cached_classproperty, normpath, samefile, syspath, @@ -389,18 +386,6 @@ class LibModel(dbcore.Model): # Config key that specifies how an instance should be formatted. _format_config_key: str - @cached_classproperty - def all_model_db_fields(cls) -> Set[str]: - return cls._fields.keys() | cls._relation._fields.keys() - - @cached_classproperty - def shared_model_db_fields(cls) -> Set[str]: - return cls._fields.keys() & cls._relation._fields.keys() - - @cached_classproperty - def writable_fields(cls) -> Set[str]: - return MediaFile.fields() & cls._relation._fields.keys() - def _template_funcs(self): funcs = DefaultTemplateFunctions(self, self._db).functions() funcs.update(plugins.template_funcs()) @@ -430,61 +415,6 @@ class LibModel(dbcore.Model): def __bytes__(self): return self.__str__().encode("utf-8") - # Convenient queries. - - @classmethod - def field_query( - cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery] - ) -> dbcore.Query: - """Get a `FieldQuery` for this model.""" - fast = field in cls.all_model_db_fields - if field in cls.shared_model_db_fields: - # This field exists in both tables, so SQLite will encounter - # an OperationalError if we try to use it in a query. - # Using an explicit table name resolves this. 
- field = f"{cls._table}.{field}" - - return query_cls(field, pattern, fast) - - @classmethod - def any_field_query( - cls, query_class: Type[dbcore.FieldQuery], pattern: str - ) -> dbcore.OrQuery: - return dbcore.OrQuery( - [ - cls.field_query(f, pattern, query_class) - for f in cls._search_fields - ] - ) - - @classmethod - def any_writable_field_query( - cls, query_class: Type[dbcore.FieldQuery], pattern: str - ) -> dbcore.OrQuery: - return dbcore.OrQuery( - [ - cls.field_query(f, pattern, query_class) - for f in cls.writable_fields - ] - ) - - @classmethod - def match_all_query( - cls, pattern_by_field: Mapping[str, str] - ) -> dbcore.AndQuery: - """Get a query that matches many fields with different patterns. - - `pattern_by_field` should be a mapping from field names to patterns. - The resulting query is a conjunction ("and") of per-field queries - for all of these field/pattern pairs. - """ - return dbcore.AndQuery( - [ - cls.field_query(f, p, dbcore.MatchQuery) - for f, p in pattern_by_field.items() - ] - ) - class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -710,22 +640,6 @@ class Item(LibModel): # Cached album object. Read-only. __album = None - @cached_classproperty - def _relation(cls) -> type[Album]: - return Album - - @cached_classproperty - def relation_join(cls) -> str: - """Return the FROM clause which includes related albums. - - We need to use a LEFT JOIN here, otherwise items that are not part of - an album (e.g. singletons) would be left out. 
- """ - return ( - f"LEFT JOIN {cls._relation.table_with_flex_attrs}" - f" ON {cls._table}.album_id = {cls._relation._table}.id" - ) - @property def _cached_album(self): """The Album object that this item belongs to, if any, or @@ -1326,22 +1240,6 @@ class Album(LibModel): _format_config_key = "format_album" - @cached_classproperty - def _relation(cls) -> type[Item]: - return Item - - @cached_classproperty - def relation_join(cls) -> str: - """Return FROM clause which joins on related album items. - - Here we can use INNER JOIN (which is more performant than LEFT JOIN), - since we only want to see albums that have at least one Item in them. - """ - return ( - f"INNER JOIN {cls._relation.table_with_flex_attrs}" - f" ON {cls._table}.id = {cls._relation._table}.album_id" - ) - @classmethod def _getters(cls): # In addition to plugin-provided computed fields, also expose @@ -2030,10 +1928,9 @@ class DefaultTemplateFunctions: subqueries.extend(initial_subqueries) for key in keys: value = db_item.get(key, "") - subqueries.append( - db_item.field_query(key, value, dbcore.MatchQuery) - ) - + # Use slow queries for flexible attributes. + fast = key in item_keys + subqueries.append(dbcore.MatchQuery(key, value, fast)) query = dbcore.AndQuery(subqueries) ambigous_items = ( self.lib.items(query) diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 6b5b984a0..8cbbda8f9 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -1055,20 +1055,3 @@ def par_map(transform: Callable, items: Iterable): pool.map(transform, items) pool.close() pool.join() - - -class cached_classproperty: # noqa: N801 - """A decorator implementing a read-only property that is *lazy* in - the sense that the getter is only invoked once. Subsequent accesses - through *any* instance use the cached result. 
- """ - - def __init__(self, getter): - self.getter = getter - self.cache = {} - - def __get__(self, instance, owner): - if owner not in self.cache: - self.cache[owner] = self.getter(owner) - - return self.cache[owner] diff --git a/beetsplug/aura.py b/beetsplug/aura.py index 39c88b4d1..35c3e919f 100644 --- a/beetsplug/aura.py +++ b/beetsplug/aura.py @@ -180,9 +180,8 @@ class AURADocument: converter = self.get_attribute_converter(beets_attr) value = converter(value) # Add exact match query to list - queries.append( - self.model_cls.field_query(beets_attr, value, MatchQuery) - ) + # Use a slow query so it works with all fields + queries.append(MatchQuery(beets_attr, value, fast=False)) # NOTE: AURA doesn't officially support multiple queries return AndQuery(queries) diff --git a/beetsplug/bpd/__init__.py b/beetsplug/bpd/__init__.py index 4171d02d1..a4cb4d291 100644 --- a/beetsplug/bpd/__init__.py +++ b/beetsplug/bpd/__init__.py @@ -29,6 +29,8 @@ import traceback from string import Template from typing import List +from mediafile import MediaFile + import beets import beets.ui from beets import dbcore, vfs @@ -91,6 +93,8 @@ SUBSYSTEMS = [ "partition", ] +ITEM_KEYS_WRITABLE = set(MediaFile.fields()).intersection(Item._fields.keys()) + # Gstreamer import error. class NoGstreamerError(Exception): @@ -1397,7 +1401,7 @@ class Server(BaseServer): return test_tag, key raise BPDError(ERROR_UNKNOWN, "no such tagtype") - def _metadata_query(self, query_type, kv, allow_any_query: bool = False): + def _metadata_query(self, query_type, any_query_type, kv): """Helper function returns a query object that will find items according to the library query type provided and the key-value pairs specified. 
The any_query_type is used for queries of @@ -1409,9 +1413,11 @@ class Server(BaseServer): it = iter(kv) for tag, value in zip(it, it): if tag.lower() == "any": - if allow_any_query: + if any_query_type: queries.append( - Item.any_writable_field_query(query_type, value) + any_query_type( + value, ITEM_KEYS_WRITABLE, query_type + ) ) else: raise BPDError(ERROR_UNKNOWN, "no such tagtype") @@ -1425,14 +1431,14 @@ class Server(BaseServer): def cmd_search(self, conn, *kv): """Perform a substring match for items.""" query = self._metadata_query( - dbcore.query.SubstringQuery, kv, allow_any_query=True + dbcore.query.SubstringQuery, dbcore.query.AnyFieldQuery, kv ) for item in self.lib.items(query): yield self._item_info(item) def cmd_find(self, conn, *kv): """Perform an exact match for items.""" - query = self._metadata_query(dbcore.query.MatchQuery, kv) + query = self._metadata_query(dbcore.query.MatchQuery, None, kv) for item in self.lib.items(query): yield self._item_info(item) @@ -1452,7 +1458,7 @@ class Server(BaseServer): raise BPDError(ERROR_ARG, 'should be "Album" for 3 arguments') elif len(kv) % 2 != 0: raise BPDError(ERROR_ARG, "Incorrect number of filter arguments") - query = self._metadata_query(dbcore.query.MatchQuery, kv) + query = self._metadata_query(dbcore.query.MatchQuery, None, kv) clause, subvals = query.clause() statement = ( diff --git a/docs/changelog.rst b/docs/changelog.rst index bb9e5f740..3725e4993 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,16 +6,6 @@ Unreleased Changelog goes here! Please add your entry to the bottom of one of the lists below! -New features: - -* Ability to query albums with track-level (and vice-versa) **db** or - **flexible** field queries, for example `beet list -a title:something`, `beet - list artpath:cover`. -* Queries have been made faster, and their speed is constant regardless of - their complexity or the type of queried fields. 
Notably, album queries for - the `path` field and those that involve flexible attributes have seen the - most significant speedup. - Bug fixes: * Improved naming of temporary files by separating the random part with the file extension. diff --git a/docs/reference/query.rst b/docs/reference/query.rst index 3cc994431..2bed2ed68 100644 --- a/docs/reference/query.rst +++ b/docs/reference/query.rst @@ -17,9 +17,7 @@ This command:: $ beet list love -will show all tracks matching the query string ``love``. By default any -unadorned word like this matches in a track's title, artist, album name, album -artist, genre and comments. See below on how to search other fields. +will show all tracks matching the query string ``love``. By default any unadorned word like this matches in a track's title, artist, album name, album artist, genre and comments. See below on how to search other fields. For example, this is what I might see when I run the command above:: @@ -85,15 +83,6 @@ For multi-valued tags (such as ``artists`` or ``albumartists``), a regular expression search must be used to search for a single value within the multi-valued tag. -Note that you can filter albums by querying their tracks fields, including -flexible attributes:: - - $ beet list -a title:love - -and vice versa:: - - $ beet list art_path::love - Phrases ------- @@ -126,9 +115,9 @@ the field name's colon and before the expression:: $ beet list artist:=AIR The first query is a simple substring one that returns tracks by Air, AIR, and -Air Supply. The second query returns tracks by Air and AIR, since both are a +Air Supply. The second query returns tracks by Air and AIR, since both are a case-insensitive match for the entire expression, but does not return anything -by Air Supply. The third query, which requires a case-sensitive exact match, +by Air Supply. The third query, which requires a case-sensitive exact match, returns tracks by AIR only. 
Exact matches may be performed on phrases as well:: @@ -369,7 +358,7 @@ result in lower-case values being placed after upper-case values, e.g., ``Bar Qux foo``. Note that when sorting by fields that are not present on all items (such as -flexible fields, or those defined by plugins) in *ascending* order, the items +flexible fields, or those defined by plugins) in *ascending* order, the items that lack that particular field will be listed at the *beginning* of the list. You can set the default sorting behavior with the :ref:`sort_item` and diff --git a/poetry.lock b/poetry.lock index 3ff525c9c..1e3b4cd1d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -685,17 +685,6 @@ files = [ [package.dependencies] Flask = ">=0.9" -[[package]] -name = "funcy" -version = "2.0" -description = "A fancy and practical functional tools" -optional = false -python-versions = "*" -files = [ - {file = "funcy-2.0-py2.py3-none-any.whl", hash = "sha256:53df23c8bb1651b12f095df764bfb057935d49537a56de211b098f4c79614bb0"}, - {file = "funcy-2.0.tar.gz", hash = "sha256:3963315d59d41c6f30c04bc910e10ab50a3ac4a225868bfa96feed133df075cb"}, -] - [[package]] name = "h11" version = "0.14.0" @@ -1181,30 +1170,6 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=3.0.10)"] -[[package]] -name = "markdown-it-py" -version = "3.0.0" -description = "Python port of markdown-it. Markdown parsing, done right!" 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, - {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, -] - -[package.dependencies] -mdurl = ">=0.1,<1.0" - -[package.extras] -benchmarking = ["psutil", "pytest", "pytest-benchmark"] -code-style = ["pre-commit (>=3.0,<4.0)"] -compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] -linkify = ["linkify-it-py (>=1,<3)"] -plugins = ["mdit-py-plugins"] -profiling = ["gprof2dot"] -rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] -testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] - [[package]] name = "markupsafe" version = "2.1.5" @@ -1285,17 +1250,6 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] -[[package]] -name = "mdurl" -version = "0.1.2" -description = "Markdown URL utilities" -optional = false -python-versions = ">=3.7" -files = [ - {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, - {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, -] - [[package]] name = "mediafile" version = "0.12.0" @@ -1330,17 +1284,6 @@ build = ["blurb", "twine", "wheel"] docs = ["sphinx"] test = ["pytest", "pytest-cov"] -[[package]] -name = "multimethod" -version = "1.10" -description = "Multiple argument dispatching." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "multimethod-1.10-py3-none-any.whl", hash = "sha256:afd84da9c3d0445c84f827e4d63ad42d17c6d29b122427c6dee9032ac2d2a0d4"}, - {file = "multimethod-1.10.tar.gz", hash = "sha256:daa45af3fe257f73abb69673fd54ddeaf31df0eb7363ad6e1251b7c9b192d8c5"}, -] - [[package]] name = "multivolumefile" version = "0.2.3" @@ -2369,47 +2312,6 @@ urllib3 = ">=1.25.10,<3.0" [package.extras] tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"] -[[package]] -name = "rich" -version = "13.7.1" -description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, - {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, -] - -[package.dependencies] -markdown-it-py = ">=2.2.0" -pygments = ">=2.13.0,<3.0.0" -typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} - -[package.extras] -jupyter = ["ipywidgets (>=7.5.1,<9)"] - -[[package]] -name = "rich-tables" -version = "0.5.1" -description = "Ready-made rich tables for various purposes" -optional = false -python-versions = "<4,>=3.8" -files = [ - {file = "rich_tables-0.5.1-py3-none-any.whl", hash = "sha256:26980f9881a44cd5a530f634c17fa4bed40875ee962127bbdafec9c237589b8d"}, - {file = "rich_tables-0.5.1.tar.gz", hash = "sha256:7cc9887f380d773aa0e2da05256970bcbb61bc40445193f32a1f7e167e77a971"}, -] - -[package.dependencies] -funcy = ">=2.0" -multimethod = "*" -platformdirs = ">=4.2.0" -rich = ">=12.3.0" -sqlparse = ">=0.4.4" -typing-extensions = ">=4.7.1" - -[package.extras] -hue = ["rgbxy (>=0.5)"] - [[package]] name = "six" version = "1.16.0" @@ -2600,21 
+2502,6 @@ files = [ lint = ["docutils-stubs", "flake8", "mypy"] test = ["pytest"] -[[package]] -name = "sqlparse" -version = "0.5.0" -description = "A non-validating SQL parser." -optional = false -python-versions = ">=3.8" -files = [ - {file = "sqlparse-0.5.0-py3-none-any.whl", hash = "sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663"}, - {file = "sqlparse-0.5.0.tar.gz", hash = "sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93"}, -] - -[package.extras] -dev = ["build", "hatch"] -doc = ["sphinx"] - [[package]] name = "texttable" version = "1.7.0" @@ -2833,4 +2720,4 @@ web = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4" -content-hash = "0de3f4cf9e0fc7ace1de5e9c3aa859cb2b5b2a42d0a58e4b1d96a4dc251bde07" +content-hash = "740281ee3ddba4c6015eab9cfc24bb947e8816e3b7f5a6bebeb39ff2413d7ac3" diff --git a/pyproject.toml b/pyproject.toml index b565fb543..3ef11ac14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ mediafile = ">=0.12.0" munkres = ">=1.0.0" musicbrainzngs = ">=0.4" pyyaml = "*" -rich-tables = ">=0.5.1" typing_extensions = "*" unidecode = ">=1.3.6" beautifulsoup4 = { version = "*", optional = true } diff --git a/setup.cfg b/setup.cfg index b918bdb1d..b91c84bda 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,9 +31,7 @@ show_contexts = true min-version = 3.8 accept-encodings = utf-8 max-line-length = 88 -classmethod-decorators = - classmethod - cached_classproperty +docstring-convention = google # errors we ignore; see https://www.flake8rules.com/ for more info ignore = # pycodestyle errors diff --git a/test/plugins/test_limit.py b/test/plugins/test_limit.py index 5a7308fa1..0ed6c9202 100644 --- a/test/plugins/test_limit.py +++ b/test/plugins/test_limit.py @@ -15,8 +15,6 @@ import unittest -import pytest - from beets.test.helper import TestHelper @@ -81,17 +79,11 @@ class LimitPluginTest(unittest.TestCase, TestHelper): ) self.assertEqual(result.count("\n"), 
self.num_limit) - @pytest.mark.xfail( - reason="Will be restored together with removal of slow sorts" - ) def test_prefix(self): """Returns the expected number with the query prefix.""" result = self.lib.items(self.num_limit_prefix) self.assertEqual(len(result), self.num_limit) - @pytest.mark.xfail( - reason="Will be restored together with removal of slow sorts" - ) def test_prefix_when_correctly_ordered(self): """Returns the expected number with the query prefix and filter when the prefix portion (correctly) appears last.""" @@ -99,9 +91,6 @@ class LimitPluginTest(unittest.TestCase, TestHelper): result = self.lib.items(correct_order) self.assertEqual(len(result), self.num_limit) - @pytest.mark.xfail( - reason="Will be restored together with removal of slow sorts" - ) def test_prefix_when_incorrectly_ordred(self): """Returns no results with the query prefix and filter when the prefix portion (incorrectly) appears first.""" diff --git a/test/plugins/test_web.py b/test/plugins/test_web.py index 7dfee8321..afd1ed706 100644 --- a/test/plugins/test_web.py +++ b/test/plugins/test_web.py @@ -5,7 +5,6 @@ import os.path import platform import shutil import unittest -from pathlib import Path from beets import logging from beets.library import Album, Item @@ -30,38 +29,36 @@ class WebPluginTest(_common.LibTestCase): # Add library elements. Note that self.lib.add overrides any "id=" # and assigns the next free id number. 
# The following adds will create items #1, #2 and #3 - base_path = Path(self.path_prefix + os.sep) - album2_item1 = Item( - title="title", - path=str(base_path / "path_1"), - album_id=2, - artist="AAA Singers", + path1 = ( + self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8") ) - album1_item = Item( - title="another title", - path=str(base_path / "somewhere" / "a"), - artist="AAA Singers", + self.lib.add( + Item(title="title", path=path1, album_id=2, artist="AAA Singers") ) - album2_item2 = Item( - title="and a third", - testattr="ABC", - path=str(base_path / "somewhere" / "abc"), - album_id=2, + path2 = ( + self.path_prefix + + os.sep + + os.path.join(b"somewhere", b"a").decode("utf-8") + ) + self.lib.add( + Item(title="another title", path=path2, artist="AAA Singers") + ) + path3 = ( + self.path_prefix + + os.sep + + os.path.join(b"somewhere", b"abc").decode("utf-8") + ) + self.lib.add( + Item(title="and a third", testattr="ABC", path=path3, album_id=2) ) - self.lib.add(album2_item1) - self.lib.add(album1_item) - self.lib.add(album2_item2) - # The following adds will create albums #1 and #2 - album1 = self.lib.add_album([album1_item]) - album1.album = "album" - album1.albumtest = "xyz" - album1.store() - - album2 = self.lib.add_album([album2_item1, album2_item2]) - album2.album = "other album" - album2.artpath = str(base_path / "somewhere2" / "art_path_2") - album2.store() + self.lib.add(Album(album="album", albumtest="xyz")) + path4 = ( + self.path_prefix + + os.sep + + os.path.join(b"somewhere2", b"art_path_2").decode("utf-8") + ) + self.lib.add(Album(album="other album", artpath=path4)) web.app.config["TESTING"] = True web.app.config["lib"] = self.lib diff --git a/test/test_autotag.py b/test/test_autotag.py index e9b44458c..868138411 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -143,7 +143,7 @@ def _clear_weights(): """Hack around the lazy descriptor used to cache weights for Distance calculations. 
""" - Distance.__dict__["_weights"].cache = {} + Distance.__dict__["_weights"].computed = False class DistanceTest(_common.TestCase): diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 3feacfd7b..763601b7f 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -21,7 +21,6 @@ import unittest from tempfile import mkstemp from beets import dbcore -from beets.library import LibModel from beets.test import _common # Fixture: concrete database and model classes. For migration tests, we @@ -43,7 +42,7 @@ class QueryFixture(dbcore.query.FieldQuery): return True -class ModelFixture1(LibModel): +class ModelFixture1(dbcore.Model): _table = "test" _flex_table = "testflex" _fields = { @@ -590,7 +589,7 @@ class QueryFromStringsTest(unittest.TestCase): q = self.qfs(["foo", "bar:baz"]) self.assertIsInstance(q, dbcore.query.AndQuery) self.assertEqual(len(q.subqueries), 2) - self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery) + self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery) self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery) def test_parse_fixed_type_query(self): diff --git a/test/test_query.py b/test/test_query.py index 47195ce03..b710da13b 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -48,6 +48,40 @@ class TestHelper(helper.TestHelper): self.assertNotIn(item.id, result_ids) +class AnyFieldQueryTest(_common.LibTestCase): + def test_no_restriction(self): + q = dbcore.query.AnyFieldQuery( + "title", + beets.library.Item._fields.keys(), + dbcore.query.SubstringQuery, + ) + self.assertEqual(self.lib.items(q).get().title, "the title") + + def test_restriction_completeness(self): + q = dbcore.query.AnyFieldQuery( + "title", ["title"], dbcore.query.SubstringQuery + ) + self.assertEqual(self.lib.items(q).get().title, "the title") + + def test_restriction_soundness(self): + q = dbcore.query.AnyFieldQuery( + "title", ["artist"], dbcore.query.SubstringQuery + ) + self.assertIsNone(self.lib.items(q).get()) + + def 
test_eq(self): + q1 = dbcore.query.AnyFieldQuery( + "foo", ["bar"], dbcore.query.SubstringQuery + ) + q2 = dbcore.query.AnyFieldQuery( + "foo", ["bar"], dbcore.query.SubstringQuery + ) + self.assertEqual(q1, q2) + + q2.query_class = None + self.assertNotEqual(q1, q2) + + class AssertsMixin: def assert_items_matched(self, results, titles): self.assertEqual({i.title for i in results}, set(titles)) @@ -487,7 +521,7 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): self.assert_items_matched(results, ["path item"]) results = self.lib.albums(q) - self.assert_albums_matched(results, ["path album"]) + self.assert_albums_matched(results, []) # FIXME: fails on windows @unittest.skipIf(sys.platform == "win32", "win32") @@ -570,9 +604,6 @@ class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): results = self.lib.items(q) self.assert_items_matched(results, ["path item"]) - results = self.lib.albums(q) - self.assert_albums_matched(results, ["path album"]) - def test_path_album_regex(self): q = "path::b" results = self.lib.albums(q) @@ -823,17 +854,17 @@ class NoneQueryTest(unittest.TestCase, TestHelper): def test_match_slow(self): item = self.add_item() - matched = self.lib.items(NoneQuery("rg_track_peak")) + matched = self.lib.items(NoneQuery("rg_track_peak", fast=False)) self.assertInResult(item, matched) def test_match_slow_after_set_none(self): item = self.add_item(rg_track_gain=0) - matched = self.lib.items(NoneQuery("rg_track_gain")) + matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) self.assertNotInResult(item, matched) item["rg_track_gain"] = None item.store() - matched = self.lib.items(NoneQuery("rg_track_gain")) + matched = self.lib.items(NoneQuery("rg_track_gain", fast=False)) self.assertInResult(item, matched) @@ -947,6 +978,14 @@ class NotQueryTest(DummyDataTestCase): self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"]) self.assertNegationProperties(q) + def test_type_anyfield(self): + q = 
dbcore.query.AnyFieldQuery( + "foo", ["title", "artist", "album"], dbcore.query.SubstringQuery + ) + not_results = self.lib.items(dbcore.query.NotQuery(q)) + self.assert_items_matched(not_results, ["baz qux"]) + self.assertNegationProperties(q) + def test_type_boolean(self): q = dbcore.query.BooleanQuery("comp", True) not_results = self.lib.items(dbcore.query.NotQuery(q)) @@ -1055,87 +1094,36 @@ class NotQueryTest(DummyDataTestCase): results = self.lib.items(q) self.assert_items_matched(results, ["baz qux"]) + def test_fast_vs_slow(self): + """Test that the results are the same regardless of the `fast` flag + for negated `FieldQuery`s. -class RelatedQueriesTest(_common.TestCase, AssertsMixin): - """Test album-level queries with track-level filters and vice-versa.""" + TODO: investigate NoneQuery(fast=False), as it is raising + AttributeError: type object 'NoneQuery' has no attribute 'field' + at NoneQuery.match() (due to being @classmethod, and no self?) + """ + classes = [ + (dbcore.query.DateQuery, ["added", "2001-01-01"]), + (dbcore.query.MatchQuery, ["artist", "one"]), + # (dbcore.query.NoneQuery, ['rg_track_gain']), + (dbcore.query.NumericQuery, ["year", "2002"]), + (dbcore.query.StringFieldQuery, ["year", "2001"]), + (dbcore.query.RegexpQuery, ["album", "^.a"]), + (dbcore.query.SubstringQuery, ["title", "x"]), + ] - def setUp(self): - super().setUp() - self.lib = beets.library.Library(":memory:") + for klass, args in classes: + q_fast = dbcore.query.NotQuery(klass(*(args + [True]))) + q_slow = dbcore.query.NotQuery(klass(*(args + [False]))) - albums = [] - for album_idx in range(1, 3): - album_name = f"Album{album_idx}" - album_items = [] - for item_idx in range(1, 3): - item = _common.item() - item.album = album_name - title = f"{album_name} Item{item_idx}" - item.title = title - item.item_flex1 = f"{title} Flex1" - item.item_flex2 = f"{title} Flex2" - self.lib.add(item) - album_items.append(item) - album = self.lib.add_album(album_items) - album.artpath = 
f"{album_name} Artpath" - album.catalognum = "ABC" - album.album_flex = f"{album_name} Flex" - album.store() - albums.append(album) - - self.album, self.another_album = albums - - def test_get_albums_filter_by_track_field(self): - q = "title:Album1" - results = self.lib.albums(q) - self.assert_albums_matched(results, ["Album1"]) - - def test_get_items_filter_by_album_field(self): - q = "artpath::Album1" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) - - def test_filter_albums_by_common_field(self): - # title:Album1 ensures that the items table is joined for the query - q = "title:Album1 catalognum:ABC" - results = self.lib.albums(q) - self.assert_albums_matched(results, ["Album1"]) - - def test_filter_items_by_common_field(self): - # artpath::A ensures that the albums table is joined for the query - q = "artpath::A Album1" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) - - def test_get_items_filter_by_track_flex(self): - q = "item_flex1:Item1" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"]) - - def test_get_albums_filter_by_album_flex(self): - q = "album_flex:Album1" - results = self.lib.albums(q) - self.assert_albums_matched(results, ["Album1"]) - - def test_get_albums_filter_by_track_flex(self): - q = "item_flex1:Album1" - results = self.lib.albums(q) - self.assert_albums_matched(results, ["Album1"]) - - def test_get_items_filter_by_album_flex(self): - q = "album_flex:Album1" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) - - def test_filter_by_flex(self): - q = "item_flex1:'Item1 Flex1'" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"]) - - def test_filter_by_many_flex(self): - q = "item_flex1:'Item1 Flex1' item_flex2:Album1" - results = self.lib.items(q) - self.assert_items_matched(results, ["Album1 
Item1"]) + try: + self.assertEqual( + [i.title for i in self.lib.items(q_fast)], + [i.title for i in self.lib.items(q_slow)], + ) + except NotImplementedError: + # ignore classes that do not provide `fast` implementation + pass def suite():