diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 653adf298..3f53dd888 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -23,7 +23,7 @@ from tempfile import mkstemp import pytest from beets import dbcore -from beets.dbcore.db import DBCustomFunctionError +from beets.dbcore.db import DBCustomFunctionError, Index from beets.library import LibModel from beets.test import _common from beets.util import cached_classproperty @@ -66,6 +66,7 @@ class ModelFixture1(LibModel): _sorts = { "some_sort": SortFixture, } + _indices = (Index("field_one_index", ("field_one",)),) @cached_classproperty def _types(cls): @@ -137,6 +138,7 @@ class AnotherModelFixture(ModelFixture1): "id": dbcore.types.PRIMARY_ID, "foo": dbcore.types.INTEGER, } + _indices = (Index("another_foo_index", ("foo",)),) class ModelFixture5(ModelFixture1): @@ -808,3 +810,71 @@ class TestException: with pytest.raises(DBCustomFunctionError): with db.transaction() as tx: tx.query("select * from test where plz_raise()") + + +class TestIndex: + @pytest.fixture(autouse=True) + def db(self): + """Set up an in-memory SQLite database.""" + db = DatabaseFixture1(":memory:") + yield db + db._connection().close() + + @pytest.fixture + def sample_index(self): + """Fixture for a sample Index object.""" + return Index(name="sample_index", columns=("field_one",)) + + def test_from_db(self, db, sample_index: Index): + """Test retrieving an index from the database.""" + with db.transaction() as tx: + sample_index.recreate(tx, "test") + retrieved = Index.from_db(tx, sample_index.name) + assert retrieved == sample_index + + @pytest.mark.parametrize( + "index1, index2, equality", + [ + ( + # Same + Index(name="sample_index", columns=("field_one",)), + Index(name="sample_index", columns=("field_one",)), + True, + ), + ( + # Multiple columns + Index(name="sample_index", columns=("f1", "f2")), + Index(name="sample_index", columns=("f1", "f2")), + True, + ), + ( + # Difference in name + Index(name="sample_indey", columns=("field_one",)), + Index(name="sample_index", columns=("field_one",)), + False, + ), + ( + # Difference in columns + Index(name="sample_indey", columns=("field_one",)), + Index(name="sample_index", columns=("field_two",)), + False, + ), + ( + # Difference in num columns + Index(name="sample_index", columns=("f1",)), + Index(name="sample_index", columns=("f1", "f2")), + False, + ), + ], + ) + def test_index_equality(self, index1: Index, index2: Index, equality: bool): + """Test the hashing and set behavior of the Index class.""" + + # Simple equality + assert (index1 == index2) == equality + + # Should be unique or not + index_set = {index1, index2} + assert len(index_set) == (1 if equality else 2) + +