Roll back fast flexible field sorts (#953)

Sad to see them go, but happy be rid of the SQL injection.
This commit is contained in:
Adrian Sampson 2014-09-13 17:16:12 -07:00
parent 395be1a0f6
commit 4870d7e0fa
5 changed files with 36 additions and 74 deletions

View file

@ -751,7 +751,7 @@ class Database(object):
Query object, or None (to fetch everything). If provided, Query object, or None (to fetch everything). If provided,
`sort_order` is either a SQLite ORDER BY clause for sorting or a `sort_order` is either a SQLite ORDER BY clause for sorting or a
Sort object. Sort object.
""" """
sql, subvals, query, sort = build_sql(model_cls, query, sort_order) sql, subvals, query, sort = build_sql(model_cls, query, sort_order)

View file

@ -613,56 +613,29 @@ class MultipleSort(Sort):
return items return items
class FlexFieldSort(Sort): class FieldSort(Sort):
"""Sort object to sort on a flexible attribute field """An abstract sort criterion that orders by a specific field (of
any kind).
""" """
def __init__(self, model_cls, field, is_ascending): def __init__(self, field, ascending=True):
self.model_cls = model_cls
self.field = field self.field = field
self.is_ascending = is_ascending self.ascending = ascending
def select_clause(self): def sort(self, objs):
"""Return a SELECT fragment. # TODO: Conversion and null-detection here. In Python 3,
""" # comparisons with None fail. We should also support flexible
return "sort_flexattr{0!s}.value as flex_{0!s} ".format(self.field) # attributes with different types without falling over.
return sorted(objs, key=attrgetter(self.field),
def union_clause(self): reverse=not self.ascending)
"""Return a JOIN fragment.
"""
union = ("LEFT JOIN {flextable} as sort_flexattr{index!s} "
"ON {table}.id = sort_flexattr{index!s}.entity_id "
"AND sort_flexattr{index!s}.key='{flexattr}' ").format(
flextable=self.model_cls._flex_table,
table=self.model_cls._table,
index=self.field, flexattr=self.field)
return union
def order_clause(self):
"""Return an ORDER BY fragment.
"""
order = "ASC" if self.is_ascending else "DESC"
return "flex_{0} {1} ".format(self.field, order)
def sort(self, items):
return sorted(items, key=attrgetter(self.field),
reverse=(not self.is_ascending))
class FixedFieldSort(Sort): class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field """Sort object to sort on a fixed field.
""" """
def __init__(self, field, is_ascending=True):
self.field = field
self.is_ascending = is_ascending
def order_clause(self): def order_clause(self):
order = "ASC" if self.is_ascending else "DESC" order = "ASC" if self.ascending else "DESC"
return "{0} {1}".format(self.field, order) return "{0} {1}".format(self.field, order)
def sort(self, items):
return sorted(items, key=attrgetter(self.field),
reverse=(not self.is_ascending))
class SmartArtistSort(Sort): class SmartArtistSort(Sort):
""" Sort Album or Item on artist sort fields, defaulting back on """ Sort Album or Item on artist sort fields, defaulting back on
@ -695,19 +668,13 @@ class SmartArtistSort(Sort):
return order_str return order_str
class ComputedFieldSort(Sort): class SlowFieldSort(FieldSort):
"""A sort criterion by some model field other than a fixed field:
def __init__(self, model_cls, field, is_ascending=True): i.e., a computed or flexible field.
self.is_ascending = is_ascending """
self.field = field
self._getters = model_cls._getters()
def is_slow(self): def is_slow(self):
return True return True
def sort(self, items):
return sorted(items, key=lambda x: self._getters[self.field](x),
reverse=(not self.is_ascending))
special_sorts = {'smartartist': SmartArtistSort} special_sorts = {'smartartist': SmartArtistSort}
@ -740,7 +707,7 @@ def build_sql(model_cls, query, sort):
order_clause = sort.order_clause() order_clause = sort.order_clause()
sort_order = " ORDER BY {0}".format(order_clause) \ sort_order = " ORDER BY {0}".format(order_clause) \
if order_clause else "" if order_clause else ""
if sort.is_slow(): if not sort.is_slow():
sort = None sort = None
sql = ("SELECT {table}.* {sort_select} FROM {table} {sort_union} WHERE " sql = ("SELECT {table}.* {sort_select} FROM {table} {sort_union} WHERE "

View file

@ -131,14 +131,11 @@ def construct_sort_part(model_cls, part):
is_ascending = (part[-1] == '+') is_ascending = (part[-1] == '+')
if field in model_cls._fields: if field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending) sort = query.FixedFieldSort(field, is_ascending)
elif field in model_cls._getters():
# Computed field, all following fields must use the slow path.
sort = query.ComputedFieldSort(model_cls, field, is_ascending)
elif field in query.special_sorts: elif field in query.special_sorts:
sort = query.special_sorts[field](model_cls, is_ascending) sort = query.special_sorts[field](model_cls, is_ascending)
else: else:
# Neither fixed nor computed : must be a flex attr. # Flexible or comptued.
sort = query.FlexFieldSort(model_cls, field, is_ascending) sort = query.SlowFieldSort(field, is_ascending)
return sort return sort

View file

@ -453,7 +453,7 @@ class SortFromStringsTest(unittest.TestCase):
def test_flex_field_sort(self): def test_flex_field_sort(self):
s = self.sfs(['flex_field+']) s = self.sfs(['flex_field+'])
self.assertIsInstance(s, dbcore.query.MultipleSort) self.assertIsInstance(s, dbcore.query.MultipleSort)
self.assertIsInstance(s.sorts[0], dbcore.query.FlexFieldSort) self.assertIsInstance(s.sorts[0], dbcore.query.SlowFieldSort)
def suite(): def suite():

View file

@ -131,7 +131,7 @@ class SortFixedFieldTest(DummyDataTestCase):
class SortFlexFieldTest(DummyDataTestCase): class SortFlexFieldTest(DummyDataTestCase):
def test_sort_asc(self): def test_sort_asc(self):
q = '' q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", True) sort = dbcore.query.SlowFieldSort("flex1", True)
results = self.lib.items(q, sort) results = self.lib.items(q, sort)
self.assertLessEqual(results[0]['flex1'], results[1]['flex1']) self.assertLessEqual(results[0]['flex1'], results[1]['flex1'])
self.assertEqual(results[0]['flex1'], 'flex1-0') self.assertEqual(results[0]['flex1'], 'flex1-0')
@ -143,7 +143,7 @@ class SortFlexFieldTest(DummyDataTestCase):
def test_sort_desc(self): def test_sort_desc(self):
q = '' q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", False) sort = dbcore.query.SlowFieldSort("flex1", False)
results = self.lib.items(q, sort) results = self.lib.items(q, sort)
self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1']) self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1'])
self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1']) self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1'])
@ -157,8 +157,8 @@ class SortFlexFieldTest(DummyDataTestCase):
def test_sort_two_field(self): def test_sort_two_field(self):
q = '' q = ''
s1 = dbcore.query.FlexFieldSort(beets.library.Item, "flex2", False) s1 = dbcore.query.SlowFieldSort("flex2", False)
s2 = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", True) s2 = dbcore.query.SlowFieldSort("flex1", True)
sort = dbcore.query.MultipleSort() sort = dbcore.query.MultipleSort()
sort.add_sort(s1) sort.add_sort(s1)
sort.add_sort(s2) sort.add_sort(s2)
@ -220,10 +220,10 @@ class SortAlbumFixedFieldTest(DummyDataTestCase):
self.assertEqual(r1.id, r2.id) self.assertEqual(r1.id, r2.id)
class SortAlbumFlexdFieldTest(DummyDataTestCase): class SortAlbumFlexFieldTest(DummyDataTestCase):
def test_sort_asc(self): def test_sort_asc(self):
q = '' q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", True) sort = dbcore.query.SlowFieldSort("flex1", True)
results = self.lib.albums(q, sort) results = self.lib.albums(q, sort)
self.assertLessEqual(results[0]['flex1'], results[1]['flex1']) self.assertLessEqual(results[0]['flex1'], results[1]['flex1'])
self.assertLessEqual(results[1]['flex1'], results[2]['flex1']) self.assertLessEqual(results[1]['flex1'], results[2]['flex1'])
@ -235,7 +235,7 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
def test_sort_desc(self): def test_sort_desc(self):
q = '' q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", False) sort = dbcore.query.SlowFieldSort("flex1", False)
results = self.lib.albums(q, sort) results = self.lib.albums(q, sort)
self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1']) self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1'])
self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1']) self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1'])
@ -247,8 +247,8 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
def test_sort_two_field_asc(self): def test_sort_two_field_asc(self):
q = '' q = ''
s1 = dbcore.query.FlexFieldSort(beets.library.Album, "flex2", True) s1 = dbcore.query.SlowFieldSort("flex2", True)
s2 = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", True) s2 = dbcore.query.SlowFieldSort("flex1", True)
sort = dbcore.query.MultipleSort() sort = dbcore.query.MultipleSort()
sort.add_sort(s1) sort.add_sort(s1)
sort.add_sort(s2) sort.add_sort(s2)
@ -268,8 +268,7 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
class SortAlbumComputedFieldTest(DummyDataTestCase): class SortAlbumComputedFieldTest(DummyDataTestCase):
def test_sort_asc(self): def test_sort_asc(self):
q = '' q = ''
sort = dbcore.query.ComputedFieldSort(beets.library.Album, "path", sort = dbcore.query.SlowFieldSort("path", True)
True)
results = self.lib.albums(q, sort) results = self.lib.albums(q, sort)
self.assertLessEqual(results[0]['path'], results[1]['path']) self.assertLessEqual(results[0]['path'], results[1]['path'])
self.assertLessEqual(results[1]['path'], results[2]['path']) self.assertLessEqual(results[1]['path'], results[2]['path'])
@ -281,8 +280,7 @@ class SortAlbumComputedFieldTest(DummyDataTestCase):
def test_sort_desc(self): def test_sort_desc(self):
q = '' q = ''
sort = dbcore.query.ComputedFieldSort(beets.library.Album, "path", sort = dbcore.query.SlowFieldSort("path", False)
False)
results = self.lib.albums(q, sort) results = self.lib.albums(q, sort)
self.assertGreaterEqual(results[0]['path'], results[1]['path']) self.assertGreaterEqual(results[0]['path'], results[1]['path'])
self.assertGreaterEqual(results[1]['path'], results[2]['path']) self.assertGreaterEqual(results[1]['path'], results[2]['path'])
@ -296,7 +294,7 @@ class SortAlbumComputedFieldTest(DummyDataTestCase):
class SortCombinedFieldTest(DummyDataTestCase): class SortCombinedFieldTest(DummyDataTestCase):
def test_computed_first(self): def test_computed_first(self):
q = '' q = ''
s1 = dbcore.query.ComputedFieldSort(beets.library.Album, "path", True) s1 = dbcore.query.SlowFieldSort("path", True)
s2 = dbcore.query.FixedFieldSort("year", True) s2 = dbcore.query.FixedFieldSort("year", True)
sort = dbcore.query.MultipleSort() sort = dbcore.query.MultipleSort()
sort.add_sort(s1) sort.add_sort(s1)
@ -312,7 +310,7 @@ class SortCombinedFieldTest(DummyDataTestCase):
def test_computed_second(self): def test_computed_second(self):
q = '' q = ''
s1 = dbcore.query.FixedFieldSort("year", True) s1 = dbcore.query.FixedFieldSort("year", True)
s2 = dbcore.query.ComputedFieldSort(beets.library.Album, "path", True) s2 = dbcore.query.SlowFieldSort("path", True)
sort = dbcore.query.MultipleSort() sort = dbcore.query.MultipleSort()
sort.add_sort(s1) sort.add_sort(s1)
sort.add_sort(s2) sort.add_sort(s2)