mirror of
https://github.com/beetbox/beets.git
synced 2026-01-13 03:34:31 +01:00
dbcore: Add types for non-fixed fields
The base Type class now serves as the catch-all logic for untyped fields.
This commit is contained in:
parent
429188e8e1
commit
394e4e45eb
4 changed files with 81 additions and 36 deletions
|
|
@ -25,6 +25,7 @@ import collections
|
|||
import beets
|
||||
from beets.util.functemplate import Template
|
||||
from .query import MatchQuery
|
||||
from .types import Type
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
|
@ -65,7 +66,7 @@ class Model(object):
|
|||
|
||||
_fields = {}
|
||||
"""A mapping indicating available "fixed" fields on this type. The
|
||||
keys are field names and the values are Type objects.
|
||||
keys are field names and the values are `Type` objects.
|
||||
"""
|
||||
|
||||
_bytes_keys = ()
|
||||
|
|
@ -78,6 +79,10 @@ class Model(object):
|
|||
terms.
|
||||
"""
|
||||
|
||||
_types = {}
|
||||
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
|
|
@ -120,7 +125,10 @@ class Model(object):
|
|||
for key, value in fixed_values.items():
|
||||
obj._values_fixed[key] = cls._fields[key].normalize(value)
|
||||
if flex_values:
|
||||
obj._values_flex.update(flex_values)
|
||||
for key, value in flex_values.items():
|
||||
if key in cls._types:
|
||||
value = cls._types[key].normalize(value)
|
||||
obj._values_flex[key] = value
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -147,6 +155,15 @@ class Model(object):
|
|||
|
||||
# Essential field accessors.
|
||||
|
||||
@classmethod
|
||||
def _type(self, key):
|
||||
"""Get the type of a field, a `Type` instance.
|
||||
|
||||
If the field has no explicit type, it is given the base `Type`,
|
||||
which does no conversion.
|
||||
"""
|
||||
return self._fields.get(key) or self._types.get(key) or Type()
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
|
|
@ -164,14 +181,15 @@ class Model(object):
|
|||
def __setitem__(self, key, value):
|
||||
"""Assign the value for a field.
|
||||
"""
|
||||
# Choose where to place the value. If the corresponding field
|
||||
# has a type, filter the value.
|
||||
# Choose where to place the value.
|
||||
if key in self._fields:
|
||||
source = self._values_fixed
|
||||
value = self._fields[key].normalize(value)
|
||||
else:
|
||||
source = self._values_flex
|
||||
|
||||
# If the field has a type, filter the value.
|
||||
value = self._type(key).normalize(value)
|
||||
|
||||
# Assign value and possibly mark as dirty.
|
||||
old_value = source.get(key)
|
||||
source[key] = value
|
||||
|
|
@ -366,25 +384,15 @@ class Model(object):
|
|||
def _format(cls, key, value, for_path=False):
|
||||
"""Format a value as the given field for this model.
|
||||
"""
|
||||
# Format the value as a string according to its type, if any.
|
||||
if key in cls._fields:
|
||||
value = cls._fields[key].format(value)
|
||||
# Formatting must result in a string. To deal with
|
||||
# Python2isms, implicitly convert ASCII strings.
|
||||
assert isinstance(value, basestring), \
|
||||
u'field formatter must produce strings'
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
# Format the value as a string according to its type.
|
||||
value = cls._type(key).format(value)
|
||||
|
||||
elif not isinstance(value, unicode):
|
||||
# Fallback formatter. Convert to unicode at all cost.
|
||||
if value is None:
|
||||
value = u''
|
||||
elif isinstance(value, basestring):
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
else:
|
||||
value = unicode(value)
|
||||
# Formatting must result in a string. To deal with
|
||||
# Python2isms, implicitly convert ASCII strings.
|
||||
assert isinstance(value, basestring), \
|
||||
u'field formatter must produce strings'
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
if for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].get(unicode)
|
||||
|
|
@ -439,12 +447,7 @@ class Model(object):
|
|||
if not isinstance(string, basestring):
|
||||
raise TypeError("_parse() argument must be a string")
|
||||
|
||||
typ = cls._fields.get(key)
|
||||
if typ:
|
||||
return typ.parse(string)
|
||||
else:
|
||||
# Fall back to unparsed string.
|
||||
return string
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
"""Parsing of strings into DBCore queries.
|
||||
"""
|
||||
import re
|
||||
import itertools
|
||||
from . import query
|
||||
|
||||
|
||||
|
|
@ -83,8 +84,13 @@ def construct_query_part(model_cls, prefixes, query_part):
|
|||
if not query_part:
|
||||
return query.TrueQuery()
|
||||
|
||||
# Set up and parse the string.
|
||||
query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items())
|
||||
# Get the query classes for each possible field.
|
||||
query_classes = {}
|
||||
for k, t in itertools.chain(model_cls._fields.items(),
|
||||
model_cls._types.items()):
|
||||
query_classes[k] = t.query
|
||||
|
||||
# Parse the string.
|
||||
key, pattern, query_class = \
|
||||
parse_query_part(query_part, query_classes, prefixes)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ class Type(object):
|
|||
field.
|
||||
"""
|
||||
|
||||
sql = None
|
||||
sql = u'TEXT'
|
||||
"""The SQLite column type for the value.
|
||||
"""
|
||||
|
||||
query = None
|
||||
query = query.SubstringQuery
|
||||
"""The `Query` subclass to be used when querying the field.
|
||||
"""
|
||||
|
||||
|
|
@ -44,20 +44,28 @@ class Type(object):
|
|||
"""Given a value of this type, produce a Unicode string
|
||||
representing the value. This is used in template evaluation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
# Fallback formatter. Convert to Unicode at all cost.
|
||||
if value is None:
|
||||
return u''
|
||||
elif isinstance(value, basestring):
|
||||
if isinstance(value, bytes):
|
||||
return value.decode('utf8', 'ignore')
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
return unicode(value)
|
||||
|
||||
def parse(self, string):
|
||||
"""Parse a (possibly human-written) string and return the
|
||||
indicated value of this type.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return string
|
||||
|
||||
def normalize(self, value):
|
||||
"""Given a value that will be assigned into a field of this
|
||||
type, normalize the value to have the appropriate type. This
|
||||
base implementation only reinterprets `None`.
|
||||
"""
|
||||
# TODO gradually remove the normalization of None.
|
||||
if value is None:
|
||||
return self.null
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ class TestModel1(dbcore.Model):
|
|||
'id': dbcore.types.Id(),
|
||||
'field_one': dbcore.types.Integer(),
|
||||
}
|
||||
_types = {
|
||||
'some_float_field': dbcore.types.Float(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
|
|
@ -242,6 +245,11 @@ class ModelTest(_common.TestCase):
|
|||
model.foo = None
|
||||
self.assertEqual(model.foo, None)
|
||||
|
||||
def test_normalization_for_typed_flex_fields(self):
|
||||
model = TestModel1()
|
||||
model.some_float_field = None
|
||||
self.assertEqual(model.some_float_field, 0.0)
|
||||
|
||||
|
||||
class FormatTest(_common.TestCase):
|
||||
def test_format_fixed_field(self):
|
||||
|
|
@ -268,6 +276,12 @@ class FormatTest(_common.TestCase):
|
|||
value = model._get_formatted('other_field')
|
||||
self.assertEqual(value, u'')
|
||||
|
||||
def test_format_typed_flex_field(self):
|
||||
model = TestModel1()
|
||||
model.some_float_field = 3.14159265358979
|
||||
value = model._get_formatted('some_float_field')
|
||||
self.assertEqual(value, u'3.1')
|
||||
|
||||
|
||||
class FormattedMappingTest(_common.TestCase):
|
||||
def test_keys_equal_model_keys(self):
|
||||
|
|
@ -295,8 +309,14 @@ class FormattedMappingTest(_common.TestCase):
|
|||
class ParseTest(_common.TestCase):
|
||||
def test_parse_fixed_field(self):
|
||||
value = TestModel1._parse('field_one', u'2')
|
||||
self.assertIsInstance(value, int)
|
||||
self.assertEqual(value, 2)
|
||||
|
||||
def test_parse_flex_field(self):
|
||||
value = TestModel1._parse('some_float_field', u'2')
|
||||
self.assertIsInstance(value, float)
|
||||
self.assertEqual(value, 2.0)
|
||||
|
||||
def test_parse_untyped_field(self):
|
||||
value = TestModel1._parse('field_nine', u'2')
|
||||
self.assertEqual(value, u'2')
|
||||
|
|
@ -383,6 +403,14 @@ class QueryFromStringsTest(_common.TestCase):
|
|||
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
|
||||
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
|
||||
|
||||
def test_parse_fixed_type_query(self):
|
||||
q = self.qfs(['field_one:2..3'])
|
||||
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
|
||||
|
||||
def test_parse_flex_type_query(self):
|
||||
q = self.qfs(['some_float_field:2..3'])
|
||||
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue