diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index cbdaf5a7f..a5f29e0ed 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -152,8 +152,15 @@ class Model(object): def __setitem__(self, key, value): """Assign the value for a field. """ - source = self._values_fixed if key in self._fields \ - else self._values_flex + # Choose where to place the value. If the corresponding field + # has a type, filter the value. + if key in self._fields: + source = self._values_fixed + value = self._fields[key].normalize(value) + else: + source = self._values_flex + + # Assign value and possibly mark as dirty. old_value = source.get(key) source[key] = value if old_value != value: diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 165c0b601..ccd260c0d 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -24,8 +24,8 @@ from beets.util import str2bool class Type(object): """An object encapsulating the type of a model field. Includes - information about how to store the value in the database, query, - format, and parse a given field. + information about how to store, query, format, and parse a given + field. """ sql = None @@ -36,6 +36,10 @@ class Type(object): """The `Query` subclass to be used when querying the field. """ + null = None + """The value to be exposed when the underlying value is None. + """ + def format(self, value): """Given a value of this type, produce a Unicode string representing the value. This is used in template evaluation. @@ -48,6 +52,16 @@ class Type(object): """ raise NotImplementedError() + def normalize(self, value): + """Given a value that will be assigned into a field of this + type, normalize the value to have the appropriate type. This + base implementation only reinterprets `None`. + """ + if value is None: + return self.null + else: + return value + # Reusable types. @@ -58,6 +72,7 @@ class Integer(Type): """ sql = u'INTEGER' query = query.NumericQuery + null = 0 def format(self, value): return unicode(value or 0) @@ -93,9 +108,14 @@ class ScaledInt(Integer): class Id(Integer): - """An integer used as the row key for a SQLite table. + """An integer used as the row id or a foreign key in a SQLite table. + This type is nullable: None values are not translated to zero. """ - sql = u'INTEGER PRIMARY KEY' + null = None + + def __init__(self, primary=True): + if primary: + self.sql = u'INTEGER PRIMARY KEY' class Float(Type): @@ -103,6 +123,7 @@ class Float(Type): """ sql = u'REAL' query = query.NumericQuery + null = 0.0 def format(self, value): return u'{0:.1f}'.format(value or 0.0) @@ -119,6 +140,7 @@ class String(Type): """ sql = u'TEXT' query = query.SubstringQuery + null = u'' def format(self, value): return unicode(value) if value else u'' @@ -132,6 +154,7 @@ class Boolean(Type): """ sql = u'INTEGER' query = query.BooleanQuery + null = False def format(self, value): return unicode(bool(value)) diff --git a/beets/library.py b/beets/library.py index fe4d9e0b5..349f8459a 100644 --- a/beets/library.py +++ b/beets/library.py @@ -79,6 +79,7 @@ class SingletonQuery(dbcore.Query): class DateType(types.Type): sql = u'REAL' query = dbcore.query.DateQuery + null = 0.0 def format(self, value): return time.strftime(beets.config['time_format'].get(unicode), @@ -95,7 +96,7 @@ class DateType(types.Type): try: return float(string) except ValueError: - return 0.0 + return self.null class PathType(types.Type): @@ -122,9 +123,9 @@ class PathType(types.Type): # - Is the field writable? # - Does the field reflect an attribute of a MediaFile? ITEM_FIELDS = [ - ('id', types.Id(), False, False), + ('id', types.Id(True), False, False), ('path', PathType(), False, False), - ('album_id', types.Integer(), False, False), + ('album_id', types.Id(False), False, False), ('title', types.String(), True, True), ('artist', types.String(), True, True), @@ -192,9 +193,9 @@ ITEM_KEYS = [f[0] for f in ITEM_FIELDS] # The third entry in each tuple indicates whether the field reflects an # identically-named field in the items table. ALBUM_FIELDS = [ - ('id', types.Id(), False), - ('artpath', PathType(), False), - ('added', DateType(), True), + ('id', types.Id(True), False), + ('artpath', PathType(), False), + ('added', DateType(), True), ('albumartist', types.String(), True), ('albumartist_sort', types.String(), True), diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 85fb94f87..2f4af7472 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -219,6 +219,16 @@ class ModelTest(_common.TestCase): with self.assertRaises(KeyError): del model['field_one'] + def test_null_value_normalization_by_type(self): + model = TestModel1() + model.field_one = None + self.assertEqual(model.field_one, 0) + + def test_null_value_stays_none_for_untyped_field(self): + model = TestModel1() + model.foo = None + self.assertEqual(model.foo, None) + class FormatTest(_common.TestCase): def test_format_fixed_field(self):