add keys() method to Distance

This commit is contained in:
Adrian Sampson 2013-06-10 15:40:51 -07:00
parent 33ff001d0a
commit 7983c94ef8
4 changed files with 13 additions and 11 deletions

View file

@ -267,7 +267,7 @@ class Distance(object):
@property
def distance(self):
"""Returns a weighted and normalised distance across all
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
@ -277,7 +277,7 @@ class Distance(object):
@property
def max_distance(self):
"""Returns the maximum distance penalty.
"""Return the maximum distance penalty (normalization factor).
"""
dist_max = 0.0
for key, penalty in self._penalties.iteritems():
@ -286,16 +286,15 @@ class Distance(object):
@property
def raw_distance(self):
"""Returns the raw (denormalized) distance.
"""Return the raw (denormalized) distance.
"""
dist_raw = 0.0
for key, penalty in self._penalties.iteritems():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
@property
def items(self):
"""Returns a list of (key, dist) pairs, with `dist` being the
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
@ -336,10 +335,13 @@ class Distance(object):
return 0.0
def __iter__(self):
return iter(self.items)
return iter(self.items())
def __len__(self):
return len(self.items)
return len(self.items())
def keys(self):
return [key for key, _ in self.items()]
def update(self, dist):
"""Adds all the distance penalties from `dist`.

View file

@ -291,10 +291,10 @@ def _recommendation(results):
# Downgrade to the max rec if it is lower than the current rec for an
# applied penalty.
keys = set(key for key, _ in min_dist)
keys = set(min_dist.keys())
if isinstance(results[0], hooks.AlbumMatch):
for track_dist in min_dist.tracks.values():
keys.update(key for key, _ in track_dist)
keys.update(track_dist.keys())
for key in keys:
max_rec = config['match']['max_rec'][key].as_choice({
'strong': recommendation.strong,

View file

@ -168,7 +168,7 @@ def penalty_string(distance, limit=None):
applied to a distance object.
"""
penalties = []
for key, _ in distance:
for key in distance.keys():
key = key.replace('album_', '')
key = key.replace('track_', '')
key = key.replace('_', ' ')

View file

@ -239,7 +239,7 @@ class DistanceTest(_common.TestCase):
dist = Distance()
dist.add('album', 0.1875)
dist.add('medium', 0.75)
self.assertEqual(dist.items, [('medium', 0.25), ('album', 0.125)])
self.assertEqual(dist.items(), [('medium', 0.25), ('album', 0.125)])
# Sort by key if distance is equal.
dist = Distance()