move Distance class to hooks module

This commit is contained in:
Adrian Sampson 2013-06-10 15:21:32 -07:00
parent d1ebe423c9
commit f6faf72328
7 changed files with 379 additions and 345 deletions

View file

@ -15,9 +15,13 @@
"""Glue between metadata sources and the matching logic."""
import logging
from collections import namedtuple
import re
from beets import plugins
from beets import config
from beets.autotag import mb
from beets.util import levenshtein
from unidecode import unidecode
log = logging.getLogger('beets')
@ -158,6 +162,294 @@ class TrackInfo(object):
if isinstance(value, str):
setattr(self, fld, value.decode(codec, 'ignore'))
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ['the', 'a', 'an']
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r'^the ', 0.1),
(r'[\[\(]?(ep|single)[\]\)]?', 0.0),
(r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
(r'\(.*?\)', 0.3),
(r'\[.*?\]', 0.3),
(r'(, )?(pt\.|part) .+', 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r'&', 'and'),
]
def _string_dist_basic(str1, str2):
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
str1 = unidecode(str1)
str2 = unidecode(str2)
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word)-2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word)-2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, '', str1)
case_str2 = re.sub(pat, '', str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
dist = base_dist + penalty
return dist
class Distance(object):
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self):
self._penalties = {}
weights_view = config['match']['distance_weights']
self._weights = {}
for key in weights_view.keys():
self._weights[key] = weights_view[key].as_number()
# Access the components and their aggregates.
@property
def distance(self):
"""Returns a weighted and normalised distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self):
"""Returns the maximum distance penalty.
"""
dist_max = 0.0
for key, penalty in self._penalties.iteritems():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self):
"""Returns the raw (denormalized) distance.
"""
dist_raw = 0.0
for key, penalty in self._penalties.iteritems():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
@property
def sorted(self):
"""Returns a list of (dist, key) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((dist, key))
# Convert distance into a negative float we can sort items in ascending
# order (for keys, when the penalty is equal) and still get the items
# with the biggest distance first.
return sorted(list_, key=lambda (dist, key): (0-dist, key))
# Behave like a float.
def __cmp__(self, other):
return cmp(self.distance, other)
def __float__(self):
return self.distance
def __sub__(self, other):
return self.distance - other
def __rsub__(self, other):
return other - self.distance
# Behave like a dict.
def __getitem__(self, key):
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self):
return iter(self.sorted)
def __len__(self):
return len(self.sorted)
def update(self, dist):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
raise ValueError(
'`dist` must be a Distance object. It is: %r' % dist)
for key, penalties in dist._penalties.iteritems():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1, value2):
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re._pattern_type):
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
'`dist` must be between 0.0 and 1.0. It is: %r' % dist)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
# Structures that compose all the information for a candidate match.
AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
'extra_items', 'extra_tracks'])

View file

@ -21,34 +21,16 @@ import datetime
import logging
import re
from munkres import Munkres
from unidecode import unidecode
from beets import plugins
from beets import config
from beets.util import levenshtein, plurality
from beets.util import plurality
from beets.util.enumeration import enum
from beets.autotag import hooks
# A configuration view for the distance weights.
weights = config['match']['distance_weights']
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ['the', 'a', 'an']
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r'^the ', 0.1),
(r'[\[\(]?(ep|single)[\]\)]?', 0.0),
(r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
(r'\(.*?\)', 0.3),
(r'\[.*?\]', 0.3),
(r'(, )?(pt\.|part) .+', 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r'&', 'and'),
]
# Recommendation enumeration.
recommendation = enum('none', 'low', 'medium', 'strong', name='recommendation')
@ -64,73 +46,6 @@ log = logging.getLogger('beets')
# Primary matching functionality.
def _string_dist_basic(str1, str2):
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
str1 = unidecode(str1)
str2 = unidecode(str2)
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word)-2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word)-2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, '', str1)
case_str2 = re.sub(pat, '', str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# the current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
dist = base_dist + penalty
return dist
def current_metadata(items):
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
@ -187,189 +102,12 @@ def track_index_changed(item, track_info):
"""
return item.track not in (track_info.medium_index, track_info.index)
class Distance(object):
"""Keeps track of multiple distance penalties. Provides a single weighted
distance for all penalties as well as a weighted distance for each
individual penalty.
"""
def __cmp__(self, other):
return cmp(self.distance, other)
def __float__(self):
return self.distance
def __getitem__(self, key):
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * weights[key].as_number()
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __init__(self):
self._penalties = {}
def __iter__(self):
return iter(self.sorted)
def __len__(self):
return len(self.sorted)
def __sub__(self, other):
return self.distance - other
def __rsub__(self, other):
return other - self.distance
def _eq(self, value1, value2):
"""Returns True if `value1` is equal to `value2`. `value1` may be a
compiled regular expression, in which case it will be matched against
`value2`.
"""
if isinstance(value1, re._pattern_type):
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
"""Adds a distance penalty. `key` must correspond with a configured
weight setting. `dist` must be a float between 0.0 and 1.0, and will be
added to any existing distance penalties for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
'`dist` must be between 0.0 and 1.0. It is: %r' % dist)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
"""Adds a distance penalty of 1.0 if `value` doesn't match any of the
values in `options`. If an option is a compiled regular expression, it
will be considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True, or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
"""Adds a distance penalty of 1.0 for each number of difference between
`number1` and `number2`, or 0.0 when there is no difference. Use this
when there is no upper limit on the difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
"""Adds a distance penalty that corresponds to the position at which
`value` appears in `options`. A distance penalty of 0.0 for the first
option, or 1.0 if there is no matching option. If an option is a
compiled regular expression, it will be considered equal if it matches
against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
"""Adds a distance penalty based on the edit distance between `str1`
and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
@property
def distance(self):
"""Returns a weighted and normalised distance across all penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self):
"""Returns the maximum distance penalty.
"""
dist_max = 0.0
for key, penalty in self._penalties.iteritems():
dist_max += len(penalty) * weights[key].as_number()
return dist_max
@property
def raw_distance(self):
"""Returns the raw (denormalised) distance.
"""
dist_raw = 0.0
for key, penalty in self._penalties.iteritems():
dist_raw += sum(penalty) * weights[key].as_number()
return dist_raw
@property
def sorted(self):
"""Returns a list of (dist, key) pairs, with `dist` being the weighted
distance, sorted from highest to lowest. Does not include penalties
with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((dist, key))
# Convert distance into a negative float we can sort items in ascending
# order (for keys, when the penalty is equal) and still get the items
# with the biggest distance first.
return sorted(list_, key=lambda (dist, key): (0-dist, key))
def update(self, dist):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
raise ValueError(
'`dist` must be a Distance object. It is: %r' % dist)
for key, penalties in dist._penalties.iteritems():
self._penalties.setdefault(key, []).extend(penalties)
def track_distance(item, track_info, incl_artist=False):
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
"""
dist = Distance()
dist = hooks.Distance()
# Length.
if track_info.length:
@ -410,7 +148,7 @@ def distance(items, album_info, mapping):
"""
likelies, _ = current_metadata(items)
dist = Distance()
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:

View file

