mirror of
https://github.com/beetbox/beets.git
synced 2025-12-27 11:02:43 +01:00
Add Distance.__iter__() and Distance.__len__(), for convenience.
This commit is contained in:
parent
e92b8bb8fb
commit
ea1becfea1
3 changed files with 30 additions and 4 deletions
|
|
@ -210,6 +210,12 @@ class Distance(object):
|
|||
def __init__(self):
|
||||
self._penalties = {}
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.sorted)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sorted)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.distance - other
|
||||
|
||||
|
|
@ -344,6 +350,9 @@ class Distance(object):
|
|||
dist = self[key]
|
||||
if dist:
|
||||
list_.append((dist, key))
|
||||
# Convert distance into a negative float we can sort items in ascending
|
||||
# order (for keys, when the penalty is equal) and still get the items
|
||||
# with the biggest distance first.
|
||||
return sorted(list_, key=lambda (dist, key): (0-dist, key))
|
||||
|
||||
def update(self, dist):
|
||||
|
|
@ -545,10 +554,10 @@ def _recommendation(results):
|
|||
|
||||
# Downgrade to the max rec if it is lower than the current rec for an
|
||||
# applied penalty.
|
||||
keys = set(key for _, key in min_dist.sorted)
|
||||
keys = set(key for _, key in min_dist)
|
||||
if isinstance(results[0], hooks.AlbumMatch):
|
||||
for track_dist in min_dist.tracks.values():
|
||||
keys.update(key for _, key in track_dist.sorted)
|
||||
keys.update(key for _, key in track_dist)
|
||||
for key in keys:
|
||||
max_rec = config['match']['max_rec'][key].as_choice({
|
||||
'strong': recommendation.strong,
|
||||
|
|
@ -580,7 +589,7 @@ def _add_candidate(items, results, info):
|
|||
dist = distance(items, info, mapping)
|
||||
|
||||
# Skip matches with ignored penalties.
|
||||
penalties = [key for _, key in dist.sorted]
|
||||
penalties = [key for _, key in dist]
|
||||
for penalty in config['match']['ignored'].as_str_seq():
|
||||
if penalty in penalties:
|
||||
log.debug('Ignored. Penalty: %s' % penalty)
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ def penalty_string(distance, limit=None):
|
|||
a distance object.
|
||||
"""
|
||||
penalties = []
|
||||
for _, key in distance.sorted:
|
||||
for _, key in distance:
|
||||
key = key.replace('album_', '')
|
||||
key = key.replace('track_', '')
|
||||
key = key.replace('_', ' ')
|
||||
|
|
|
|||
|
|
@ -200,6 +200,23 @@ class DistanceTest(unittest.TestCase):
|
|||
self.dist.add('medium', 0.0)
|
||||
self.assertEqual(self.dist.max_distance, 5.0)
|
||||
|
||||
def test_operators(self):
|
||||
config['match']['distance_weights']['source'] = 1.0
|
||||
config['match']['distance_weights']['album'] = 2.0
|
||||
config['match']['distance_weights']['medium'] = 1.0
|
||||
self.dist.add('source', 0.0)
|
||||
self.dist.add('album', 0.5)
|
||||
self.dist.add('medium', 0.25)
|
||||
self.dist.add('medium', 0.75)
|
||||
self.assertEqual(len(self.dist), 2)
|
||||
self.assertEqual(list(self.dist), [(0.2, 'album'), (0.2, 'medium')])
|
||||
self.assertTrue(self.dist == 0.4)
|
||||
self.assertTrue(self.dist < 1.0)
|
||||
self.assertTrue(self.dist > 0.0)
|
||||
self.assertEqual(self.dist - 0.4, 0.0)
|
||||
self.assertEqual(0.4 - self.dist, 0.0)
|
||||
self.assertEqual(float(self.dist), 0.4)
|
||||
|
||||
def test_raw_distance(self):
|
||||
config['match']['distance_weights']['album'] = 3.0
|
||||
config['match']['distance_weights']['medium'] = 1.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue