diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index a2d7b3513..6d09eb006 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -25,6 +25,7 @@ import collections import beets from beets.util.functemplate import Template from .query import MatchQuery +from .types import Type # Abstract base for model classes. @@ -65,7 +66,7 @@ class Model(object): _fields = {} """A mapping indicating available "fixed" fields on this type. The - keys are field names and the values are Type objects. + keys are field names and the values are `Type` objects. """ _bytes_keys = () @@ -78,6 +79,10 @@ class Model(object): terms. """ + _types = {} + """Optional Types for non-fixed (i.e., flexible and computed) fields. + """ + @classmethod def _getters(cls): """Return a mapping from field names to getter functions. @@ -120,7 +125,10 @@ class Model(object): for key, value in fixed_values.items(): obj._values_fixed[key] = cls._fields[key].normalize(value) if flex_values: - obj._values_flex.update(flex_values) + for key, value in flex_values.items(): + if key in cls._types: + value = cls._types[key].normalize(value) + obj._values_flex[key] = value return obj def __repr__(self): @@ -147,6 +155,15 @@ class Model(object): # Essential field accessors. + @classmethod + def _type(self, key): + """Get the type of a field, a `Type` instance. + + If the field has no explicit type, it is given the base `Type`, + which does no conversion. + """ + return self._fields.get(key) or self._types.get(key) or Type() + def __getitem__(self, key): """Get the value for a field. Raise a KeyError if the field is not available. @@ -164,14 +181,15 @@ class Model(object): def __setitem__(self, key, value): """Assign the value for a field. """ - # Choose where to place the value. If the corresponding field - # has a type, filter the value. + # Choose where to place the value. if key in self._fields: source = self._values_fixed - value = self._fields[key].normalize(value) else: source = self._values_flex + # If the field has a type, filter the value. + value = self._type(key).normalize(value) + # Assign value and possibly mark as dirty. old_value = source.get(key) source[key] = value @@ -366,25 +384,15 @@ class Model(object): def _format(cls, key, value, for_path=False): """Format a value as the given field for this model. """ - # Format the value as a string according to its type, if any. - if key in cls._fields: - value = cls._fields[key].format(value) - # Formatting must result in a string. To deal with - # Python2isms, implicitly convert ASCII strings. - assert isinstance(value, basestring), \ - u'field formatter must produce strings' - if isinstance(value, bytes): - value = value.decode('utf8', 'ignore') + # Format the value as a string according to its type. + value = cls._type(key).format(value) - elif not isinstance(value, unicode): - # Fallback formatter. Convert to unicode at all cost. - if value is None: - value = u'' - elif isinstance(value, basestring): - if isinstance(value, bytes): - value = value.decode('utf8', 'ignore') - else: - value = unicode(value) + # Formatting must result in a string. To deal with + # Python2isms, implicitly convert ASCII strings. + assert isinstance(value, basestring), \ + u'field formatter must produce strings' + if isinstance(value, bytes): + value = value.decode('utf8', 'ignore') if for_path: sep_repl = beets.config['path_sep_replace'].get(unicode) @@ -439,12 +447,7 @@ class Model(object): if not isinstance(string, basestring): raise TypeError("_parse() argument must be a string") - typ = cls._fields.get(key) - if typ: - return typ.parse(string) - else: - # Fall back to unparsed string. - return string + return cls._type(key).parse(string) class FormattedMapping(collections.Mapping): diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 9c0116d36..a767b56d1 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -15,6 +15,7 @@ """Parsing of strings into DBCore queries. """ import re +import itertools from . import query @@ -83,8 +84,13 @@ def construct_query_part(model_cls, prefixes, query_part): if not query_part: return query.TrueQuery() - # Set up and parse the string. - query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items()) + # Get the query classes for each possible field. + query_classes = {} + for k, t in itertools.chain(model_cls._fields.items(), + model_cls._types.items()): + query_classes[k] = t.query + + # Parse the string. key, pattern, query_class = \ parse_query_part(query_part, query_classes, prefixes) diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index d6bdeb0e9..fea70a5a3 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -28,11 +28,11 @@ class Type(object): field. """ - sql = None + sql = u'TEXT' """The SQLite column type for the value. """ - query = None + query = query.SubstringQuery """The `Query` subclass to be used when querying the field. """ @@ -44,20 +44,28 @@ class Type(object): """Given a value of this type, produce a Unicode string representing the value. This is used in template evaluation. """ - raise NotImplementedError() + # Fallback formatter. Convert to Unicode at all cost. + if value is None: + return u'' + elif isinstance(value, basestring): + if isinstance(value, bytes): + return value.decode('utf8', 'ignore') + else: + return value + else: + return unicode(value) def parse(self, string): """Parse a (possibly human-written) string and return the indicated value of this type. """ - raise NotImplementedError() + return string def normalize(self, value): """Given a value that will be assigned into a field of this type, normalize the value to have the appropriate type. This base implementation only reinterprets `None`. """ - # TODO gradually remove the normalization of None. if value is None: return self.null else: diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 6b10df700..768533590 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -32,6 +32,9 @@ class TestModel1(dbcore.Model): 'id': dbcore.types.Id(), 'field_one': dbcore.types.Integer(), } + _types = { + 'some_float_field': dbcore.types.Float(), + } @classmethod def _getters(cls): @@ -242,6 +245,11 @@ class ModelTest(_common.TestCase): model.foo = None self.assertEqual(model.foo, None) + def test_normalization_for_typed_flex_fields(self): + model = TestModel1() + model.some_float_field = None + self.assertEqual(model.some_float_field, 0.0) + class FormatTest(_common.TestCase): def test_format_fixed_field(self): @@ -268,6 +276,12 @@ class FormatTest(_common.TestCase): value = model._get_formatted('other_field') self.assertEqual(value, u'') + def test_format_typed_flex_field(self): + model = TestModel1() + model.some_float_field = 3.14159265358979 + value = model._get_formatted('some_float_field') + self.assertEqual(value, u'3.1') + class FormattedMappingTest(_common.TestCase): def test_keys_equal_model_keys(self): @@ -295,8 +309,14 @@ class FormattedMappingTest(_common.TestCase): class ParseTest(_common.TestCase): def test_parse_fixed_field(self): value = TestModel1._parse('field_one', u'2') + self.assertIsInstance(value, int) self.assertEqual(value, 2) + def test_parse_flex_field(self): + value = TestModel1._parse('some_float_field', u'2') + self.assertIsInstance(value, float) + self.assertEqual(value, 2.0) + def test_parse_untyped_field(self): value = TestModel1._parse('field_nine', u'2') self.assertEqual(value, u'2') @@ -383,6 +403,14 @@ class QueryFromStringsTest(_common.TestCase): self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery) self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery) + def test_parse_fixed_type_query(self): + q = self.qfs(['field_one:2..3']) + self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery) + + def test_parse_flex_type_query(self): + q = self.qfs(['some_float_field:2..3']) + self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)