Merge pull request #2988 from FichteForks/pr/item-album-fallback

Add fallback for item access to album's attributes
This commit is contained in:
Adrian Sampson 2021-03-07 09:29:37 -05:00 committed by GitHub
commit 3e8261393d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 232 additions and 29 deletions

View file

@ -56,10 +56,11 @@ class FormattedMapping(Mapping):
are replaced.
"""
def __init__(self, model, for_path=False):
def __init__(self, model, for_path=False, compute_keys=True):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
if compute_keys:
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
@ -257,6 +258,11 @@ class Model(object):
value is the same as the old value (e.g., `o.f = o.f`).
"""
_revision = -1
"""A revision number from when the model was loaded from or written
to the database.
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
@ -309,9 +315,11 @@ class Model(object):
def clear_dirty(self):
"""Mark all fields as *clean* (i.e., not needing to be stored to
the database).
the database). Also update the revision.
"""
self._dirty = set()
if self._db:
self._revision = self._db.revision
def _check_db(self, need_id=True):
"""Ensure that this object is associated with a database row: it
@ -351,9 +359,9 @@ class Model(object):
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
def _get(self, key, default=None, raise_=False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
getters = self._getters()
if key in getters: # Computed.
@ -365,8 +373,18 @@ class Model(object):
return self._type(key).null
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
else:
elif raise_:
raise KeyError(key)
else:
return default
get = _get
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""
return self._get(key, raise_=True)
def _setitem(self, key, value):
"""Assign the value for a field, return whether new and old value
@ -441,19 +459,10 @@ class Model(object):
for key in self:
yield key, self[key]
def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default
def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(True)
return key in self.keys(computed=True)
def __iter__(self):
"""Iterate over the available field names (excluding computed
@ -538,8 +547,14 @@ class Model(object):
def load(self):
"""Refresh the object's metadata from the library database.
If check_revision is true, the database is only queried loaded when a
transaction has been committed since the item was last loaded.
"""
self._check_db()
if not self._dirty and self._db.revision == self._revision:
# Exit early
return
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
self._values_fixed = LazyConvertDict(self)
@ -794,6 +809,12 @@ class Transaction(object):
"""A context manager for safe, concurrent access to the database.
All SQL commands should be executed through a transaction.
"""
_mutated = False
"""A flag storing whether a mutation has been executed in the
current transaction.
"""
def __init__(self, db):
self.db = db
@ -815,12 +836,15 @@ class Transaction(object):
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
# Beware of races; currently secured by db._db_lock
self.db.revision += self._mutated
with self.db._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
# Ending a "root" transaction. End the SQLite transaction.
self.db._connection().commit()
self._mutated = False
self.db._db_lock.release()
def query(self, statement, subvals=()):
@ -836,7 +860,6 @@ class Transaction(object):
"""
try:
cursor = self.db._connection().execute(statement, subvals)
return cursor.lastrowid
except sqlite3.OperationalError as e:
# In two specific cases, SQLite reports an error while accessing
# the underlying database file. We surface these exceptions as
@ -846,9 +869,14 @@ class Transaction(object):
raise DBAccessError(e.args[0])
else:
raise
else:
self._mutated = True
return cursor.lastrowid
def script(self, statements):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
self.db._connection().executescript(statements)
@ -864,6 +892,11 @@ class Database(object):
supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
"""Whether or not the current version of SQLite supports extensions"""
revision = 0
"""The current revision of the database. To be increased whenever
data is written in a transaction.
"""
def __init__(self, path, timeout=5.0):
self.path = path
self.timeout = timeout

View file

@ -786,7 +786,7 @@ class ImportTask(BaseImportTask):
if (not dup_item.album_id or
dup_item.album_id in replaced_album_ids):
continue
replaced_album = dup_item.get_album()
replaced_album = dup_item._cached_album
if replaced_album:
replaced_album_ids.add(dup_item.album_id)
self.replaced_albums[replaced_album.path] = replaced_album

View file

@ -375,7 +375,11 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
"""
def __init__(self, item, for_path=False):
super(FormattedItemMapping, self).__init__(item, for_path)
# We treat album and item keys specially here,
# so exclude transitive album keys from the model's keys.
super(FormattedItemMapping, self).__init__(item, for_path,
compute_keys=False)
self.model_keys = item.keys(computed=True, with_album=False)
self.item = item
@lazy_property
@ -386,15 +390,15 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
def album_keys(self):
album_keys = []
if self.album:
for key in self.album.keys(True):
for key in self.album.keys(computed=True):
if key in Album.item_keys \
or key not in self.item._fields.keys():
album_keys.append(key)
return album_keys
@lazy_property
@property
def album(self):
return self.item.get_album()
return self.item._cached_album
def _get(self, key):
"""Get the value for a key, either from the album or the item.
@ -545,6 +549,29 @@ class Item(LibModel):
_format_config_key = 'format_item'
__album = None
"""Cached album object. Read-only."""
@property
def _cached_album(self):
"""The Album object that this item belongs to, if any, or
None if the item is a singleton or is not associated with a
library.
The instance is cached and refreshed on access.
DO NOT MODIFY!
If you want a copy to modify, use :meth:`get_album`.
"""
if not self.__album and self._db:
self.__album = self._db.get_album(self)
elif self.__album:
self.__album.load()
return self.__album
@_cached_album.setter
def _cached_album(self, album):
self.__album = album
@classmethod
def _getters(cls):
getters = plugins.item_field_getters()
@ -571,12 +598,45 @@ class Item(LibModel):
value = bytestring_path(value)
elif isinstance(value, BLOB_TYPE):
value = bytes(value)
elif key == 'album_id':
self._cached_album = None
changed = super(Item, self)._setitem(key, value)
if changed and key in MediaFile.fields():
self.mtime = 0 # Reset mtime on dirty.
def __getitem__(self, key):
"""Get the value for a field, falling back to the album if
necessary. Raise a KeyError if the field is not available.
"""
try:
return super(Item, self).__getitem__(key)
except KeyError:
if self._cached_album:
return self._cached_album[key]
raise
def keys(self, computed=False, with_album=True):
"""Get a list of available field names. `with_album`
controls whether the album's fields are included.
"""
keys = super(Item, self).keys(computed=computed)
if with_album and self._cached_album:
keys += self._cached_album.keys(computed=computed)
return keys
def get(self, key, default=None, with_album=True):
"""Get the value for a given key or `default` if it does not
exist. Set `with_album` to false to skip album fallback.
"""
try:
return self._get(key, default, raise_=with_album)
except KeyError:
if self._cached_album:
return self._cached_album.get(key, default)
return default
def update(self, values):
"""Set all key/value pairs in the mapping. If mtime is
specified, it is not reset (as it might otherwise be).

View file

@ -1155,8 +1155,13 @@ def _setup(options, lib=None):
plugins.send("library_opened", lib=lib)
# Add types and queries defined by plugins.
library.Item._types.update(plugins.types(library.Item))
library.Album._types.update(plugins.types(library.Album))
plugin_types_album = plugins.types(library.Album)
library.Album._types.update(plugin_types_album)
item_types = plugin_types_album.copy()
item_types.update(library.Item._types)
item_types.update(plugins.types(library.Item))
library.Item._types = item_types
library.Item._queries.update(plugins.named_queries(library.Item))
library.Album._queries.update(plugins.named_queries(library.Album))

View file

@ -358,7 +358,7 @@ class ConvertPlugin(BeetsPlugin):
item.store() # Store new path and audio data.
if self.config['embed'] and not linked:
album = item.get_album()
album = item._cached_album
if album and album.artpath:
self._log.debug(u'embedding album art from {}',
util.displayable_path(album.artpath))

View file

@ -178,6 +178,11 @@ New features:
* :doc:`/plugins/replaygain` now does its analysis in parallel when using
the ``command`` or ``ffmpeg`` backends.
:bug:`3478`
* Fields in queries now fall back to an item's album and check its fields too.
Notably, this allows querying items by an album flex attribute, also in path
configuration.
Thanks to :user:`FichteFoll`.
:bug:`2797` :bug:`2988`
* Removes usage of the bs1770gain replaygain backend.
Thanks to :user:`SamuelCook`.
* Added ``trackdisambig`` which stores the recording disambiguation from
@ -344,6 +349,12 @@ For plugin developers:
:bug:`3355`
* The autotag hooks have been modified such that they now take 'bpm',
'musical_key' and a per-track based 'genre' as attributes.
* Item (and attribute) access on an item now falls back to the album's
attributes as well. If you specifically want to access an item's attributes,
use ``Item.get(key, with_album=False)``. :bug:`2988`
* ``Item.keys`` also has a ``with_album`` argument now, defaulting to ``True``.
* A ``revision`` attribute has been added to ``Database``. It is increased on
every transaction that mutates it. :bug:`2988`
For packagers:

View file

@ -225,6 +225,31 @@ class MigrationTest(unittest.TestCase):
self.fail("select failed")
class TransactionTest(unittest.TestCase):
def setUp(self):
self.db = DatabaseFixture1(':memory:')
def tearDown(self):
self.db._connection().close()
def test_mutate_increase_revision(self):
old_rev = self.db.revision
with self.db.transaction() as tx:
tx.mutate(
'INSERT INTO {0} '
'(field_one) '
'VALUES (?);'.format(ModelFixture1._table),
(111,),
)
self.assertGreater(self.db.revision, old_rev)
def test_query_no_increase_revision(self):
old_rev = self.db.revision
with self.db.transaction() as tx:
tx.query('PRAGMA table_info(%s)' % ModelFixture1._table)
self.assertEqual(self.db.revision, old_rev)
class ModelTest(unittest.TestCase):
def setUp(self):
self.db = DatabaseFixture1(':memory:')
@ -246,6 +271,30 @@ class ModelTest(unittest.TestCase):
row = self.db._connection().execute('select * from test').fetchone()
self.assertEqual(row['field_one'], 123)
def test_revision(self):
old_rev = self.db.revision
model = ModelFixture1()
model.add(self.db)
model.store()
self.assertEqual(model._revision, self.db.revision)
self.assertGreater(self.db.revision, old_rev)
mid_rev = self.db.revision
model2 = ModelFixture1()
model2.add(self.db)
model2.store()
self.assertGreater(model2._revision, mid_rev)
self.assertGreater(self.db.revision, model._revision)
# revision changed, so the model should be re-loaded
model.load()
self.assertEqual(model._revision, self.db.revision)
# revision did not change, so no reload
mod2_old_rev = model2._revision
model2.load()
self.assertEqual(model2._revision, mod2_old_rev)
def test_retrieve_by_id(self):
model = ModelFixture1()
model.add(self.db)

View file

@ -49,7 +49,7 @@ class IPFSPluginTest(unittest.TestCase, TestHelper):
want_item = test_album.items()[2]
for check_item in added_album.items():
try:
if check_item.ipfs:
if check_item.get('ipfs', with_album=False):
ipfs_item = os.path.basename(want_item.path).decode(
_fsencoding(),
)
@ -57,7 +57,8 @@ class IPFSPluginTest(unittest.TestCase, TestHelper):
ipfs_item)
want_path = bytestring_path(want_path)
self.assertEqual(check_item.path, want_path)
self.assertEqual(check_item.ipfs, want_item.ipfs)
self.assertEqual(check_item.get('ipfs', with_album=False),
want_item.ipfs)
self.assertEqual(check_item.title, want_item.title)
found = True
except AttributeError:

View file

@ -132,6 +132,21 @@ class GetSetTest(_common.TestCase):
def test_invalid_field_raises_attributeerror(self):
self.assertRaises(AttributeError, getattr, self.i, u'xyzzy')
def test_album_fallback(self):
# integration test of item-album fallback
lib = beets.library.Library(':memory:')
i = item(lib)
album = lib.add_album([i])
album['flex'] = u'foo'
album.store()
self.assertTrue('flex' in i)
self.assertFalse('flex' in i.keys(with_album=False))
self.assertEqual(i['flex'], u'foo')
self.assertEqual(i.get('flex'), u'foo')
self.assertEqual(i.get('flex', with_album=False), None)
self.assertEqual(i.get('flexx'), None)
class DestinationTest(_common.TestCase):
def setUp(self):
@ -491,6 +506,24 @@ class DestinationTest(_common.TestCase):
dest = self.i.destination()
self.assertEqual(dest[-2:], b'XX')
def test_album_field_query(self):
self.lib.directory = b'one'
self.lib.path_formats = [(u'default', u'two'),
(u'flex:foo', u'three')]
album = self.lib.add_album([self.i])
self.assertEqual(self.i.destination(), np('one/two'))
album['flex'] = u'foo'
album.store()
self.assertEqual(self.i.destination(), np('one/three'))
def test_album_field_in_template(self):
self.lib.directory = b'one'
self.lib.path_formats = [(u'default', u'$flex/two')]
album = self.lib.add_album([self.i])
album['flex'] = u'foo'
album.store()
self.assertEqual(self.i.destination(), np('one/foo/two'))
class ItemFormattedMappingTest(_common.LibTestCase):
def test_formatted_item_value(self):

View file

@ -109,7 +109,7 @@ class DummyDataTestCase(_common.TestCase, AssertsMixin):
items[2].comp = False
for item in items:
self.lib.add(item)
self.lib.add_album(items[:2])
self.album = self.lib.add_album(items[:2])
def assert_items_matched_all(self, results):
self.assert_items_matched(results, [
@ -300,6 +300,17 @@ class GetTest(DummyDataTestCase):
results = self.lib.items(q)
self.assertFalse(results)
def test_album_field_fallback(self):
self.album['albumflex'] = u'foo'
self.album.store()
q = u'albumflex:foo'
results = self.lib.items(q)
self.assert_items_matched(results, [
u'foo bar',
u'baz qux',
])
def test_invalid_query(self):
with self.assertRaises(InvalidQueryArgumentValueError) as raised:
dbcore.query.NumericQuery('year', u'199a')