From 0798af7774472a3e9bdbcdd1a000f7db6b7afcc1 Mon Sep 17 00:00:00 2001 From: Thomas Scholtes Date: Mon, 25 Aug 2014 14:24:39 +0200 Subject: [PATCH] Refactor model formatting Remove all formatting related code from models. It now lives in the `FormattedMapping` class. Only API change is from `model.formatted` to `model.formatted()`. --- beets/dbcore/db.py | 122 ++++++++++++++++++------------------------- beets/library.py | 108 ++++++++++++++++++-------------------- beets/ui/__init__.py | 6 +-- test/test_dbcore.py | 22 ++++---- test/test_library.py | 32 ++++++------ 5 files changed, 131 insertions(+), 159 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 5172ad523..756448720 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -28,6 +28,50 @@ from .query import MatchQuery, build_sql from .types import BASE_TYPE +class FormattedMapping(collections.Mapping): + """A `dict`-like formatted view of a model. + + The accessor `mapping[key]` returns the formated version of + `model[key]` as a unicode string. + + If `for_path` is true, all path separators in the formatted values + are replaced. + """ + + def __init__(self, model, for_path=False): + self.for_path = for_path + self.model = model + self.model_keys = model.keys(True) + + def __getitem__(self, key): + if key in self.model_keys: + return self._get_formatted(self.model, key) + else: + raise KeyError(key) + + def __iter__(self): + return iter(self.model_keys) + + def __len__(self): + return len(self.model_keys) + + def get(self, key, default=u''): + return super(FormattedMapping, self).get(key, default) + + def _get_formatted(self, model, key): + value = model._type(key).format(model.get(key)) + if isinstance(value, bytes): + value = value.decode('utf8', 'ignore') + + if self.for_path: + sep_repl = beets.config['path_sep_replace'].get(unicode) + for sep in (os.path.sep, os.path.altsep): + if sep: + value = value.replace(sep, sep_repl) + + return value + + # Abstract base for model classes. class Model(object): @@ -380,63 +424,24 @@ class Model(object): # Formatting and templating. - @classmethod - def _format(cls, key, value, for_path=False): - """Format a value as the given field for this model. - """ - # Format the value as a string according to its type. - value = cls._type(key).format(value) + _formatter = FormattedMapping - # Formatting must result in a string. To deal with - # Python2isms, implicitly convert ASCII strings. - assert isinstance(value, basestring), \ - u'field formatter must produce strings' - if isinstance(value, bytes): - value = value.decode('utf8', 'ignore') - - if for_path: - sep_repl = beets.config['path_sep_replace'].get(unicode) - for sep in (os.path.sep, os.path.altsep): - if sep: - value = value.replace(sep, sep_repl) - - return value - - def _get_formatted(self, key, for_path=False): - """Get a field value formatted as a string (`unicode` object) - for display to the user. If `for_path` is true, then the value - will be sanitized for inclusion in a pathname (i.e., path - separators will be removed from the value). - """ - return self._format(key, self.get(key), for_path) - - def _formatted_mapping(self, for_path=False): + def formatted(self, for_path=False): """Get a mapping containing all values on this object formatted - as human-readable strings. + as human-readable unicode strings. """ - return FormattedMapping(self, for_path) - - @property - def formatted(self): - """A `dict`-like view containing formatted values. - """ - return self._formatted_mapping(False) + return self._formatter(self, for_path) def evaluate_template(self, template, for_path=False): """Evaluate a template (a string or a `Template` object) using the object's fields. If `for_path` is true, then no new path separators will be added to the template. """ - # Build value mapping. - mapping = self._formatted_mapping(for_path) - - # Get template functions. - funcs = self._template_funcs() - # Perform substitution. if isinstance(template, basestring): template = Template(template) - return template.substitute(mapping, funcs) + return template.substitute(self.formatted(for_path), + self._template_funcs()) # Parsing. @@ -450,33 +455,6 @@ class Model(object): return cls._type(key).parse(string) -class FormattedMapping(collections.Mapping): - """A `dict`-like formatted view of a model. - - The accessor ``mapping[key]`` returns the formated version of - ``model[key]``. The formatting is handled by `model._format()`. - """ - # TODO Move all formatting logic here - # TODO Add caching - - def __init__(self, model, for_path=False): - self.for_path = for_path - self.model = model - self.model_keys = model.keys(True) - - def __getitem__(self, key): - if key in self.model_keys: - return self.model._get_formatted(key, self.for_path) - else: - raise KeyError(key) - - def __iter__(self): - return iter(self.model_keys) - - def __len__(self): - return len(self.model_keys) - - # Database controller and supporting interfaces. class Results(object): diff --git a/beets/library.py b/beets/library.py index 456f51f7b..ee9184bb0 100644 --- a/beets/library.py +++ b/beets/library.py @@ -216,6 +216,54 @@ class LibModel(dbcore.Model): plugins.send('database_change', lib=self._db) +class FormattedItemMapping(dbcore.db.FormattedMapping): + """Add lookup for album level fields. + """ + + def __init__(self, item, for_path=False): + super(FormattedItemMapping, self).__init__(item, for_path) + self.album = item.get_album() + self.album_keys = [] + if self.album: + for key in self.album.keys(True): + if key in Album.item_keys or key not in item._fields.keys(): + self.album_keys.append(key) + self.all_keys = set(self.model_keys).union(self.album_keys) + + def _get(self, key): + """Get the value for a key, either from the album or the item. + Raise a KeyError for invalid keys. + """ + if key in self.album_keys: + return self._get_formatted(self.album, key) + elif key in self.model_keys: + return self._get_formatted(self.model, key) + else: + raise KeyError(key) + + def __getitem__(self, key): + """Get the value for a key. Certain unset values are remapped. + """ + value = self._get(key) + + # `artist` and `albumartist` fields fall back to one another. + # This is helpful in path formats when the album artist is unset + # on as-is imports. + if key == 'artist' and not value: + return self._get('albumartist') + elif key == 'albumartist' and not value: + return self._get('artist') + else: + return value + + def __iter__(self): + return iter(self.all_keys) + + def __len__(self): + return len(self.all_keys) + + + class Item(LibModel): _table = 'items' _flex_table = 'item_attributes' @@ -296,6 +344,8 @@ class Item(LibModel): `write`. """ + _formatter = FormattedItemMapping + @classmethod def _getters(cls): getters = plugins.item_field_getters() @@ -523,12 +573,6 @@ class Item(LibModel): # Templating. - def _formatted_mapping(self, for_path=False): - """Get a mapping containing string-formatted values from either - this item or the associated album, if any. - """ - return FormattedItemMapping(self, for_path) - def destination(self, fragment=False, basedir=None, platform=None, path_formats=None): """Returns the path in the library directory designated for the @@ -604,56 +648,6 @@ class Item(LibModel): return normpath(os.path.join(basedir, subpath)) -class FormattedItemMapping(dbcore.db.FormattedMapping): - """A `dict`-like formatted view of an item that inherits album fields. - - The accessor ``mapping[key]`` returns the formated version of either - ``item[key]`` or ``album[key]``. Here `album` is the album - associated to `item` if it exists. - """ - def __init__(self, item, for_path=False): - super(FormattedItemMapping, self).__init__(item, for_path) - self.album = item.get_album() - self.album_keys = [] - if self.album: - for key in self.album.keys(True): - if key in Album.item_keys or key not in item._fields.keys(): - self.album_keys.append(key) - self.all_keys = set(self.model_keys).union(self.album_keys) - - def _get(self, key): - """Get the value for a key, either from the album or the item. - Raise a KeyError for invalid keys. - """ - if key in self.album_keys: - return self.album._get_formatted(key, self.for_path) - elif key in self.model_keys: - return self.model._get_formatted(key, self.for_path) - else: - raise KeyError(key) - - def __getitem__(self, key): - """Get the value for a key. Certain unset values are remapped. - """ - value = self._get(key) - - # `artist` and `albumartist` fields fall back to one another. - # This is helpful in path formats when the album artist is unset - # on as-is imports. - if key == 'artist' and not value: - return self._get('albumartist') - elif key == 'albumartist' and not value: - return self._get('artist') - else: - return value - - def __iter__(self): - return iter(self.all_keys) - - def __len__(self): - return len(self.all_keys) - - class Album(LibModel): """Provides access to information about albums stored in a library. Reflects the library's "albums" table, including album @@ -1219,7 +1213,7 @@ class DefaultTemplateFunctions(object): return res # Flatten disambiguation value into a string. - disam_value = album._get_formatted(disambiguator, True) + disam_value = album.formatted(True).get(disambiguator) res = u' [{0}]'.format(disam_value) self.lib._memotable[memokey] = res return res diff --git a/beets/ui/__init__.py b/beets/ui/__init__.py index 1e64cd89c..829089200 100644 --- a/beets/ui/__init__.py +++ b/beets/ui/__init__.py @@ -576,8 +576,8 @@ def _field_diff(field, old, new): return None # Get formatted values for output. - oldstr = old.formatted.get(field, u'') - newstr = new.formatted.get(field, u'') + oldstr = old.formatted().get(field, u'') + newstr = new.formatted().get(field, u'') # For strings, highlight changes. For others, colorize the whole # thing. @@ -620,7 +620,7 @@ def show_model_changes(new, old=None, fields=None, always=False): changes.append(u' {0}: {1}'.format( field, - colorize('red', new.formatted[field]) + colorize('red', new.formatted()[field]) )) # Print changes. diff --git a/test/test_dbcore.py b/test/test_dbcore.py index e55bd84db..d6e92a1fb 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -255,54 +255,54 @@ class FormatTest(_common.TestCase): def test_format_fixed_field(self): model = TestModel1() model.field_one = u'caf\xe9' - value = model._get_formatted('field_one') + value = model.formatted().get('field_one') self.assertEqual(value, u'caf\xe9') def test_format_flex_field(self): model = TestModel1() model.other_field = u'caf\xe9' - value = model._get_formatted('other_field') + value = model.formatted().get('other_field') self.assertEqual(value, u'caf\xe9') def test_format_flex_field_bytes(self): model = TestModel1() model.other_field = u'caf\xe9'.encode('utf8') - value = model._get_formatted('other_field') + value = model.formatted().get('other_field') self.assertTrue(isinstance(value, unicode)) self.assertEqual(value, u'caf\xe9') def test_format_unset_field(self): model = TestModel1() - value = model._get_formatted('other_field') + value = model.formatted().get('other_field') self.assertEqual(value, u'') def test_format_typed_flex_field(self): model = TestModel1() model.some_float_field = 3.14159265358979 - value = model._get_formatted('some_float_field') + value = model.formatted().get('some_float_field') self.assertEqual(value, u'3.1') class FormattedMappingTest(_common.TestCase): def test_keys_equal_model_keys(self): model = TestModel1() - formatted = model._formatted_mapping() + formatted = model.formatted() self.assertEqual(set(model.keys(True)), set(formatted.keys())) def test_get_unset_field(self): model = TestModel1() - formatted = model._formatted_mapping() + formatted = model.formatted() with self.assertRaises(KeyError): formatted['other_field'] - def test_get_method_with_none_default(self): + def test_get_method_with_default(self): model = TestModel1() - formatted = model._formatted_mapping() - self.assertIsNone(formatted.get('other_field')) + formatted = model.formatted() + self.assertEqual(formatted.get('other_field'), u'') def test_get_method_with_specified_default(self): model = TestModel1() - formatted = model._formatted_mapping() + formatted = model.formatted() self.assertEqual(formatted.get('other_field', 'default'), 'default') diff --git a/test/test_library.py b/test/test_library.py index aac58b9e9..7a29ea5db 100644 --- a/test/test_library.py +++ b/test/test_library.py @@ -340,37 +340,37 @@ class DestinationTest(_common.TestCase): with _common.platform_posix(): name = os.path.join('a', 'b') self.i.title = name - newname = self.i._get_formatted('title') + newname = self.i.formatted().get('title') self.assertEqual(name, newname) def test_get_formatted_pads_with_zero(self): with _common.platform_posix(): self.i.track = 1 - name = self.i._get_formatted('track') + name = self.i.formatted().get('track') self.assertTrue(name.startswith('0')) def test_get_formatted_uses_kbps_bitrate(self): with _common.platform_posix(): self.i.bitrate = 12345 - val = self.i._get_formatted('bitrate') + val = self.i.formatted().get('bitrate') self.assertEqual(val, u'12kbps') def test_get_formatted_uses_khz_samplerate(self): with _common.platform_posix(): self.i.samplerate = 12345 - val = self.i._get_formatted('samplerate') + val = self.i.formatted().get('samplerate') self.assertEqual(val, u'12kHz') def test_get_formatted_datetime(self): with _common.platform_posix(): self.i.added = 1368302461.210265 - val = self.i._get_formatted('added') + val = self.i.formatted().get('added') self.assertTrue(val.startswith('2013')) def test_get_formatted_none(self): with _common.platform_posix(): self.i.some_other_field = None - val = self.i._get_formatted('some_other_field') + val = self.i.formatted().get('some_other_field') self.assertEqual(val, u'') def test_artist_falls_back_to_albumartist(self): @@ -462,20 +462,20 @@ class DestinationTest(_common.TestCase): class ItemFormattedMappingTest(_common.LibTestCase): def test_formatted_item_value(self): - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted['artist'], 'the artist') def test_get_unset_field(self): - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() with self.assertRaises(KeyError): formatted['other_field'] - def test_get_method_with_none_default(self): - formatted = self.i._formatted_mapping() - self.assertIsNone(formatted.get('other_field')) + def test_get_method_with_default(self): + formatted = self.i.formatted() + self.assertEqual(formatted.get('other_field'), u'') def test_get_method_with_specified_default(self): - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted.get('other_field', 'default'), 'default') def test_album_field_overrides_item_field(self): @@ -487,23 +487,23 @@ class ItemFormattedMappingTest(_common.LibTestCase): self.i.store() # Ensure the album takes precedence. - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted['album'], 'foo') def test_artist_falls_back_to_albumartist(self): self.i.artist = '' - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted['artist'], 'the album artist') def test_albumartist_falls_back_to_artist(self): self.i.albumartist = '' - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted['albumartist'], 'the artist') def test_both_artist_and_albumartist_empty(self): self.i.artist = '' self.i.albumartist = '' - formatted = self.i._formatted_mapping() + formatted = self.i.formatted() self.assertEqual(formatted['albumartist'], '')