This commit is contained in:
Sebastian Mohr 2025-12-04 21:05:50 +00:00 committed by GitHub
commit 7bdca84f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 133 additions and 3 deletions

View file

@ -16,7 +16,7 @@
Library.
"""
from .db import Database, Model, Results
from .db import Database, Index, Model, Results
from .query import (
AndQuery,
FieldQuery,
@ -43,6 +43,7 @@ __all__ = [
"Query",
"Results",
"Type",
"Index",
"parse_sorted_query",
"query_from_strings",
"sort_from_strings",

View file

@ -35,7 +35,7 @@ from collections.abc import (
Sequence,
)
from sqlite3 import Connection, sqlite_version_info
from typing import TYPE_CHECKING, Any, AnyStr, Generic
from typing import TYPE_CHECKING, Any, AnyStr, Generic, NamedTuple
from typing_extensions import TypeVar # default value support
from unidecode import unidecode
@ -306,6 +306,11 @@ class Model(ABC, Generic[D]):
terms.
"""
_indices: Sequence[Index] = ()
"""A sequence of `Index` objects that describe the indices to be
created for this table.
"""
@cached_classproperty
def _types(cls) -> dict[str, types.Type]:
"""Optional types for non-fixed (flexible and computed) fields."""
@ -1066,6 +1071,7 @@ class Database:
for model_cls in self._models:
self._make_table(model_cls._table, model_cls._fields)
self._make_attribute_table(model_cls._flex_table)
self._migrate_indices(model_cls._table, model_cls._indices)
# Primitive access control: connections and transactions.
@ -1243,6 +1249,25 @@ class Database:
ON {flex_table} (entity_id);
""")
def _migrate_indices(
self,
table: str,
indices: Sequence[Index],
):
"""Create or replace indices for the given table.
If the indices already exists and are up to date (i.e., the
index name and columns match), nothing is done. Otherwise, the
indices are created or replaced.
"""
with self.transaction() as tx:
current = {
Index.from_db(tx, r[1])
for r in tx.query(f"PRAGMA index_list({table})")
}
for index in set(indices) - current:
index.recreate(tx, table)
# Querying.
def _fetch(
@ -1312,3 +1337,38 @@ class Database:
exist.
"""
return self._fetch(model_cls, MatchQuery("id", id)).get()
class Index(NamedTuple):
"""A helper class to represent the index
information in the database schema.
"""
name: str
columns: tuple[str, ...]
def __hash__(self) -> int:
"""Unique hash for the index based on its name and columns."""
return hash((self.name, *self.columns))
def recreate(self, tx: Transaction, table: str) -> None:
"""Recreate the index in the database.
This is useful when the index has been changed and needs to be
updated.
"""
tx.script(f"""
DROP INDEX IF EXISTS {self.name};
CREATE INDEX {self.name} ON {table} ({", ".join(self.columns)})
""")
@classmethod
def from_db(cls, tx: Transaction, name: str) -> Index:
"""Create an Index object from the database if it exists.
The name has to exists in the database! Otherwise, an
Error will be raised.
"""
rows = tx.query(f"PRAGMA index_info({name})")
columns = tuple(row[2] for row in rows)
return cls(name, columns)

View file

@ -716,6 +716,7 @@ class Item(LibModel):
"mtime": types.DATE,
"added": types.DATE,
}
_indices = (dbcore.Index("idx_item_album_id", ("album_id",)),)
_search_fields = (
"artist",

View file

@ -23,7 +23,7 @@ from tempfile import mkstemp
import pytest
from beets import dbcore
from beets.dbcore.db import DBCustomFunctionError
from beets.dbcore.db import DBCustomFunctionError, Index
from beets.library import LibModel
from beets.test import _common
from beets.util import cached_classproperty
@ -66,6 +66,7 @@ class ModelFixture1(LibModel):
_sorts = {
"some_sort": SortFixture,
}
_indices = (Index("field_one_index", ("field_one",)),)
@cached_classproperty
def _types(cls):
@ -137,6 +138,7 @@ class AnotherModelFixture(ModelFixture1):
"id": dbcore.types.PRIMARY_ID,
"foo": dbcore.types.INTEGER,
}
_indices = (Index("another_foo_index", ("foo",)),)
class ModelFixture5(ModelFixture1):
@ -808,3 +810,69 @@ class TestException:
with pytest.raises(DBCustomFunctionError):
with db.transaction() as tx:
tx.query("select * from test where plz_raise()")
class TestIndex:
@pytest.fixture(autouse=True)
def db(self):
"""Set up an in-memory SQLite database."""
db = DatabaseFixture1(":memory:")
yield db
db._connection().close()
@pytest.fixture
def sample_index(self):
"""Fixture for a sample Index object."""
return Index(name="sample_index", columns=("field_one",))
def test_from_db(self, db, sample_index: Index):
"""Test retrieving an index from the database."""
with db.transaction() as tx:
sample_index.recreate(tx, "test")
retrieved = Index.from_db(tx, sample_index.name)
assert retrieved == sample_index
@pytest.mark.parametrize(
"index1, index2, equality",
[
(
# Same
Index(name="sample_index", columns=("field_one",)),
Index(name="sample_index", columns=("field_one",)),
True,
),
(
# Multiple columns
Index(name="sample_index", columns=("f1", "f2")),
Index(name="sample_index", columns=("f1", "f2")),
True,
),
(
# Difference in name
Index(name="sample_indey", columns=("field_one",)),
Index(name="sample_index", columns=("field_one",)),
False,
),
(
# Difference in columns
Index(name="sample_indey", columns=("field_one",)),
Index(name="sample_index", columns=("field_two",)),
False,
),
(
# Difference in num columns
Index(name="sample_index", columns=("f1",)),
Index(name="sample_index", columns=("f1", "f2")),
False,
),
],
)
def test_index_equality(self, index1: Index, index2: Index, equality: bool):
"""Test the hashing and set behavior of the Index class."""
# Simple equality
assert (index1 == index2) == equality
# Should be unique or not
index_set = {index1, index2}
assert len(index_set) == (1 if equality else 2)