generalize YearQuery to NumericQuery

This commit is contained in:
Adrian Sampson 2013-05-09 15:47:25 -07:00
parent 0f06c79991
commit ea0928c845
2 changed files with 71 additions and 42 deletions

View file

@ -548,33 +548,75 @@ class BooleanQuery(MatchQuery):
self.pattern = util.str2bool(pattern)
self.pattern = int(self.pattern)
class YearQuery(FieldQuery):
"""Matches a year or years against a year field.
class NumericQuery(FieldQuery):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
"""
kinds = dict((r[0], r[1]) for r in ITEM_FIELDS)
@classmethod
def applies_to(cls, field):
return field in ['year', 'original_year']
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""Determine whether a field has numeric type. NumericQuery
should only be used with such fields.
"""
return value in cls._expanded_years(pattern)
return cls.kinds.get(field) in ('int', 'real')
def _convert(self, s):
"""Convert a string to the appropriate numeric type. If the
string cannot be converted, return None.
"""
try:
return self.numtype(s)
except ValueError:
return None
def __init__(self, field, pattern):
super(NumericQuery, self).__init__(field, pattern)
if self.kinds[field] == 'int':
self.numtype = int
else:
self.numtype = float
parts = pattern.split('..', 1)
if len(parts) == 1:
# No range.
self.point = self._convert(parts[0])
self.rangemin = None
self.rangemax = None
else:
# One- or two-sided range.
self.point = None
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item):
value = getattr(item, self.field)
if isinstance(value, basestring):
value = self._convert(value)
if self.point is not None:
return value == self.point
else:
if self.rangemin is not None and value < self.rangemin:
return False
if self.rangemax is not None and value > self.rangemax:
return False
return True
def clause(self):
years = YearQuery._expanded_years(self.pattern)
return self.field + " IN (" + ",".join(years) + ")", ()
@classmethod
def _expanded_years(self, pattern):
try:
ranges = [[int(y) for y in se.split('-')] for se in pattern.split(',')]
except ValueError:
raise ValueError('invalid year')
years = [range(r[0], r[1] + 1) if len(r) > 1 else [r[0]] for r in ranges]
return [str(y) for yrs in years for y in yrs]
if self.point is not None:
return self.field + '=?', (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
(self.rangemin, self.rangemax))
elif self.rangemin is not None:
return u'{0} >= ?'.format(self.field), (self.rangemin,)
elif self.rangemax is not None:
return u'{0} <= ?'.format(self.field), (self.rangemax,)
else:
return '1'
class SingletonQuery(Query):
"""Matches either singleton or non-singleton items."""
@ -777,8 +819,8 @@ def parse_query_part(part):
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
if YearQuery.applies_to(key):
return key, term, YearQuery
if key and NumericQuery.applies_to(key):
return key, term, NumericQuery
return key, term, SubstringQuery # The default query type.
def construct_query_part(query_part, default_fields, all_keys):

View file

@ -62,17 +62,12 @@ class QueryParseTest(unittest.TestCase):
def test_single_year(self):
q = 'year:1999'
r = ('year', '1999', beets.library.YearQuery)
r = ('year', '1999', beets.library.NumericQuery)
self.assertEqual(pqp(q), r)
def test_multiple_years(self):
q = 'year:1999,2002,2010'
r = ('year', '1999,2002,2010', beets.library.YearQuery)
self.assertEqual(pqp(q), r)
def test_year_range(self):
q = 'year:1999-2001'
r = ('year', '1999-2001', beets.library.YearQuery)
q = 'year:1999..2010'
r = ('year', '1999..2010', beets.library.NumericQuery)
self.assertEqual(pqp(q), r)
class AnyFieldQueryTest(unittest.TestCase):
@ -242,21 +237,13 @@ class GetTest(unittest.TestCase, AssertsMixin):
self.assert_done(results)
def test_year_range(self):
q = 'year:2000-2010'
q = 'year:2000..2010'
results = self.lib.items(q)
self.assert_matched(results, 'Littlest Things')
self.assert_matched(results, 'Take Pills')
self.assert_matched(results, 'Lovers Who Uncover')
self.assert_done(results)
def test_multiple_years(self):
q = 'year:1987,2004-2006'
results = self.lib.items(q)
self.assert_matched(results, 'Littlest Things')
self.assert_matched(results, 'Lovers Who Uncover')
self.assert_matched(results, 'Boracay')
self.assert_done(results)
def test_bad_year(self):
q = 'year:delete from items'
self.assertRaises(ValueError, self.lib.items, q)
@ -355,11 +342,11 @@ class MatchTest(unittest.TestCase):
self.assertTrue(q.match(self.item))
def test_year_match_positive(self):
q = beets.library.YearQuery('year', '1')
q = beets.library.NumericQuery('year', '1')
self.assertTrue(q.match(self.item))
def test_year_match_negative(self):
q = beets.library.YearQuery('year', '10')
q = beets.library.NumericQuery('year', '10')
self.assertFalse(q.match(self.item))
class PathQueryTest(unittest.TestCase, AssertsMixin):