mirror of
https://github.com/beetbox/beets.git
synced 2026-01-17 21:52:34 +01:00
Merge pull request #2988 from FichteForks/pr/item-album-fallback
Add fallback for item access to album's attributes
This commit is contained in:
commit
3e8261393d
10 changed files with 232 additions and 29 deletions
|
|
@ -56,10 +56,11 @@ class FormattedMapping(Mapping):
|
|||
are replaced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
def __init__(self, model, for_path=False, compute_keys=True):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
if compute_keys:
|
||||
self.model_keys = model.keys(True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
|
|
@ -257,6 +258,11 @@ class Model(object):
|
|||
value is the same as the old value (e.g., `o.f = o.f`).
|
||||
"""
|
||||
|
||||
_revision = -1
|
||||
"""A revision number from when the model was loaded from or written
|
||||
to the database.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
|
|
@ -309,9 +315,11 @@ class Model(object):
|
|||
|
||||
def clear_dirty(self):
|
||||
"""Mark all fields as *clean* (i.e., not needing to be stored to
|
||||
the database).
|
||||
the database). Also update the revision.
|
||||
"""
|
||||
self._dirty = set()
|
||||
if self._db:
|
||||
self._revision = self._db.revision
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
|
|
@ -351,9 +359,9 @@ class Model(object):
|
|||
"""
|
||||
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
def _get(self, key, default=None, raise_=False):
|
||||
"""Get the value for a field, or `default`. Alternatively,
|
||||
raise a KeyError if the field is not available.
|
||||
"""
|
||||
getters = self._getters()
|
||||
if key in getters: # Computed.
|
||||
|
|
@ -365,8 +373,18 @@ class Model(object):
|
|||
return self._type(key).null
|
||||
elif key in self._values_flex: # Flexible.
|
||||
return self._values_flex[key]
|
||||
else:
|
||||
elif raise_:
|
||||
raise KeyError(key)
|
||||
else:
|
||||
return default
|
||||
|
||||
get = _get
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
"""
|
||||
return self._get(key, raise_=True)
|
||||
|
||||
def _setitem(self, key, value):
|
||||
"""Assign the value for a field, return whether new and old value
|
||||
|
|
@ -441,19 +459,10 @@ class Model(object):
|
|||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(True)
|
||||
return key in self.keys(computed=True)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the available field names (excluding computed
|
||||
|
|
@ -538,8 +547,14 @@ class Model(object):
|
|||
|
||||
def load(self):
|
||||
"""Refresh the object's metadata from the library database.
|
||||
|
||||
If check_revision is true, the database is only queried loaded when a
|
||||
transaction has been committed since the item was last loaded.
|
||||
"""
|
||||
self._check_db()
|
||||
if not self._dirty and self._db.revision == self._revision:
|
||||
# Exit early
|
||||
return
|
||||
stored_obj = self._db._get(type(self), self.id)
|
||||
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
|
|
@ -794,6 +809,12 @@ class Transaction(object):
|
|||
"""A context manager for safe, concurrent access to the database.
|
||||
All SQL commands should be executed through a transaction.
|
||||
"""
|
||||
|
||||
_mutated = False
|
||||
"""A flag storing whether a mutation has been executed in the
|
||||
current transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
|
|
@ -815,12 +836,15 @@ class Transaction(object):
|
|||
entered but not yet exited transaction. If it is the last active
|
||||
transaction, the database updates are committed.
|
||||
"""
|
||||
# Beware of races; currently secured by db._db_lock
|
||||
self.db.revision += self._mutated
|
||||
with self.db._tx_stack() as stack:
|
||||
assert stack.pop() is self
|
||||
empty = not stack
|
||||
if empty:
|
||||
# Ending a "root" transaction. End the SQLite transaction.
|
||||
self.db._connection().commit()
|
||||
self._mutated = False
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement, subvals=()):
|
||||
|
|
@ -836,7 +860,6 @@ class Transaction(object):
|
|||
"""
|
||||
try:
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.lastrowid
|
||||
except sqlite3.OperationalError as e:
|
||||
# In two specific cases, SQLite reports an error while accessing
|
||||
# the underlying database file. We surface these exceptions as
|
||||
|
|
@ -846,9 +869,14 @@ class Transaction(object):
|
|||
raise DBAccessError(e.args[0])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
self._mutated = True
|
||||
return cursor.lastrowid
|
||||
|
||||
def script(self, statements):
|
||||
"""Execute a string containing multiple SQL statements."""
|
||||
# We don't know whether this mutates, but quite likely it does.
|
||||
self._mutated = True
|
||||
self.db._connection().executescript(statements)
|
||||
|
||||
|
||||
|
|
@ -864,6 +892,11 @@ class Database(object):
|
|||
supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
|
||||
"""Whether or not the current version of SQLite supports extensions"""
|
||||
|
||||
revision = 0
|
||||
"""The current revision of the database. To be increased whenever
|
||||
data is written in a transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, path, timeout=5.0):
|
||||
self.path = path
|
||||
self.timeout = timeout
|
||||
|
|
|
|||
|
|
@ -786,7 +786,7 @@ class ImportTask(BaseImportTask):
|
|||
if (not dup_item.album_id or
|
||||
dup_item.album_id in replaced_album_ids):
|
||||
continue
|
||||
replaced_album = dup_item.get_album()
|
||||
replaced_album = dup_item._cached_album
|
||||
if replaced_album:
|
||||
replaced_album_ids.add(dup_item.album_id)
|
||||
self.replaced_albums[replaced_album.path] = replaced_album
|
||||
|
|
|
|||
|
|
@ -375,7 +375,11 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
|
|||
"""
|
||||
|
||||
def __init__(self, item, for_path=False):
|
||||
super(FormattedItemMapping, self).__init__(item, for_path)
|
||||
# We treat album and item keys specially here,
|
||||
# so exclude transitive album keys from the model's keys.
|
||||
super(FormattedItemMapping, self).__init__(item, for_path,
|
||||
compute_keys=False)
|
||||
self.model_keys = item.keys(computed=True, with_album=False)
|
||||
self.item = item
|
||||
|
||||
@lazy_property
|
||||
|
|
@ -386,15 +390,15 @@ class FormattedItemMapping(dbcore.db.FormattedMapping):
|
|||
def album_keys(self):
|
||||
album_keys = []
|
||||
if self.album:
|
||||
for key in self.album.keys(True):
|
||||
for key in self.album.keys(computed=True):
|
||||
if key in Album.item_keys \
|
||||
or key not in self.item._fields.keys():
|
||||
album_keys.append(key)
|
||||
return album_keys
|
||||
|
||||
@lazy_property
|
||||
@property
|
||||
def album(self):
|
||||
return self.item.get_album()
|
||||
return self.item._cached_album
|
||||
|
||||
def _get(self, key):
|
||||
"""Get the value for a key, either from the album or the item.
|
||||
|
|
@ -545,6 +549,29 @@ class Item(LibModel):
|
|||
|
||||
_format_config_key = 'format_item'
|
||||
|
||||
__album = None
|
||||
"""Cached album object. Read-only."""
|
||||
|
||||
@property
|
||||
def _cached_album(self):
|
||||
"""The Album object that this item belongs to, if any, or
|
||||
None if the item is a singleton or is not associated with a
|
||||
library.
|
||||
The instance is cached and refreshed on access.
|
||||
|
||||
DO NOT MODIFY!
|
||||
If you want a copy to modify, use :meth:`get_album`.
|
||||
"""
|
||||
if not self.__album and self._db:
|
||||
self.__album = self._db.get_album(self)
|
||||
elif self.__album:
|
||||
self.__album.load()
|
||||
return self.__album
|
||||
|
||||
@_cached_album.setter
|
||||
def _cached_album(self, album):
|
||||
self.__album = album
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
getters = plugins.item_field_getters()
|
||||
|
|
@ -571,12 +598,45 @@ class Item(LibModel):
|
|||
value = bytestring_path(value)
|
||||
elif isinstance(value, BLOB_TYPE):
|
||||
value = bytes(value)
|
||||
elif key == 'album_id':
|
||||
self._cached_album = None
|
||||
|
||||
changed = super(Item, self)._setitem(key, value)
|
||||
|
||||
if changed and key in MediaFile.fields():
|
||||
self.mtime = 0 # Reset mtime on dirty.
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field, falling back to the album if
|
||||
necessary. Raise a KeyError if the field is not available.
|
||||
"""
|
||||
try:
|
||||
return super(Item, self).__getitem__(key)
|
||||
except KeyError:
|
||||
if self._cached_album:
|
||||
return self._cached_album[key]
|
||||
raise
|
||||
|
||||
def keys(self, computed=False, with_album=True):
|
||||
"""Get a list of available field names. `with_album`
|
||||
controls whether the album's fields are included.
|
||||
"""
|
||||
keys = super(Item, self).keys(computed=computed)
|
||||
if with_album and self._cached_album:
|
||||
keys += self._cached_album.keys(computed=computed)
|
||||
return keys
|
||||
|
||||
def get(self, key, default=None, with_album=True):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist. Set `with_album` to false to skip album fallback.
|
||||
"""
|
||||
try:
|
||||
return self._get(key, default, raise_=with_album)
|
||||
except KeyError:
|
||||
if self._cached_album:
|
||||
return self._cached_album.get(key, default)
|
||||
return default
|
||||
|
||||
def update(self, values):
|
||||
"""Set all key/value pairs in the mapping. If mtime is
|
||||
specified, it is not reset (as it might otherwise be).
|
||||
|
|
|
|||
|
|
@ -1155,8 +1155,13 @@ def _setup(options, lib=None):
|
|||
plugins.send("library_opened", lib=lib)
|
||||
|
||||
# Add types and queries defined by plugins.
|
||||
library.Item._types.update(plugins.types(library.Item))
|
||||
library.Album._types.update(plugins.types(library.Album))
|
||||
plugin_types_album = plugins.types(library.Album)
|
||||
library.Album._types.update(plugin_types_album)
|
||||
item_types = plugin_types_album.copy()
|
||||
item_types.update(library.Item._types)
|
||||
item_types.update(plugins.types(library.Item))
|
||||
library.Item._types = item_types
|
||||
|
||||
library.Item._queries.update(plugins.named_queries(library.Item))
|
||||
library.Album._queries.update(plugins.named_queries(library.Album))
|
||||
|
||||
|
|
|
|||
|
|
@ -358,7 +358,7 @@ class ConvertPlugin(BeetsPlugin):
|
|||
item.store() # Store new path and audio data.
|
||||
|
||||
if self.config['embed'] and not linked:
|
||||
album = item.get_album()
|
||||
album = item._cached_album
|
||||
if album and album.artpath:
|
||||
self._log.debug(u'embedding album art from {}',
|
||||
util.displayable_path(album.artpath))
|
||||
|
|
|
|||
|
|
@ -178,6 +178,11 @@ New features:
|
|||
* :doc:`/plugins/replaygain` now does its analysis in parallel when using
|
||||
the ``command`` or ``ffmpeg`` backends.
|
||||
:bug:`3478`
|
||||
* Fields in queries now fall back to an item's album and check its fields too.
|
||||
Notably, this allows querying items by an album flex attribute, also in path
|
||||
configuration.
|
||||
Thanks to :user:`FichteFoll`.
|
||||
:bug:`2797` :bug:`2988`
|
||||
* Removes usage of the bs1770gain replaygain backend.
|
||||
Thanks to :user:`SamuelCook`.
|
||||
* Added ``trackdisambig`` which stores the recording disambiguation from
|
||||
|
|
@ -344,6 +349,12 @@ For plugin developers:
|
|||
:bug:`3355`
|
||||
* The autotag hooks have been modified such that they now take 'bpm',
|
||||
'musical_key' and a per-track based 'genre' as attributes.
|
||||
* Item (and attribute) access on an item now falls back to the album's
|
||||
attributes as well. If you specifically want to access an item's attributes,
|
||||
use ``Item.get(key, with_album=False)``. :bug:`2988`
|
||||
* ``Item.keys`` also has a ``with_album`` argument now, defaulting to ``True``.
|
||||
* A ``revision`` attribute has been added to ``Database``. It is increased on
|
||||
every transaction that mutates it. :bug:`2988`
|
||||
|
||||
For packagers:
|
||||
|
||||
|
|
|
|||
|
|
@ -225,6 +225,31 @@ class MigrationTest(unittest.TestCase):
|
|||
self.fail("select failed")
|
||||
|
||||
|
||||
class TransactionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = DatabaseFixture1(':memory:')
|
||||
|
||||
def tearDown(self):
|
||||
self.db._connection().close()
|
||||
|
||||
def test_mutate_increase_revision(self):
|
||||
old_rev = self.db.revision
|
||||
with self.db.transaction() as tx:
|
||||
tx.mutate(
|
||||
'INSERT INTO {0} '
|
||||
'(field_one) '
|
||||
'VALUES (?);'.format(ModelFixture1._table),
|
||||
(111,),
|
||||
)
|
||||
self.assertGreater(self.db.revision, old_rev)
|
||||
|
||||
def test_query_no_increase_revision(self):
|
||||
old_rev = self.db.revision
|
||||
with self.db.transaction() as tx:
|
||||
tx.query('PRAGMA table_info(%s)' % ModelFixture1._table)
|
||||
self.assertEqual(self.db.revision, old_rev)
|
||||
|
||||
|
||||
class ModelTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = DatabaseFixture1(':memory:')
|
||||
|
|
@ -246,6 +271,30 @@ class ModelTest(unittest.TestCase):
|
|||
row = self.db._connection().execute('select * from test').fetchone()
|
||||
self.assertEqual(row['field_one'], 123)
|
||||
|
||||
def test_revision(self):
|
||||
old_rev = self.db.revision
|
||||
model = ModelFixture1()
|
||||
model.add(self.db)
|
||||
model.store()
|
||||
self.assertEqual(model._revision, self.db.revision)
|
||||
self.assertGreater(self.db.revision, old_rev)
|
||||
|
||||
mid_rev = self.db.revision
|
||||
model2 = ModelFixture1()
|
||||
model2.add(self.db)
|
||||
model2.store()
|
||||
self.assertGreater(model2._revision, mid_rev)
|
||||
self.assertGreater(self.db.revision, model._revision)
|
||||
|
||||
# revision changed, so the model should be re-loaded
|
||||
model.load()
|
||||
self.assertEqual(model._revision, self.db.revision)
|
||||
|
||||
# revision did not change, so no reload
|
||||
mod2_old_rev = model2._revision
|
||||
model2.load()
|
||||
self.assertEqual(model2._revision, mod2_old_rev)
|
||||
|
||||
def test_retrieve_by_id(self):
|
||||
model = ModelFixture1()
|
||||
model.add(self.db)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class IPFSPluginTest(unittest.TestCase, TestHelper):
|
|||
want_item = test_album.items()[2]
|
||||
for check_item in added_album.items():
|
||||
try:
|
||||
if check_item.ipfs:
|
||||
if check_item.get('ipfs', with_album=False):
|
||||
ipfs_item = os.path.basename(want_item.path).decode(
|
||||
_fsencoding(),
|
||||
)
|
||||
|
|
@ -57,7 +57,8 @@ class IPFSPluginTest(unittest.TestCase, TestHelper):
|
|||
ipfs_item)
|
||||
want_path = bytestring_path(want_path)
|
||||
self.assertEqual(check_item.path, want_path)
|
||||
self.assertEqual(check_item.ipfs, want_item.ipfs)
|
||||
self.assertEqual(check_item.get('ipfs', with_album=False),
|
||||
want_item.ipfs)
|
||||
self.assertEqual(check_item.title, want_item.title)
|
||||
found = True
|
||||
except AttributeError:
|
||||
|
|
|
|||
|
|
@ -132,6 +132,21 @@ class GetSetTest(_common.TestCase):
|
|||
def test_invalid_field_raises_attributeerror(self):
|
||||
self.assertRaises(AttributeError, getattr, self.i, u'xyzzy')
|
||||
|
||||
def test_album_fallback(self):
|
||||
# integration test of item-album fallback
|
||||
lib = beets.library.Library(':memory:')
|
||||
i = item(lib)
|
||||
album = lib.add_album([i])
|
||||
album['flex'] = u'foo'
|
||||
album.store()
|
||||
|
||||
self.assertTrue('flex' in i)
|
||||
self.assertFalse('flex' in i.keys(with_album=False))
|
||||
self.assertEqual(i['flex'], u'foo')
|
||||
self.assertEqual(i.get('flex'), u'foo')
|
||||
self.assertEqual(i.get('flex', with_album=False), None)
|
||||
self.assertEqual(i.get('flexx'), None)
|
||||
|
||||
|
||||
class DestinationTest(_common.TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -491,6 +506,24 @@ class DestinationTest(_common.TestCase):
|
|||
dest = self.i.destination()
|
||||
self.assertEqual(dest[-2:], b'XX')
|
||||
|
||||
def test_album_field_query(self):
|
||||
self.lib.directory = b'one'
|
||||
self.lib.path_formats = [(u'default', u'two'),
|
||||
(u'flex:foo', u'three')]
|
||||
album = self.lib.add_album([self.i])
|
||||
self.assertEqual(self.i.destination(), np('one/two'))
|
||||
album['flex'] = u'foo'
|
||||
album.store()
|
||||
self.assertEqual(self.i.destination(), np('one/three'))
|
||||
|
||||
def test_album_field_in_template(self):
|
||||
self.lib.directory = b'one'
|
||||
self.lib.path_formats = [(u'default', u'$flex/two')]
|
||||
album = self.lib.add_album([self.i])
|
||||
album['flex'] = u'foo'
|
||||
album.store()
|
||||
self.assertEqual(self.i.destination(), np('one/foo/two'))
|
||||
|
||||
|
||||
class ItemFormattedMappingTest(_common.LibTestCase):
|
||||
def test_formatted_item_value(self):
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class DummyDataTestCase(_common.TestCase, AssertsMixin):
|
|||
items[2].comp = False
|
||||
for item in items:
|
||||
self.lib.add(item)
|
||||
self.lib.add_album(items[:2])
|
||||
self.album = self.lib.add_album(items[:2])
|
||||
|
||||
def assert_items_matched_all(self, results):
|
||||
self.assert_items_matched(results, [
|
||||
|
|
@ -300,6 +300,17 @@ class GetTest(DummyDataTestCase):
|
|||
results = self.lib.items(q)
|
||||
self.assertFalse(results)
|
||||
|
||||
def test_album_field_fallback(self):
|
||||
self.album['albumflex'] = u'foo'
|
||||
self.album.store()
|
||||
|
||||
q = u'albumflex:foo'
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, [
|
||||
u'foo bar',
|
||||
u'baz qux',
|
||||
])
|
||||
|
||||
def test_invalid_query(self):
|
||||
with self.assertRaises(InvalidQueryArgumentValueError) as raised:
|
||||
dbcore.query.NumericQuery('year', u'199a')
|
||||
|
|
|
|||
Loading…
Reference in a new issue