synchronize access to mutable shared maps (GC-399)

We now properly synchronize access to _tx_stacks and _connections, which can be
concurrently mutated by different threads. This way I don't have to worry about
GIL semantics: DRF => SC!
This commit is contained in:
Adrian Sampson 2012-06-08 11:24:16 -07:00
parent 00c47b6811
commit 5d6e9b387a

View file

@ -22,6 +22,7 @@ import logging
import shlex
import unicodedata
import threading
import contextlib
from collections import defaultdict
from unidecode import unidecode
from beets.mediafile import MediaFile
@ -906,19 +907,12 @@ class Transaction(object):
def __init__(self, lib):
self.lib = lib
@property
def _stack(self):
"""Return the transaction stack that this transaction belongs
to. This is the associated library's stack for the current
thread ID. Transactions should never migrate across threads.
"""
return self.lib._tx_stacks[threading.current_thread().ident]
def __enter__(self):
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
self._stack.append(self)
with self.lib._tx_stack() as stack:
stack.append(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
@ -926,8 +920,10 @@ class Transaction(object):
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
assert self._stack.pop() is self
if not self._stack:
with self.lib._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
self.lib._connection().commit()
def query(self, statement, subvals=()):
@ -973,7 +969,11 @@ class Library(BaseLibrary):
self.timeout = timeout
self._connections = {}
self._tx_stacks = defaultdict(list)
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
self._shared_map_lock = threading.Lock()
# Set up database schema.
self._make_table('items', item_fields)
self._make_table('albums', album_fields)
@ -1025,19 +1025,30 @@ class Library(BaseLibrary):
One connection object is created per thread.
"""
thread_id = threading.current_thread().ident
if thread_id in self._connections:
return self._connections[thread_id]
else:
# Make a new connection.
conn = sqlite3.connect(self.path, self.timeout)
with self._shared_map_lock:
if thread_id in self._connections:
return self._connections[thread_id]
else:
# Make a new connection.
conn = sqlite3.connect(self.path, self.timeout)
# Access SELECT results like dictionaries.
conn.row_factory = sqlite3.Row
# Add the REGEXP function to SQLite queries.
conn.create_function("REGEXP", 2, _regexp)
# Access SELECT results like dictionaries.
conn.row_factory = sqlite3.Row
# Add the REGEXP function to SQLite queries.
conn.create_function("REGEXP", 2, _regexp)
self._connections[thread_id] = conn
return conn
self._connections[thread_id] = conn
return conn
@contextlib.contextmanager
def _tx_stack(self):
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
"""
thread_id = threading.current_thread().ident
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
def transaction(self):
"""Get a transaction object for interacting with the database.