From b08db06c05435e23734d316cff8cd4778264426f Mon Sep 17 00:00:00 2001 From: Jack Wilsdon Date: Sat, 20 Apr 2019 20:43:24 +0100 Subject: [PATCH] Add load_extension method for loading SQLite extensions --- beets/dbcore/db.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 43c044572..97a4a7ce3 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -850,16 +850,21 @@ class Database(object): """A container for Model objects that wraps an SQLite database as the backend. """ + _models = () """The Model subclasses representing tables in this database. """ + supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension') + """Whether or not the current version of SQLite supports extensions""" + def __init__(self, path, timeout=5.0): self.path = path self.timeout = timeout self._connections = {} self._tx_stacks = defaultdict(list) + self._extensions = [] # A lock to protect the _connections and _tx_stacks maps, which # both map thread IDs to private resources. @@ -909,6 +914,13 @@ class Database(object): py3_path(self.path), timeout=self.timeout ) + if self.supports_extensions: + conn.enable_load_extension(True) + + # Load any extension that are already loaded for other connections. + for path in self._extensions: + conn.load_extension(path) + # Access SELECT results like dictionaries. conn.row_factory = sqlite3.Row return conn @@ -937,6 +949,18 @@ class Database(object): """ return Transaction(self) + def load_extension(self, path): + """Load an SQLite extension into all open connections.""" + if not self.supports_extensions: + raise ValueError( + 'this sqlite3 installation does not support extensions') + + self._extensions.append(path) + + # Load the extension into every open connection. + for conn in self._connections.values(): + conn.load_extension(path) + # Schema setup and migration. def _make_table(self, table, fields):