diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 63a668601..6247ef41b 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -20,6 +20,7 @@ from collections import defaultdict import threading import sqlite3 import contextlib +import collections import beets from beets.util.functemplate import Template @@ -425,7 +426,7 @@ class Model(object): return string -class FormattedMapping(object): +class FormattedMapping(collections.Mapping): """A `dict`-like formatted view of a model. The accessor ``mapping[key]`` returns the formated version of @@ -445,8 +446,11 @@ class FormattedMapping(object): else: raise KeyError(key) - def __contains__(self, key): - return key in self.model_keys + def __iter__(self): + return iter(self.model_keys) + + def __len__(self): + return len(self.model_keys) # Database controller and supporting interfaces. diff --git a/test/test_dbcore.py b/test/test_dbcore.py index bfcd1581a..20fa67510 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -269,6 +269,19 @@ class FormatTest(_common.TestCase): self.assertEqual(value, u'') +class FormattedMappingTest(_common.TestCase): + def test_keys_equal_model_keys(self): + model = TestModel1() + formatted = model._formatted_mapping() + self.assertEqual(set(model.keys(True)), set(formatted.keys())) + + def test_get_unset_field(self): + model = TestModel1() + formatted = model._formatted_mapping() + with self.assertRaises(KeyError): + formatted['other_field'] + + class ParseTest(_common.TestCase): def test_parse_fixed_field(self): value = TestModel1._parse('field_one', u'2')