mirror of
https://github.com/beetbox/beets.git
synced 2026-02-13 10:51:59 +01:00
Refactor model formatting
Remove all formatting related code from models. It now lives in the `FormattedMapping` class. Only API change is from `model.formatted` to `model.formatted()`.
This commit is contained in:
parent
b5c9271baa
commit
0798af7774
5 changed files with 131 additions and 159 deletions
|
|
@ -28,6 +28,50 @@ from .query import MatchQuery, build_sql
|
|||
from .types import BASE_TYPE
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor `mapping[key]` returns the formated version of
|
||||
`model[key]` as a unicode string.
|
||||
|
||||
If `for_path` is true, all path separators in the formatted values
|
||||
are replaced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
return self._get_formatted(self.model, key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=u''):
|
||||
return super(FormattedMapping, self).get(key, default)
|
||||
|
||||
def _get_formatted(self, model, key):
|
||||
value = model._type(key).format(model.get(key))
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
if self.for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].get(unicode)
|
||||
for sep in (os.path.sep, os.path.altsep):
|
||||
if sep:
|
||||
value = value.replace(sep, sep_repl)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
||||
class Model(object):
|
||||
|
|
@ -380,63 +424,24 @@ class Model(object):
|
|||
|
||||
# Formatting and templating.
|
||||
|
||||
@classmethod
|
||||
def _format(cls, key, value, for_path=False):
|
||||
"""Format a value as the given field for this model.
|
||||
"""
|
||||
# Format the value as a string according to its type.
|
||||
value = cls._type(key).format(value)
|
||||
_formatter = FormattedMapping
|
||||
|
||||
# Formatting must result in a string. To deal with
|
||||
# Python2isms, implicitly convert ASCII strings.
|
||||
assert isinstance(value, basestring), \
|
||||
u'field formatter must produce strings'
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
if for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].get(unicode)
|
||||
for sep in (os.path.sep, os.path.altsep):
|
||||
if sep:
|
||||
value = value.replace(sep, sep_repl)
|
||||
|
||||
return value
|
||||
|
||||
def _get_formatted(self, key, for_path=False):
|
||||
"""Get a field value formatted as a string (`unicode` object)
|
||||
for display to the user. If `for_path` is true, then the value
|
||||
will be sanitized for inclusion in a pathname (i.e., path
|
||||
separators will be removed from the value).
|
||||
"""
|
||||
return self._format(key, self.get(key), for_path)
|
||||
|
||||
def _formatted_mapping(self, for_path=False):
|
||||
def formatted(self, for_path=False):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable strings.
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return FormattedMapping(self, for_path)
|
||||
|
||||
@property
|
||||
def formatted(self):
|
||||
"""A `dict`-like view containing formatted values.
|
||||
"""
|
||||
return self._formatted_mapping(False)
|
||||
return self._formatter(self, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
the object's fields. If `for_path` is true, then no new path
|
||||
separators will be added to the template.
|
||||
"""
|
||||
# Build value mapping.
|
||||
mapping = self._formatted_mapping(for_path)
|
||||
|
||||
# Get template functions.
|
||||
funcs = self._template_funcs()
|
||||
|
||||
# Perform substitution.
|
||||
if isinstance(template, basestring):
|
||||
template = Template(template)
|
||||
return template.substitute(mapping, funcs)
|
||||
return template.substitute(self.formatted(for_path),
|
||||
self._template_funcs())
|
||||
|
||||
# Parsing.
|
||||
|
||||
|
|
@ -450,33 +455,6 @@ class Model(object):
|
|||
return cls._type(key).parse(string)
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor ``mapping[key]`` returns the formated version of
|
||||
``model[key]``. The formatting is handled by `model._format()`.
|
||||
"""
|
||||
# TODO Move all formatting logic here
|
||||
# TODO Add caching
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
return self.model._get_formatted(key, self.for_path)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model_keys)
|
||||
|
||||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
class Results(object):
|
||||
|
|
|
|||
108
beets/library.py
108
beets/library.py
|
|
@ -216,6 +216,54 @@ class LibModel(dbcore.Model):
|
|||
plugins.send('database_change', lib=self._db)
|
||||
|
||||
|
||||
class FormattedItemMapping(dbcore.db.FormattedMapping):
|
||||
"""Add lookup for album level fields.
|
||||
"""
|
||||
|
||||
def __init__(self, item, for_path=False):
|
||||
super(FormattedItemMapping, self).__init__(item, for_path)
|
||||
self.album = item.get_album()
|
||||
self.album_keys = []
|
||||
if self.album:
|
||||
for key in self.album.keys(True):
|
||||
if key in Album.item_keys or key not in item._fields.keys():
|
||||
self.album_keys.append(key)
|
||||
self.all_keys = set(self.model_keys).union(self.album_keys)
|
||||
|
||||
def _get(self, key):
|
||||
"""Get the value for a key, either from the album or the item.
|
||||
Raise a KeyError for invalid keys.
|
||||
"""
|
||||
if key in self.album_keys:
|
||||
return self._get_formatted(self.album, key)
|
||||
elif key in self.model_keys:
|
||||
return self._get_formatted(self.model, key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a key. Certain unset values are remapped.
|
||||
"""
|
||||
value = self._get(key)
|
||||
|
||||
# `artist` and `albumartist` fields fall back to one another.
|
||||
# This is helpful in path formats when the album artist is unset
|
||||
# on as-is imports.
|
||||
if key == 'artist' and not value:
|
||||
return self._get('albumartist')
|
||||
elif key == 'albumartist' and not value:
|
||||
return self._get('artist')
|
||||
else:
|
||||
return value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.all_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_keys)
|
||||
|
||||
|
||||
|
||||
class Item(LibModel):
|
||||
_table = 'items'
|
||||
_flex_table = 'item_attributes'
|
||||
|
|
@ -296,6 +344,8 @@ class Item(LibModel):
|
|||
`write`.
|
||||
"""
|
||||
|
||||
_formatter = FormattedItemMapping
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
getters = plugins.item_field_getters()
|
||||
|
|
@ -523,12 +573,6 @@ class Item(LibModel):
|
|||
|
||||
# Templating.
|
||||
|
||||
def _formatted_mapping(self, for_path=False):
|
||||
"""Get a mapping containing string-formatted values from either
|
||||
this item or the associated album, if any.
|
||||
"""
|
||||
return FormattedItemMapping(self, for_path)
|
||||
|
||||
def destination(self, fragment=False, basedir=None, platform=None,
|
||||
path_formats=None):
|
||||
"""Returns the path in the library directory designated for the
|
||||
|
|
@ -604,56 +648,6 @@ class Item(LibModel):
|
|||
return normpath(os.path.join(basedir, subpath))
|
||||
|
||||
|
||||
class FormattedItemMapping(dbcore.db.FormattedMapping):
|
||||
"""A `dict`-like formatted view of an item that inherits album fields.
|
||||
|
||||
The accessor ``mapping[key]`` returns the formated version of either
|
||||
``item[key]`` or ``album[key]``. Here `album` is the album
|
||||
associated to `item` if it exists.
|
||||
"""
|
||||
def __init__(self, item, for_path=False):
|
||||
super(FormattedItemMapping, self).__init__(item, for_path)
|
||||
self.album = item.get_album()
|
||||
self.album_keys = []
|
||||
if self.album:
|
||||
for key in self.album.keys(True):
|
||||
if key in Album.item_keys or key not in item._fields.keys():
|
||||
self.album_keys.append(key)
|
||||
self.all_keys = set(self.model_keys).union(self.album_keys)
|
||||
|
||||
def _get(self, key):
|
||||
"""Get the value for a key, either from the album or the item.
|
||||
Raise a KeyError for invalid keys.
|
||||
"""
|
||||
if key in self.album_keys:
|
||||
return self.album._get_formatted(key, self.for_path)
|
||||
elif key in self.model_keys:
|
||||
return self.model._get_formatted(key, self.for_path)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a key. Certain unset values are remapped.
|
||||
"""
|
||||
value = self._get(key)
|
||||
|
||||
# `artist` and `albumartist` fields fall back to one another.
|
||||
# This is helpful in path formats when the album artist is unset
|
||||
# on as-is imports.
|
||||
if key == 'artist' and not value:
|
||||
return self._get('albumartist')
|
||||
elif key == 'albumartist' and not value:
|
||||
return self._get('artist')
|
||||
else:
|
||||
return value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.all_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_keys)
|
||||
|
||||
|
||||
class Album(LibModel):
|
||||
"""Provides access to information about albums stored in a
|
||||
library. Reflects the library's "albums" table, including album
|
||||
|
|
@ -1219,7 +1213,7 @@ class DefaultTemplateFunctions(object):
|
|||
return res
|
||||
|
||||
# Flatten disambiguation value into a string.
|
||||
disam_value = album._get_formatted(disambiguator, True)
|
||||
disam_value = album.formatted(True).get(disambiguator)
|
||||
res = u' [{0}]'.format(disam_value)
|
||||
self.lib._memotable[memokey] = res
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -576,8 +576,8 @@ def _field_diff(field, old, new):
|
|||
return None
|
||||
|
||||
# Get formatted values for output.
|
||||
oldstr = old.formatted.get(field, u'')
|
||||
newstr = new.formatted.get(field, u'')
|
||||
oldstr = old.formatted().get(field, u'')
|
||||
newstr = new.formatted().get(field, u'')
|
||||
|
||||
# For strings, highlight changes. For others, colorize the whole
|
||||
# thing.
|
||||
|
|
@ -620,7 +620,7 @@ def show_model_changes(new, old=None, fields=None, always=False):
|
|||
|
||||
changes.append(u' {0}: {1}'.format(
|
||||
field,
|
||||
colorize('red', new.formatted[field])
|
||||
colorize('red', new.formatted()[field])
|
||||
))
|
||||
|
||||
# Print changes.
|
||||
|
|
|
|||
|
|
@ -255,54 +255,54 @@ class FormatTest(_common.TestCase):
|
|||
def test_format_fixed_field(self):
|
||||
model = TestModel1()
|
||||
model.field_one = u'caf\xe9'
|
||||
value = model._get_formatted('field_one')
|
||||
value = model.formatted().get('field_one')
|
||||
self.assertEqual(value, u'caf\xe9')
|
||||
|
||||
def test_format_flex_field(self):
|
||||
model = TestModel1()
|
||||
model.other_field = u'caf\xe9'
|
||||
value = model._get_formatted('other_field')
|
||||
value = model.formatted().get('other_field')
|
||||
self.assertEqual(value, u'caf\xe9')
|
||||
|
||||
def test_format_flex_field_bytes(self):
|
||||
model = TestModel1()
|
||||
model.other_field = u'caf\xe9'.encode('utf8')
|
||||
value = model._get_formatted('other_field')
|
||||
value = model.formatted().get('other_field')
|
||||
self.assertTrue(isinstance(value, unicode))
|
||||
self.assertEqual(value, u'caf\xe9')
|
||||
|
||||
def test_format_unset_field(self):
|
||||
model = TestModel1()
|
||||
value = model._get_formatted('other_field')
|
||||
value = model.formatted().get('other_field')
|
||||
self.assertEqual(value, u'')
|
||||
|
||||
def test_format_typed_flex_field(self):
|
||||
model = TestModel1()
|
||||
model.some_float_field = 3.14159265358979
|
||||
value = model._get_formatted('some_float_field')
|
||||
value = model.formatted().get('some_float_field')
|
||||
self.assertEqual(value, u'3.1')
|
||||
|
||||
|
||||
class FormattedMappingTest(_common.TestCase):
|
||||
def test_keys_equal_model_keys(self):
|
||||
model = TestModel1()
|
||||
formatted = model._formatted_mapping()
|
||||
formatted = model.formatted()
|
||||
self.assertEqual(set(model.keys(True)), set(formatted.keys()))
|
||||
|
||||
def test_get_unset_field(self):
|
||||
model = TestModel1()
|
||||
formatted = model._formatted_mapping()
|
||||
formatted = model.formatted()
|
||||
with self.assertRaises(KeyError):
|
||||
formatted['other_field']
|
||||
|
||||
def test_get_method_with_none_default(self):
|
||||
def test_get_method_with_default(self):
|
||||
model = TestModel1()
|
||||
formatted = model._formatted_mapping()
|
||||
self.assertIsNone(formatted.get('other_field'))
|
||||
formatted = model.formatted()
|
||||
self.assertEqual(formatted.get('other_field'), u'')
|
||||
|
||||
def test_get_method_with_specified_default(self):
|
||||
model = TestModel1()
|
||||
formatted = model._formatted_mapping()
|
||||
formatted = model.formatted()
|
||||
self.assertEqual(formatted.get('other_field', 'default'), 'default')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -340,37 +340,37 @@ class DestinationTest(_common.TestCase):
|
|||
with _common.platform_posix():
|
||||
name = os.path.join('a', 'b')
|
||||
self.i.title = name
|
||||
newname = self.i._get_formatted('title')
|
||||
newname = self.i.formatted().get('title')
|
||||
self.assertEqual(name, newname)
|
||||
|
||||
def test_get_formatted_pads_with_zero(self):
|
||||
with _common.platform_posix():
|
||||
self.i.track = 1
|
||||
name = self.i._get_formatted('track')
|
||||
name = self.i.formatted().get('track')
|
||||
self.assertTrue(name.startswith('0'))
|
||||
|
||||
def test_get_formatted_uses_kbps_bitrate(self):
|
||||
with _common.platform_posix():
|
||||
self.i.bitrate = 12345
|
||||
val = self.i._get_formatted('bitrate')
|
||||
val = self.i.formatted().get('bitrate')
|
||||
self.assertEqual(val, u'12kbps')
|
||||
|
||||
def test_get_formatted_uses_khz_samplerate(self):
|
||||
with _common.platform_posix():
|
||||
self.i.samplerate = 12345
|
||||
val = self.i._get_formatted('samplerate')
|
||||
val = self.i.formatted().get('samplerate')
|
||||
self.assertEqual(val, u'12kHz')
|
||||
|
||||
def test_get_formatted_datetime(self):
|
||||
with _common.platform_posix():
|
||||
self.i.added = 1368302461.210265
|
||||
val = self.i._get_formatted('added')
|
||||
val = self.i.formatted().get('added')
|
||||
self.assertTrue(val.startswith('2013'))
|
||||
|
||||
def test_get_formatted_none(self):
|
||||
with _common.platform_posix():
|
||||
self.i.some_other_field = None
|
||||
val = self.i._get_formatted('some_other_field')
|
||||
val = self.i.formatted().get('some_other_field')
|
||||
self.assertEqual(val, u'')
|
||||
|
||||
def test_artist_falls_back_to_albumartist(self):
|
||||
|
|
@ -462,20 +462,20 @@ class DestinationTest(_common.TestCase):
|
|||
|
||||
class ItemFormattedMappingTest(_common.LibTestCase):
|
||||
def test_formatted_item_value(self):
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted['artist'], 'the artist')
|
||||
|
||||
def test_get_unset_field(self):
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
with self.assertRaises(KeyError):
|
||||
formatted['other_field']
|
||||
|
||||
def test_get_method_with_none_default(self):
|
||||
formatted = self.i._formatted_mapping()
|
||||
self.assertIsNone(formatted.get('other_field'))
|
||||
def test_get_method_with_default(self):
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted.get('other_field'), u'')
|
||||
|
||||
def test_get_method_with_specified_default(self):
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted.get('other_field', 'default'), 'default')
|
||||
|
||||
def test_album_field_overrides_item_field(self):
|
||||
|
|
@ -487,23 +487,23 @@ class ItemFormattedMappingTest(_common.LibTestCase):
|
|||
self.i.store()
|
||||
|
||||
# Ensure the album takes precedence.
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted['album'], 'foo')
|
||||
|
||||
def test_artist_falls_back_to_albumartist(self):
|
||||
self.i.artist = ''
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted['artist'], 'the album artist')
|
||||
|
||||
def test_albumartist_falls_back_to_artist(self):
|
||||
self.i.albumartist = ''
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted['albumartist'], 'the artist')
|
||||
|
||||
def test_both_artist_and_albumartist_empty(self):
|
||||
self.i.artist = ''
|
||||
self.i.albumartist = ''
|
||||
formatted = self.i._formatted_mapping()
|
||||
formatted = self.i.formatted()
|
||||
self.assertEqual(formatted['albumartist'], '')
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue