diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index d01e8a5c3..edd611928 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -733,19 +733,26 @@ class Database(object): if thread_id in self._connections: return self._connections[thread_id] else: - # Make a new connection. The `sqlite3` module can't use - # bytestring paths here on Python 3, so we need to - # provide a `str` using `py3_path`. - conn = sqlite3.connect( - py3_path(self.path), timeout=self.timeout - ) - - # Access SELECT results like dictionaries. - conn.row_factory = sqlite3.Row - + conn = self._create_connection() self._connections[thread_id] = conn return conn + def _create_connection(self): + """Create a SQLite connection to the underlying database. Makes + a new connection every time. If you need to add custom functions + to each connection, override this method. + """ + # Make a new connection. The `sqlite3` module can't use + # bytestring paths here on Python 3, so we need to + # provide a `str` using `py3_path`. + conn = sqlite3.connect( + py3_path(self.path), timeout=self.timeout + ) + + # Access SELECT results like dictionaries. + conn.row_factory = sqlite3.Row + return conn + def _close(self): """Close the all connections to the underlying SQLite database from all threads. This does not render the database object diff --git a/beets/library.py b/beets/library.py index 32176b68d..e3ac1bd40 100644 --- a/beets/library.py +++ b/beets/library.py @@ -1237,14 +1237,17 @@ class Library(dbcore.Database): timeout = beets.config['timeout'].as_number() super(Library, self).__init__(path, timeout=timeout) - self._connection().create_function('bytelower', 1, _sqlite_bytelower) - self.directory = bytestring_path(normpath(directory)) self.path_formats = path_formats self.replacements = replacements self._memotable = {} # Used for template substitution performance. + def _create_connection(self): + conn = super(Library, self)._create_connection() + conn.create_function('bytelower', 1, _sqlite_bytelower) + return conn + # Adding objects to the database. def add(self, obj):