diff --git a/beets/autotag/__init__.py b/beets/autotag/__init__.py index 4cc4ff30a..feeefbf28 100644 --- a/beets/autotag/__init__.py +++ b/beets/autotag/__init__.py @@ -167,7 +167,6 @@ def correct_list_fields(m: LibModel) -> None: setattr(m, single_field, list_val[0]) ensure_first_value("albumtype", "albumtypes") - ensure_first_value("genre", "genres") if hasattr(m, "mb_artistids"): ensure_first_value("mb_artistid", "mb_artistids") diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 902eae634..75ce9d3e7 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -30,7 +30,16 @@ from collections.abc import Mapping from dataclasses import dataclass from functools import cached_property from sqlite3 import Connection, sqlite_version_info -from typing import TYPE_CHECKING, Any, AnyStr, ClassVar, Generic, NamedTuple +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + ClassVar, + Generic, + Literal, + NamedTuple, + TypedDict, +) from typing_extensions import ( Self, @@ -1063,17 +1072,22 @@ class Migration(ABC): name = cls.__name__.removesuffix("Migration") # type: ignore[attr-defined] return re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower() - def migrate_table(self, table: str) -> None: + def migrate_table(self, table: str, *args, **kwargs) -> None: """Migrate a specific table.""" if not self.db.migration_exists(self.name, table): - self._migrate_data(table) + self._migrate_data(table, *args, **kwargs) self.db.record_migration(self.name, table) @abstractmethod - def _migrate_data(self, table: str) -> None: + def _migrate_data(self, table: str, current_fields: set[str]) -> None: """Migrate data for a specific table.""" +class TableInfo(TypedDict): + columns: set[str] + migrations: set[str] + + class Database: """A container for Model objects that wraps an SQLite database as the backend. @@ -1138,6 +1152,32 @@ class Database: self._migrate() + @cached_property + def db_tables(self) -> dict[str, TableInfo]: + column_queries = [ + f""" + SELECT '{m._table}' AS table_name, 'columns' AS source, name + FROM pragma_table_info('{m._table}') + """ + for m in self._models + ] + with self.transaction() as tx: + rows = tx.query(f""" + {" UNION ALL ".join(column_queries)} + UNION ALL + SELECT table_name, 'migrations' AS source, name FROM migrations + """) + + tables_data: dict[str, TableInfo] = defaultdict( + lambda: TableInfo(columns=set(), migrations=set()) + ) + + source: Literal["columns", "migrations"] + for table_name, source, name in rows: + tables_data[table_name][source].add(name) + + return tables_data + # Primitive access control: connections and transactions. def _connection(self) -> Connection: @@ -1278,32 +1318,22 @@ class Database: """Set up the schema of the database. `fields` is a mapping from field names to `Type`s. Columns are added if necessary. """ - # Get current schema. - with self.transaction() as tx: - rows = tx.query(f"PRAGMA table_info({table})") - current_fields = {row[1] for row in rows} - - field_names = set(fields.keys()) - if current_fields.issuperset(field_names): - # Table exists and has all the required columns. - return - - if not current_fields: + if table not in self.db_tables: # No table exists. columns = [] for name, typ in fields.items(): columns.append(f"{name} {typ.sql}") setup_sql = f"CREATE TABLE {table} ({', '.join(columns)});\n" - + self.db_tables[table]["columns"] = set(fields) else: # Table exists does not match the field set. setup_sql = "" + current_fields = self.db_tables[table]["columns"] for name, typ in fields.items(): - if name in current_fields: - continue - setup_sql += ( - f"ALTER TABLE {table} ADD COLUMN {name} {typ.sql};\n" - ) + if name not in current_fields: + setup_sql += ( + f"ALTER TABLE {table} ADD COLUMN {name} {typ.sql};\n" + ) with self.transaction() as tx: tx.script(setup_sql) @@ -1354,19 +1384,12 @@ class Database: for migration_cls, model_classes in self._migrations: migration = migration_cls(self) for model_cls in model_classes: - migration.migrate_table(model_cls._table) + table = model_cls._table + migration.migrate_table(table, self.db_tables[table]["columns"]) def migration_exists(self, name: str, table: str) -> bool: """Return whether a named migration has been marked complete.""" - with self.transaction() as tx: - return tx.execute( - """ - SELECT EXISTS( - SELECT 1 FROM migrations WHERE name = ? AND table_name = ? - ) - """, - (name, table), - ).fetchone()[0] + return name in self.db_tables[table]["migrations"] def record_migration(self, name: str, table: str) -> None: """Set completion state for a named migration.""" diff --git a/beets/library/migrations.py b/beets/library/migrations.py index e2fa80f63..16f4c6761 100644 --- a/beets/library/migrations.py +++ b/beets/library/migrations.py @@ -57,8 +57,11 @@ class MultiGenreFieldMigration(Migration): return genre - def _migrate_data(self, table: str) -> None: + def _migrate_data(self, table: str, current_fields: set[str]) -> None: """Migrate legacy genre values to the multi-value genres field.""" + if "genre" not in current_fields: + # No legacy genre field, so nothing to migrate. + return with self.db.transaction() as tx, self.with_factory(GenreRow): rows: list[GenreRow] = tx.query( # type: ignore[assignment] diff --git a/beets/library/models.py b/beets/library/models.py index 1f01581c2..eba2fb618 100644 --- a/beets/library/models.py +++ b/beets/library/models.py @@ -241,7 +241,6 @@ class Album(LibModel): "albumartists_sort": types.MULTI_VALUE_DSV, "albumartists_credit": types.MULTI_VALUE_DSV, "album": types.STRING, - "genre": types.STRING, "genres": types.MULTI_VALUE_DSV, "style": types.STRING, "discogs_albumid": types.INTEGER, @@ -277,7 +276,7 @@ class Album(LibModel): "original_day": types.PaddedInt(2), } - _search_fields = ("album", "albumartist", "genre") + _search_fields = ("album", "albumartist", "genres") @cached_classproperty def _types(cls) -> dict[str, types.Type]: @@ -298,7 +297,6 @@ class Album(LibModel): "albumartist_credit", "albumartists_credit", "album", - "genre", "genres", "style", "discogs_albumid", @@ -652,7 +650,6 @@ class Item(LibModel): "albumartists_sort": types.MULTI_VALUE_DSV, "albumartist_credit": types.STRING, "albumartists_credit": types.MULTI_VALUE_DSV, - "genre": types.STRING, "genres": types.MULTI_VALUE_DSV, "style": types.STRING, "discogs_albumid": types.INTEGER, @@ -735,7 +732,7 @@ class Item(LibModel): "comments", "album", "albumartist", - "genre", + "genres", ) # Set of item fields that are backed by `MediaFile` fields. diff --git a/test/library/test_migrations.py b/test/library/test_migrations.py index dba0d8718..2c0dece8b 100644 --- a/test/library/test_migrations.py +++ b/test/library/test_migrations.py @@ -1,5 +1,6 @@ import pytest +from beets.dbcore import types from beets.library.migrations import MultiGenreFieldMigration from beets.library.models import Album, Item from beets.test.helper import TestHelper @@ -10,6 +11,19 @@ class TestMultiGenreFieldMigration: def helper(self, monkeypatch): # do not apply migrations upon library initialization monkeypatch.setattr("beets.library.library.Library._migrations", ()) + # add genre field to both models to make sure this column is created + monkeypatch.setattr( + "beets.library.models.Item._fields", + {**Item._fields, "genre": types.STRING}, + ) + monkeypatch.setattr( + "beets.library.models.Album._fields", + {**Album._fields, "genre": types.STRING}, + ) + monkeypatch.setattr( + "beets.library.models.Album.item_keys", + {*Album.item_keys, "genre"}, + ) helper = TestHelper() helper.setup_beets() @@ -52,5 +66,7 @@ class TestMultiGenreFieldMigration: unmigrated_album.load() assert unmigrated_album.genres == ["Album Rock", "Alternative"] + # remove cached initial db tables data + del helper.lib.db_tables assert helper.lib.migration_exists("multi_genre_field", "items") assert helper.lib.migration_exists("multi_genre_field", "albums") diff --git a/test/test_autotag.py b/test/test_autotag.py index e6a122ae2..119ca15e8 100644 --- a/test/test_autotag.py +++ b/test/test_autotag.py @@ -475,71 +475,3 @@ def test_correct_list_fields( single_val, list_val = item[single_field], item[list_field] assert (not single_val and not list_val) or single_val == list_val[0] - - -# Tests for multi-value genres functionality -class TestGenreSync: - """Test the genre/genres field synchronization.""" - - def test_genres_list_to_genre_first(self): - """Genres list sets genre to first item.""" - item = Item(genres=["Rock", "Alternative", "Indie"]) - correct_list_fields(item) - - assert item.genre == "Rock" - assert item.genres == ["Rock", "Alternative", "Indie"] - - def test_genre_string_to_genres_list(self): - """Genre string becomes first item in genres list.""" - item = Item(genre="Rock") - correct_list_fields(item) - - assert item.genre == "Rock" - assert item.genres == ["Rock"] - - def test_genre_and_genres_both_present(self): - """When both genre and genres exist, genre becomes first in list.""" - item = Item(genre="Jazz", genres=["Rock", "Alternative"]) - correct_list_fields(item) - - # genre should be prepended to genres list (deduplicated) - assert item.genre == "Jazz" - assert item.genres == ["Jazz", "Rock", "Alternative"] - - def test_empty_genre(self): - """Empty genre field.""" - item = Item(genre="") - correct_list_fields(item) - - assert item.genre == "" - assert item.genres == [] - - def test_empty_genres(self): - """Empty genres list.""" - item = Item(genres=[]) - correct_list_fields(item) - - assert item.genre == "" - assert item.genres == [] - - def test_none_values(self): - """Handle None values in genre/genres fields without errors.""" - # Test with None genre - item = Item(genre=None, genres=["Rock"]) - correct_list_fields(item) - assert item.genres == ["Rock"] - assert item.genre == "Rock" - - # Test with None genres - item = Item(genre="Jazz", genres=None) - correct_list_fields(item) - assert item.genre == "Jazz" - assert item.genres == ["Jazz"] - - def test_none_both(self): - """Handle None in both genre and genres.""" - item = Item(genre=None, genres=None) - correct_list_fields(item) - - assert item.genres == [] - assert item.genre == ""