Sort implementation

* sort can be sepcified using the 'field_name'(+|-) syntax
 * supports fixed fields and flexible attributes
 * includes plugins fix for API changes (might have missed some)
This commit is contained in:
Pierre Rust 2014-06-11 16:21:43 +02:00
parent eb5c37ecc0
commit 1303a915c1
8 changed files with 263 additions and 38 deletions

View file

@ -19,5 +19,6 @@ from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings
# flake8: noqa

View file

@ -24,7 +24,7 @@ import collections
import beets
from beets.util.functemplate import Template
from .query import MatchQuery
from .query import MatchQuery, build_sql
from .types import BASE_TYPE
@ -509,7 +509,10 @@ class Results(object):
),
(row['id'],)
)
values = dict(row)
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k[:4] == 'flex')
flex_values = dict((row['key'], row['value']) for row in flex_rows)
# Construct the Python object and yield it if it passes the
@ -739,24 +742,20 @@ class Database(object):
# Querying.
def _fetch(self, model_cls, query, order_by=None):
def _fetch(self, model_cls, query, sort_order=None):
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). If provided,
`order_by` is a SQLite ORDER BY clause for sorting.
"""
where, subvals = query.clause()
`sort_order` is either a SQLite ORDER BY clause for sorting or a
Sort object.
"""
sql, subvals, is_slow = build_sql(model_cls, query, sort_order)
sql = "SELECT * FROM {0} WHERE {1}".format(
model_cls._table,
where or '1',
)
if order_by:
sql += " ORDER BY {0}".format(order_by)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
return Results(model_cls, rows, self, None if where else query)
return Results(model_cls, rows, self, None if not is_slow else query)
def _get(self, model_cls, id):
"""Get a Model object by its id or None if the id does not

View file

@ -497,3 +497,163 @@ class DateQuery(FieldQuery):
# Match any date.
clause = '1'
return clause, subvals
class Sort(object):
"""An abstract class representing a sort opertation for a query into the
item database.
"""
def select_clause(self):
""" Generates a select sql fragment.
"""
return None
def union_clause(self):
""" Generates a union sql fragment or None if the Sort is a slow sort.
"""
return None
def order_clause(self):
"""Generates a sql fragment to be use in a ORDER BY clause
or None if it's a slow query
"""
return None
def sort(self, items):
"""Sort the given items list. Meant to be used with slow queries.
"""
return items
class MultipleSort(Sort):
""" Sort class that combines several sort criteria.
"""
def __init__(self):
self.sorts = []
def add_criteria(self, sort):
self.sorts.append(sort)
def select_clause(self):
""" Generate a select sql fragment.
"""
select_strings = []
index = 0
for sort in self.sorts:
select = sort.select_clause()
if select is None:
# FIXME : sort for slow sort
break
else:
select_strings.append(select)
index = index + 1
select_string = ",".join(select_strings)
return "" if not select_string else ", " + select_string
def union_clause(self):
""" Returns a union sql fragment.
"""
union_strings = []
for sort in self.sorts:
union = sort.union_clause()
if union is None:
pass
else:
union_strings.append(union)
return "".join(union_strings)
def order_clause(self):
"""Returns a sql fragment to be use in a ORDER BY clause
or None if it's a slow query
"""
order_strings = []
index = 0
for sort in self.sorts:
order = sort.order_clause()
if order is None:
break
else:
order_strings.append(order)
index = index + 1
return ",".join(order_strings)
def sort(self, items):
# FIXME : sort according to criteria following the first slow sort
return items
class FlexFieldSort(Sort):
def __init__(self, model_cls, field, is_ascending):
self.model_cls = model_cls
self.field = field
self.is_ascending = is_ascending
def select_clause(self):
""" Return a select sql fragment.
"""
return "sort_flexattr{0!s}.value as flex_{0!s} ".format(self.field)
def union_clause(self):
""" Returns an union sql fragment.
"""
return "LEFT JOIN {flextable} as sort_flexattr{index!s} \
ON {table}.id = sort_flexattr{index!s}.entity_id \
AND sort_flexattr{index!s}.key='{flexattr}' ".format(
flextable=self.model_cls._flex_table,
table=self.model_cls._table,
index=self.field, flexattr=self.field)
def order_clause(self):
""" Returns an order sql fragment.
"""
order = "ASC" if self.is_ascending else "DESC"
return "flex_{0} {1} ".format(self.field, order)
class FixedFieldSort(Sort):
def __init__(self, field, is_ascending=True):
self.field = field
self.is_ascending = is_ascending
""" Sort on a fixed field
"""
def order_clause(self):
order = "ASC" if self.is_ascending else "DESC"
return "{0} {1}".format(self.field, order)
def build_sql(model_cls, query, sort_order):
""" Generate a sql statement (and the values that must be injected into it)
from a query, sort and a model class.
"""
where, subvals = query.clause()
if not sort_order:
sort_select = ""
sort_union = ""
sort_order = ""
elif isinstance(sort_order, basestring):
sort_select = ""
sort_union = ""
sort_order = " ORDER BY {0}".format(sort_order)
elif isinstance(sort_order, Sort):
sort_select = sort_order.select_clause()
sort_union = sort_order.union_clause()
sort_order = " ORDER BY {0}".format(sort_order.order_clause())
sql = "SELECT {table}.* {sort_select} FROM {table} {sort_union} WHERE \
{query_clause} {sort_order}".format(
sort_select=sort_select,
sort_union=sort_union,
table=model_cls._table,
query_clause=where or '1',
sort_order=sort_order
)
return (sql, subvals, where is None)

View file

@ -121,3 +121,31 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
if not subqueries: # No terms in query.
subqueries = [query.TrueQuery()]
return query_cls(subqueries)
def construct_sort_part(model_cls, part):
""" Creates a Sort object from a single criteria. Returns a `Sort` instance.
"""
sort = None
field = part[:-1]
is_ascending = (part[-1] == '+')
if field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending)
elif field in model_cls._getters():
# Computed field, all following fields must use the slow path.
pass
else:
# Neither fixed nor computed : must be a flex attr.
sort = query.FlexFieldSort(model_cls, field, is_ascending)
return sort
def sort_from_strings(model_cls, sort_parts):
"""Creates a Sort object from a list of sort criteria strings.
"""
if not sort_parts:
return None
sort = query.MultipleSort()
for part in sort_parts:
sort.add_criteria(construct_sort_part(model_cls, part))
return sort

View file

@ -546,7 +546,7 @@ class Item(LibModel):
for query, path_format in path_formats:
if query == PF_KEY_DEFAULT:
continue
query = get_query(query, type(self))
(query, _) = get_query(query, type(self))
if query.match(self):
# The query matches the item! Use the corresponding path
# format.
@ -889,7 +889,8 @@ class Album(LibModel):
def get_query(val, model_cls):
"""Take a value which may be None, a query string, a query string
list, or a Query object, and return a suitable Query object.
list, or a Query object, and return a suitable Query object and Sort
object.
`model_cls` is the subclass of Model indicating which entity this
is a query for (i.e., Album or Item) and is used to determine which
@ -910,7 +911,7 @@ def get_query(val, model_cls):
val = [s.decode('utf8') for s in shlex.split(val)]
if val is None:
return dbcore.query.TrueQuery()
return (dbcore.query.TrueQuery(), None)
elif isinstance(val, list) or isinstance(val, tuple):
# Special-case path-like queries, which are non-field queries
@ -928,18 +929,23 @@ def get_query(val, model_cls):
path_parts = ()
non_path_parts = val
# separate query token and sort token
query_val = [s for s in non_path_parts if not s.endswith(('+', '-'))]
sort_val = [s for s in non_path_parts if s.endswith(('+', '-'))]
# Parse remaining parts and construct an AndQuery.
query = dbcore.query_from_strings(
dbcore.AndQuery, model_cls, prefixes, non_path_parts
dbcore.AndQuery, model_cls, prefixes, query_val
)
sort = dbcore.sort_from_strings(model_cls, sort_val)
# Add path queries to aggregate query.
if path_parts:
query.subqueries += [PathQuery('path', s) for s in path_parts]
return query
return (query, sort)
elif isinstance(val, dbcore.Query):
return val
return (val, None)
else:
raise ValueError('query must be None or have type Query or str')
@ -1006,30 +1012,30 @@ class Library(dbcore.Database):
# Querying.
def _fetch(self, model_cls, query, order_by=None):
"""Parse a query and fetch.
"""
def _fetch(self, model_cls, query, sort_order=None):
"""Parse a query and fetch. If a sort_order is explicitly given,
any sort order specification present in the query string is ignored.
"""
(query, sort) = get_query(query, model_cls)
sort = sort if sort_order is None else sort_order
return super(Library, self)._fetch(
model_cls, get_query(query, model_cls), order_by
model_cls, query, sort
)
def albums(self, query=None):
def albums(self, query=None, sort_order=None):
"""Get a sorted list of :class:`Album` objects matching the
given query.
given sort order. If a sort_order is explicitly given,
any sort order specification present in the query string is ignored.
"""
order = '{0}, album'.format(
_orelse("albumartist_sort", "albumartist")
)
return self._fetch(Album, query, order)
return self._fetch(Album, query, sort_order)
def items(self, query=None):
def items(self, query=None, sort_order=None):
"""Get a sorted list of :class:`Item` objects matching the given
query.
given sort order. If a sort_order is explicitly given,
any sort order specification present in the query string is ignored.
"""
order = '{0}, album, disc, track'.format(
_orelse("artist_sort", "artist")
)
return self._fetch(Item, query, order)
return self._fetch(Item, query, sort_order)
# Convenience accessors.

