diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 46b47a2e1..409ecc9af 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -56,10 +56,11 @@ class FormattedMapping(Mapping): are replaced. """ - def __init__(self, model, for_path=False): + def __init__(self, model, for_path=False, compute_keys=True): self.for_path = for_path self.model = model - self.model_keys = model.keys(True) + if compute_keys: + self.model_keys = model.keys(True) def __getitem__(self, key): if key in self.model_keys: @@ -257,6 +258,11 @@ class Model(object): value is the same as the old value (e.g., `o.f = o.f`). """ + _revision = -1 + """A revision number from when the model was loaded from or written + to the database. + """ + @classmethod def _getters(cls): """Return a mapping from field names to getter functions. @@ -309,9 +315,11 @@ class Model(object): def clear_dirty(self): """Mark all fields as *clean* (i.e., not needing to be stored to - the database). + the database). Also update the revision. """ self._dirty = set() + if self._db: + self._revision = self._db.revision def _check_db(self, need_id=True): """Ensure that this object is associated with a database row: it @@ -351,9 +359,9 @@ class Model(object): """ return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT - def __getitem__(self, key): - """Get the value for a field. Raise a KeyError if the field is - not available. + def _get(self, key, default=None, raise_=False): + """Get the value for a field, or `default`. Alternatively, + raise a KeyError if the field is not available. """ getters = self._getters() if key in getters: # Computed. @@ -365,8 +373,18 @@ class Model(object): return self._type(key).null elif key in self._values_flex: # Flexible. return self._values_flex[key] - else: + elif raise_: raise KeyError(key) + else: + return default + + get = _get + + def __getitem__(self, key): + """Get the value for a field. Raise a KeyError if the field is + not available. + """ + return self._get(key, raise_=True) def _setitem(self, key, value): """Assign the value for a field, return whether new and old value @@ -441,19 +459,10 @@ class Model(object): for key in self: yield key, self[key] - def get(self, key, default=None): - """Get the value for a given key or `default` if it does not - exist. - """ - if key in self: - return self[key] - else: - return default - def __contains__(self, key): """Determine whether `key` is an attribute on this object. """ - return key in self.keys(True) + return key in self.keys(computed=True) def __iter__(self): """Iterate over the available field names (excluding computed @@ -538,8 +547,14 @@ class Model(object): def load(self): """Refresh the object's metadata from the library database. + + If check_revision is true, the database is only queried loaded when a + transaction has been committed since the item was last loaded. """ self._check_db() + if not self._dirty and self._db.revision == self._revision: + # Exit early + return stored_obj = self._db._get(type(self), self.id) assert stored_obj is not None, u"object {0} not in DB".format(self.id) self._values_fixed = LazyConvertDict(self) @@ -794,6 +809,12 @@ class Transaction(object): """A context manager for safe, concurrent access to the database. All SQL commands should be executed through a transaction. """ + + _mutated = False + """A flag storing whether a mutation has been executed in the + current transaction. + """ + def __init__(self, db): self.db = db @@ -815,12 +836,15 @@ class Transaction(object): entered but not yet exited transaction. If it is the last active transaction, the database updates are committed. """ + # Beware of races; currently secured by db._db_lock + self.db.revision += self._mutated with self.db._tx_stack() as stack: assert stack.pop() is self empty = not stack if empty: # Ending a "root" transaction. End the SQLite transaction. self.db._connection().commit() + self._mutated = False self.db._db_lock.release() def query(self, statement, subvals=()): @@ -836,7 +860,6 @@ class Transaction(object): """ try: cursor = self.db._connection().execute(statement, subvals) - return cursor.lastrowid except sqlite3.OperationalError as e: # In two specific cases, SQLite reports an error while accessing # the underlying database file. We surface these exceptions as @@ -846,9 +869,14 @@ class Transaction(object): raise DBAccessError(e.args[0]) else: raise + else: + self._mutated = True + return cursor.lastrowid def script(self, statements): """Execute a string containing multiple SQL statements.""" + # We don't know whether this mutates, but quite likely it does. + self._mutated = True self.db._connection().executescript(statements) @@ -864,6 +892,11 @@ class Database(object): supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension') """Whether or not the current version of SQLite supports extensions""" + revision = 0 + """The current revision of the database. To be increased whenever + data is written in a transaction. + """ + def __init__(self, path, timeout=5.0): self.path = path self.timeout = timeout diff --git a/beets/importer.py b/beets/importer.py index 3220b260f..c5701ff30 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -786,7 +786,7 @@ class ImportTask(BaseImportTask): if (not dup_item.album_id or dup_item.album_id in replaced_album_ids): continue - replaced_album = dup_item.get_album() + replaced_album = dup_item._cached_album if replaced_album: replaced_album_ids.add(dup_item.album_id) self.replaced_albums[replaced_album.path] = replaced_album diff --git a/beets/library.py b/beets/library.py index 78552bb61..3ba79d069 100644 --- a/beets/library.py +++ b/beets/library.py @@ -375,7 +375,11 @@ class FormattedItemMapping(dbcore.db.FormattedMapping): """ def __init__(self, item, for_path=False): - super(FormattedItemMapping, self).__init__(item, for_path) + # We treat album and item keys specially here, + # so exclude transitive album keys from the model's keys. + super(FormattedItemMapping, self).__init__(item, for_path, + compute_keys=False) + self.model_keys = item.keys(computed=True, with_album=False) self.item = item @lazy_property @@ -386,15 +390,15 @@ class FormattedItemMapping(dbcore.db.FormattedMapping): def album_keys(self): album_keys = [] if self.album: - for key in self.album.keys(True): + for key in self.album.keys(computed=True): if key in Album.item_keys \ or key not in self.item._fields.keys(): album_keys.append(key) return album_keys - @lazy_property + @property def album(self): - return self.item.get_album() + return self.item._cached_album def _get(self, key): """Get the value for a key, either from the album or the item. @@ -545,6 +549,29 @@ class Item(LibModel): _format_config_key = 'format_item' + __album = None + """Cached album object. Read-only.""" + + @property + def _cached_album(self): + """The Album object that this item belongs to, if any, or + None if the item is a singleton or is not associated with a + library. + The instance is cached and refreshed on access. + + DO NOT MODIFY! + If you want a copy to modify, use :meth:`get_album`. + """ + if not self.__album and self._db: + self.__album = self._db.get_album(self) + elif self.__album: + self.__album.load() + return self.__album + + @_cached_album.setter + def _cached_album(self, album): + self.__album = album + @classmethod def _getters(cls): getters = plugins.item_field_getters() @@ -571,12 +598,45 @@ class Item(LibModel): value = bytestring_path(value) elif isinstance(value, BLOB_TYPE): value = bytes(value) + elif key == 'album_id': + self._cached_album = None changed = super(Item, self)._setitem(key, value) if changed and key in MediaFile.fields(): self.mtime = 0 # Reset mtime on dirty. + def __getitem__(self, key): + """Get the value for a field, falling back to the album if + necessary. Raise a KeyError if the field is not available. + """ + try: + return super(Item, self).__getitem__(key) + except KeyError: + if self._cached_album: + return self._cached_album[key] + raise + + def keys(self, computed=False, with_album=True): + """Get a list of available field names. `with_album` + controls whether the album's fields are included. + """ + keys = super(Item, self).keys(computed=computed) + if with_album and self._cached_album: + keys += self._cached_album.keys(computed=computed) + return keys + + def get(self, key, default=None, with_album=True): + """Get the value for a given key or `default` if it does not + exist. Set `with_album` to false to skip album fallback. + """ + try: + return self._get(key, default, raise_=with_album) + except KeyError: + if self._cached_album: + return self._cached_album.get(key, default) + return default + def update(self, values): """Set all key/value pairs in the mapping. If mtime is specified, it is not reset (as it might otherwise be). diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index 28879a731..362e8752a 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -1155,8 +1155,13 @@ def _setup(options, lib=None): plugins.send("library_opened", lib=lib) # Add types and queries defined by plugins. - library.Item._types.update(plugins.types(library.Item)) - library.Album._types.update(plugins.types(library.Album)) + plugin_types_album = plugins.types(library.Album) + library.Album._types.update(plugin_types_album) + item_types = plugin_types_album.copy() + item_types.update(library.Item._types) + item_types.update(plugins.types(library.Item)) + library.Item._types = item_types + library.Item._queries.update(plugins.named_queries(library.Item)) library.Album._queries.update(plugins.named_queries(library.Album)) diff --git a/beetsplug/convert.py b/beetsplug/convert.py index 275703e97..45a571d38 100644 --- a/beetsplug/convert.py +++ b/beetsplug/convert.py @@ -358,7 +358,7 @@ class ConvertPlugin(BeetsPlugin): item.store() # Store new path and audio data. if self.config['embed'] and not linked: - album = item.get_album() + album = item._cached_album if album and album.artpath: self._log.debug(u'embedding album art from {}', util.displayable_path(album.artpath)) diff --git a/docs/changelog.rst b/docs/changelog.rst index b9020621b..df286a10a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -178,6 +178,11 @@ New features: * :doc:`/plugins/replaygain` now does its analysis in parallel when using the ``command`` or ``ffmpeg`` backends. :bug:`3478` +* Fields in queries now fall back to an item's album and check its fields too. + Notably, this allows querying items by an album flex attribute, also in path + configuration. + Thanks to :user:`FichteFoll`. + :bug:`2797` :bug:`2988` * Removes usage of the bs1770gain replaygain backend. Thanks to :user:`SamuelCook`. * Added ``trackdisambig`` which stores the recording disambiguation from @@ -344,6 +349,12 @@ For plugin developers: :bug:`3355` * The autotag hooks have been modified such that they now take 'bpm', 'musical_key' and a per-track based 'genre' as attributes. +* Item (and attribute) access on an item now falls back to the album's + attributes as well. If you specifically want to access an item's attributes, + use ``Item.get(key, with_album=False)``. :bug:`2988` +* ``Item.keys`` also has a ``with_album`` argument now, defaulting to ``True``. +* A ``revision`` attribute has been added to ``Database``. It is increased on + every transaction that mutates it. :bug:`2988` For packagers: diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 0d40896da..1dd2284c6 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -225,6 +225,31 @@ class MigrationTest(unittest.TestCase): self.fail("select failed") +class TransactionTest(unittest.TestCase): + def setUp(self): + self.db = DatabaseFixture1(':memory:') + + def tearDown(self): + self.db._connection().close() + + def test_mutate_increase_revision(self): + old_rev = self.db.revision + with self.db.transaction() as tx: + tx.mutate( + 'INSERT INTO {0} ' + '(field_one) ' + 'VALUES (?);'.format(ModelFixture1._table), + (111,), + ) + self.assertGreater(self.db.revision, old_rev) + + def test_query_no_increase_revision(self): + old_rev = self.db.revision + with self.db.transaction() as tx: + tx.query('PRAGMA table_info(%s)' % ModelFixture1._table) + self.assertEqual(self.db.revision, old_rev) + + class ModelTest(unittest.TestCase): def setUp(self): self.db = DatabaseFixture1(':memory:') @@ -246,6 +271,30 @@ class ModelTest(unittest.TestCase): row = self.db._connection().execute('select * from test').fetchone() self.assertEqual(row['field_one'], 123) + def test_revision(self): + old_rev = self.db.revision + model = ModelFixture1() + model.add(self.db) + model.store() + self.assertEqual(model._revision, self.db.revision) + self.assertGreater(self.db.revision, old_rev) + + mid_rev = self.db.revision + model2 = ModelFixture1() + model2.add(self.db) + model2.store() + self.assertGreater(model2._revision, mid_rev) + self.assertGreater(self.db.revision, model._revision) + + # revision changed, so the model should be re-loaded + model.load() + self.assertEqual(model._revision, self.db.revision) + + # revision did not change, so no reload + mod2_old_rev = model2._revision + model2.load() + self.assertEqual(model2._revision, mod2_old_rev) + def test_retrieve_by_id(self): model = ModelFixture1() model.add(self.db) diff --git a/test/test_ipfs.py b/test/test_ipfs.py index d670bfc25..2fe89e7e5 100644 --- a/test/test_ipfs.py +++ b/test/test_ipfs.py @@ -49,7 +49,7 @@ class IPFSPluginTest(unittest.TestCase, TestHelper): want_item = test_album.items()[2] for check_item in added_album.items(): try: - if check_item.ipfs: + if check_item.get('ipfs', with_album=False): ipfs_item = os.path.basename(want_item.path).decode( _fsencoding(), ) @@ -57,7 +57,8 @@ class IPFSPluginTest(unittest.TestCase, TestHelper): ipfs_item) want_path = bytestring_path(want_path) self.assertEqual(check_item.path, want_path) - self.assertEqual(check_item.ipfs, want_item.ipfs) + self.assertEqual(check_item.get('ipfs', with_album=False), + want_item.ipfs) self.assertEqual(check_item.title, want_item.title) found = True except AttributeError: diff --git a/test/test_library.py b/test/test_library.py index 4e3be878c..51171b1f8 100644 --- a/test/test_library.py +++ b/test/test_library.py @@ -132,6 +132,21 @@ class GetSetTest(_common.TestCase): def test_invalid_field_raises_attributeerror(self): self.assertRaises(AttributeError, getattr, self.i, u'xyzzy') + def test_album_fallback(self): + # integration test of item-album fallback + lib = beets.library.Library(':memory:') + i = item(lib) + album = lib.add_album([i]) + album['flex'] = u'foo' + album.store() + + self.assertTrue('flex' in i) + self.assertFalse('flex' in i.keys(with_album=False)) + self.assertEqual(i['flex'], u'foo') + self.assertEqual(i.get('flex'), u'foo') + self.assertEqual(i.get('flex', with_album=False), None) + self.assertEqual(i.get('flexx'), None) + class DestinationTest(_common.TestCase): def setUp(self): @@ -491,6 +506,24 @@ class DestinationTest(_common.TestCase): dest = self.i.destination() self.assertEqual(dest[-2:], b'XX') + def test_album_field_query(self): + self.lib.directory = b'one' + self.lib.path_formats = [(u'default', u'two'), + (u'flex:foo', u'three')] + album = self.lib.add_album([self.i]) + self.assertEqual(self.i.destination(), np('one/two')) + album['flex'] = u'foo' + album.store() + self.assertEqual(self.i.destination(), np('one/three')) + + def test_album_field_in_template(self): + self.lib.directory = b'one' + self.lib.path_formats = [(u'default', u'$flex/two')] + album = self.lib.add_album([self.i]) + album['flex'] = u'foo' + album.store() + self.assertEqual(self.i.destination(), np('one/foo/two')) + class ItemFormattedMappingTest(_common.LibTestCase): def test_formatted_item_value(self): diff --git a/test/test_query.py b/test/test_query.py index f88a12c92..4017ff44b 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -109,7 +109,7 @@ class DummyDataTestCase(_common.TestCase, AssertsMixin): items[2].comp = False for item in items: self.lib.add(item) - self.lib.add_album(items[:2]) + self.album = self.lib.add_album(items[:2]) def assert_items_matched_all(self, results): self.assert_items_matched(results, [ @@ -300,6 +300,17 @@ class GetTest(DummyDataTestCase): results = self.lib.items(q) self.assertFalse(results) + def test_album_field_fallback(self): + self.album['albumflex'] = u'foo' + self.album.store() + + q = u'albumflex:foo' + results = self.lib.items(q) + self.assert_items_matched(results, [ + u'foo bar', + u'baz qux', + ]) + def test_invalid_query(self): with self.assertRaises(InvalidQueryArgumentValueError) as raised: dbcore.query.NumericQuery('year', u'199a')