diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index c7c84054c..3865034c5 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -66,7 +66,7 @@ def format_for_path(value, key=None): # Abstract base for model classes and their field types. -Type = namedtuple('Type', 'py_type sql_type') +Type = namedtuple('Type', 'py_type sql_type query') class Model(object): diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index f9b8801cb..2d20f9e0f 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -165,32 +165,18 @@ class NumericQuery(FieldQuery): (``..``) lets users specify one- or two-sided ranges. For example, ``year:2001..`` finds music released since the turn of the century. """ - # FIXME - # types = dict((r[0], r[1]) for r in ITEM_FIELDS) - - @classmethod - def applies_to(cls, field): - """Determine whether a field has numeric type. NumericQuery - should only be used with such fields. - """ - if field not in cls.types: - # This can happen when using album fields. - # FIXME should no longer be necessary with the new type system. - return False - return cls.types.get(field).py_type in (int, float) - def _convert(self, s): """Convert a string to the appropriate numeric type. If the string cannot be converted, return None. """ try: - return self.numtype(s) + # FIXME should work w/ either integer or float + return float(s) except ValueError: return None def __init__(self, field, pattern, fast=True): super(NumericQuery, self).__init__(field, pattern, fast) - self.numtype = self.types[field].py_type parts = pattern.split('..', 1) if len(parts) == 1: diff --git a/beets/library.py b/beets/library.py index 3ea83d06a..1015b9e19 100644 --- a/beets/library.py +++ b/beets/library.py @@ -36,12 +36,12 @@ from datetime import datetime # Common types used in field definitions. TYPES = { - int: Type(int, 'INTEGER'), - float: Type(float, 'REAL'), - datetime: Type(datetime, 'REAL'), - bytes: Type(bytes, 'BLOB'), - unicode: Type(unicode, 'TEXT'), - bool: Type(bool, 'INTEGER'), + int: Type(int, 'INTEGER', dbcore.query.NumericQuery), + float: Type(float, 'REAL', dbcore.query.NumericQuery), + datetime: Type(datetime, 'REAL', dbcore.query.NumericQuery), + bytes: Type(bytes, 'BLOB', dbcore.query.MatchQuery), + unicode: Type(unicode, 'TEXT', dbcore.query.SubstringQuery), + bool: Type(bool, 'INTEGER', dbcore.query.BooleanQuery), } @@ -54,9 +54,10 @@ TYPES = { # - Is the field writable? # - Does the field reflect an attribute of a MediaFile? ITEM_FIELDS = [ - ('id', Type(int, 'INTEGER PRIMARY KEY'), False, False), - ('path', TYPES[bytes], False, False), - ('album_id', TYPES[int], False, False), + ('id', Type(int, 'INTEGER PRIMARY KEY', dbcore.query.NumericQuery), + False, False), + ('path', TYPES[bytes], False, False), + ('album_id', TYPES[int], False, False), ('title', TYPES[unicode], True, True), ('artist', TYPES[unicode], True, True), @@ -125,7 +126,8 @@ ITEM_KEYS = [f[0] for f in ITEM_FIELDS] # The third entry in each tuple indicates whether the field reflects an # identically-named field in the items table. ALBUM_FIELDS = [ - ('id', Type(int, 'INTEGER PRIMARY KEY'), False), + ('id', Type(int, 'INTEGER PRIMARY KEY', dbcore.query.NumericQuery), + False), ('artpath', TYPES[bytes], False), ('added', TYPES[datetime], True), @@ -759,19 +761,23 @@ PARSE_QUERY_PART_REGEX = re.compile( re.I # Case-insensitive. ) -def parse_query_part(part): - """Takes a query in the form of a key/value pair separated by a - colon. The value part is matched against a list of prefixes that - can be extended by plugins to add custom query types. For - example, the colon prefix denotes a regular expression query. +def parse_query_part(part, query_classes={}, + default_class=dbcore.query.SubstringQuery): + """Take a query in the form of a key/value pair separated by a + colon and return a tuple of `(key, value, cls)`. `key` may be None, + indicating that any field may be matched. `cls` is a subclass of + `FieldQuery`. The optional `query_classes` parameter maps field names + to default query types; `default_class` is the fallback. - The function returns a tuple of `(key, value, cls)`. `key` may - be None, indicating that any field may be matched. `cls` is a - subclass of `FieldQuery`. + To determine the query class, two factors are used: prefixes and + field types. For example, the colon prefix denotes a regular + expression query and a type map might provide a special kind of + query for numeric values. If neither a prefix nor a specific query + class is available, `default_class` is used. For instance, - parse_query('stapler') == (None, 'stapler', None) - parse_query('color:red') == ('color', 'red', None) + parse_query('stapler') == (None, 'stapler', SubstringQuery) + parse_query('color:red') == ('color', 'red', SubstringQuery) parse_query(':^Quiet') == (None, '^Quiet', RegexpQuery) parse_query('color::b..e') == ('color', 'b..e', RegexpQuery) @@ -781,6 +787,7 @@ def parse_query_part(part): part = part.strip() match = PARSE_QUERY_PART_REGEX.match(part) + # FIXME parameterize prefixes = {':': dbcore.query.RegexpQuery} prefixes.update(plugins.queries()) @@ -791,16 +798,17 @@ def parse_query_part(part): for pre, query_class in prefixes.items(): if term.startswith(pre): return key, term[len(pre):], query_class - if key and dbcore.query.NumericQuery.applies_to(key): - return key, term, dbcore.query.NumericQuery - return key, term, dbcore.query.SubstringQuery # Default query type. + query_class = query_classes.get(key, default_class) + return key, term, query_class -def construct_query_part(query_part, default_fields, all_keys): - """Create a query from a single query component. Return a Query - instance or None if the value cannot be parsed. +def construct_query_part(query_part, model_cls): + """Create a query from a single query component, `query_part`, for + querying instances of `model_cls`. Return a `Query` instance or + `None` if the value cannot be parsed. """ - parsed = parse_query_part(query_part) + query_classes = dict((k, t.query) for (k, t) in model_cls._fields.items()) + parsed = parse_query_part(query_part, query_classes) if not parsed: return @@ -808,14 +816,15 @@ def construct_query_part(query_part, default_fields, all_keys): # No key specified. if key is None: - if os.sep in pattern and 'path' in all_keys: + if os.sep in pattern and 'path' in model_cls._fields: # This looks like a path. return PathQuery(pattern) elif issubclass(query_class, dbcore.FieldQuery): # The query type matches a specific field, but none was # specified. So we use a version of the query that matches # any field. - return dbcore.query.AnyFieldQuery(pattern, default_fields, + return dbcore.query.AnyFieldQuery(pattern, + model_cls._search_fields, query_class) else: # Other query type. @@ -828,7 +837,7 @@ def construct_query_part(query_part, default_fields, all_keys): return dbcore.query.BooleanQuery(key, pattern) # Path field. - elif key == 'path' and 'path' in all_keys: + elif key == 'path' and 'path' in model_cls._fields: if query_class is dbcore.query.SubstringQuery: # By default, use special path matching logic. return PathQuery(pattern) @@ -842,18 +851,17 @@ def construct_query_part(query_part, default_fields, all_keys): # Other field. else: - return query_class(key.lower(), pattern, key in all_keys) + return query_class(key.lower(), pattern, key in model_cls._fields) -def query_from_strings(query_cls, query_parts, default_fields, all_keys): - """Creates a collection query of type `query-cls` from a list of - strings in the format used by parse_query_part. If default_fields - are specified, they are the fields to be searched by unqualified - search terms. Otherwise, all fields are searched for those terms. +def query_from_strings(query_cls, model_cls, query_parts): + """Creates a collection query of type `query_cls` from a list of + strings in the format used by parse_query_part. `model_cls` + determines how queries are constructed from strings. """ subqueries = [] for part in query_parts: - subq = construct_query_part(part, default_fields, all_keys) + subq = construct_query_part(part, model_cls) if subq: subqueries.append(subq) if not subqueries: # No terms in query. @@ -881,9 +889,7 @@ def get_query(val, model_cls): if val is None: return dbcore.query.TrueQuery() elif isinstance(val, list) or isinstance(val, tuple): - return query_from_strings(dbcore.AndQuery, - val, model_cls._search_fields, - model_cls._fields.keys()) + return query_from_strings(dbcore.AndQuery, model_cls, val) elif isinstance(val, dbcore.Query): return val else: diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 4f95c3070..e8bbd3b69 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -25,12 +25,15 @@ from beets import dbcore # Fixture: concrete database and model classes. For migration tests, we # have multiple models with different numbers of fields. +ID_TYPE = dbcore.Type(int, 'INTEGER PRIMARY KEY', dbcore.query.NumericQuery) +INT_TYPE = dbcore.Type(int, 'INTEGER', dbcore.query.NumericQuery) + class TestModel1(dbcore.Model): _table = 'test' _flex_table = 'testflex' _fields = { - 'id': dbcore.Type(int, 'INTEGER PRIMARY KEY'), - 'field_one': dbcore.Type(int, 'INTEGER'), + 'id': ID_TYPE, + 'field_one': INT_TYPE, } @classmethod @@ -46,9 +49,9 @@ class TestDatabase1(dbcore.Database): class TestModel2(TestModel1): _fields = { - 'id': dbcore.Type(int, 'INTEGER PRIMARY KEY'), - 'field_one': dbcore.Type(int, 'INTEGER'), - 'field_two': dbcore.Type(int, 'INTEGER'), + 'id': ID_TYPE, + 'field_one': INT_TYPE, + 'field_two': INT_TYPE, } class TestDatabase2(dbcore.Database): @@ -57,10 +60,10 @@ class TestDatabase2(dbcore.Database): class TestModel3(TestModel1): _fields = { - 'id': dbcore.Type(int, 'INTEGER PRIMARY KEY'), - 'field_one': dbcore.Type(int, 'INTEGER'), - 'field_two': dbcore.Type(int, 'INTEGER'), - 'field_three': dbcore.Type(int, 'INTEGER'), + 'id': ID_TYPE, + 'field_one': INT_TYPE, + 'field_two': INT_TYPE, + 'field_three': INT_TYPE, } class TestDatabase3(dbcore.Database): @@ -69,11 +72,11 @@ class TestDatabase3(dbcore.Database): class TestModel4(TestModel1): _fields = { - 'id': dbcore.Type(int, 'INTEGER PRIMARY KEY'), - 'field_one': dbcore.Type(int, 'INTEGER'), - 'field_two': dbcore.Type(int, 'INTEGER'), - 'field_three': dbcore.Type(int, 'INTEGER'), - 'field_four': dbcore.Type(int, 'INTEGER'), + 'id': ID_TYPE, + 'field_one': INT_TYPE, + 'field_two': INT_TYPE, + 'field_three': INT_TYPE, + 'field_four': INT_TYPE, } class TestDatabase4(dbcore.Database): @@ -84,8 +87,8 @@ class AnotherTestModel(TestModel1): _table = 'another' _flex_table = 'anotherflex' _fields = { - 'id': dbcore.Type(int, 'INTEGER PRIMARY KEY'), - 'foo': dbcore.Type(int, 'INTEGER'), + 'id': ID_TYPE, + 'foo': INT_TYPE, } class TestDatabaseTwoModels(dbcore.Database): diff --git a/test/test_query.py b/test/test_query.py index f629ae27b..d05470081 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -22,51 +22,55 @@ from beets import dbcore pqp = beets.library.parse_query_part +TEST_TYPES = { + 'year': dbcore.query.NumericQuery +} + class QueryParseTest(_common.TestCase): def test_one_basic_term(self): q = 'test' r = (None, 'test', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_one_keyed_term(self): q = 'test:val' r = ('test', 'val', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_colon_at_end(self): q = 'test:' r = (None, 'test:', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_one_basic_regexp(self): q = r':regexp' r = (None, 'regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_keyed_regexp(self): q = r'test::regexp' r = ('test', 'regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_escaped_colon(self): q = r'test\:val' r = (None, 'test:val', dbcore.query.SubstringQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_escaped_colon_in_regexp(self): q = r':test\:regexp' r = (None, 'test:regexp', dbcore.query.RegexpQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_single_year(self): q = 'year:1999' r = ('year', '1999', dbcore.query.NumericQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) def test_multiple_years(self): q = 'year:1999..2010' r = ('year', '1999..2010', dbcore.query.NumericQuery) - self.assertEqual(pqp(q), r) + self.assertEqual(pqp(q, TEST_TYPES), r) class AnyFieldQueryTest(_common.LibTestCase):