diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 8640a5678..deb31ba71 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -26,14 +26,7 @@ import threading import time from abc import ABC from collections import defaultdict -from collections.abc import ( - Callable, - Generator, - Iterable, - Iterator, - Mapping, - Sequence, -) +from collections.abc import Mapping from functools import cached_property from sqlite3 import Connection, sqlite_version_info from typing import TYPE_CHECKING, Any, AnyStr, ClassVar, Generic, NamedTuple @@ -1088,11 +1081,14 @@ class Database: self._db_lock = threading.Lock() # Set up database schema. + self._ensure_migration_state_table() for model_cls in self._models: self._make_table(model_cls._table, model_cls._fields) self._make_attribute_table(model_cls._flex_table) self._create_indices(model_cls._table, model_cls._indices) + self._migrate() + # Primitive access control: connections and transactions. def _connection(self) -> Connection: @@ -1292,6 +1288,41 @@ class Database: f"ON {table} ({', '.join(index.columns)});" ) + # Generic migration state handling. + + def _ensure_migration_state_table(self) -> None: + with self.transaction() as tx: + tx.script(""" + CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL, + table_name TEXT NOT NULL, + PRIMARY KEY(name, table_name) + ); + """) + + def _migrate(self) -> None: + """Perform any necessary migration for the database.""" + + def migration_exists(self, name: str, table: str) -> bool: + """Return whether a named migration has been marked complete.""" + with self.transaction() as tx: + return tx.execute( + """ + SELECT EXISTS( + SELECT 1 FROM migrations WHERE name = ? AND table_name = ? + ) + """, + (name, table), + ).fetchone()[0] + + def record_migration(self, name: str, table: str) -> None: + """Set completion state for a named migration.""" + with self.transaction() as tx: + tx.mutate( + "INSERT INTO migrations(name, table_name) VALUES (?, ?)", + (name, table), + ) + # Querying. def _fetch(