Roll back fast flexible field sorts (#953)

Sad to see them go, but happy be rid of the SQL injection.
This commit is contained in:
Adrian Sampson 2014-09-13 17:16:12 -07:00
parent 395be1a0f6
commit 4870d7e0fa
5 changed files with 36 additions and 74 deletions

View file

@ -613,56 +613,29 @@ class MultipleSort(Sort):
return items
class FlexFieldSort(Sort):
"""Sort object to sort on a flexible attribute field
class FieldSort(Sort):
"""An abstract sort criterion that orders by a specific field (of
any kind).
"""
def __init__(self, model_cls, field, is_ascending):
self.model_cls = model_cls
def __init__(self, field, ascending=True):
self.field = field
self.is_ascending = is_ascending
self.ascending = ascending
def select_clause(self):
"""Return a SELECT fragment.
def sort(self, objs):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
return sorted(objs, key=attrgetter(self.field),
reverse=not self.ascending)
class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
return "sort_flexattr{0!s}.value as flex_{0!s} ".format(self.field)
def union_clause(self):
"""Return a JOIN fragment.
"""
union = ("LEFT JOIN {flextable} as sort_flexattr{index!s} "
"ON {table}.id = sort_flexattr{index!s}.entity_id "
"AND sort_flexattr{index!s}.key='{flexattr}' ").format(
flextable=self.model_cls._flex_table,
table=self.model_cls._table,
index=self.field, flexattr=self.field)
return union
def order_clause(self):
"""Return an ORDER BY fragment.
"""
order = "ASC" if self.is_ascending else "DESC"
return "flex_{0} {1} ".format(self.field, order)
def sort(self, items):
return sorted(items, key=attrgetter(self.field),
reverse=(not self.is_ascending))
class FixedFieldSort(Sort):
"""Sort object to sort on a fixed field
"""
def __init__(self, field, is_ascending=True):
self.field = field
self.is_ascending = is_ascending
def order_clause(self):
order = "ASC" if self.is_ascending else "DESC"
order = "ASC" if self.ascending else "DESC"
return "{0} {1}".format(self.field, order)
def sort(self, items):
return sorted(items, key=attrgetter(self.field),
reverse=(not self.is_ascending))
class SmartArtistSort(Sort):
""" Sort Album or Item on artist sort fields, defaulting back on
@ -695,19 +668,13 @@ class SmartArtistSort(Sort):
return order_str
class ComputedFieldSort(Sort):
def __init__(self, model_cls, field, is_ascending=True):
self.is_ascending = is_ascending
self.field = field
self._getters = model_cls._getters()
class SlowFieldSort(FieldSort):
"""A sort criterion by some model field other than a fixed field:
i.e., a computed or flexible field.
"""
def is_slow(self):
return True
def sort(self, items):
return sorted(items, key=lambda x: self._getters[self.field](x),
reverse=(not self.is_ascending))
special_sorts = {'smartartist': SmartArtistSort}
@ -740,7 +707,7 @@ def build_sql(model_cls, query, sort):
order_clause = sort.order_clause()
sort_order = " ORDER BY {0}".format(order_clause) \
if order_clause else ""
if sort.is_slow():
if not sort.is_slow():
sort = None
sql = ("SELECT {table}.* {sort_select} FROM {table} {sort_union} WHERE "

View file

@ -131,14 +131,11 @@ def construct_sort_part(model_cls, part):
is_ascending = (part[-1] == '+')
if field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending)
elif field in model_cls._getters():
# Computed field, all following fields must use the slow path.
sort = query.ComputedFieldSort(model_cls, field, is_ascending)
elif field in query.special_sorts:
sort = query.special_sorts[field](model_cls, is_ascending)
else:
# Neither fixed nor computed : must be a flex attr.
sort = query.FlexFieldSort(model_cls, field, is_ascending)
# Flexible or comptued.
sort = query.SlowFieldSort(field, is_ascending)
return sort

View file

@ -453,7 +453,7 @@ class SortFromStringsTest(unittest.TestCase):
def test_flex_field_sort(self):
s = self.sfs(['flex_field+'])
self.assertIsInstance(s, dbcore.query.MultipleSort)
self.assertIsInstance(s.sorts[0], dbcore.query.FlexFieldSort)
self.assertIsInstance(s.sorts[0], dbcore.query.SlowFieldSort)
def suite():

View file

@ -131,7 +131,7 @@ class SortFixedFieldTest(DummyDataTestCase):
class SortFlexFieldTest(DummyDataTestCase):
def test_sort_asc(self):
q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", True)
sort = dbcore.query.SlowFieldSort("flex1", True)
results = self.lib.items(q, sort)
self.assertLessEqual(results[0]['flex1'], results[1]['flex1'])
self.assertEqual(results[0]['flex1'], 'flex1-0')
@ -143,7 +143,7 @@ class SortFlexFieldTest(DummyDataTestCase):
def test_sort_desc(self):
q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", False)
sort = dbcore.query.SlowFieldSort("flex1", False)
results = self.lib.items(q, sort)
self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1'])
self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1'])
@ -157,8 +157,8 @@ class SortFlexFieldTest(DummyDataTestCase):
def test_sort_two_field(self):
q = ''
s1 = dbcore.query.FlexFieldSort(beets.library.Item, "flex2", False)
s2 = dbcore.query.FlexFieldSort(beets.library.Item, "flex1", True)
s1 = dbcore.query.SlowFieldSort("flex2", False)
s2 = dbcore.query.SlowFieldSort("flex1", True)
sort = dbcore.query.MultipleSort()
sort.add_sort(s1)
sort.add_sort(s2)
@ -220,10 +220,10 @@ class SortAlbumFixedFieldTest(DummyDataTestCase):
self.assertEqual(r1.id, r2.id)
class SortAlbumFlexdFieldTest(DummyDataTestCase):
class SortAlbumFlexFieldTest(DummyDataTestCase):
def test_sort_asc(self):
q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", True)
sort = dbcore.query.SlowFieldSort("flex1", True)
results = self.lib.albums(q, sort)
self.assertLessEqual(results[0]['flex1'], results[1]['flex1'])
self.assertLessEqual(results[1]['flex1'], results[2]['flex1'])
@ -235,7 +235,7 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
def test_sort_desc(self):
q = ''
sort = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", False)
sort = dbcore.query.SlowFieldSort("flex1", False)
results = self.lib.albums(q, sort)
self.assertGreaterEqual(results[0]['flex1'], results[1]['flex1'])
self.assertGreaterEqual(results[1]['flex1'], results[2]['flex1'])
@ -247,8 +247,8 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
def test_sort_two_field_asc(self):
q = ''
s1 = dbcore.query.FlexFieldSort(beets.library.Album, "flex2", True)
s2 = dbcore.query.FlexFieldSort(beets.library.Album, "flex1", True)
s1 = dbcore.query.SlowFieldSort("flex2", True)
s2 = dbcore.query.SlowFieldSort("flex1", True)
sort = dbcore.query.MultipleSort()
sort.add_sort(s1)
sort.add_sort(s2)
@ -268,8 +268,7 @@ class SortAlbumFlexdFieldTest(DummyDataTestCase):
class SortAlbumComputedFieldTest(DummyDataTestCase):
def test_sort_asc(self):
q = ''
sort = dbcore.query.ComputedFieldSort(beets.library.Album, "path",
True)
sort = dbcore.query.SlowFieldSort("path", True)
results = self.lib.albums(q, sort)
self.assertLessEqual(results[0]['path'], results[1]['path'])
self.assertLessEqual(results[1]['path'], results[2]['path'])
@ -281,8 +280,7 @@ class SortAlbumComputedFieldTest(DummyDataTestCase):
def test_sort_desc(self):
q = ''
sort = dbcore.query.ComputedFieldSort(beets.library.Album, "path",
False)
sort = dbcore.query.SlowFieldSort("path", False)
results = self.lib.albums(q, sort)
self.assertGreaterEqual(results[0]['path'], results[1]['path'])
self.assertGreaterEqual(results[1]['path'], results[2]['path'])
@ -296,7 +294,7 @@ class SortAlbumComputedFieldTest(DummyDataTestCase):
class SortCombinedFieldTest(DummyDataTestCase):
def test_computed_first(self):
q = ''
s1 = dbcore.query.ComputedFieldSort(beets.library.Album, "path", True)
s1 = dbcore.query.SlowFieldSort("path", True)
s2 = dbcore.query.FixedFieldSort("year", True)
sort = dbcore.query.MultipleSort()
sort.add_sort(s1)
@ -312,7 +310,7 @@ class SortCombinedFieldTest(DummyDataTestCase):
def test_computed_second(self):
q = ''
s1 = dbcore.query.FixedFieldSort("year", True)
s2 = dbcore.query.ComputedFieldSort(beets.library.Album, "path", True)
s2 = dbcore.query.SlowFieldSort("path", True)
sort = dbcore.query.MultipleSort()
sort.add_sort(s1)
sort.add_sort(s2)