From 898f439d5e2ba82087a6102cd655389edeebb26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Fri, 26 Dec 2025 09:35:47 +0000 Subject: [PATCH] db: stricten from_id --- beets/dbcore/db.py | 43 +++++++++++++++++++++------------------ beets/library/library.py | 24 +++++++++++++--------- beets/library/models.py | 2 ++ beets/ui/__init__.py | 2 +- beetsplug/mpdstats.py | 8 ++++---- beetsplug/web/__init__.py | 7 +++---- test/test_dbcore.py | 5 +---- test/test_importer.py | 8 ++++---- 8 files changed, 52 insertions(+), 47 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 1c7cf47da..a0cbd17fd 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -38,7 +38,10 @@ from functools import cached_property from sqlite3 import Connection, sqlite_version_info from typing import TYPE_CHECKING, Any, AnyStr, Generic -from typing_extensions import TypeVar # default value support +from typing_extensions import ( + Self, + TypeVar, # default value support +) from unidecode import unidecode import beets @@ -84,9 +87,14 @@ class DBCustomFunctionError(Exception): ) +class NotFoundError(LookupError): + pass + + class FormattedMapping(Mapping[str, str]): """A `dict`-like formatted view of a model. + The accessor `mapping[key]` returns the formatted version of `model[key]` as a unicode string. @@ -365,6 +373,9 @@ class Model(ABC, Generic[D]): def db(self) -> D: return self._check_db() + def get_fresh_from_db(self) -> Self: + return self.db.from_id(self.__class__, self.id) + @classmethod def _getters(cls: type[Model]): """Return a mapping from field names to getter functions.""" @@ -652,11 +663,10 @@ class Model(ABC, Generic[D]): if not self._dirty and self.db.revision == self._revision: # Exit early return - stored_obj = self.db.from_id(type(self), self.id) - assert stored_obj is not None, f"object {self.id} not in DB" + stored_obj = dict(self.get_fresh_from_db()) self._values_fixed = LazyConvertDict(self) self._values_flex = LazyConvertDict(self) - self.update(dict(stored_obj)) + self.update(stored_obj) self.clear_dirty() def remove(self): @@ -911,15 +921,14 @@ class Results(Generic[AnyModel]): except StopIteration: raise IndexError(f"result index {n} out of range") - def get(self) -> AnyModel | None: - """Return the first matching object, or None if no objects - match. - """ - it = iter(self) + def get(self) -> AnyModel: + """Return the first matching object.""" try: - return next(it) + return next(iter(self)) except StopIteration: - return None + raise NotFoundError( + f"No matching {self.model_class.__name__} found" + ) from None class Transaction: @@ -1306,12 +1315,6 @@ class Database: sort if sort.is_slow() else None, # Slow sort component. ) - def from_id( - self, - model_cls: type[AnyModel], - id, - ) -> AnyModel | None: - """Get a Model object by its id or None if the id does not - exist. - """ - return self._fetch(model_cls, MatchQuery("id", id)).get() + def from_id(self, model_cls: type[AnyModel], id_: int) -> AnyModel: + """Get a Model object by its id.""" + return self._fetch(model_cls, MatchQuery("id", id_)).get() diff --git a/beets/library/library.py b/beets/library/library.py index 0f3b6adbd..eae33b4b6 100644 --- a/beets/library/library.py +++ b/beets/library/library.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import suppress from typing import TYPE_CHECKING import platformdirs @@ -13,6 +14,7 @@ from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string if TYPE_CHECKING: from beets.dbcore import Results + from beets.dbcore.db import AnyModel class Library(dbcore.Database): @@ -125,24 +127,26 @@ class Library(dbcore.Database): return self._fetch(Item, query, sort or self.get_default_item_sort()) # Convenience accessors. + def _try_get(self, model_cls: type[AnyModel], id_: int) -> AnyModel | None: + with suppress(dbcore.db.NotFoundError): + return self.from_id(model_cls, id_) - def get_item(self, id): + return None + + def get_item(self, id_: int) -> Item | None: """Fetch a :class:`Item` by its ID. Return `None` if no match is found. """ - return self.from_id(Item, id) + return self._try_get(Item, id_) - def get_album(self, item_or_id): + def get_album(self, item_or_id: Item | int) -> Album | None: """Given an album ID or an item associated with an album, return a :class:`Album` object for the album. If no such album exists, return `None`. """ - if isinstance(item_or_id, int): - album_id = item_or_id - else: - album_id = item_or_id.album_id - if album_id is None: - return None - return self.from_id(Album, album_id) + album_id = ( + item_or_id if isinstance(item_or_id, int) else item_or_id.album_id + ) + return self._try_get(Album, album_id) if album_id else None diff --git a/beets/library/models.py b/beets/library/models.py index 76618d929..9609989bc 100644 --- a/beets/library/models.py +++ b/beets/library/models.py @@ -620,6 +620,8 @@ class Album(LibModel): class Item(LibModel): """Represent a song or track.""" + album_id: int | None + _table = "items" _flex_table = "item_attributes" _fields = { diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index 664531359..cbe0fb109 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -1073,7 +1073,7 @@ def show_model_changes( restrict the detection to. `always` indicates whether the object is always identified, regardless of whether any changes are present. """ - old = old or new._db.from_id(type(new), new.id) + old = old or new.get_fresh_from_db() # Keep the formatted views around instead of re-creating them in each # iteration step diff --git a/beetsplug/mpdstats.py b/beetsplug/mpdstats.py index 0a3e1de02..25786d9bf 100644 --- a/beetsplug/mpdstats.py +++ b/beetsplug/mpdstats.py @@ -20,6 +20,7 @@ import mpd from beets import config, plugins, ui from beets.dbcore import types +from beets.dbcore.db import NotFoundError from beets.dbcore.query import PathQuery from beets.util import displayable_path @@ -165,10 +166,9 @@ class MPDStats: def get_item(self, path): """Return the beets item related to path.""" query = PathQuery("path", path) - item = self.lib.items(query).get() - if item: - return item - else: + try: + return self.lib.items(query).get() + except NotFoundError: self._log.info("item not found: {}", displayable_path(path)) def update_item(self, item, attribute, value=None, increment=None): diff --git a/beetsplug/web/__init__.py b/beetsplug/web/__init__.py index 28bc20152..9cf97aaf8 100644 --- a/beetsplug/web/__init__.py +++ b/beetsplug/web/__init__.py @@ -344,10 +344,9 @@ def item_query(queries): @app.route("/item/path/") def item_at_path(path): query = PathQuery("path", path.encode("utf-8")) - item = g.lib.items(query).get() - if item: - return flask.jsonify(_rep(item)) - else: + try: + return flask.jsonify(_rep(g.lib.items(query).get())) + except beets.dbcore.db.NotFoundError: return flask.abort(404) diff --git a/test/test_dbcore.py b/test/test_dbcore.py index ff208c13e..f8f7abbff 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -782,10 +782,7 @@ class ResultsIteratorTest(unittest.TestCase): objs[100] def test_no_results(self): - assert ( - self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()).get() - is None - ) + assert not self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()) class TestException: diff --git a/test/test_importer.py b/test/test_importer.py index c1768df3e..713812030 100644 --- a/test/test_importer.py +++ b/test/test_importer.py @@ -305,7 +305,7 @@ class ImportSingletonTest(AutotagImportTestCase): } # As-is item import. - assert self.lib.albums().get() is None + assert not self.lib.albums() self.importer.add_choice(importer.Action.ASIS) self.importer.run() @@ -444,7 +444,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase): assert f"No files imported from {import_dir}" in logs def test_asis_no_data_source(self): - assert self.lib.items().get() is None + assert not self.lib.items() self.importer.add_choice(importer.Action.ASIS) self.importer.run() @@ -467,7 +467,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase): } # As-is album import. - assert self.lib.albums().get() is None + assert not self.lib.albums() self.importer.add_choice(importer.Action.ASIS) self.importer.run() @@ -488,7 +488,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase): album.remove() # Autotagged. - assert self.lib.albums().get() is None + assert not self.lib.albums() self.importer.clear_choices() self.importer.add_choice(importer.Action.APPLY) self.importer.run()