diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index d13f923f6..613430651 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -24,9 +24,10 @@ import sqlite3 import sys import threading import time -from abc import ABC +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping +from dataclasses import dataclass from functools import cached_property from sqlite3 import Connection, sqlite_version_info from typing import TYPE_CHECKING, Any, AnyStr, ClassVar, Generic, NamedTuple @@ -1029,6 +1030,29 @@ class Transaction: self.db._connection().executescript(statements) +@dataclass +class Migration(ABC): + db: Database + + @cached_classproperty + def flag_prefix(cls) -> str: + """Class name (except Migration) converted to snake case.""" + name = cls.__name__.removesuffix("Migration") # type: ignore[attr-defined] + return re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() + + def migrate_table(self, table: str) -> None: + """Migrate a specific table.""" + migration_flag = f"{self.flag_prefix}_{table}" + + if not self.db.get_migration_state(migration_flag): + self._migrate_data(table) + self.db.set_migration_state(migration_flag, True) + + @abstractmethod + def _migrate_data(self, table: str) -> None: + """Migrate data for a specific table.""" + + class Database: """A container for Model objects that wraps an SQLite database as the backend. @@ -1038,6 +1062,10 @@ class Database: """The Model subclasses representing tables in this database. """ + _migrations: Sequence[tuple[type[Migration], Sequence[type[Model]]]] = () + """The Model subclasses representing tables in this database. + """ + supports_extensions = hasattr(sqlite3.Connection, "enable_load_extension") """Whether or not the current version of SQLite supports extensions""" @@ -1301,6 +1329,10 @@ class Database: def _migrate(self) -> None: """Perform any necessary migration for the database.""" + for migration_cls, model_classes in self._migrations: + migration = migration_cls(self) + for model_cls in model_classes: + migration.migrate_table(model_cls._table) def get_migration_state(self, name: str) -> bool: """Return whether a named migration has been marked complete."""