diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index c554736fb..d0c02b146 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import json import os import re import sqlite3 @@ -60,8 +61,8 @@ else: DEBUG = bool(os.getenv("BEETS_DEBUG", False)) - -FlexAttrs = dict[str, str] +# convert data under 'json_str' type name to Python dictionary automatically +sqlite3.register_converter("json_str", json.loads) def print_query(sql, subvals=None): @@ -357,6 +358,47 @@ class Model(ABC, Generic[D]): """Fields in the related table.""" return cls._relation._fields.keys() - cls.shared_db_fields + @cached_classproperty + def table_with_flex_attrs(cls) -> str: + """Return a SQL for entity table which includes aggregated flexible attributes. + + The clause selects entity rows, flexible attributes rows and LEFT JOINs + them on entity id and 'entity_id' field respectively. + + 'json_group_object' aggregate function groups flexible attributes into a + single JSON object 'flex_attrs [json_str]'. The column name ending with + ' [json_str]' means that this column is converted to a Python dictionary + automatically (see 'register_converter' call at the top of this module). + + 'REPLACE' function handles absence of flexible attributes and replaces + some weird null JSON object (that SQLite gives us by default) with an + empty JSON object. + + Availability of the 'flex_attrs' means we can query flexible attributes + in the same manner we query other entity fields, see + `FieldQuery.field`. This way, we also remove the need for an + additional query to fetch them. + + Note: we use LEFT join to include entities without flexible attributes. + Note: we name this SELECT clause after the original entity table name + so that we can query it in the way like the original table. + """ + flex_attrs = "REPLACE(json_group_object(key, value), '{:null}', '{}')" + return f"""( + SELECT + *, + {flex_attrs} AS "flex_attrs [json_str]" + FROM {cls._table} LEFT JOIN ( + SELECT + entity_id, + key, + CAST(value AS text) AS value + FROM {cls._flex_table} + ) ON entity_id == {cls._table}.id + GROUP BY {cls._table}.id + ) {cls._table} + """ + @classmethod def _getters(cls: type[Model]): """Return a mapping from field names to getter functions.""" @@ -777,7 +819,6 @@ class Results(Generic[AnyModel]): model_class: type[AnyModel], rows: list[sqlite3.Row], db: D, - flex_rows, query: Query | None = None, sort=None, ): @@ -800,7 +841,6 @@ class Results(Generic[AnyModel]): self.db = db self.query = query self.sort = sort - self.flex_rows = flex_rows # We keep a queue of rows we haven't yet consumed for # materialization. We preserve the original total number of @@ -822,10 +862,6 @@ class Results(Generic[AnyModel]): a `Results` object a second time should be much faster than the first. """ - - # Index flexible attributes by the item ID, so we have easier access - flex_attrs = self._get_indexed_flex_attrs() - index = 0 # Position in the materialized objects. while index < len(self._objects) or self._rows: # Are there previously-materialized objects to produce? @@ -838,7 +874,7 @@ class Results(Generic[AnyModel]): else: while self._rows: row = self._rows.pop(0) - obj = self._make_model(row, flex_attrs.get(row["id"], {})) + obj = self._make_model(row) # If there is a slow-query predicate, ensurer that the # object passes it. if not self.query or self.query.match(obj): @@ -860,23 +896,10 @@ class Results(Generic[AnyModel]): # Objects are pre-sorted (i.e., by the database). return self._get_objects() - def _get_indexed_flex_attrs(self) -> dict[int, FlexAttrs]: - """Index flexible attributes by the entity id they belong to""" - flex_values: dict[int, FlexAttrs] = {} - for row in self.flex_rows: - if row["entity_id"] not in flex_values: - flex_values[row["entity_id"]] = {} - - flex_values[row["entity_id"]][row["key"]] = row["value"] - - return flex_values - - def _make_model( - self, row: sqlite3.Row, flex_values: FlexAttrs = {} - ) -> AnyModel: + def _make_model(self, row: sqlite3.Row) -> AnyModel: """Create a Model object for the given row""" - cols = dict(row) - values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"} + values = dict(row) + flex_values = values.pop("flex_attrs") or {} # Construct the Python object obj = self.model_class._awaken(self.db, values, flex_values) @@ -1107,6 +1130,8 @@ class Database: # We have our own same-thread checks in _connection(), but need to # call conn.close() in _close() check_same_thread=False, + # enable type name "col [type]" conversion (`register_converter`) + detect_types=sqlite3.PARSE_COLNAMES, ) self.add_functions(conn) @@ -1262,11 +1287,11 @@ class Database: where, subvals = query.clause() order_by = sort.order_clause() - table = model_cls._table - _from = table + _from = model_cls.table_with_flex_attrs if query.field_names & model_cls.other_db_fields: _from += f" {model_cls.relation_join}" + table = model_cls._table # group by id to avoid duplicates when joining with the relation sql = ( f"SELECT {table}.* " @@ -1274,14 +1299,6 @@ class Database: f"WHERE {where or 1} " f"GROUP BY {table}.id" ) - # Fetch flexible attributes for items matching the main query. - # Doing the per-item filtering in python is faster than issuing - # one query per item to sqlite. - flex_sql = ( - "SELECT * " - f"FROM {model_cls._flex_table} " - f"WHERE entity_id IN (SELECT id FROM ({sql}))" - ) if order_by: # the sort field may exist in both 'items' and 'albums' tables @@ -1293,13 +1310,11 @@ class Database: with self.transaction() as tx: rows = tx.query(sql, subvals) - flex_rows = tx.query(flex_sql, subvals) return Results( model_cls, rows, self, - flex_rows, None if where else query, # Slow query component. sort if sort.is_slow() else None, # Slow sort component. )