further paramaterize parse_query_part

This commit is contained in:
Adrian Sampson 2014-01-21 19:09:34 -08:00
parent aa07eb9551
commit 74d0dc8352
2 changed files with 37 additions and 31 deletions

View file

@ -752,13 +752,17 @@ PARSE_QUERY_PART_REGEX = re.compile(
re.I # Case-insensitive. re.I # Case-insensitive.
) )
def parse_query_part(part, query_classes={}, def parse_query_part(part, query_classes={}, prefixes={},
default_class=dbcore.query.SubstringQuery): default_class=dbcore.query.SubstringQuery):
"""Take a query in the form of a key/value pair separated by a """Take a query in the form of a key/value pair separated by a
colon and return a tuple of `(key, value, cls)`. `key` may be None, colon and return a tuple of `(key, value, cls)`. `key` may be None,
indicating that any field may be matched. `cls` is a subclass of indicating that any field may be matched. `cls` is a subclass of
`FieldQuery`. The optional `query_classes` parameter maps field names `FieldQuery`.
to default query types; `default_class` is the fallback.
The optional `query_classes` parameter maps field names to default
query types; `default_class` is the fallback. `prefixes` is a map
from query prefix markers and query types. Prefix-indicated queries
take precedence over type-based queries.
To determine the query class, two factors are used: prefixes and To determine the query class, two factors are used: prefixes and
field types. For example, the colon prefix denotes a regular field types. For example, the colon prefix denotes a regular
@ -778,19 +782,18 @@ def parse_query_part(part, query_classes={},
part = part.strip() part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part) match = PARSE_QUERY_PART_REGEX.match(part)
# FIXME parameterize assert match # Regex should always match.
prefixes = {':': dbcore.query.RegexpQuery} key = match.group(1)
prefixes.update(plugins.queries()) term = match.group(2).replace('\:', ':')
if match: # Match the search term against the list of prefixes.
key = match.group(1) for pre, query_class in prefixes.items():
term = match.group(2).replace('\:', ':') if term.startswith(pre):
# Match the search term against the list of prefixes. return key, term[len(pre):], query_class
for pre, query_class in prefixes.items():
if term.startswith(pre): # No matching prefix: use type-based or fallback/default query.
return key, term[len(pre):], query_class query_class = query_classes.get(key, default_class)
query_class = query_classes.get(key, default_class) return key, term, query_class
return key, term, query_class
def construct_query_part(query_part, model_cls): def construct_query_part(query_part, model_cls):
@ -799,7 +802,9 @@ def construct_query_part(query_part, model_cls):
`None` if the value cannot be parsed. `None` if the value cannot be parsed.
""" """
query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items()) query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items())
parsed = parse_query_part(query_part, query_classes) prefixes = {':': dbcore.query.RegexpQuery}
prefixes.update(plugins.queries())
parsed = parse_query_part(query_part, query_classes, prefixes)
if not parsed: if not parsed:
return return

View file

@ -19,58 +19,59 @@ from _common import unittest
import beets.library import beets.library
from beets import dbcore from beets import dbcore
pqp = beets.library.parse_query_part
TEST_TYPES = {
'year': dbcore.query.NumericQuery
}
class QueryParseTest(_common.TestCase): class QueryParseTest(_common.TestCase):
def pqp(self, part):
return beets.library.parse_query_part(
part,
{'year': dbcore.query.NumericQuery},
{':': dbcore.query.RegexpQuery},
)
def test_one_basic_term(self): def test_one_basic_term(self):
q = 'test' q = 'test'
r = (None, 'test', dbcore.query.SubstringQuery) r = (None, 'test', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_one_keyed_term(self): def test_one_keyed_term(self):
q = 'test:val' q = 'test:val'
r = ('test', 'val', dbcore.query.SubstringQuery) r = ('test', 'val', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_colon_at_end(self): def test_colon_at_end(self):
q = 'test:' q = 'test:'
r = (None, 'test:', dbcore.query.SubstringQuery) r = (None, 'test:', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_one_basic_regexp(self): def test_one_basic_regexp(self):
q = r':regexp' q = r':regexp'
r = (None, 'regexp', dbcore.query.RegexpQuery) r = (None, 'regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_keyed_regexp(self): def test_keyed_regexp(self):
q = r'test::regexp' q = r'test::regexp'
r = ('test', 'regexp', dbcore.query.RegexpQuery) r = ('test', 'regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_escaped_colon(self): def test_escaped_colon(self):
q = r'test\:val' q = r'test\:val'
r = (None, 'test:val', dbcore.query.SubstringQuery) r = (None, 'test:val', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_escaped_colon_in_regexp(self): def test_escaped_colon_in_regexp(self):
q = r':test\:regexp' q = r':test\:regexp'
r = (None, 'test:regexp', dbcore.query.RegexpQuery) r = (None, 'test:regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_single_year(self): def test_single_year(self):
q = 'year:1999' q = 'year:1999'
r = ('year', '1999', dbcore.query.NumericQuery) r = ('year', '1999', dbcore.query.NumericQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
def test_multiple_years(self): def test_multiple_years(self):
q = 'year:1999..2010' q = 'year:1999..2010'
r = ('year', '1999..2010', dbcore.query.NumericQuery) r = ('year', '1999..2010', dbcore.query.NumericQuery)
self.assertEqual(pqp(q, TEST_TYPES), r) self.assertEqual(self.pqp(q), r)
class AnyFieldQueryTest(_common.LibTestCase): class AnyFieldQueryTest(_common.LibTestCase):