mirror of
https://github.com/beetbox/beets.git
synced 2026-02-18 05:17:31 +01:00
Add generic Migration implementation
This commit is contained in:
parent
a1bed02a58
commit
79167f1df7
1 changed files with 33 additions and 1 deletions
|
|
@ -24,9 +24,10 @@ import sqlite3
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from sqlite3 import Connection, sqlite_version_info
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, ClassVar, Generic, NamedTuple
|
||||
|
|
@ -1029,6 +1030,29 @@ class Transaction:
|
|||
self.db._connection().executescript(statements)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Migration(ABC):
|
||||
db: Database
|
||||
|
||||
@cached_classproperty
|
||||
def flag_prefix(cls) -> str:
|
||||
"""Class name (except Migration) converted to snake case."""
|
||||
name = cls.__name__.removesuffix("Migration") # type: ignore[attr-defined]
|
||||
return re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower()
|
||||
|
||||
def migrate_table(self, table: str) -> None:
|
||||
"""Migrate a specific table."""
|
||||
migration_flag = f"{self.flag_prefix}_{table}"
|
||||
|
||||
if not self.db.get_migration_state(migration_flag):
|
||||
self._migrate_data(table)
|
||||
self.db.set_migration_state(migration_flag, True)
|
||||
|
||||
@abstractmethod
|
||||
def _migrate_data(self, table: str) -> None:
|
||||
"""Migrate data for a specific table."""
|
||||
|
||||
|
||||
class Database:
|
||||
"""A container for Model objects that wraps an SQLite database as
|
||||
the backend.
|
||||
|
|
@ -1038,6 +1062,10 @@ class Database:
|
|||
"""The Model subclasses representing tables in this database.
|
||||
"""
|
||||
|
||||
_migrations: Sequence[tuple[type[Migration], Sequence[type[Model]]]] = ()
|
||||
"""The Model subclasses representing tables in this database.
|
||||
"""
|
||||
|
||||
supports_extensions = hasattr(sqlite3.Connection, "enable_load_extension")
|
||||
"""Whether or not the current version of SQLite supports extensions"""
|
||||
|
||||
|
|
@ -1301,6 +1329,10 @@ class Database:
|
|||
|
||||
def _migrate(self) -> None:
|
||||
"""Perform any necessary migration for the database."""
|
||||
for migration_cls, model_classes in self._migrations:
|
||||
migration = migration_cls(self)
|
||||
for model_cls in model_classes:
|
||||
migration.migrate_table(model_cls._table)
|
||||
|
||||
def get_migration_state(self, name: str) -> bool:
|
||||
"""Return whether a named migration has been marked complete."""
|
||||
|
|
|
|||
Loading…
Reference in a new issue