diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index b0c29e84c..43c044572 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -91,6 +91,100 @@ class FormattedMapping(Mapping): return value +class LazyConvertDict(object): + """Lazily convert types for attributes fetched from the database + """ + + def __init__(self, model_cls): + """Initialize the object empty + """ + self.data = {} + self.model_cls = model_cls + self._converted = {} + + def init(self, data): + """Set the base data that should be lazily converted + """ + self.data = data + + def _convert(self, key, value): + """Convert the attribute type according the the SQL type + """ + return self.model_cls._type(key).from_sql(value) + + def __setitem__(self, key, value): + """Set an attribute value, assume it's already converted + """ + self._converted[key] = value + + def __getitem__(self, key): + """Get an attribute value, converting the type on demand + if needed + """ + if key in self._converted: + return self._converted[key] + elif key in self.data: + value = self._convert(key, self.data[key]) + self._converted[key] = value + return value + + def __delitem__(self, key): + """Delete both converted and base data + """ + if key in self._converted: + del self._converted[key] + if key in self.data: + del self.data[key] + + def keys(self): + """Get a list of available field names for this object. + """ + return list(self._converted.keys()) + list(self.data.keys()) + + def copy(self): + """Create a copy of the object. + """ + new = self.__class__(self.model_cls) + new.data = self.data.copy() + new._converted = self._converted.copy() + return new + + # Act like a dictionary. + + def update(self, values): + """Assign all values in the given dict. + """ + for key, value in values.items(): + self[key] = value + + def items(self): + """Iterate over (key, value) pairs that this object contains. + Computed fields are not included. + """ + for key in self: + yield key, self[key] + + def get(self, key, default=None): + """Get the value for a given key or `default` if it does not + exist. + """ + if key in self: + return self[key] + else: + return default + + def __contains__(self, key): + """Determine whether `key` is an attribute on this object. + """ + return key in self.keys() + + def __iter__(self): + """Iterate over the available field names (excluding computed + fields). + """ + return iter(self.keys()) + + # Abstract base for model classes. class Model(object): @@ -180,8 +274,8 @@ class Model(object): """ self._db = db self._dirty = set() - self._values_fixed = {} - self._values_flex = {} + self._values_fixed = LazyConvertDict(self) + self._values_flex = LazyConvertDict(self) # Initial contents. self.update(values) @@ -195,10 +289,10 @@ class Model(object): ordinary construction are bypassed. """ obj = cls(db) - for key, value in fixed_values.items(): - obj._values_fixed[key] = cls._type(key).from_sql(value) - for key, value in flex_values.items(): - obj._values_flex[key] = cls._type(key).from_sql(value) + + obj._values_fixed.init(fixed_values) + obj._values_flex.init(flex_values) + return obj def __repr__(self): @@ -259,7 +353,10 @@ class Model(object): if key in getters: # Computed. return getters[key](self) elif key in self._fields: # Fixed. - return self._values_fixed.get(key, self._type(key).null) + if key in self._values_fixed: + return self._values_fixed[key] + else: + return self._type(key).null elif key in self._values_flex: # Flexible. return self._values_flex[key] else: @@ -439,8 +536,8 @@ class Model(object): self._check_db() stored_obj = self._db._get(type(self), self.id) assert stored_obj is not None, u"object {0} not in DB".format(self.id) - self._values_fixed = {} - self._values_flex = {} + self._values_fixed = LazyConvertDict(self) + self._values_flex = LazyConvertDict(self) self.update(dict(stored_obj)) self.clear_dirty()