further paramaterize parse_query_part

This commit is contained in:
Adrian Sampson 2014-01-21 19:09:34 -08:00
parent aa07eb9551
commit 74d0dc8352
2 changed files with 37 additions and 31 deletions

View file

@ -752,13 +752,17 @@ PARSE_QUERY_PART_REGEX = re.compile(
re.I # Case-insensitive.
)
def parse_query_part(part, query_classes={},
def parse_query_part(part, query_classes={}, prefixes={},
default_class=dbcore.query.SubstringQuery):
"""Take a query in the form of a key/value pair separated by a
colon and return a tuple of `(key, value, cls)`. `key` may be None,
indicating that any field may be matched. `cls` is a subclass of
`FieldQuery`. The optional `query_classes` parameter maps field names
to default query types; `default_class` is the fallback.
`FieldQuery`.
The optional `query_classes` parameter maps field names to default
query types; `default_class` is the fallback. `prefixes` is a map
from query prefix markers and query types. Prefix-indicated queries
take precedence over type-based queries.
To determine the query class, two factors are used: prefixes and
field types. For example, the colon prefix denotes a regular
@ -778,19 +782,18 @@ def parse_query_part(part, query_classes={},
part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part)
# FIXME parameterize
prefixes = {':': dbcore.query.RegexpQuery}
prefixes.update(plugins.queries())
assert match # Regex should always match.
key = match.group(1)
term = match.group(2).replace('\:', ':')
if match:
key = match.group(1)
term = match.group(2).replace('\:', ':')
# Match the search term against the list of prefixes.
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
query_class = query_classes.get(key, default_class)
return key, term, query_class
# Match the search term against the list of prefixes.
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
# No matching prefix: use type-based or fallback/default query.
query_class = query_classes.get(key, default_class)
return key, term, query_class
def construct_query_part(query_part, model_cls):
@ -799,7 +802,9 @@ def construct_query_part(query_part, model_cls):
`None` if the value cannot be parsed.
"""
query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items())
parsed = parse_query_part(query_part, query_classes)
prefixes = {':': dbcore.query.RegexpQuery}
prefixes.update(plugins.queries())
parsed = parse_query_part(query_part, query_classes, prefixes)
if not parsed:
return

View file

@ -19,58 +19,59 @@ from _common import unittest
import beets.library
from beets import dbcore
pqp = beets.library.parse_query_part
TEST_TYPES = {
'year': dbcore.query.NumericQuery
}
class QueryParseTest(_common.TestCase):
def pqp(self, part):
return beets.library.parse_query_part(
part,
{'year': dbcore.query.NumericQuery},
{':': dbcore.query.RegexpQuery},
)
def test_one_basic_term(self):
q = 'test'
r = (None, 'test', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_one_keyed_term(self):
q = 'test:val'
r = ('test', 'val', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_colon_at_end(self):
q = 'test:'
r = (None, 'test:', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_one_basic_regexp(self):
q = r':regexp'
r = (None, 'regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_keyed_regexp(self):
q = r'test::regexp'
r = ('test', 'regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_escaped_colon(self):
q = r'test\:val'
r = (None, 'test:val', dbcore.query.SubstringQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_escaped_colon_in_regexp(self):
q = r':test\:regexp'
r = (None, 'test:regexp', dbcore.query.RegexpQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_single_year(self):
q = 'year:1999'
r = ('year', '1999', dbcore.query.NumericQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
def test_multiple_years(self):
q = 'year:1999..2010'
r = ('year', '1999..2010', dbcore.query.NumericQuery)
self.assertEqual(pqp(q, TEST_TYPES), r)
self.assertEqual(self.pqp(q), r)
class AnyFieldQueryTest(_common.LibTestCase):