diff --git a/beets/library.py b/beets/library.py index 07b0e7129..4962a44ad 100644 --- a/beets/library.py +++ b/beets/library.py @@ -548,33 +548,75 @@ class BooleanQuery(MatchQuery): self.pattern = util.str2bool(pattern) self.pattern = int(self.pattern) -class YearQuery(FieldQuery): - """Matches a year or years against a year field. +class NumericQuery(FieldQuery): + """Matches numeric fields. A syntax using Ruby-style range ellipses + (``..``) lets users specify one- or two-sided ranges. For example, + ``year:2001..`` finds music released since the turn of the century. """ + kinds = dict((r[0], r[1]) for r in ITEM_FIELDS) + @classmethod def applies_to(cls, field): - return field in ['year', 'original_year'] - - @classmethod - def value_match(cls, pattern, value): - """Determine whether the value matches the pattern. Both - arguments are strings. + """Determine whether a field has numeric type. NumericQuery + should only be used with such fields. """ - return value in cls._expanded_years(pattern) + return cls.kinds.get(field) in ('int', 'real') + + def _convert(self, s): + """Convert a string to the appropriate numeric type. If the + string cannot be converted, return None. + """ + try: + return self.numtype(s) + except ValueError: + return None + + def __init__(self, field, pattern): + super(NumericQuery, self).__init__(field, pattern) + if self.kinds[field] == 'int': + self.numtype = int + else: + self.numtype = float + + parts = pattern.split('..', 1) + if len(parts) == 1: + # No range. + self.point = self._convert(parts[0]) + self.rangemin = None + self.rangemax = None + else: + # One- or two-sided range. + self.point = None + self.rangemin = self._convert(parts[0]) + self.rangemax = self._convert(parts[1]) + + def match(self, item): + value = getattr(item, self.field) + if isinstance(value, basestring): + value = self._convert(value) + + if self.point is not None: + return value == self.point + else: + if self.rangemin is not None and value < self.rangemin: + return False + if self.rangemax is not None and value > self.rangemax: + return False + return True def clause(self): - years = YearQuery._expanded_years(self.pattern) - return self.field + " IN (" + ",".join(years) + ")", () - - @classmethod - def _expanded_years(self, pattern): - try: - ranges = [[int(y) for y in se.split('-')] for se in pattern.split(',')] - except ValueError: - raise ValueError('invalid year') - years = [range(r[0], r[1] + 1) if len(r) > 1 else [r[0]] for r in ranges] - return [str(y) for yrs in years for y in yrs] - + if self.point is not None: + return self.field + '=?', (self.point,) + else: + if self.rangemin is not None and self.rangemax is not None: + return (u'{0} >= ? AND {0} <= ?'.format(self.field), + (self.rangemin, self.rangemax)) + elif self.rangemin is not None: + return u'{0} >= ?'.format(self.field), (self.rangemin,) + elif self.rangemax is not None: + return u'{0} <= ?'.format(self.field), (self.rangemax,) + else: + return '1' class SingletonQuery(Query): """Matches either singleton or non-singleton items.""" @@ -777,8 +819,8 @@ def parse_query_part(part): for pre, query_class in prefixes.items(): if term.startswith(pre): return key, term[len(pre):], query_class - if YearQuery.applies_to(key): - return key, term, YearQuery + if key and NumericQuery.applies_to(key): + return key, term, NumericQuery return key, term, SubstringQuery # The default query type. def construct_query_part(query_part, default_fields, all_keys): diff --git a/test/test_query.py b/test/test_query.py index 66d580a1a..ffc3af37d 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -62,17 +62,12 @@ class QueryParseTest(unittest.TestCase): def test_single_year(self): q = 'year:1999' - r = ('year', '1999', beets.library.YearQuery) + r = ('year', '1999', beets.library.NumericQuery) self.assertEqual(pqp(q), r) def test_multiple_years(self): - q = 'year:1999,2002,2010' - r = ('year', '1999,2002,2010', beets.library.YearQuery) - self.assertEqual(pqp(q), r) - - def test_year_range(self): - q = 'year:1999-2001' - r = ('year', '1999-2001', beets.library.YearQuery) + q = 'year:1999..2010' + r = ('year', '1999..2010', beets.library.NumericQuery) self.assertEqual(pqp(q), r) class AnyFieldQueryTest(unittest.TestCase): @@ -242,21 +237,13 @@ class GetTest(unittest.TestCase, AssertsMixin): self.assert_done(results) def test_year_range(self): - q = 'year:2000-2010' + q = 'year:2000..2010' results = self.lib.items(q) self.assert_matched(results, 'Littlest Things') self.assert_matched(results, 'Take Pills') self.assert_matched(results, 'Lovers Who Uncover') self.assert_done(results) - def test_multiple_years(self): - q = 'year:1987,2004-2006' - results = self.lib.items(q) - self.assert_matched(results, 'Littlest Things') - self.assert_matched(results, 'Lovers Who Uncover') - self.assert_matched(results, 'Boracay') - self.assert_done(results) - def test_bad_year(self): q = 'year:delete from items' self.assertRaises(ValueError, self.lib.items, q) @@ -355,11 +342,11 @@ class MatchTest(unittest.TestCase): self.assertTrue(q.match(self.item)) def test_year_match_positive(self): - q = beets.library.YearQuery('year', '1') + q = beets.library.NumericQuery('year', '1') self.assertTrue(q.match(self.item)) def test_year_match_negative(self): - q = beets.library.YearQuery('year', '10') + q = beets.library.NumericQuery('year', '10') self.assertFalse(q.match(self.item)) class PathQueryTest(unittest.TestCase, AssertsMixin):