dbcore: Add types for non-fixed fields

The base Type class now serves as the catch-all logic for untyped fields.
This commit is contained in:
Adrian Sampson 2014-05-25 16:23:15 -07:00
parent 429188e8e1
commit 394e4e45eb
4 changed files with 81 additions and 36 deletions

View file

@ -25,6 +25,7 @@ import collections
import beets
from beets.util.functemplate import Template
from .query import MatchQuery
from .types import Type
# Abstract base for model classes.
@ -65,7 +66,7 @@ class Model(object):
_fields = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are Type objects.
keys are field names and the values are `Type` objects.
"""
_bytes_keys = ()
@ -78,6 +79,10 @@ class Model(object):
terms.
"""
_types = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
@ -120,7 +125,10 @@ class Model(object):
for key, value in fixed_values.items():
obj._values_fixed[key] = cls._fields[key].normalize(value)
if flex_values:
obj._values_flex.update(flex_values)
for key, value in flex_values.items():
if key in cls._types:
value = cls._types[key].normalize(value)
obj._values_flex[key] = value
return obj
def __repr__(self):
@ -147,6 +155,15 @@ class Model(object):
# Essential field accessors.
@classmethod
def _type(self, key):
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
which does no conversion.
"""
return self._fields.get(key) or self._types.get(key) or Type()
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
@ -164,14 +181,15 @@ class Model(object):
def __setitem__(self, key, value):
"""Assign the value for a field.
"""
# Choose where to place the value. If the corresponding field
# has a type, filter the value.
# Choose where to place the value.
if key in self._fields:
source = self._values_fixed
value = self._fields[key].normalize(value)
else:
source = self._values_flex
# If the field has a type, filter the value.
value = self._type(key).normalize(value)
# Assign value and possibly mark as dirty.
old_value = source.get(key)
source[key] = value
@ -366,25 +384,15 @@ class Model(object):
def _format(cls, key, value, for_path=False):
"""Format a value as the given field for this model.
"""
# Format the value as a string according to its type, if any.
if key in cls._fields:
value = cls._fields[key].format(value)
# Formatting must result in a string. To deal with
# Python2isms, implicitly convert ASCII strings.
assert isinstance(value, basestring), \
u'field formatter must produce strings'
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
# Format the value as a string according to its type.
value = cls._type(key).format(value)
elif not isinstance(value, unicode):
# Fallback formatter. Convert to unicode at all cost.
if value is None:
value = u''
elif isinstance(value, basestring):
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
else:
value = unicode(value)
# Formatting must result in a string. To deal with
# Python2isms, implicitly convert ASCII strings.
assert isinstance(value, basestring), \
u'field formatter must produce strings'
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
@ -439,12 +447,7 @@ class Model(object):
if not isinstance(string, basestring):
raise TypeError("_parse() argument must be a string")
typ = cls._fields.get(key)
if typ:
return typ.parse(string)
else:
# Fall back to unparsed string.
return string
return cls._type(key).parse(string)
class FormattedMapping(collections.Mapping):

View file

@ -15,6 +15,7 @@
"""Parsing of strings into DBCore queries.
"""
import re
import itertools
from . import query
@ -83,8 +84,13 @@ def construct_query_part(model_cls, prefixes, query_part):
if not query_part:
return query.TrueQuery()
# Set up and parse the string.
query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items())
# Get the query classes for each possible field.
query_classes = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query
# Parse the string.
key, pattern, query_class = \
parse_query_part(query_part, query_classes, prefixes)

View file

@ -28,11 +28,11 @@ class Type(object):
field.
"""
sql = None
sql = u'TEXT'
"""The SQLite column type for the value.
"""
query = None
query = query.SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""
@ -44,20 +44,28 @@ class Type(object):
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
raise NotImplementedError()
# Fallback formatter. Convert to Unicode at all cost.
if value is None:
return u''
elif isinstance(value, basestring):
if isinstance(value, bytes):
return value.decode('utf8', 'ignore')
else:
return value
else:
return unicode(value)
def parse(self, string):
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
raise NotImplementedError()
return string
def normalize(self, value):
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
# TODO gradually remove the normalization of None.
if value is None:
return self.null
else:

View file

@ -32,6 +32,9 @@ class TestModel1(dbcore.Model):
'id': dbcore.types.Id(),
'field_one': dbcore.types.Integer(),
}
_types = {
'some_float_field': dbcore.types.Float(),
}
@classmethod
def _getters(cls):
@ -242,6 +245,11 @@ class ModelTest(_common.TestCase):
model.foo = None
self.assertEqual(model.foo, None)
def test_normalization_for_typed_flex_fields(self):
model = TestModel1()
model.some_float_field = None
self.assertEqual(model.some_float_field, 0.0)
class FormatTest(_common.TestCase):
def test_format_fixed_field(self):
@ -268,6 +276,12 @@ class FormatTest(_common.TestCase):
value = model._get_formatted('other_field')
self.assertEqual(value, u'')
def test_format_typed_flex_field(self):
model = TestModel1()
model.some_float_field = 3.14159265358979
value = model._get_formatted('some_float_field')
self.assertEqual(value, u'3.1')
class FormattedMappingTest(_common.TestCase):
def test_keys_equal_model_keys(self):
@ -295,8 +309,14 @@ class FormattedMappingTest(_common.TestCase):
class ParseTest(_common.TestCase):
def test_parse_fixed_field(self):
value = TestModel1._parse('field_one', u'2')
self.assertIsInstance(value, int)
self.assertEqual(value, 2)
def test_parse_flex_field(self):
value = TestModel1._parse('some_float_field', u'2')
self.assertIsInstance(value, float)
self.assertEqual(value, 2.0)
def test_parse_untyped_field(self):
value = TestModel1._parse('field_nine', u'2')
self.assertEqual(value, u'2')
@ -383,6 +403,14 @@ class QueryFromStringsTest(_common.TestCase):
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
def test_parse_fixed_type_query(self):
q = self.qfs(['field_one:2..3'])
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
def test_parse_flex_type_query(self):
q = self.qfs(['some_float_field:2..3'])
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)