Improve model changes colour display and db / field diff typing (#6240)

## 1. **Refactored UI diffs**

Using the following command:

```sh
beet modify turn page data_source= 'play_count!' hello=hi comp=1 mb_albumid=https://bandcamp
```

### Before
<img width="613" height="260" alt="before"
src="https://github.com/user-attachments/assets/785c4b73-69e4-4c60-b4dd-d114ee3170a1"
/>

* New field additions have been shown in red
* No difference in formatting between
  - field _removal_ (`field!`)
  - and it being reset to an empty string (`field=`)
 
### After
<img width="640" height="256" alt="after"
src="https://github.com/user-attachments/assets/89b036d7-a074-494b-a5e1-b1bf5100d454"
/>


* Now, the field name is colored in red or green whenever it's added or
removed

## 2. Small improvements in `Model` types:
* Added `NotFoundError` and `Model.get_fresh_from_db` for those cases
where `Database._get` must return a non-optional instance of `Model`
   * Added cached `Model.db` property to dedupe `Model._check_db` calls.

## 3. Added a global autouse model-level `config` fixture.
This commit is contained in:
Šarūnas Nejus 2025-12-27 14:35:43 +00:00 committed by GitHub
commit 21e6a1f757
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 161 additions and 93 deletions

View file

@ -34,10 +34,14 @@ from collections.abc import (
Mapping,
Sequence,
)
from functools import cached_property
from sqlite3 import Connection, sqlite_version_info
from typing import TYPE_CHECKING, Any, AnyStr, Generic
from typing_extensions import TypeVar # default value support
from typing_extensions import (
Self,
TypeVar, # default value support
)
from unidecode import unidecode
import beets
@ -83,6 +87,10 @@ class DBCustomFunctionError(Exception):
)
class NotFoundError(LookupError):
pass
class FormattedMapping(Mapping[str, str]):
"""A `dict`-like formatted view of a model.
@ -97,6 +105,8 @@ class FormattedMapping(Mapping[str, str]):
are replaced.
"""
model: Model
ALL_KEYS = "*"
def __init__(
@ -360,6 +370,22 @@ class Model(ABC, Generic[D]):
"""Fields in the related table."""
return cls._relation._fields.keys() - cls.shared_db_fields
@cached_property
def db(self) -> D:
"""Get the database associated with this object.
This validates that the database is attached and the object has an id.
"""
return self._check_db()
def get_fresh_from_db(self) -> Self:
"""Load this object from the database."""
model_cls = self.__class__
if obj := self.db._get(model_cls, self.id):
return obj
raise NotFoundError(f"No matching {model_cls.__name__} found") from None
@classmethod
def _getters(cls: type[Model]):
"""Return a mapping from field names to getter functions."""
@ -599,7 +625,6 @@ class Model(ABC, Generic[D]):
"""
if fields is None:
fields = self._fields
db = self._check_db()
# Build assignments for query.
assignments = []
@ -611,7 +636,7 @@ class Model(ABC, Generic[D]):
value = self._type(key).to_sql(self[key])
subvars.append(value)
with db.transaction() as tx:
with self.db.transaction() as tx:
# Main table update.
if assignments:
query = f"UPDATE {self._table} SET {','.join(assignments)} WHERE id=?"
@ -645,21 +670,16 @@ class Model(ABC, Generic[D]):
If check_revision is true, the database is only queried loaded when a
transaction has been committed since the item was last loaded.
"""
db = self._check_db()
if not self._dirty and db.revision == self._revision:
if not self._dirty and self.db.revision == self._revision:
# Exit early
return
stored_obj = db._get(type(self), self.id)
assert stored_obj is not None, f"object {self.id} not in DB"
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
self.update(dict(stored_obj))
self.__dict__.update(self.get_fresh_from_db().__dict__)
self.clear_dirty()
def remove(self):
"""Remove the object's associated rows from the database."""
db = self._check_db()
with db.transaction() as tx:
with self.db.transaction() as tx:
tx.mutate(f"DELETE FROM {self._table} WHERE id=?", (self.id,))
tx.mutate(
f"DELETE FROM {self._flex_table} WHERE entity_id=?", (self.id,)
@ -675,7 +695,7 @@ class Model(ABC, Generic[D]):
"""
if db:
self._db = db
db = self._check_db(False)
db = self._check_db(need_id=False)
with db.transaction() as tx:
new_id = tx.mutate(f"INSERT INTO {self._table} DEFAULT VALUES")
@ -696,7 +716,7 @@ class Model(ABC, Generic[D]):
self,
included_keys: str = _formatter.ALL_KEYS,
for_path: bool = False,
):
) -> FormattedMapping:
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
@ -740,9 +760,9 @@ class Model(ABC, Generic[D]):
Remove the database connection as sqlite connections are not
picklable.
"""
state = self.__dict__.copy()
state["_db"] = None
return state
return {
k: v for k, v in self.__dict__.items() if k not in {"_db", "db"}
}
# Database controller and supporting interfaces.
@ -1303,12 +1323,6 @@ class Database:
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(
self,
model_cls: type[AnyModel],
id,
) -> AnyModel | None:
"""Get a Model object by its id or None if the id does not
exist.
"""
return self._fetch(model_cls, MatchQuery("id", id)).get()
def _get(self, model_cls: type[AnyModel], id_: int) -> AnyModel | None:
"""Get a Model object by its id or None if the id does not exist."""
return self._fetch(model_cls, MatchQuery("id", id_)).get()

View file

@ -125,24 +125,20 @@ class Library(dbcore.Database):
return self._fetch(Item, query, sort or self.get_default_item_sort())
# Convenience accessors.
def get_item(self, id):
def get_item(self, id_: int) -> Item | None:
"""Fetch a :class:`Item` by its ID.
Return `None` if no match is found.
"""
return self._get(Item, id)
return self._get(Item, id_)
def get_album(self, item_or_id):
def get_album(self, item_or_id: Item | int) -> Album | None:
"""Given an album ID or an item associated with an album, return
a :class:`Album` object for the album.
If no such album exists, return `None`.
"""
if isinstance(item_or_id, int):
album_id = item_or_id
else:
album_id = item_or_id.album_id
if album_id is None:
return None
return self._get(Album, album_id)
album_id = (
item_or_id if isinstance(item_or_id, int) else item_or_id.album_id
)
return self._get(Album, album_id) if album_id else None

View file

@ -620,6 +620,8 @@ class Album(LibModel):
class Item(LibModel):
"""Represent a song or track."""
album_id: int | None
_table = "items"
_flex_table = "item_attributes"
_fields = {
@ -1143,7 +1145,6 @@ class Item(LibModel):
If `store` is `False` however, the item won't be stored and it will
have to be manually stored after invoking this method.
"""
self._check_db()
dest = self.destination(basedir=basedir)
# Create necessary ancestry for the move.
@ -1183,9 +1184,8 @@ class Item(LibModel):
is true, returns just the fragment of the path underneath the library
base directory.
"""
db = self._check_db()
basedir = basedir or db.directory
path_formats = path_formats or db.path_formats
basedir = basedir or self.db.directory
path_formats = path_formats or self.db.path_formats
# Use a path format based on a query, falling back on the
# default.
@ -1224,7 +1224,7 @@ class Item(LibModel):
)
lib_path_str, fallback = util.legalize_path(
subpath, db.replacements, self.filepath.suffix
subpath, self.db.replacements, self.filepath.suffix
)
if fallback:
# Print an error message if legalization fell back to

View file

@ -43,7 +43,10 @@ from beets.util.deprecation import deprecate_for_maintainers
from beets.util.functemplate import template
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Iterable
from beets.dbcore.db import FormattedMapping
# On Windows platforms, use colorama to support "ANSI" terminal colors.
if sys.platform == "win32":
@ -1028,42 +1031,47 @@ def print_newline_layout(
FLOAT_EPSILON = 0.01
def _field_diff(field, old, old_fmt, new, new_fmt):
def _field_diff(
field: str, old: FormattedMapping, new: FormattedMapping
) -> str | None:
"""Given two Model objects and their formatted views, format their values
for `field` and highlight changes among them. Return a human-readable
string. If the value has not changed, return None instead.
"""
oldval = old.get(field)
newval = new.get(field)
# If no change, abort.
if (
if (oldval := old.model.get(field)) == (newval := new.model.get(field)) or (
isinstance(oldval, float)
and isinstance(newval, float)
and abs(oldval - newval) < FLOAT_EPSILON
):
return None
elif oldval == newval:
return None
# Get formatted values for output.
oldstr = old_fmt.get(field, "")
newstr = new_fmt.get(field, "")
oldstr, newstr = old.get(field, ""), new.get(field, "")
if field not in new:
return colorize("text_diff_removed", f"{field}: {oldstr}")
if field not in old:
return colorize("text_diff_added", f"{field}: {newstr}")
# For strings, highlight changes. For others, colorize the whole
# thing.
if isinstance(oldval, str):
oldstr, newstr = colordiff(oldval, newstr)
oldstr, newstr = colordiff(oldstr, newstr)
else:
oldstr = colorize("text_diff_removed", oldstr)
newstr = colorize("text_diff_added", newstr)
return f"{oldstr} -> {newstr}"
return f"{field}: {oldstr} -> {newstr}"
def show_model_changes(
new, old=None, fields=None, always=False, print_obj: bool = True
):
new: library.LibModel,
old: library.LibModel | None = None,
fields: Iterable[str] | None = None,
always: bool = False,
print_obj: bool = True,
) -> bool:
"""Given a Model object, print a list of changes from its pristine
version stored in the database. Return a boolean indicating whether
any changes were found.
@ -1073,7 +1081,7 @@ def show_model_changes(
restrict the detection to. `always` indicates whether the object is
always identified, regardless of whether any changes are present.
"""
old = old or new._db._get(type(new), new.id)
old = old or new.get_fresh_from_db()
# Keep the formatted views around instead of re-creating them in each
# iteration step
@ -1081,31 +1089,21 @@ def show_model_changes(
new_fmt = new.formatted()
# Build up lines showing changed fields.
changes = []
for field in old:
# Subset of the fields. Never show mtime.
if field == "mtime" or (fields and field not in fields):
continue
diff_fields = (set(old) | set(new)) - {"mtime"}
if allowed_fields := set(fields or {}):
diff_fields &= allowed_fields
# Detect and show difference for this field.
line = _field_diff(field, old, old_fmt, new, new_fmt)
if line:
changes.append(f" {field}: {line}")
# New fields.
for field in set(new) - set(old):
if fields and field not in fields:
continue
changes.append(
f" {field}: {colorize('text_highlight', new_fmt[field])}"
)
changes = [
d
for f in sorted(diff_fields)
if (d := _field_diff(f, old_fmt, new_fmt))
]
# Print changes.
if print_obj and (changes or always):
print_(format(old))
if changes:
print_("\n".join(changes))
print_(textwrap.indent("\n".join(changes), " "))
return bool(changes)

View file

@ -70,6 +70,8 @@ Bug fixes:
- When using :doc:`plugins/fromfilename` together with :doc:`plugins/edit`,
temporary tags extracted from filenames are no longer lost when discarding or
cancelling an edit session during import. :bug:`6104`
- :ref:`update-cmd` :doc:`plugins/edit` fix display formatting of field changes
to clearly show added and removed flexible fields.
For plugin developers:

View file

@ -322,6 +322,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"beets/**" = ["PT"]
"test/test_util.py" = ["E501"]
"test/ui/test_field_diff.py" = ["E501"]
[tool.ruff.lint.isort]
split-on-trailing-comma = false

View file

@ -12,15 +12,13 @@ from beets.autotag.distance import (
from beets.library import Item
from beets.metadata_plugins import MetadataSourcePlugin, get_penalty
from beets.plugins import BeetsPlugin
from beets.test.helper import ConfigMixin
_p = pytest.param
class TestDistance:
@pytest.fixture(autouse=True, scope="class")
def setup_config(self):
config = ConfigMixin().config
def setup_config(self, config):
config["match"]["distance_weights"]["data_source"] = 2.0
config["match"]["distance_weights"]["album"] = 4.0
config["match"]["distance_weights"]["medium"] = 2.0

View file

@ -5,6 +5,7 @@ import pytest
from beets.autotag.distance import Distance
from beets.dbcore.query import Query
from beets.test.helper import ConfigMixin
from beets.util import cached_classproperty
@ -53,3 +54,9 @@ def pytest_assertrepr_compare(op, left, right):
@pytest.fixture(autouse=True)
def clear_cached_classproperty():
cached_classproperty.cache.clear()
@pytest.fixture(scope="module")
def config():
"""Provide a fresh beets configuration for a module, when requested."""
return ConfigMixin().config

View file

@ -8,7 +8,7 @@ from beets.autotag import AlbumMatch
from beets.autotag.distance import Distance
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.library import Item
from beets.test.helper import ConfigMixin, PluginMixin
from beets.test.helper import PluginMixin
from beetsplug._typing import JSONDict
from beetsplug.mbpseudo import (
_STATUS_PSEUDO,
@ -52,14 +52,7 @@ def pseudo_release_info() -> AlbumInfo:
)
@pytest.fixture(scope="module", autouse=True)
def config():
config = ConfigMixin().config
with pytest.MonkeyPatch.context() as m:
m.setattr("beetsplug.mbpseudo.config", config)
yield config
@pytest.mark.usefixtures("config")
class TestPseudoAlbumInfo:
def test_album_id_always_from_pseudo(
self, official_release_info: AlbumInfo, pseudo_release_info: AlbumInfo

View file

@ -19,18 +19,18 @@ import pytest
from beets import autotag, config
from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
from beets.library import Item
from beets.test.helper import BeetsTestCase, ConfigMixin
from beets.test.helper import BeetsTestCase
class TestAssignment(ConfigMixin):
class TestAssignment:
A = "one"
B = "two"
C = "three"
@pytest.fixture(autouse=True)
def _setup_config(self):
self.config["match"]["track_length_grace"] = 10
self.config["match"]["track_length_max"] = 30
def config(self, config):
config["match"]["track_length_grace"] = 10
config["match"]["track_length_max"] = 30
@pytest.mark.parametrize(
# 'expected' is a tuple of expected (mapping, extra_items, extra_tracks)

View file

@ -0,0 +1,59 @@
import pytest
from beets.library import Item
from beets.ui import _field_diff
p = pytest.param
class TestFieldDiff:
@pytest.fixture(autouse=True)
def configure_color(self, config, color):
config["ui"]["color"] = color
@pytest.fixture(autouse=True)
def patch_colorize(self, monkeypatch):
"""Patch to return a deterministic string format instead of ANSI codes."""
monkeypatch.setattr(
"beets.ui.colorize",
lambda color_name, text: f"[{color_name}]{text}[/]",
)
@staticmethod
def diff_fmt(old, new):
return f"[text_diff_removed]{old}[/] -> [text_diff_added]{new}[/]"
@pytest.mark.parametrize(
"old_data, new_data, field, expected_diff",
[
p({"title": "foo"}, {"title": "foo"}, "title", None, id="no_change"),
p({"bpm": 120.0}, {"bpm": 120.005}, "bpm", None, id="float_close_enough"),
p({"bpm": 120.0}, {"bpm": 121.0}, "bpm", f"bpm: {diff_fmt('120', '121')}", id="float_changed"),
p({"title": "foo"}, {"title": "bar"}, "title", f"title: {diff_fmt('foo', 'bar')}", id="string_full_replace"),
p({"title": "prefix foo"}, {"title": "prefix bar"}, "title", "title: prefix [text_diff_removed]foo[/] -> prefix [text_diff_added]bar[/]", id="string_partial_change"),
p({"year": 2000}, {"year": 2001}, "year", f"year: {diff_fmt('2000', '2001')}", id="int_changed"),
p({}, {"genre": "Rock"}, "genre", "genre: -> [text_diff_added]Rock[/]", id="field_added"),
p({"genre": "Rock"}, {}, "genre", "genre: [text_diff_removed]Rock[/] -> ", id="field_removed"),
p({"track": 1}, {"track": 2}, "track", f"track: {diff_fmt('01', '02')}", id="formatted_value_changed"),
p({"mb_trackid": None}, {"mb_trackid": "1234"}, "mb_trackid", "mb_trackid: -> [text_diff_added]1234[/]", id="none_to_value"),
p({}, {"new_flex": "foo"}, "new_flex", "[text_diff_added]new_flex: foo[/]", id="flex_field_added"),
p({"old_flex": "foo"}, {}, "old_flex", "[text_diff_removed]old_flex: foo[/]", id="flex_field_removed"),
],
) # fmt: skip
@pytest.mark.parametrize("color", [True], ids=["color_enabled"])
def test_field_diff_colors(self, old_data, new_data, field, expected_diff):
old_item = Item(**old_data)
new_item = Item(**new_data)
diff = _field_diff(field, old_item.formatted(), new_item.formatted())
assert diff == expected_diff
@pytest.mark.parametrize("color", [False], ids=["color_disabled"])
def test_field_diff_no_color(self):
old_item = Item(title="foo")
new_item = Item(title="bar")
diff = _field_diff("title", old_item.formatted(), new_item.formatted())
assert diff == "title: foo -> bar"