Simplify creating indices

This commit is contained in:
Šarūnas Nejus 2026-01-30 00:55:13 +00:00
parent f7ddcdeb59
commit e768f978b6
No known key found for this signature in database
2 changed files with 16 additions and 106 deletions

View file

@ -1071,7 +1071,7 @@ class Database:
for model_cls in self._models:
self._make_table(model_cls._table, model_cls._fields)
self._make_attribute_table(model_cls._flex_table)
self._migrate_indices(model_cls._table, model_cls._indices)
self._create_indices(model_cls._table, model_cls._indices)
# Primitive access control: connections and transactions.
@ -1249,24 +1249,18 @@ class Database:
ON {flex_table} (entity_id);
""")
def _migrate_indices(
def _create_indices(
self,
table: str,
indices: Sequence[Index],
):
"""Create or replace indices for the given table.
If the indices already exists and are up to date (i.e., the
index name and columns match), nothing is done. Otherwise, the
indices are created or replaced.
"""
"""Create indices for the given table if they don't exist."""
with self.transaction() as tx:
current = {
Index.from_db(tx, r[1])
for r in tx.query(f"PRAGMA index_list({table})")
}
for index in set(indices) - current:
index.recreate(tx, table)
for index in indices:
tx.script(
f"CREATE INDEX IF NOT EXISTS {index.name} "
f"ON {table} ({', '.join(index.columns)});"
)
# Querying.
@ -1346,29 +1340,3 @@ class Index(NamedTuple):
name: str
columns: tuple[str, ...]
def __hash__(self) -> int:
"""Unique hash for the index based on its name and columns."""
return hash((self.name, *self.columns))
def recreate(self, tx: Transaction, table: str) -> None:
"""Recreate the index in the database.
This is useful when the index has been changed and needs to be
updated.
"""
tx.script(f"""
DROP INDEX IF EXISTS {self.name};
CREATE INDEX {self.name} ON {table} ({", ".join(self.columns)})
""")
@classmethod
def from_db(cls, tx: Transaction, name: str) -> Index:
"""Create an Index object from the database if it exists.
The name has to exists in the database! Otherwise, an
Error will be raised.
"""
rows = tx.query(f"PRAGMA index_info({name})")
columns = tuple(row[2] for row in rows)
return cls(name, columns)

View file

@ -240,6 +240,14 @@ class MigrationTest(unittest.TestCase):
except sqlite3.OperationalError:
self.fail("select failed")
def test_index_creation(self):
"""Test that declared indices are created on database initialization."""
db = DatabaseFixture1(":memory:")
with db.transaction() as tx:
rows = tx.query("PRAGMA index_info(field_one_index)")
assert len(rows) > 0 # Index exists
db._connection().close()
class TransactionTest(unittest.TestCase):
def setUp(self):
@ -810,69 +818,3 @@ class TestException:
with pytest.raises(DBCustomFunctionError):
with db.transaction() as tx:
tx.query("select * from test where plz_raise()")
class TestIndex:
@pytest.fixture(autouse=True)
def db(self):
"""Set up an in-memory SQLite database."""
db = DatabaseFixture1(":memory:")
yield db
db._connection().close()
@pytest.fixture
def sample_index(self):
"""Fixture for a sample Index object."""
return Index(name="sample_index", columns=("field_one",))
def test_from_db(self, db, sample_index: Index):
"""Test retrieving an index from the database."""
with db.transaction() as tx:
sample_index.recreate(tx, "test")
retrieved = Index.from_db(tx, sample_index.name)
assert retrieved == sample_index
@pytest.mark.parametrize(
"index1, index2, equality",
[
(
# Same
Index(name="sample_index", columns=("field_one",)),
Index(name="sample_index", columns=("field_one",)),
True,
),
(
# Multiple columns
Index(name="sample_index", columns=("f1", "f2")),
Index(name="sample_index", columns=("f1", "f2")),
True,
),
(
# Difference in name
Index(name="sample_indey", columns=("field_one",)),
Index(name="sample_index", columns=("field_one",)),
False,
),
(
# Difference in columns
Index(name="sample_indey", columns=("field_one",)),
Index(name="sample_index", columns=("field_two",)),
False,
),
(
# Difference in num columns
Index(name="sample_index", columns=("f1",)),
Index(name="sample_index", columns=("f1", "f2")),
False,
),
],
)
def test_index_equality(self, index1: Index, index2: Index, equality: bool):
"""Test the hashing and set behavior of the Index class."""
# Simple equality
assert (index1 == index2) == equality
# Should be unique or not
index_set = {index1, index2}
assert len(index_set) == (1 if equality else 2)