diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index e12178784..c50d2d572 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -483,43 +483,64 @@ class Results(object): """An item query result set. Iterating over the collection lazily constructs LibModel objects that reflect database rows. """ - def __init__(self, model_class, rows, db, query=None): + def __init__(self, model_class, rows, db, query=None, sort=None): """Create a result set that will construct objects of type `model_class`, which should be a subclass of `LibModel`, out of the query result mapping in `rows`. The new objects are - associated with the database `db`. If `query` is provided, it is - used as a predicate to filter the results for a "slow query" that - cannot be evaluated by the database directly. + associated with the database `db`. + If `query` is provided, it is used as a predicate to filter the results + for a "slow query" that cannot be evaluated by the database directly. + If `sort` is provided, it is used to sort the full list of results + before returning. This means it is a "slow sort" and all objects must + be built before returning the first one. """ self.model_class = model_class self.rows = rows self.db = db self.query = query + self.sort = sort def __iter__(self): """Construct Python objects for all rows that pass the query predicate. """ - for row in self.rows: - # Get the flexible attributes for the object. - with self.db.transaction() as tx: - flex_rows = tx.query( - 'SELECT * FROM {0} WHERE entity_id=?'.format( - self.model_class._flex_table - ), - (row['id'],) - ) + if self.sort: + # slow sort, must build the full list first + objects = [] + for row in self.rows: + obj = self._generate_results(row) + # check the predicate if any + if not self.query or self.query.match(obj): + objects.append(obj) + # Now that we have the full list, we can sort it + objects = self.sort.sort(objects) + for o in objects: + yield o + else: + for row in self.rows: + obj = self._generate_results(row) + # check the predicate if any + if not self.query or self.query.match(obj): + yield obj - cols = dict(row) - values = dict((k, v) for (k, v) in cols.items() - if not k[:4] == 'flex') - flex_values = dict((row['key'], row['value']) for row in flex_rows) + def _generate_results(self, row): + # Get the flexible attributes for the object. + with self.db.transaction() as tx: + flex_rows = tx.query( + 'SELECT * FROM {0} WHERE entity_id=?'.format( + self.model_class._flex_table + ), + (row['id'],) + ) - # Construct the Python object and yield it if it passes the - # predicate. - obj = self.model_class._awaken(self.db, values, flex_values) - if not self.query or self.query.match(obj): - yield obj + cols = dict(row) + values = dict((k, v) for (k, v) in cols.items() + if not k[:4] == 'flex') + flex_values = dict((row['key'], row['value']) for row in flex_rows) + + # Construct the Python object + obj = self.model_class._awaken(self.db, values, flex_values) + return obj def __len__(self): """Get the number of matching objects. @@ -750,12 +771,15 @@ class Database(object): Sort object. """ - sql, subvals, is_slow = build_sql(model_cls, query, sort_order) + sql, subvals, slow_query, slow_sort = build_sql(model_cls, query, + sort_order) with self.transaction() as tx: rows = tx.query(sql, subvals) - return Results(model_cls, rows, self, None if not is_slow else query) + return Results(model_cls, rows, self, + None if not slow_query else query, + None if not slow_sort else sort_order) def _get(self, model_cls, id): """Get a Model object by its id or None if the id does not diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 9952fcf4b..14cb2af82 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -15,6 +15,7 @@ """The Query type hierarchy for DBCore. """ import re +from operator import attrgetter from beets import util from datetime import datetime, timedelta @@ -500,33 +501,43 @@ class DateQuery(FieldQuery): class Sort(object): - """An abstract class representing a sort opertation for a query into the + """An abstract class representing a sort operation for a query into the item database. """ def select_clause(self): - """ Generates a select sql fragment. + """ Generates a select sql fragment if the sort operation requires one, + an empty string otherwise. """ - return None + return "" def union_clause(self): - """ Generates a union sql fragment or None if the Sort is a slow sort. + """ Generates a union sql fragment if the sort operation requires one, + an empty string otherwise. """ - return None + return "" def order_clause(self): - """Generates a sql fragment to be use in a ORDER BY clause - or None if it's a slow query + """Generates a sql fragment to be use in a ORDER BY clause or None if + it's a slow query. """ return None def sort(self, items): - """Sort the given items list. Meant to be used with slow queries. + """Return a key function that can be used with the list.sort() method. + Meant to be used with slow sort, it must be implemented even for sort + that can be done with sql, as they might be used in conjunction with + slow sort. """ - return items + return sorted(items, key=lambda x: x) + + def is_slow(self): + return False class MultipleSort(Sort): """ Sort class that combines several sort criteria. + This implementation tries to implement as many sort operation in sql, + falling back to python sort only when necessary. """ def __init__(self): @@ -535,59 +546,75 @@ class MultipleSort(Sort): def add_criteria(self, sort): self.sorts.append(sort) - def select_clause(self): - """ Generate a select sql fragment. + def _sql_sorts(self): + """ Returns the list of sort for which sql can be used """ - select_strings = [] - index = 0 - for sort in self.sorts: - select = sort.select_clause() - if select is None: - # FIXME : sort for slow sort - break + # with several Sort, we can use SQL sorting only if there is only + # SQL-capable Sort or if the list ends with SQl-capable Sort. + sql_sorts = [] + for sort in reversed(self.sorts): + if not sort.order_clause() is None: + sql_sorts.append(sort) else: + break + sql_sorts.reverse() + return sql_sorts + + def select_clause(self): + sql_sorts = self._sql_sorts() + select_strings = [] + for sort in sql_sorts: + select = sort.select_clause() + if select: select_strings.append(select) - index = index + 1 select_string = ",".join(select_strings) return "" if not select_string else ", " + select_string def union_clause(self): - """ Returns a union sql fragment. - """ + sql_sorts = self._sql_sorts() union_strings = [] - for sort in self.sorts: + for sort in sql_sorts: union = sort.union_clause() - if union is None: - pass - else: - union_strings.append(union) + union_strings.append(union) return "".join(union_strings) def order_clause(self): - """Returns a sql fragment to be use in a ORDER BY clause - or None if it's a slow query - """ + sql_sorts = self._sql_sorts() order_strings = [] - index = 0 - for sort in self.sorts: + for sort in sql_sorts: order = sort.order_clause() - if order is None: - break - else: - order_strings.append(order) - index = index + 1 + order_strings.append(order) return ",".join(order_strings) + def is_slow(self): + for sort in self.sorts: + if sort.is_slow(): + return True + return False + def sort(self, items): - # FIXME : sort according to criteria following the first slow sort + slow_sorts = [] + switch_slow = False + for sort in reversed(self.sorts): + if switch_slow: + slow_sorts.append(sort) + elif sort.order_clause() is None: + switch_slow = True + slow_sorts.append(sort) + else: + pass + + for sort in slow_sorts: + items = sort.sort(items) return items class FlexFieldSort(Sort): - + """Sort object to sort on a flexible attribute field + """ def __init__(self, model_cls, field, is_ascending): self.model_cls = model_cls self.field = field @@ -615,19 +642,26 @@ class FlexFieldSort(Sort): order = "ASC" if self.is_ascending else "DESC" return "flex_{0} {1} ".format(self.field, order) + def sort(self, items): + return sorted(items, key=attrgetter(self.field), + reverse=(not self.is_ascending)) + class FixedFieldSort(Sort): - + """Sort object to sort on a fixed field + """ def __init__(self, field, is_ascending=True): self.field = field self.is_ascending = is_ascending - """ Sort on a fixed field - """ def order_clause(self): order = "ASC" if self.is_ascending else "DESC" return "{0} {1}".format(self.field, order) + def sort(self, items): + return sorted(items, key=attrgetter(self.field), + reverse=(not self.is_ascending)) + class SmartArtistSort(Sort): """ Sort Album or Item on artist sort fields, defaulting back on @@ -660,27 +694,48 @@ class SmartArtistSort(Sort): return order_str +class ComputedFieldSort(Sort): + + def __init__(self, model_cls, field, is_ascending=True): + self.is_ascending = is_ascending + self.field = field + self._getters = model_cls._getters() + + def is_slow(self): + return True + + def sort(self, items): + return sorted(items, key=lambda x: self._getters[self.field](x), + reverse=(not self.is_ascending)) + special_sorts = {'smartartist': SmartArtistSort} -def build_sql(model_cls, query, sort_order): +def build_sql(model_cls, query, sort): """ Generate a sql statement (and the values that must be injected into it) from a query, sort and a model class. """ where, subvals = query.clause() + slow_query = where is None - if not sort_order: + if not sort: sort_select = "" sort_union = "" sort_order = "" - elif isinstance(sort_order, basestring): + slow_sort = False + elif isinstance(sort, basestring): sort_select = "" sort_union = "" - sort_order = " ORDER BY {0}".format(sort_order) - elif isinstance(sort_order, Sort): - sort_select = sort_order.select_clause() - sort_union = sort_order.union_clause() - sort_order = " ORDER BY {0}".format(sort_order.order_clause()) + sort_order = " ORDER BY {0}".format(sort) \ + if sort else "" + slow_sort = False + elif isinstance(sort, Sort): + sort_select = sort.select_clause() + sort_union = sort.union_clause() + slow_sort = sort.is_slow() + order_clause = sort.order_clause() + sort_order = " ORDER BY {0}".format(order_clause) \ + if order_clause else "" sql = ("SELECT {table}.* {sort_select} FROM {table} {sort_union} WHERE " "{query_clause} {sort_order}").format( @@ -691,4 +746,4 @@ def build_sql(model_cls, query, sort_order): sort_order=sort_order ) - return (sql, subvals, where is None) \ No newline at end of file + return (sql, subvals, slow_query, slow_sort) \ No newline at end of file diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 1a94c1831..03536f8ca 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -133,7 +133,7 @@ def construct_sort_part(model_cls, part): sort = query.FixedFieldSort(field, is_ascending) elif field in model_cls._getters(): # Computed field, all following fields must use the slow path. - pass + sort = query.ComputedFieldSort(model_cls, field, is_ascending) elif field in query.special_sorts: sort = query.special_sorts[field](model_cls, is_ascending) else: