Refactor model formatting

Remove all formatting related code from models. It now lives in the
`FormattedMapping` class. Only API change is from `model.formatted` to
`model.formatted()`.
This commit is contained in:
Thomas Scholtes 2014-08-25 14:24:39 +02:00
parent b5c9271baa
commit 0798af7774
5 changed files with 131 additions and 159 deletions

View file

@ -28,6 +28,50 @@ from .query import MatchQuery, build_sql
from .types import BASE_TYPE
class FormattedMapping(collections.Mapping):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formated version of
`model[key]` as a unicode string.
If `for_path` is true, all path separators in the formatted values
are replaced.
"""
def __init__(self, model, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __iter__(self):
return iter(self.model_keys)
def __len__(self):
return len(self.model_keys)
def get(self, key, default=u''):
return super(FormattedMapping, self).get(key, default)
def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if self.for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
# Abstract base for model classes.
class Model(object):
@ -380,63 +424,24 @@ class Model(object):
# Formatting and templating.
@classmethod
def _format(cls, key, value, for_path=False):
"""Format a value as the given field for this model.
"""
# Format the value as a string according to its type.
value = cls._type(key).format(value)
_formatter = FormattedMapping
# Formatting must result in a string. To deal with
# Python2isms, implicitly convert ASCII strings.
assert isinstance(value, basestring), \
u'field formatter must produce strings'
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
def _get_formatted(self, key, for_path=False):
"""Get a field value formatted as a string (`unicode` object)
for display to the user. If `for_path` is true, then the value
will be sanitized for inclusion in a pathname (i.e., path
separators will be removed from the value).
"""
return self._format(key, self.get(key), for_path)
def _formatted_mapping(self, for_path=False):
def formatted(self, for_path=False):
"""Get a mapping containing all values on this object formatted
as human-readable strings.
as human-readable unicode strings.
"""
return FormattedMapping(self, for_path)
@property
def formatted(self):
"""A `dict`-like view containing formatted values.
"""
return self._formatted_mapping(False)
return self._formatter(self, for_path)
def evaluate_template(self, template, for_path=False):
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
"""
# Build value mapping.
mapping = self._formatted_mapping(for_path)
# Get template functions.
funcs = self._template_funcs()
# Perform substitution.
if isinstance(template, basestring):
template = Template(template)
return template.substitute(mapping, funcs)
return template.substitute(self.formatted(for_path),
self._template_funcs())
# Parsing.
@ -450,33 +455,6 @@ class Model(object):
return cls._type(key).parse(string)
class FormattedMapping(collections.Mapping):
"""A `dict`-like formatted view of a model.
The accessor ``mapping[key]`` returns the formated version of
``model[key]``. The formatting is handled by `model._format()`.
"""
# TODO Move all formatting logic here
# TODO Add caching
def __init__(self, model, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
return self.model._get_formatted(key, self.for_path)
else:
raise KeyError(key)
def __iter__(self):
return iter(self.model_keys)
def __len__(self):
return len(self.model_keys)
# Database controller and supporting interfaces.
class Results(object):

View file

@ -216,6 +216,54 @@ class LibModel(dbcore.Model):
plugins.send('database_change', lib=self._db)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album level fields.
"""
def __init__(self, item, for_path=False):
super(FormattedItemMapping, self).__init__(item, for_path)
self.album = item.get_album()
self.album_keys = []
if self.album:
for key in self.album.keys(True):
if key in Album.item_keys or key not in item._fields.keys():
self.album_keys.append(key)
self.all_keys = set(self.model_keys).union(self.album_keys)
def _get(self, key):
"""Get the value for a key, either from the album or the item.
Raise a KeyError for invalid keys.
"""
if key in self.album_keys:
return self._get_formatted(self.album, key)
elif key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __getitem__(self, key):
"""Get the value for a key. Certain unset values are remapped.
"""
value = self._get(key)
# `artist` and `albumartist` fields fall back to one another.
# This is helpful in path formats when the album artist is unset
# on as-is imports.
if key == 'artist' and not value:
return self._get('albumartist')
elif key == 'albumartist' and not value:
return self._get('artist')
else:
return value
def __iter__(self):
return iter(self.all_keys)
def __len__(self):
return len(self.all_keys)
class Item(LibModel):
_table = 'items'
_flex_table = 'item_attributes'
@ -296,6 +344,8 @@ class Item(LibModel):
`write`.
"""
_formatter = FormattedItemMapping
@classmethod
def _getters(cls):
getters = plugins.item_field_getters()
@ -523,12 +573,6 @@ class Item(LibModel):
# Templating.
def _formatted_mapping(self, for_path=False):
"""Get a mapping containing string-formatted values from either
this item or the associated album, if any.
"""
return FormattedItemMapping(self, for_path)
def destination(self, fragment=False, basedir=None, platform=None,
path_formats=None):
"""Returns the path in the library directory designated for the
@ -604,56 +648,6 @@ class Item(LibModel):
return normpath(os.path.join(basedir, subpath))
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""A `dict`-like formatted view of an item that inherits album fields.
The accessor ``mapping[key]`` returns the formated version of either
``item[key]`` or ``album[key]``. Here `album` is the album
associated to `item` if it exists.
"""
def __init__(self, item, for_path=False):
super(FormattedItemMapping, self).__init__(item, for_path)
self.album = item.get_album()
self.album_keys = []
if self.album:
for key in self.album.keys(True):
if key in Album.item_keys or key not in item._fields.keys():
self.album_keys.append(key)
self.all_keys = set(self.model_keys).union(self.album_keys)
def _get(self, key):
"""Get the value for a key, either from the album or the item.
Raise a KeyError for invalid keys.
"""
if key in self.album_keys:
return self.album._get_formatted(key, self.for_path)
elif key in self.model_keys:
return self.model._get_formatted(key, self.for_path)
else:
raise KeyError(key)
def __getitem__(self, key):
"""Get the value for a key. Certain unset values are remapped.
"""
value = self._get(key)
# `artist` and `albumartist` fields fall back to one another.
# This is helpful in path formats when the album artist is unset
# on as-is imports.
if key == 'artist' and not value:
return self._get('albumartist')
elif key == 'albumartist' and not value:
return self._get('artist')
else:
return value
def __iter__(self):
return iter(self.all_keys)
def __len__(self):
return len(self.all_keys)
class Album(LibModel):
"""Provides access to information about albums stored in a
library. Reflects the library's "albums" table, including album
@ -1219,7 +1213,7 @@ class DefaultTemplateFunctions(object):
return res
# Flatten disambiguation value into a string.
disam_value = album._get_formatted(disambiguator, True)
disam_value = album.formatted(True).get(disambiguator)
res = u' [{0}]'.format(disam_value)
self.lib._memotable[memokey] = res
return res

View file

@ -576,8 +576,8 @@ def _field_diff(field, old, new):
return None
# Get formatted values for output.
oldstr = old.formatted.get(field, u'')
newstr = new.formatted.get(field, u'')
oldstr = old.formatted().get(field, u'')
newstr = new.formatted().get(field, u'')
# For strings, highlight changes. For others, colorize the whole
# thing.
@ -620,7 +620,7 @@ def show_model_changes(new, old=None, fields=None, always=False):
changes.append(u' {0}: {1}'.format(
field,
colorize('red', new.formatted[field])
colorize('red', new.formatted()[field])
))
# Print changes.

View file

@ -255,54 +255,54 @@ class FormatTest(_common.TestCase):
def test_format_fixed_field(self):
model = TestModel1()
model.field_one = u'caf\xe9'
value = model._get_formatted('field_one')
value = model.formatted().get('field_one')
self.assertEqual(value, u'caf\xe9')
def test_format_flex_field(self):
model = TestModel1()
model.other_field = u'caf\xe9'
value = model._get_formatted('other_field')
value = model.formatted().get('other_field')
self.assertEqual(value, u'caf\xe9')
def test_format_flex_field_bytes(self):
model = TestModel1()
model.other_field = u'caf\xe9'.encode('utf8')
value = model._get_formatted('other_field')
value = model.formatted().get('other_field')
self.assertTrue(isinstance(value, unicode))
self.assertEqual(value, u'caf\xe9')
def test_format_unset_field(self):
model = TestModel1()
value = model._get_formatted('other_field')
value = model.formatted().get('other_field')
self.assertEqual(value, u'')
def test_format_typed_flex_field(self):
model = TestModel1()
model.some_float_field = 3.14159265358979
value = model._get_formatted('some_float_field')
value = model.formatted().get('some_float_field')
self.assertEqual(value, u'3.1')
class FormattedMappingTest(_common.TestCase):
def test_keys_equal_model_keys(self):
model = TestModel1()
formatted = model._formatted_mapping()
formatted = model.formatted()
self.assertEqual(set(model.keys(True)), set(formatted.keys()))
def test_get_unset_field(self):
model = TestModel1()
formatted = model._formatted_mapping()
formatted = model.formatted()
with self.assertRaises(KeyError):
formatted['other_field']
def test_get_method_with_none_default(self):
def test_get_method_with_default(self):
model = TestModel1()
formatted = model._formatted_mapping()
self.assertIsNone(formatted.get('other_field'))
formatted = model.formatted()
self.assertEqual(formatted.get('other_field'), u'')
def test_get_method_with_specified_default(self):
model = TestModel1()
formatted = model._formatted_mapping()
formatted = model.formatted()
self.assertEqual(formatted.get('other_field', 'default'), 'default')

View file

@ -340,37 +340,37 @@ class DestinationTest(_common.TestCase):
with _common.platform_posix():
name = os.path.join('a', 'b')
self.i.title = name
newname = self.i._get_formatted('title')
newname = self.i.formatted().get('title')
self.assertEqual(name, newname)
def test_get_formatted_pads_with_zero(self):
with _common.platform_posix():
self.i.track = 1
name = self.i._get_formatted('track')
name = self.i.formatted().get('track')
self.assertTrue(name.startswith('0'))
def test_get_formatted_uses_kbps_bitrate(self):
with _common.platform_posix():
self.i.bitrate = 12345
val = self.i._get_formatted('bitrate')
val = self.i.formatted().get('bitrate')
self.assertEqual(val, u'12kbps')
def test_get_formatted_uses_khz_samplerate(self):
with _common.platform_posix():
self.i.samplerate = 12345
val = self.i._get_formatted('samplerate')
val = self.i.formatted().get('samplerate')
self.assertEqual(val, u'12kHz')
def test_get_formatted_datetime(self):
with _common.platform_posix():
self.i.added = 1368302461.210265
val = self.i._get_formatted('added')
val = self.i.formatted().get('added')
self.assertTrue(val.startswith('2013'))
def test_get_formatted_none(self):
with _common.platform_posix():
self.i.some_other_field = None
val = self.i._get_formatted('some_other_field')
val = self.i.formatted().get('some_other_field')
self.assertEqual(val, u'')
def test_artist_falls_back_to_albumartist(self):
@ -462,20 +462,20 @@ class DestinationTest(_common.TestCase):
class ItemFormattedMappingTest(_common.LibTestCase):
def test_formatted_item_value(self):
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted['artist'], 'the artist')
def test_get_unset_field(self):
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
with self.assertRaises(KeyError):
formatted['other_field']
def test_get_method_with_none_default(self):
formatted = self.i._formatted_mapping()
self.assertIsNone(formatted.get('other_field'))
def test_get_method_with_default(self):
formatted = self.i.formatted()
self.assertEqual(formatted.get('other_field'), u'')
def test_get_method_with_specified_default(self):
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted.get('other_field', 'default'), 'default')
def test_album_field_overrides_item_field(self):
@ -487,23 +487,23 @@ class ItemFormattedMappingTest(_common.LibTestCase):
self.i.store()
# Ensure the album takes precedence.
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted['album'], 'foo')
def test_artist_falls_back_to_albumartist(self):
self.i.artist = ''
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted['artist'], 'the album artist')
def test_albumartist_falls_back_to_artist(self):
self.i.albumartist = ''
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted['albumartist'], 'the artist')
def test_both_artist_and_albumartist_empty(self):
self.i.artist = ''
self.i.albumartist = ''
formatted = self.i._formatted_mapping()
formatted = self.i.formatted()
self.assertEqual(formatted['albumartist'], '')