diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index e0c1bb486..60e201448 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -1078,7 +1078,9 @@ def _field_diff(field, old, old_fmt, new, new_fmt): return f"{oldstr} -> {newstr}" -def show_model_changes(new, old=None, fields=None, always=False): +def show_model_changes( + new, old=None, fields=None, always=False, print_obj: bool = True +): """Given a Model object, print a list of changes from its pristine version stored in the database. Return a boolean indicating whether any changes were found. @@ -1117,7 +1119,7 @@ def show_model_changes(new, old=None, fields=None, always=False): ) # Print changes. - if changes or always: + if print_obj and (changes or always): print_(format(old)) if changes: print_("\n".join(changes)) diff --git a/beetsplug/lastgenre/__init__.py b/beetsplug/lastgenre/__init__.py index 1da5ecde4..902cef9ef 100644 --- a/beetsplug/lastgenre/__init__.py +++ b/beetsplug/lastgenre/__init__.py @@ -22,10 +22,13 @@ The scraper script used is available here: https://gist.github.com/1241307 """ +from __future__ import annotations + import os import traceback +from functools import singledispatchmethod from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING, Union import pylast import yaml @@ -34,6 +37,9 @@ from beets import config, library, plugins, ui from beets.library import Album, Item from beets.util import plurality, unique_list +if TYPE_CHECKING: + from beets.library import LibModel + LASTFM = pylast.LastFMNetwork(api_key=plugins.LASTFM_KEY) PYLAST_EXCEPTIONS = ( @@ -101,6 +107,7 @@ class LastGenrePlugin(plugins.BeetsPlugin): "prefer_specific": False, "title_case": True, "extended_debug": False, + "pretend": False, } ) self.setup() @@ -321,7 +328,7 @@ class LastGenrePlugin(plugins.BeetsPlugin): return self.config["separator"].as_str().join(formatted) - def _get_existing_genres(self, obj: Union[Album, Item]) -> list[str]: + def _get_existing_genres(self, obj: LibModel) -> list[str]: """Return a list of genres for this Item or Album. Empty string genres are removed.""" separator = self.config["separator"].get() @@ -342,9 +349,7 @@ class LastGenrePlugin(plugins.BeetsPlugin): combined = old + new return self._resolve_genres(combined) - def _get_genre( - self, obj: Union[Album, Item] - ) -> tuple[Union[str, None], ...]: + def _get_genre(self, obj: LibModel) -> tuple[Union[str, None], ...]: """Get the final genre string for an Album or Item object. `self.sources` specifies allowed genre sources. Starting with the first @@ -459,6 +464,39 @@ class LastGenrePlugin(plugins.BeetsPlugin): # Beets plugin hooks and CLI. + def _fetch_and_log_genre(self, obj: LibModel) -> None: + """Fetch genre and log it.""" + self._log.info(str(obj)) + obj.genre, label = self._get_genre(obj) + self._log.debug("Resolved ({}): {}", label, obj.genre) + + ui.show_model_changes(obj, fields=["genre"], print_obj=False) + + @singledispatchmethod + def _process(self, obj: LibModel, write: bool) -> None: + """Process an object, dispatching to the appropriate method.""" + raise NotImplementedError + + @_process.register + def _process_track(self, obj: Item, write: bool) -> None: + """Process a single track/item.""" + self._fetch_and_log_genre(obj) + if not self.config["pretend"]: + obj.try_sync(write=write, move=False) + + @_process.register + def _process_album(self, obj: Album, write: bool) -> None: + """Process an entire album.""" + self._fetch_and_log_genre(obj) + if "track" in self.sources: + for item in obj.items(): + self._process(item, write) + + if not self.config["pretend"]: + obj.try_sync( + write=write, move=False, inherit="track" not in self.sources + ) + def commands(self): lastgenre_cmd = ui.Subcommand("lastgenre", help="fetch genres") lastgenre_cmd.parser.add_option( @@ -526,101 +564,17 @@ class LastGenrePlugin(plugins.BeetsPlugin): lastgenre_cmd.parser.set_defaults(album=True) def lastgenre_func(lib, opts, args): - write = ui.should_write() - pretend = getattr(opts, "pretend", False) self.config.set_args(opts) - if opts.album: - # Fetch genres for whole albums - for album in lib.albums(args): - album_genre, src = self._get_genre(album) - prefix = "Pretend: " if pretend else "" - self._log.info( - '{}genre for album "{.album}" ({}): {}', - prefix, - album, - src, - album_genre, - ) - if not pretend: - album.genre = album_genre - if "track" in self.sources: - album.store(inherit=False) - else: - album.store() - - for item in album.items(): - # If we're using track-level sources, also look up each - # track on the album. - if "track" in self.sources: - item_genre, src = self._get_genre(item) - self._log.info( - '{}genre for track "{.title}" ({}): {}', - prefix, - item, - src, - item_genre, - ) - if not pretend: - item.genre = item_genre - item.store() - - if write and not pretend: - item.try_write() - else: - # Just query singletons, i.e. items that are not part of - # an album - for item in lib.items(args): - item_genre, src = self._get_genre(item) - prefix = "Pretend: " if pretend else "" - self._log.info( - '{}genre for track "{0.title}" ({1}): {}', - prefix, - item, - src, - item_genre, - ) - if not pretend: - item.genre = item_genre - item.store() - if write and not pretend: - item.try_write() + method = lib.albums if opts.album else lib.items + for obj in method(args): + self._process(obj, write=ui.should_write()) lastgenre_cmd.func = lastgenre_func return [lastgenre_cmd] def imported(self, session, task): - """Event hook called when an import task finishes.""" - if task.is_album: - album = task.album - album.genre, src = self._get_genre(album) - self._log.debug( - 'genre for album "{0.album}" ({1}): {0.genre}', album, src - ) - - # If we're using track-level sources, store the album genre only, - # then also look up individual track genres. - if "track" in self.sources: - album.store(inherit=False) - for item in album.items(): - item.genre, src = self._get_genre(item) - self._log.debug( - 'genre for track "{0.title}" ({1}): {0.genre}', - item, - src, - ) - item.store() - # Store the album genre and inherit to tracks. - else: - album.store() - - else: - item = task.item - item.genre, src = self._get_genre(item) - self._log.debug( - 'genre for track "{0.title}" ({1}): {0.genre}', item, src - ) - item.store() + self._process(task.album if task.is_album else task.item, write=False) def _tags_for(self, obj, min_weight=None): """Core genre identification routine. diff --git a/test/plugins/test_lastgenre.py b/test/plugins/test_lastgenre.py index d6df42f97..12ff30f8e 100644 --- a/test/plugins/test_lastgenre.py +++ b/test/plugins/test_lastgenre.py @@ -19,11 +19,13 @@ from unittest.mock import Mock, patch import pytest from beets.test import _common -from beets.test.helper import BeetsTestCase +from beets.test.helper import PluginTestCase from beetsplug import lastgenre -class LastGenrePluginTest(BeetsTestCase): +class LastGenrePluginTest(PluginTestCase): + plugin = "lastgenre" + def setUp(self): super().setUp() self.plugin = lastgenre.LastGenrePlugin() @@ -131,6 +133,11 @@ class LastGenrePluginTest(BeetsTestCase): "math rock", ] + @patch("beets.ui.should_write", Mock(return_value=True)) + @patch( + "beetsplug.lastgenre.LastGenrePlugin._get_genre", + Mock(return_value=("Mock Genre", "mock stage")), + ) def test_pretend_option_skips_library_updates(self): item = self.create_item( album="Pretend Album", @@ -141,32 +148,17 @@ class LastGenrePluginTest(BeetsTestCase): ) album = self.lib.add_album([item]) - command = self.plugin.commands()[0] - opts, args = command.parser.parse_args(["--pretend"]) - - with patch.object(lastgenre.ui, "should_write", return_value=True): - with patch.object( - self.plugin, - "_get_genre", - return_value=("Mock Genre", "mock stage"), - ) as mock_get_genre: - with patch.object(self.plugin._log, "info") as log_info: - # Mock try_write to verify it's never called in pretend mode - with patch.object(item, "try_write") as mock_try_write: - command.func(self.lib, opts, args) - - mock_get_genre.assert_called_once() - - assert any( - call.args[1] == "Pretend: " for call in log_info.call_args_list - ) + def unexpected_store(*_, **__): + raise AssertionError("Unexpected store call") # Verify that try_write was never called (file operations skipped) - mock_try_write.assert_not_called() + with patch("beetsplug.lastgenre.Item.store", unexpected_store): + output = self.run_with_output("lastgenre", "--pretend") - stored_album = self.lib.get_album(album.id) - assert stored_album.genre == "Original Genre" - assert stored_album.items()[0].genre == "Original Genre" + assert "Mock Genre" in output + album.load() + assert album.genre == "Original Genre" + assert album.items()[0].genre == "Original Genre" def test_no_duplicate(self): """Remove duplicated genres."""