From aeec7b886793104d08027056f4ce7e5dfd1b757a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 8 Feb 2026 21:24:00 +0000 Subject: [PATCH] Add generic Migration implementation --- beets/dbcore/db.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index deb31ba71..4b0ba4f15 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -24,9 +24,10 @@ import sqlite3 import sys import threading import time -from abc import ABC +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping +from dataclasses import dataclass from functools import cached_property from sqlite3 import Connection, sqlite_version_info from typing import TYPE_CHECKING, Any, AnyStr, ClassVar, Generic, NamedTuple @@ -1029,6 +1030,27 @@ class Transaction: self.db._connection().executescript(statements) +@dataclass +class Migration(ABC): + db: Database + + @cached_classproperty + def name(cls) -> str: + """Class name (except Migration) converted to snake case.""" + name = cls.__name__.removesuffix("Migration") # type: ignore[attr-defined] + return re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() + + def migrate_table(self, table: str) -> None: + """Migrate a specific table.""" + if not self.db.migration_exists(self.name, table): + self._migrate_data(table) + self.db.record_migration(self.name, table) + + @abstractmethod + def _migrate_data(self, table: str) -> None: + """Migrate data for a specific table.""" + + class Database: """A container for Model objects that wraps an SQLite database as the backend. @@ -1038,6 +1060,9 @@ class Database: """The Model subclasses representing tables in this database. """ + _migrations: Sequence[tuple[type[Migration], Sequence[type[Model]]]] = () + """Migrations that are to be performed for the configured models.""" + supports_extensions = hasattr(sqlite3.Connection, "enable_load_extension") """Whether or not the current version of SQLite supports extensions""" @@ -1302,6 +1327,10 @@ class Database: def _migrate(self) -> None: """Perform any necessary migration for the database.""" + for migration_cls, model_classes in self._migrations: + migration = migration_cls(self) + for model_cls in model_classes: + migration.migrate_table(model_cls._table) def migration_exists(self, name: str, table: str) -> bool: """Return whether a named migration has been marked complete."""