@ -67,13 +67,13 @@ class BeetsPlugin(object):
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
return beets.autotag.match.Distance()
return beets.autotag.hooks.Distance()
def album_distance(self, items, album_info, mapping):
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
return beets.autotag.match.Distance()
return beets.autotag.hooks.Distance()
def candidates(self, items, artist, album, va_likely):
"""Should return a sequence of AlbumInfo objects that match the
@ -244,14 +244,16 @@ def track_distance(item, info):
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
dist = beets.autotag.match.Distance()
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.track_distance(item, info))
return dist
def album_distance(items, album_info, mapping):
"""Returns the album distance calculated by plugins."""
dist = beets.autotag.match.Distance()
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
return dist

View file

@ -20,8 +20,7 @@ from datetime import datetime, timedelta
import requests
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.autotag.match import Distance
from beets.autotag.hooks import AlbumInfo, TrackInfo, Distance
from beets.plugins import BeetsPlugin
log = logging.getLogger('beets')

View file

@ -21,7 +21,6 @@ from beets import util
from beets import config
from beets.util import confit
from beets.autotag import hooks
from beets.autotag.match import Distance
import acoustid
import logging
from collections import defaultdict
@ -114,7 +113,7 @@ def _all_releases(items):
class AcoustidPlugin(plugins.BeetsPlugin):
def track_distance(self, item, info):
dist = Distance()
dist = hooks.Distance()
if item.path not in _matches or not info.track_id:
# Match failed or no track ID.
return dist

View file

@ -15,11 +15,9 @@
"""Adds Discogs album search support to the autotagger. Requires the
discogs-client library.
"""
from beets import config
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.autotag.match import current_metadata, Distance, VA_ARTISTS
from beets.autotag.hooks import AlbumInfo, TrackInfo, Distance
from beets.plugins import BeetsPlugin
from discogs_client import Artist, DiscogsAPIError, Release, Search
from discogs_client import DiscogsAPIError, Release, Search
import beets
import discogs_client
import logging

View file

@ -23,13 +23,13 @@ import _common
from _common import unittest
from beets import autotag
from beets.autotag import match
from beets.autotag.match import Distance
from beets.autotag.hooks import Distance, string_dist
from beets.library import Item
from beets.util import plurality
from beets.autotag import AlbumInfo, TrackInfo
from beets import config
class PluralityTest(unittest.TestCase):
class PluralityTest(_common.TestCase):
def test_plurality_consensus(self):
objs = [1, 1, 1, 1]
obj, freq = plurality(objs)
@ -106,8 +106,9 @@ def _make_trackinfo():
TrackInfo(u'three', None, u'some artist', length=1, index=3),
]
class DistanceTest(unittest.TestCase):
class DistanceTest(_common.TestCase):
def setUp(self):
super(DistanceTest, self).setUp()
self.dist = Distance()
def test_add(self):
@ -176,62 +177,66 @@ class DistanceTest(unittest.TestCase):
self.assertEqual(self.dist._penalties['ratio'], [0.25, 1.0, 0.0, 0.0])
def test_add_string(self):
dist = match.string_dist(u'abc', u'bcd')
dist = string_dist(u'abc', u'bcd')
self.dist.add_string('string', u'abc', u'bcd')
self.assertEqual(self.dist._penalties['string'], [dist])
def test_distance(self):
config['match']['distance_weights']['album'] = 2.0
config['match']['distance_weights']['medium'] = 1.0
self.dist.add('album', 0.5)
self.dist.add('media', 0.25)
self.dist.add('media', 0.75)
self.assertEqual(self.dist.distance, 0.5)
dist = Distance()
dist.add('album', 0.5)
dist.add('media', 0.25)
dist.add('media', 0.75)
self.assertEqual(dist.distance, 0.5)
# __getitem__()
self.assertEqual(self.dist['album'], 0.25)
self.assertEqual(self.dist['media'], 0.25)
self.assertEqual(dist['album'], 0.25)
self.assertEqual(dist['media'], 0.25)
def test_max_distance(self):
config['match']['distance_weights']['album'] = 3.0
config['match']['distance_weights']['medium'] = 1.0
self.dist.add('album', 0.5)
self.dist.add('medium', 0.0)
self.dist.add('medium', 0.0)
self.assertEqual(self.dist.max_distance, 5.0)
dist = Distance()
dist.add('album', 0.5)
dist.add('medium', 0.0)
dist.add('medium', 0.0)
self.assertEqual(dist.max_distance, 5.0)
def test_operators(self):
config['match']['distance_weights']['source'] = 1.0
config['match']['distance_weights']['album'] = 2.0
config['match']['distance_weights']['medium'] = 1.0
self.dist.add('source', 0.0)
self.dist.add('album', 0.5)
self.dist.add('medium', 0.25)
self.dist.add('medium', 0.75)
self.assertEqual(len(self.dist), 2)
self.assertEqual(list(self.dist), [(0.2, 'album'), (0.2, 'medium')])
self.assertTrue(self.dist == 0.4)
self.assertTrue(self.dist < 1.0)
self.assertTrue(self.dist > 0.0)
self.assertEqual(self.dist - 0.4, 0.0)
self.assertEqual(0.4 - self.dist, 0.0)
self.assertEqual(float(self.dist), 0.4)
dist = Distance()
dist.add('source', 0.0)
dist.add('album', 0.5)
dist.add('medium', 0.25)
dist.add('medium', 0.75)
self.assertEqual(len(dist), 2)
self.assertEqual(list(dist), [(0.2, 'album'), (0.2, 'medium')])
self.assertTrue(dist == 0.4)
self.assertTrue(dist < 1.0)
self.assertTrue(dist > 0.0)
self.assertEqual(dist - 0.4, 0.0)
self.assertEqual(0.4 - dist, 0.0)
self.assertEqual(float(dist), 0.4)
def test_raw_distance(self):
config['match']['distance_weights']['album'] = 3.0
config['match']['distance_weights']['medium'] = 1.0
self.dist.add('album', 0.5)
self.dist.add('medium', 0.25)
self.dist.add('medium', 0.5)
self.assertEqual(self.dist.raw_distance, 2.25)
dist = Distance()
dist.add('album', 0.5)
dist.add('medium', 0.25)
dist.add('medium', 0.5)
self.assertEqual(dist.raw_distance, 2.25)
def test_sorted(self):
config['match']['distance_weights']['album'] = 4.0
config['match']['distance_weights']['medium'] = 2.0
self.dist.add('album', 0.1875)
self.dist.add('medium', 0.75)
self.assertEqual(self.dist.sorted, [(0.25, 'medium'), (0.125, 'album')])
dist = Distance()
dist.add('album', 0.1875)
dist.add('medium', 0.75)
self.assertEqual(dist.sorted, [(0.25, 'medium'), (0.125, 'album')])
# Sort by key if distance is equal.
dist = Distance()
@ -240,20 +245,21 @@ class DistanceTest(unittest.TestCase):
self.assertEqual(dist.sorted, [(0.25, 'album'), (0.25, 'medium')])
def test_update(self):
self.dist.add('album', 0.5)
self.dist.add('media', 1.0)
dist1 = Distance()
dist1.add('album', 0.5)
dist1.add('media', 1.0)
dist = Distance()
dist.add('album', 0.75)
dist.add('album', 0.25)
self.dist.add('media', 0.05)
dist2 = Distance()
dist2.add('album', 0.75)
dist2.add('album', 0.25)
dist2.add('media', 0.05)
self.dist.update(dist)
dist1.update(dist2)
self.assertEqual(self.dist._penalties, {'album': [0.5, 0.75, 0.25],
'media': [1.0, 0.05]})
self.assertEqual(dist1._penalties, {'album': [0.5, 0.75, 0.25],
'media': [1.0, 0.05]})
class TrackDistanceTest(unittest.TestCase):
class TrackDistanceTest(_common.TestCase):
def test_identical_tracks(self):
item = _make_item(u'one', 1)
info = _make_trackinfo()[0]
@ -280,7 +286,7 @@ class TrackDistanceTest(unittest.TestCase):
dist = match.track_distance(item, info, incl_artist=True)
self.assertEqual(dist, 0.0)
class AlbumDistanceTest(unittest.TestCase):
class AlbumDistanceTest(_common.TestCase):
def _mapping(self, items, info):
out = {}
for i, t in zip(items, info.tracks):
@ -863,77 +869,77 @@ class ApplyCompilationTest(_common.TestCase, ApplyTestUtil):
class StringDistanceTest(unittest.TestCase):
def test_equal_strings(self):
dist = match.string_dist(u'Some String', u'Some String')
dist = string_dist(u'Some String', u'Some String')
self.assertEqual(dist, 0.0)
def test_different_strings(self):
dist = match.string_dist(u'Some String', u'Totally Different')
dist = string_dist(u'Some String', u'Totally Different')
self.assertNotEqual(dist, 0.0)
def test_punctuation_ignored(self):
dist = match.string_dist(u'Some String', u'Some.String!')
dist = string_dist(u'Some String', u'Some.String!')
self.assertEqual(dist, 0.0)
def test_case_ignored(self):
dist = match.string_dist(u'Some String', u'sOME sTring')
dist = string_dist(u'Some String', u'sOME sTring')
self.assertEqual(dist, 0.0)
def test_leading_the_has_lower_weight(self):
dist1 = match.string_dist(u'XXX Band Name', u'Band Name')
dist2 = match.string_dist(u'The Band Name', u'Band Name')
dist1 = string_dist(u'XXX Band Name', u'Band Name')
dist2 = string_dist(u'The Band Name', u'Band Name')
self.assert_(dist2 < dist1)
def test_parens_have_lower_weight(self):
dist1 = match.string_dist(u'One .Two.', u'One')
dist2 = match.string_dist(u'One (Two)', u'One')
dist1 = string_dist(u'One .Two.', u'One')
dist2 = string_dist(u'One (Two)', u'One')
self.assert_(dist2 < dist1)
def test_brackets_have_lower_weight(self):
dist1 = match.string_dist(u'One .Two.', u'One')
dist2 = match.string_dist(u'One [Two]', u'One')
dist1 = string_dist(u'One .Two.', u'One')
dist2 = string_dist(u'One [Two]', u'One')
self.assert_(dist2 < dist1)
def test_ep_label_has_zero_weight(self):
dist = match.string_dist(u'My Song (EP)', u'My Song')
dist = string_dist(u'My Song (EP)', u'My Song')
self.assertEqual(dist, 0.0)
def test_featured_has_lower_weight(self):
dist1 = match.string_dist(u'My Song blah Someone', u'My Song')
dist2 = match.string_dist(u'My Song feat Someone', u'My Song')
dist1 = string_dist(u'My Song blah Someone', u'My Song')
dist2 = string_dist(u'My Song feat Someone', u'My Song')
self.assert_(dist2 < dist1)
def test_postfix_the(self):
dist = match.string_dist(u'The Song Title', u'Song Title, The')
dist = string_dist(u'The Song Title', u'Song Title, The')
self.assertEqual(dist, 0.0)
def test_postfix_a(self):
dist = match.string_dist(u'A Song Title', u'Song Title, A')
dist = string_dist(u'A Song Title', u'Song Title, A')
self.assertEqual(dist, 0.0)
def test_postfix_an(self):
dist = match.string_dist(u'An Album Title', u'Album Title, An')
dist = string_dist(u'An Album Title', u'Album Title, An')
self.assertEqual(dist, 0.0)
def test_empty_strings(self):
dist = match.string_dist(u'', u'')
dist = string_dist(u'', u'')
self.assertEqual(dist, 0.0)
def test_solo_pattern(self):
# Just make sure these don't crash.
match.string_dist(u'The ', u'')
match.string_dist(u'(EP)', u'(EP)')
match.string_dist(u', An', u'')
string_dist(u'The ', u'')
string_dist(u'(EP)', u'(EP)')
string_dist(u', An', u'')
def test_heuristic_does_not_harm_distance(self):
dist = match.string_dist(u'Untitled', u'[Untitled]')
dist = string_dist(u'Untitled', u'[Untitled]')
self.assertEqual(dist, 0.0)
def test_ampersand_expansion(self):
dist = match.string_dist(u'And', u'&')
dist = string_dist(u'And', u'&')
self.assertEqual(dist, 0.0)
def test_accented_characters(self):
dist = match.string_dist(u'\xe9\xe1\xf1', u'ean')
dist = string_dist(u'\xe9\xe1\xf1', u'ean')
self.assertEqual(dist, 0.0)
def suite():