diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 8cd89111e..b3a6c7dd8 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -64,6 +64,16 @@ class DBAccessError(Exception): """ +class DBCustomFunctionError(Exception): + """A sqlite function registered by beets failed.""" + + def __init__(self): + super().__init__( + "beets defined SQLite function failed; " + "see the other errors above for details" + ) + + class FormattedMapping(Mapping[str, str]): """A `dict`-like formatted view of a model. @@ -947,6 +957,12 @@ class Transaction: self._mutated = False self.db._db_lock.release() + if ( + isinstance(exc_value, sqlite3.OperationalError) + and exc_value.args[0] == "user-defined function raised exception" + ): + raise DBCustomFunctionError() + def query( self, statement: str, subvals: Sequence[SQLiteType] = () ) -> list[sqlite3.Row]: @@ -1007,6 +1023,13 @@ class Database: "sqlite3 must be compiled with multi-threading support" ) + # Print tracebacks for exceptions in user defined functions + # See also `self.add_functions` and `DBCustomFunctionError`. + # + # `if`: use feature detection because PyPy doesn't support this. + if hasattr(sqlite3, "enable_callback_tracebacks"): + sqlite3.enable_callback_tracebacks(True) + self.path = path self.timeout = timeout diff --git a/test/test_dbcore.py b/test/test_dbcore.py index b2ec2e968..d2c76d852 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -23,6 +23,7 @@ from tempfile import mkstemp import pytest from beets import dbcore +from beets.dbcore.db import DBCustomFunctionError from beets.library import LibModel from beets.test import _common from beets.util import cached_classproperty @@ -31,6 +32,13 @@ from beets.util import cached_classproperty # have multiple models with different numbers of fields. +@pytest.fixture +def db(model): + db = model(":memory:") + yield db + db._connection().close() + + class SortFixture(dbcore.query.FieldSort): pass @@ -784,3 +792,25 @@ class ResultsIteratorTest(unittest.TestCase): self.db._fetch(ModelFixture1, dbcore.query.FalseQuery()).get() is None ) + + +class TestException: + @pytest.mark.parametrize("model", [DatabaseFixture1]) + @pytest.mark.filterwarnings( + "ignore: .*plz_raise.*: pytest.PytestUnraisableExceptionWarning" + ) + @pytest.mark.filterwarnings( + "error: .*: pytest.PytestUnraisableExceptionWarning" + ) + def test_custom_function_error(self, db: DatabaseFixture1): + def plz_raise(): + raise Exception("i haz raized") + + db._connection().create_function("plz_raise", 0, plz_raise) + + with db.transaction() as tx: + tx.mutate("insert into test (field_one) values (1)") + + with pytest.raises(DBCustomFunctionError): + with db.transaction() as tx: + tx.query("select * from test where plz_raise()")