diff --git a/beets/library.py b/beets/library.py index 569d7d2b5..41e25984e 100644 --- a/beets/library.py +++ b/beets/library.py @@ -752,13 +752,17 @@ PARSE_QUERY_PART_REGEX = re.compile( re.I # Case-insensitive. ) -def parse_query_part(part, query_classes={}, +def parse_query_part(part, query_classes={}, prefixes={}, default_class=dbcore.query.SubstringQuery): """Take a query in the form of a key/value pair separated by a colon and return a tuple of `(key, value, cls)`. `key` may be None, indicating that any field may be matched. `cls` is a subclass of - `FieldQuery`. The optional `query_classes` parameter maps field names - to default query types; `default_class` is the fallback. + `FieldQuery`. + + The optional `query_classes` parameter maps field names to default + query types; `default_class` is the fallback. `prefixes` is a map + from query prefix markers and query types. Prefix-indicated queries + take precedence over type-based queries. To determine the query class, two factors are used: prefixes and field types. For example, the colon prefix denotes a regular @@ -778,19 +782,18 @@ def parse_query_part(part, query_classes={}, part = part.strip() match = PARSE_QUERY_PART_REGEX.match(part) - # FIXME parameterize - prefixes = {':': dbcore.query.RegexpQuery} - prefixes.update(plugins.queries()) + assert match # Regex should always match. + key = match.group(1) + term = match.group(2).replace('\:', ':') - if match: - key = match.group(1) - term = match.group(2).replace('\:', ':') - # Match the search term against the list of prefixes. - for pre, query_class in prefixes.items(): - if term.startswith(pre): - return key, term[len(pre):], query_class - query_class = query_classes.get(key, default_class) - return key, term, query_class + # Match the search term against the list of prefixes. + for pre, query_class in prefixes.items(): + if term.startswith(pre): + return key, term[len(pre):], query_class + + # No matching prefix: use type-based or fallback/default query. + query_class = query_classes.get(key, default_class) + return key, term, query_class def construct_query_part(query_part, model_cls): @@ -799,7 +802,9 @@ def construct_query_part(query_part, model_cls): `None` if the value cannot be parsed. """ query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items()) - parsed = parse_query_part(query_part, query_classes) + prefixes = {':': dbcore.query.RegexpQuery} + prefixes.update(plugins.queries()) + parsed = parse_query_part(query_part, query_classes, prefixes) if not parsed: return diff --git a/test/test_query.py b/test/test_query.py index d05470081..a37ad3851 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -19,58 +19,59 @@ from _common import unittest import beets.library from beets import dbcore -pqp = beets.library.parse_query_part - - -TEST_TYPES = { - 'year': dbcore.query.NumericQuery -} class QueryParseTest(_common.TestCase): + def pqp(self, part): + return beets.library.parse_query_part( + part, + {'year': dbcore.query.NumericQuery}, + {':': dbcore.query.RegexpQuery}, + ) + def test_one_basic_term(self): q = 'test' r = (None, 'test', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_one_keyed_term(self): q = 'test:val' r = ('test', 'val', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_colon_at_end(self): q = 'test:' r = (None, 'test:', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_one_basic_regexp(self): q = r':regexp' r = (None, 'regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_keyed_regexp(self): q = r'test::regexp' r = ('test', 'regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_escaped_colon(self): q = r'test\:val' r = (None, 'test:val', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_escaped_colon_in_regexp(self): q = r':test\:regexp' r = (None, 'test:regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_single_year(self): q = 'year:1999' r = ('year', '1999', dbcore.query.NumericQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) def test_multiple_years(self): q = 'year:1999..2010' r = ('year', '1999..2010', dbcore.query.NumericQuery) - self.assertEqual(pqp(q, TEST_TYPES), r) + self.assertEqual(self.pqp(q), r) class AnyFieldQueryTest(_common.LibTestCase):