View file

@ -57,9 +57,9 @@ class IHatePlugin(BeetsPlugin):
for query_string in action_patterns:
query = None
if task.is_album:
query = get_query(query_string, Album)
(query, _) = get_query(query_string, Album)
else:
query = get_query(query_string, Item)
(query, _) = get_query(query_string, Item)
if any(query.match(item) for item in task.imported_items()):
return True
return False

View file

@ -42,7 +42,7 @@ def _items_for_query(lib, playlist, album=False):
query_strings = [query_strings]
model = library.Album if album else library.Item
query = dbcore.OrQuery(
[library.get_query(q, model) for q in query_strings]
[library.get_query(q, model)[0] for q in query_strings]
)
# Execute query, depending on type.

View file

@ -412,6 +412,37 @@ class QueryFromStringsTest(_common.TestCase):
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
class SortFromStringsTest(_common.TestCase):
def sfs(self, strings):
return dbcore.queryparse.sort_from_strings(
TestModel1,
strings,
)
def test_zero_parts(self):
s = self.sfs([])
self.assertIsNone(s)
def test_one_parts(self):
s = self.sfs(['field+'])
self.assertIsInstance(s, dbcore.query.Sort)
def test_two_parts(self):
s = self.sfs(['field+', 'another_field-'])
self.assertIsInstance(s, dbcore.query.MultipleSort)
self.assertEqual(len(s.sorts), 2)
def test_fixed_field_sort(self):
s = self.sfs(['field_one+'])
self.assertIsInstance(s, dbcore.query.MultipleSort)
self.assertIsInstance(s.sorts[0], dbcore.query.FixedFieldSort)
def test_flex_field_sort(self):
s = self.sfs(['flex_field+'])
self.assertIsInstance(s, dbcore.query.MultipleSort)
self.assertIsInstance(s.sorts[0], dbcore.query.FlexFieldSort)
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)