diff --git a/beets/mediafile.py b/beets/mediafile.py index 12281f1f5..73e17d65e 100644 --- a/beets/mediafile.py +++ b/beets/mediafile.py @@ -94,7 +94,7 @@ def _safe_cast(out_type, val): if not isinstance(val, basestring): val = unicode(val) # Get a number from the front of the string. - val = re.match('[0-9]*', val.strip()).group(0) + val = re.match(r'[0-9]*', val.strip()).group(0) if not val: return 0 else: @@ -116,6 +116,20 @@ def _safe_cast(out_type, val): else: return unicode(val) + elif out_type == float: + if val is None: + return 0.0 + elif isinstance(val, int) or isinstance(val, float): + return float(val) + else: + if not isinstance(val, basestring): + val = unicode(val) + val = re.match(r'[\+-]?[0-9\.]*', val.strip()).group(0) + if not val: + return 0.0 + else: + return float(val) + else: return val @@ -612,6 +626,30 @@ class ImageField(object): base64.b64encode(pic.write()) ] +class FloatValueField(MediaField): + """A field that stores a floating-point number as a string.""" + def __init__(self, places=2, suffix=None, **kwargs): + """Make a field that stores ``places`` digits after the decimal + point and appends ``suffix`` (if specified) when encoding as a + string. + """ + super(FloatValueField, self).__init__(unicode, **kwargs) + + fmt = ['%.', str(places), 'f'] + if suffix: + fmt += [' ', suffix] + self.fmt = ''.join(fmt) + + def __get__(self, obj, owner): + valstr = super(FloatValueField, self).__get__(obj, owner) + return _safe_cast(float, valstr) + + def __set__(self, obj, val): + if not val: + val = 0.0 + valstr = self.fmt % val + super(FloatValueField, self).__set__(obj, valstr) + # The file (a collection of fields). @@ -865,6 +903,32 @@ class MediaFile(object): etc = StorageStyle('musicbrainz_albumartistid') ) + # ReplayGain fields. + rg_track_gain = FloatValueField(2, 'dB', + mp3 = StorageStyle('TXXX', + id3_desc=u'REPLAYGAIN_TRACK_GAIN'), + mp4 = None, + etc = StorageStyle(u'REPLAYGAIN_TRACK_GAIN') + ) + rg_album_gain = FloatValueField(2, 'dB', + mp3 = StorageStyle('TXXX', + id3_desc=u'REPLAYGAIN_ALBUM_GAIN'), + mp4 = None, + etc = StorageStyle(u'REPLAYGAIN_ALBUM_GAIN') + ) + rg_track_peak = FloatValueField(6, None, + mp3 = StorageStyle('TXXX', + id3_desc=u'REPLAYGAIN_TRACK_PEAK'), + mp4 = None, + etc = StorageStyle(u'REPLAYGAIN_TRACK_PEAK') + ) + rg_album_peak = FloatValueField(6, None, + mp3 = StorageStyle('TXXX', + id3_desc=u'REPLAYGAIN_ALBUM_PEAK'), + mp4 = None, + etc = StorageStyle(u'REPLAYGAIN_ALBUM_PEAK') + ) + @property def length(self): return self.mgfile.info.length diff --git a/test/test_mediafile.py b/test/test_mediafile.py index 5fa7cdd62..a75d6b2d2 100644 --- a/test/test_mediafile.py +++ b/test/test_mediafile.py @@ -94,6 +94,18 @@ class InvalidValueToleranceTest(unittest.TestCase): def test_safe_cast_intstring_to_bool(self): self.assertEqual(_sc(bool, '5'), True) + def test_safe_cast_string_to_float(self): + self.assertAlmostEqual(_sc(float, '1.234'), 1.234) + + def test_safe_cast_int_to_float(self): + self.assertAlmostEqual(_sc(float, 2), 2.0) + + def test_safe_cast_string_with_cruft_to_float(self): + self.assertAlmostEqual(_sc(float, '1.234stuff'), 1.234) + + def test_safe_cast_negative_string_to_float(self): + self.assertAlmostEqual(_sc(float, '-1.234'), -1.234) + class SafetyTest(unittest.TestCase): def _exccheck(self, fn, exc, data=''): fn = os.path.join(_common.RSRC, fn) diff --git a/test/test_mediafile_basic.py b/test/test_mediafile_basic.py index 3e622fd8f..839964423 100644 --- a/test/test_mediafile_basic.py +++ b/test/test_mediafile_basic.py @@ -32,9 +32,12 @@ def MakeReadingTest(path, correct_dict, field): def runTest(self): got = getattr(self.f, field) correct = correct_dict[field] - self.assertEqual(got, correct, - field + ' incorrect (expected ' + repr(correct) + ', got ' + \ - repr(got) + ') when testing ' + os.path.basename(path)) + if isinstance(correct, float): + self.assertAlmostEqual(got, correct) + else: + self.assertEqual(got, correct, + field + ' incorrect (expected ' + repr(correct) + ', got ' + + repr(got) + ') when testing ' + os.path.basename(path)) return ReadingTest def MakeReadOnlyTest(path, field, value): @@ -74,6 +77,8 @@ def MakeWritingTest(path, correct_dict, field, testsuffix='_test'): self.value = correct_dict[field] + datetime.timedelta(42) elif type(correct_dict[field]) is str: self.value = 'TestValue-' + str(field) + elif type(correct_dict[field]) is float: + self.value = 9.87 else: raise ValueError('unknown field type ' + \ str(type(correct_dict[field]))) @@ -91,10 +96,13 @@ def MakeWritingTest(path, correct_dict, field, testsuffix='_test'): # Make sure the modified field was changed correctly... if readfield == field: - self.assertEqual(got, self.value, - field + ' modified incorrectly (changed to ' + \ - repr(self.value) + ' but read ' + repr(got) + \ - ') when testing ' + os.path.basename(path)) + if isinstance(self.value, float): + self.assertAlmostEqual(got, self.value) + else: + self.assertEqual(got, self.value, + field + ' modified incorrectly (changed to ' + \ + repr(self.value) + ' but read ' + repr(got) + \ + ') when testing ' + os.path.basename(path)) # ... and that no other field was changed. else: @@ -117,11 +125,14 @@ def MakeWritingTest(path, correct_dict, field, testsuffix='_test'): if field=='date' and readfield in ('year', 'month', 'day'): correct = getattr(self.value, readfield) - self.assertEqual(got, correct, - readfield + ' changed when it should not have' - ' (expected ' + repr(correct) + ', got ' + \ - repr(got) + ') when modifying ' + field + ' in ' + \ - os.path.basename(path)) + if isinstance(correct, float): + self.assertAlmostEqual(got, correct) + else: + self.assertEqual(got, correct, + readfield + ' changed when it should not have' + ' (expected ' + repr(correct) + ', got ' + \ + repr(got) + ') when modifying ' + field + ' in ' + \ + os.path.basename(path)) def tearDown(self): if os.path.exists(self.tpath): @@ -156,6 +167,11 @@ correct_dicts = { 'mb_artistid':'7cf0ea9d-86b9-4dad-ba9e-2355a64899ea', 'art': None, 'label': u'the label', + + 'rg_track_peak': 1.23, + 'rg_track_gain': 1.34, + 'rg_album_peak': 1.45, + 'rg_album_gain': 1.56, }, # Additional coverage for common cases when "total" fields are unset.