diff --git a/beets/attachments.py b/beets/attachments.py index f88216034..a49bcbf8c 100644 --- a/beets/attachments.py +++ b/beets/attachments.py @@ -21,6 +21,17 @@ from beets import dbcore from beets.dbcore.query import Query, AndQuery +def ref_type(entity): + # FIXME prevents circular dependency + from beets.library import Item, Album + if isinstance(entity, Item): + return 'item' + elif isinstance(entity, Album): + return 'album' + else: + raise ValueError('{} must be a Item or Album'.format(entity)) + + class Attachment(dbcore.db.Model): """Represents an attachment in the database. @@ -62,15 +73,7 @@ class Attachment(dbcore.db.Model): """Set the `ref` and `ref_type` properties so that `self.entity == entity`. """ - # FIXME prevents circular dependency - from beets.library import Item, Album - if isinstance(entity, Item): - self.ref_type = 'item' - elif isinstance(entity, Album): - self.ref_type = 'album' - else: - raise ValueError('{} must be a Item or Album'.format(entity)) - + self.ref_type = ref_type(entity) if not entity.id: raise ValueError('{} must have an id', format(entity)) self.ref = entity.id @@ -297,14 +300,27 @@ class AttachmentCommand(ArgumentParser): pass +class AttachmentRefQuery(Query): + + def __init__(self, entity): + self.entity = entity + + def clause(self): + return ('(ref = ? AND ref_type = ?)', + (self.entity.id, ref_type(self.entity))) + + def match(self, attachment): + return attachment.entity == self.entity + + class AttachmentEntityQuery(Query): - def __init__(self, query): - self.query = query + def __init__(self, entity_query): + self.query = entity_query def match(self, attachment): if self.query is not None: - return self.query.match(attachment.entity) + return self.query.match(attachment.query) else: return True @@ -314,5 +330,4 @@ class LibModelMixin(object): """ def attachments(self): - # TODO implement - raise NotImplementedError + return self._db._fetch(Attachment, AttachmentRefQuery(self)) diff --git a/beets/library.py b/beets/library.py index 7b46c1f77..79f912a02 100644 --- a/beets/library.py +++ b/beets/library.py @@ -30,6 +30,7 @@ from beets.util.functemplate import Template from beets import dbcore from beets.dbcore import types import beets +from beets import attachments from beets.attachments import Attachment log = logging.getLogger('beets') @@ -193,7 +194,7 @@ class WriteError(FileOperationError): # Item and Album model classes. -class LibModel(dbcore.Model): +class LibModel(dbcore.Model, attachments.LibModelMixin): """Shared concrete functionality for Items and Albums. """ _bytes_keys = ('path', 'artpath') diff --git a/test/test_attachments.py b/test/test_attachments.py index 1d8045f30..62069228d 100644 --- a/test/test_attachments.py +++ b/test/test_attachments.py @@ -16,7 +16,7 @@ from _common import unittest from beets.attachments import AttachmentFactory -from beets.library import Library, Album +from beets.library import Library, Album, Item class AttachmentFactoryTest(unittest.TestCase): @@ -58,6 +58,35 @@ class AttachmentFactoryTest(unittest.TestCase): self.assertEqual(attachment.type, 'atype') +class EntityAttachmentsTest(unittest.TestCase): + + def setUp(self): + self.lib = Library(':memory:') + self.factory = AttachmentFactory(self.lib) + + def test_all_item_attachments(self): + item = Item() + item.add(self.lib) + + attachment = self.factory.create('/path/to/attachment', + 'coverart', item) + attachment.add() + + self.assertItemsEqual(map(lambda a: a.id, item.attachments()), + [attachment.id]) + + def test_all_album_attachments(self): + album = Album() + album.add(self.lib) + + attachment = self.factory.create('/path/to/attachment', + 'coverart', album) + attachment.add() + + self.assertItemsEqual(map(lambda a: a.id, album.attachments()), + [attachment.id]) + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)