diff --git a/beets/dbcore/__init__.py b/beets/dbcore/__init__.py index fa20eb00d..0b5e700cb 100644 --- a/beets/dbcore/__init__.py +++ b/beets/dbcore/__init__.py @@ -16,7 +16,7 @@ Library. """ -from .db import Database, Model, Results +from .db import Database, Index, Model, Results from .query import ( AndQuery, FieldQuery, @@ -43,6 +43,7 @@ __all__ = [ "Query", "Results", "Type", + "Index", "parse_sorted_query", "query_from_strings", "sort_from_strings", diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index cc172d0d8..843bfeaff 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -35,7 +35,7 @@ from collections.abc import ( Sequence, ) from sqlite3 import Connection, sqlite_version_info -from typing import TYPE_CHECKING, Any, AnyStr, Generic +from typing import TYPE_CHECKING, Any, AnyStr, Generic, NamedTuple from typing_extensions import TypeVar # default value support from unidecode import unidecode @@ -306,6 +306,11 @@ class Model(ABC, Generic[D]): terms. """ + _indices: Sequence[Index] = () + """A sequence of `Index` objects that describe the indices to be + created for this table. + """ + @cached_classproperty def _types(cls) -> dict[str, types.Type]: """Optional types for non-fixed (flexible and computed) fields.""" @@ -1066,6 +1071,7 @@ class Database: for model_cls in self._models: self._make_table(model_cls._table, model_cls._fields) self._make_attribute_table(model_cls._flex_table) + self._migrate_indices(model_cls._table, model_cls._indices) # Primitive access control: connections and transactions. @@ -1243,6 +1249,25 @@ class Database: ON {flex_table} (entity_id); """) + def _migrate_indices( + self, + table: str, + indices: Sequence[Index], + ): + """Create or replace indices for the given table. + + If the indices already exists and are up to date (i.e., the + index name and columns match), nothing is done. Otherwise, the + indices are created or replaced. + """ + with self.transaction() as tx: + current = { + Index.from_db(tx, r[1]) + for r in tx.query(f"PRAGMA index_list({table})") + } + for index in set(indices) - current: + index.recreate(tx, table) + # Querying. def _fetch( @@ -1312,3 +1337,38 @@ class Database: exist. """ return self._fetch(model_cls, MatchQuery("id", id)).get() + + +class Index(NamedTuple): + """A helper class to represent the index + information in the database schema. + """ + + name: str + columns: tuple[str, ...] + + def __hash__(self) -> int: + """Unique hash for the index based on its name and columns.""" + return hash((self.name, *self.columns)) + + def recreate(self, tx: Transaction, table: str) -> None: + """Recreate the index in the database. + + This is useful when the index has been changed and needs to be + updated. + """ + tx.script(f""" + DROP INDEX IF EXISTS {self.name}; + CREATE INDEX {self.name} ON {table} ({", ".join(self.columns)}) + """) + + @classmethod + def from_db(cls, tx: Transaction, name: str) -> Index: + """Create an Index object from the database if it exists. + + The name has to exists in the database! Otherwise, an + Error will be raised. + """ + rows = tx.query(f"PRAGMA index_info({name})") + columns = tuple(row[2] for row in rows) + return cls(name, columns) diff --git a/beets/library/models.py b/beets/library/models.py index cbee2a411..f57b4201a 100644 --- a/beets/library/models.py +++ b/beets/library/models.py @@ -716,6 +716,7 @@ class Item(LibModel): "mtime": types.DATE, "added": types.DATE, } + _indices = (dbcore.Index("idx_item_album_id", ("album_id",)),) _search_fields = ( "artist", diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 653adf298..06aceaec0 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -23,7 +23,7 @@ from tempfile import mkstemp import pytest from beets import dbcore -from beets.dbcore.db import DBCustomFunctionError +from beets.dbcore.db import DBCustomFunctionError, Index from beets.library import LibModel from beets.test import _common from beets.util import cached_classproperty @@ -66,6 +66,7 @@ class ModelFixture1(LibModel): _sorts = { "some_sort": SortFixture, } + _indices = (Index("field_one_index", ("field_one",)),) @cached_classproperty def _types(cls): @@ -137,6 +138,7 @@ class AnotherModelFixture(ModelFixture1): "id": dbcore.types.PRIMARY_ID, "foo": dbcore.types.INTEGER, } + _indices = (Index("another_foo_index", ("foo",)),) class ModelFixture5(ModelFixture1): @@ -808,3 +810,69 @@ class TestException: with pytest.raises(DBCustomFunctionError): with db.transaction() as tx: tx.query("select * from test where plz_raise()") + + +class TestIndex: + @pytest.fixture(autouse=True) + def db(self): + """Set up an in-memory SQLite database.""" + db = DatabaseFixture1(":memory:") + yield db + db._connection().close() + + @pytest.fixture + def sample_index(self): + """Fixture for a sample Index object.""" + return Index(name="sample_index", columns=("field_one",)) + + def test_from_db(self, db, sample_index: Index): + """Test retrieving an index from the database.""" + with db.transaction() as tx: + sample_index.recreate(tx, "test") + retrieved = Index.from_db(tx, sample_index.name) + assert retrieved == sample_index + + @pytest.mark.parametrize( + "index1, index2, equality", + [ + ( + # Same + Index(name="sample_index", columns=("field_one",)), + Index(name="sample_index", columns=("field_one",)), + True, + ), + ( + # Multiple columns + Index(name="sample_index", columns=("f1", "f2")), + Index(name="sample_index", columns=("f1", "f2")), + True, + ), + ( + # Difference in name + Index(name="sample_indey", columns=("field_one",)), + Index(name="sample_index", columns=("field_one",)), + False, + ), + ( + # Difference in columns + Index(name="sample_indey", columns=("field_one",)), + Index(name="sample_index", columns=("field_two",)), + False, + ), + ( + # Difference in num columns + Index(name="sample_index", columns=("f1",)), + Index(name="sample_index", columns=("f1", "f2")), + False, + ), + ], + ) + def test_index_equality(self, index1: Index, index2: Index, equality: bool): + """Test the hashing and set behavior of the Index class.""" + + # Simple equality + assert (index1 == index2) == equality + + # Should be unique or not + index_set = {index1, index2} + assert len(index_set) == (1 if equality else 2)