db: stricten from_id

This commit is contained in:
Šarūnas Nejus 2025-12-26 09:35:47 +00:00
parent bbedf5b3fb
commit 898f439d5e
No known key found for this signature in database
8 changed files with 52 additions and 47 deletions

View file

@ -38,7 +38,10 @@ from functools import cached_property
from sqlite3 import Connection, sqlite_version_info
from typing import TYPE_CHECKING, Any, AnyStr, Generic
from typing_extensions import TypeVar # default value support
from typing_extensions import (
Self,
TypeVar, # default value support
)
from unidecode import unidecode
import beets
@ -84,9 +87,14 @@ class DBCustomFunctionError(Exception):
)
class NotFoundError(LookupError):
pass
class FormattedMapping(Mapping[str, str]):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
`model[key]` as a unicode string.
@ -365,6 +373,9 @@ class Model(ABC, Generic[D]):
def db(self) -> D:
return self._check_db()
def get_fresh_from_db(self) -> Self:
return self.db.from_id(self.__class__, self.id)
@classmethod
def _getters(cls: type[Model]):
"""Return a mapping from field names to getter functions."""
@ -652,11 +663,10 @@ class Model(ABC, Generic[D]):
if not self._dirty and self.db.revision == self._revision:
# Exit early
return
stored_obj = self.db.from_id(type(self), self.id)
assert stored_obj is not None, f"object {self.id} not in DB"
stored_obj = dict(self.get_fresh_from_db())
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
self.update(dict(stored_obj))
self.update(stored_obj)
self.clear_dirty()
def remove(self):
@ -911,15 +921,14 @@ class Results(Generic[AnyModel]):
except StopIteration:
raise IndexError(f"result index {n} out of range")
def get(self) -> AnyModel | None:
"""Return the first matching object, or None if no objects
match.
"""
it = iter(self)
def get(self) -> AnyModel:
"""Return the first matching object."""
try:
return next(it)
return next(iter(self))
except StopIteration:
return None
raise NotFoundError(
f"No matching {self.model_class.__name__} found"
) from None
class Transaction:
@ -1306,12 +1315,6 @@ class Database:
sort if sort.is_slow() else None, # Slow sort component.
)
def from_id(
self,
model_cls: type[AnyModel],
id,
) -> AnyModel | None:
"""Get a Model object by its id or None if the id does not
exist.
"""
return self._fetch(model_cls, MatchQuery("id", id)).get()
def from_id(self, model_cls: type[AnyModel], id_: int) -> AnyModel:
"""Get a Model object by its id."""
return self._fetch(model_cls, MatchQuery("id", id_)).get()

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from contextlib import suppress
from typing import TYPE_CHECKING
import platformdirs
@ -13,6 +14,7 @@ from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string
if TYPE_CHECKING:
from beets.dbcore import Results
from beets.dbcore.db import AnyModel
class Library(dbcore.Database):
@ -125,24 +127,26 @@ class Library(dbcore.Database):
return self._fetch(Item, query, sort or self.get_default_item_sort())
# Convenience accessors.
def _try_get(self, model_cls: type[AnyModel], id_: int) -> AnyModel | None:
with suppress(dbcore.db.NotFoundError):
return self.from_id(model_cls, id_)
def get_item(self, id):
return None
def get_item(self, id_: int) -> Item | None:
"""Fetch a :class:`Item` by its ID.
Return `None` if no match is found.
"""
return self.from_id(Item, id)
return self._try_get(Item, id_)
def get_album(self, item_or_id):
def get_album(self, item_or_id: Item | int) -> Album | None:
"""Given an album ID or an item associated with an album, return
a :class:`Album` object for the album.
If no such album exists, return `None`.
"""
if isinstance(item_or_id, int):
album_id = item_or_id
else:
album_id = item_or_id.album_id
if album_id is None:
return None
return self.from_id(Album, album_id)
album_id = (
item_or_id if isinstance(item_or_id, int) else item_or_id.album_id
)
return self._try_get(Album, album_id) if album_id else None

View file

@ -620,6 +620,8 @@ class Album(LibModel):
class Item(LibModel):
"""Represent a song or track."""
album_id: int | None
_table = "items"
_flex_table = "item_attributes"
_fields = {

View file

@ -1073,7 +1073,7 @@ def show_model_changes(
restrict the detection to. `always` indicates whether the object is
always identified, regardless of whether any changes are present.
"""
old = old or new._db.from_id(type(new), new.id)
old = old or new.get_fresh_from_db()
# Keep the formatted views around instead of re-creating them in each
# iteration step

View file

@ -20,6 +20,7 @@ import mpd
from beets import config, plugins, ui
from beets.dbcore import types
from beets.dbcore.db import NotFoundError
from beets.dbcore.query import PathQuery
from beets.util import displayable_path
@ -165,10 +166,9 @@ class MPDStats:
def get_item(self, path):
"""Return the beets item related to path."""
query = PathQuery("path", path)
item = self.lib.items(query).get()
if item:
return item
else:
try:
return self.lib.items(query).get()
except NotFoundError:
self._log.info("item not found: {}", displayable_path(path))
def update_item(self, item, attribute, value=None, increment=None):

View file

@ -344,10 +344,9 @@ def item_query(queries):
@app.route("/item/path/<everything:path>")
def item_at_path(path):
query = PathQuery("path", path.encode("utf-8"))
item = g.lib.items(query).get()
if item:
return flask.jsonify(_rep(item))
else:
try:
return flask.jsonify(_rep(g.lib.items(query).get()))
except beets.dbcore.db.NotFoundError:
return flask.abort(404)

View file

@ -782,10 +782,7 @@ class ResultsIteratorTest(unittest.TestCase):
objs[100]
def test_no_results(self):
assert (
self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()).get()
is None
)
assert not self.db._fetch(ModelFixture1, dbcore.query.FalseQuery())
class TestException:

View file

@ -305,7 +305,7 @@ class ImportSingletonTest(AutotagImportTestCase):
}
# As-is item import.
assert self.lib.albums().get() is None
assert not self.lib.albums()
self.importer.add_choice(importer.Action.ASIS)
self.importer.run()
@ -444,7 +444,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
assert f"No files imported from {import_dir}" in logs
def test_asis_no_data_source(self):
assert self.lib.items().get() is None
assert not self.lib.items()
self.importer.add_choice(importer.Action.ASIS)
self.importer.run()
@ -467,7 +467,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
}
# As-is album import.
assert self.lib.albums().get() is None
assert not self.lib.albums()
self.importer.add_choice(importer.Action.ASIS)
self.importer.run()
@ -488,7 +488,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
album.remove()
# Autotagged.
assert self.lib.albums().get() is None
assert not self.lib.albums()
self.importer.clear_choices()
self.importer.add_choice(importer.Action.APPLY)
self.importer.run()