mirror of
https://github.com/beetbox/beets.git
synced 2026-02-15 20:03:25 +01:00
db: stricten from_id
This commit is contained in:
parent
bbedf5b3fb
commit
898f439d5e
8 changed files with 52 additions and 47 deletions
|
|
@ -38,7 +38,10 @@ from functools import cached_property
|
|||
from sqlite3 import Connection, sqlite_version_info
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Generic
|
||||
|
||||
from typing_extensions import TypeVar # default value support
|
||||
from typing_extensions import (
|
||||
Self,
|
||||
TypeVar, # default value support
|
||||
)
|
||||
from unidecode import unidecode
|
||||
|
||||
import beets
|
||||
|
|
@ -84,9 +87,14 @@ class DBCustomFunctionError(Exception):
|
|||
)
|
||||
|
||||
|
||||
class NotFoundError(LookupError):
|
||||
pass
|
||||
|
||||
|
||||
class FormattedMapping(Mapping[str, str]):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
|
||||
The accessor `mapping[key]` returns the formatted version of
|
||||
`model[key]` as a unicode string.
|
||||
|
||||
|
|
@ -365,6 +373,9 @@ class Model(ABC, Generic[D]):
|
|||
def db(self) -> D:
|
||||
return self._check_db()
|
||||
|
||||
def get_fresh_from_db(self) -> Self:
|
||||
return self.db.from_id(self.__class__, self.id)
|
||||
|
||||
@classmethod
|
||||
def _getters(cls: type[Model]):
|
||||
"""Return a mapping from field names to getter functions."""
|
||||
|
|
@ -652,11 +663,10 @@ class Model(ABC, Generic[D]):
|
|||
if not self._dirty and self.db.revision == self._revision:
|
||||
# Exit early
|
||||
return
|
||||
stored_obj = self.db.from_id(type(self), self.id)
|
||||
assert stored_obj is not None, f"object {self.id} not in DB"
|
||||
stored_obj = dict(self.get_fresh_from_db())
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
self._values_flex = LazyConvertDict(self)
|
||||
self.update(dict(stored_obj))
|
||||
self.update(stored_obj)
|
||||
self.clear_dirty()
|
||||
|
||||
def remove(self):
|
||||
|
|
@ -911,15 +921,14 @@ class Results(Generic[AnyModel]):
|
|||
except StopIteration:
|
||||
raise IndexError(f"result index {n} out of range")
|
||||
|
||||
def get(self) -> AnyModel | None:
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
it = iter(self)
|
||||
def get(self) -> AnyModel:
|
||||
"""Return the first matching object."""
|
||||
try:
|
||||
return next(it)
|
||||
return next(iter(self))
|
||||
except StopIteration:
|
||||
return None
|
||||
raise NotFoundError(
|
||||
f"No matching {self.model_class.__name__} found"
|
||||
) from None
|
||||
|
||||
|
||||
class Transaction:
|
||||
|
|
@ -1306,12 +1315,6 @@ class Database:
|
|||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
def from_id(
|
||||
self,
|
||||
model_cls: type[AnyModel],
|
||||
id,
|
||||
) -> AnyModel | None:
|
||||
"""Get a Model object by its id or None if the id does not
|
||||
exist.
|
||||
"""
|
||||
return self._fetch(model_cls, MatchQuery("id", id)).get()
|
||||
def from_id(self, model_cls: type[AnyModel], id_: int) -> AnyModel:
|
||||
"""Get a Model object by its id."""
|
||||
return self._fetch(model_cls, MatchQuery("id", id_)).get()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import platformdirs
|
||||
|
|
@ -13,6 +14,7 @@ from .queries import PF_KEY_DEFAULT, parse_query_parts, parse_query_string
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from beets.dbcore import Results
|
||||
from beets.dbcore.db import AnyModel
|
||||
|
||||
|
||||
class Library(dbcore.Database):
|
||||
|
|
@ -125,24 +127,26 @@ class Library(dbcore.Database):
|
|||
return self._fetch(Item, query, sort or self.get_default_item_sort())
|
||||
|
||||
# Convenience accessors.
|
||||
def _try_get(self, model_cls: type[AnyModel], id_: int) -> AnyModel | None:
|
||||
with suppress(dbcore.db.NotFoundError):
|
||||
return self.from_id(model_cls, id_)
|
||||
|
||||
def get_item(self, id):
|
||||
return None
|
||||
|
||||
def get_item(self, id_: int) -> Item | None:
|
||||
"""Fetch a :class:`Item` by its ID.
|
||||
|
||||
Return `None` if no match is found.
|
||||
"""
|
||||
return self.from_id(Item, id)
|
||||
return self._try_get(Item, id_)
|
||||
|
||||
def get_album(self, item_or_id):
|
||||
def get_album(self, item_or_id: Item | int) -> Album | None:
|
||||
"""Given an album ID or an item associated with an album, return
|
||||
a :class:`Album` object for the album.
|
||||
|
||||
If no such album exists, return `None`.
|
||||
"""
|
||||
if isinstance(item_or_id, int):
|
||||
album_id = item_or_id
|
||||
else:
|
||||
album_id = item_or_id.album_id
|
||||
if album_id is None:
|
||||
return None
|
||||
return self.from_id(Album, album_id)
|
||||
album_id = (
|
||||
item_or_id if isinstance(item_or_id, int) else item_or_id.album_id
|
||||
)
|
||||
return self._try_get(Album, album_id) if album_id else None
|
||||
|
|
|
|||
|
|
@ -620,6 +620,8 @@ class Album(LibModel):
|
|||
class Item(LibModel):
|
||||
"""Represent a song or track."""
|
||||
|
||||
album_id: int | None
|
||||
|
||||
_table = "items"
|
||||
_flex_table = "item_attributes"
|
||||
_fields = {
|
||||
|
|
|
|||
|
|
@ -1073,7 +1073,7 @@ def show_model_changes(
|
|||
restrict the detection to. `always` indicates whether the object is
|
||||
always identified, regardless of whether any changes are present.
|
||||
"""
|
||||
old = old or new._db.from_id(type(new), new.id)
|
||||
old = old or new.get_fresh_from_db()
|
||||
|
||||
# Keep the formatted views around instead of re-creating them in each
|
||||
# iteration step
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import mpd
|
|||
|
||||
from beets import config, plugins, ui
|
||||
from beets.dbcore import types
|
||||
from beets.dbcore.db import NotFoundError
|
||||
from beets.dbcore.query import PathQuery
|
||||
from beets.util import displayable_path
|
||||
|
||||
|
|
@ -165,10 +166,9 @@ class MPDStats:
|
|||
def get_item(self, path):
|
||||
"""Return the beets item related to path."""
|
||||
query = PathQuery("path", path)
|
||||
item = self.lib.items(query).get()
|
||||
if item:
|
||||
return item
|
||||
else:
|
||||
try:
|
||||
return self.lib.items(query).get()
|
||||
except NotFoundError:
|
||||
self._log.info("item not found: {}", displayable_path(path))
|
||||
|
||||
def update_item(self, item, attribute, value=None, increment=None):
|
||||
|
|
|
|||
|
|
@ -344,10 +344,9 @@ def item_query(queries):
|
|||
@app.route("/item/path/<everything:path>")
|
||||
def item_at_path(path):
|
||||
query = PathQuery("path", path.encode("utf-8"))
|
||||
item = g.lib.items(query).get()
|
||||
if item:
|
||||
return flask.jsonify(_rep(item))
|
||||
else:
|
||||
try:
|
||||
return flask.jsonify(_rep(g.lib.items(query).get()))
|
||||
except beets.dbcore.db.NotFoundError:
|
||||
return flask.abort(404)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -782,10 +782,7 @@ class ResultsIteratorTest(unittest.TestCase):
|
|||
objs[100]
|
||||
|
||||
def test_no_results(self):
|
||||
assert (
|
||||
self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()).get()
|
||||
is None
|
||||
)
|
||||
assert not self.db._fetch(ModelFixture1, dbcore.query.FalseQuery())
|
||||
|
||||
|
||||
class TestException:
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ class ImportSingletonTest(AutotagImportTestCase):
|
|||
}
|
||||
|
||||
# As-is item import.
|
||||
assert self.lib.albums().get() is None
|
||||
assert not self.lib.albums()
|
||||
self.importer.add_choice(importer.Action.ASIS)
|
||||
self.importer.run()
|
||||
|
||||
|
|
@ -444,7 +444,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
|
|||
assert f"No files imported from {import_dir}" in logs
|
||||
|
||||
def test_asis_no_data_source(self):
|
||||
assert self.lib.items().get() is None
|
||||
assert not self.lib.items()
|
||||
|
||||
self.importer.add_choice(importer.Action.ASIS)
|
||||
self.importer.run()
|
||||
|
|
@ -467,7 +467,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
|
|||
}
|
||||
|
||||
# As-is album import.
|
||||
assert self.lib.albums().get() is None
|
||||
assert not self.lib.albums()
|
||||
self.importer.add_choice(importer.Action.ASIS)
|
||||
self.importer.run()
|
||||
|
||||
|
|
@ -488,7 +488,7 @@ class ImportTest(PathsMixin, AutotagImportTestCase):
|
|||
album.remove()
|
||||
|
||||
# Autotagged.
|
||||
assert self.lib.albums().get() is None
|
||||
assert not self.lib.albums()
|
||||
self.importer.clear_choices()
|
||||
self.importer.add_choice(importer.Action.APPLY)
|
||||
self.importer.run()
|
||||
|
|
|
|||
Loading…
Reference in a new issue