Add Distance.__iter__() and Distance.__len__(), for convenience.

This commit is contained in:
Tai Lee 2013-06-06 09:51:17 +10:00
parent e92b8bb8fb
commit ea1becfea1
3 changed files with 30 additions and 4 deletions

View file

@ -210,6 +210,12 @@ class Distance(object):
def __init__(self):
self._penalties = {}
def __iter__(self):
return iter(self.sorted)
def __len__(self):
return len(self.sorted)
def __sub__(self, other):
return self.distance - other
@ -344,6 +350,9 @@ class Distance(object):
dist = self[key]
if dist:
list_.append((dist, key))
# Convert distance into a negative float we can sort items in ascending
# order (for keys, when the penalty is equal) and still get the items
# with the biggest distance first.
return sorted(list_, key=lambda (dist, key): (0-dist, key))
def update(self, dist):
@ -545,10 +554,10 @@ def _recommendation(results):
# Downgrade to the max rec if it is lower than the current rec for an
# applied penalty.
keys = set(key for _, key in min_dist.sorted)
keys = set(key for _, key in min_dist)
if isinstance(results[0], hooks.AlbumMatch):
for track_dist in min_dist.tracks.values():
keys.update(key for _, key in track_dist.sorted)
keys.update(key for _, key in track_dist)
for key in keys:
max_rec = config['match']['max_rec'][key].as_choice({
'strong': recommendation.strong,
@ -580,7 +589,7 @@ def _add_candidate(items, results, info):
dist = distance(items, info, mapping)
# Skip matches with ignored penalties.
penalties = [key for _, key in dist.sorted]
penalties = [key for _, key in dist]
for penalty in config['match']['ignored'].as_str_seq():
if penalty in penalties:
log.debug('Ignored. Penalty: %s' % penalty)

View file

@ -168,7 +168,7 @@ def penalty_string(distance, limit=None):
a distance object.
"""
penalties = []
for _, key in distance.sorted:
for _, key in distance:
key = key.replace('album_', '')
key = key.replace('track_', '')
key = key.replace('_', ' ')

View file

@ -200,6 +200,23 @@ class DistanceTest(unittest.TestCase):
self.dist.add('medium', 0.0)
self.assertEqual(self.dist.max_distance, 5.0)
def test_operators(self):
config['match']['distance_weights']['source'] = 1.0
config['match']['distance_weights']['album'] = 2.0
config['match']['distance_weights']['medium'] = 1.0
self.dist.add('source', 0.0)
self.dist.add('album', 0.5)
self.dist.add('medium', 0.25)
self.dist.add('medium', 0.75)
self.assertEqual(len(self.dist), 2)
self.assertEqual(list(self.dist), [(0.2, 'album'), (0.2, 'medium')])
self.assertTrue(self.dist == 0.4)
self.assertTrue(self.dist < 1.0)
self.assertTrue(self.dist > 0.0)
self.assertEqual(self.dist - 0.4, 0.0)
self.assertEqual(0.4 - self.dist, 0.0)
self.assertEqual(float(self.dist), 0.4)
def test_raw_distance(self):
config['match']['distance_weights']['album'] = 3.0
config['match']['distance_weights']['medium'] = 1.0