Query flexible attributes already in _fetch()

This resolves the query length and potential security problems,
while keeping the performance benefits.
This commit is contained in:
Heinz Wiesinger 2018-11-27 17:39:09 +01:00
parent 1e1ddd276e
commit 31ec222e0e

View file

@ -32,7 +32,6 @@ from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery
import six
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
@ -524,7 +523,8 @@ class Results(object):
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, query=None, sort=None):
def __init__(self, model_class, rows, db, flex_rows,
query=None, sort=None):
"""Create a result set that will construct objects of type
`model_class`.
@ -544,6 +544,7 @@ class Results(object):
self.db = db
self.query = query
self.sort = sort
self.flex_rows = flex_rows
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
@ -569,8 +570,7 @@ class Results(object):
# First fetch all flexible attributes for all items in the result.
# Doing the per-item filtering in python is faster than issuing
# one query per item to sqlite.
item_ids = [row['id'] for row in self._rows]
flex_attrs = self._get_flex_attrs(item_ids)
flex_attrs = self._get_indexed_flex_attrs()
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
@ -609,20 +609,11 @@ class Results(object):
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _get_flex_attrs(self, ids):
# Get the flexible attributes for all ids.
with self.db.transaction() as tx:
id_list = ', '.join(str(id) for id in ids)
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id IN ({1})'.format(
self.model_class._flex_table,
id_list
)
)
# Index flexible attributes by the entity id they belong to
def _get_indexed_flex_attrs(self):
""" Index flexible attributes by the entity id they belong to
"""
flex_values = dict()
for row in flex_rows:
for row in self.flex_rows:
if row['entity_id'] not in flex_values:
flex_values[row['entity_id']] = dict()
@ -631,6 +622,8 @@ class Results(object):
return flex_values
def _make_model(self, row, flex_values={}):
""" Create a Model object for the given row
"""
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k[:4] == 'flex')
@ -920,11 +913,23 @@ class Database(object):
"ORDER BY {0}".format(order_by) if order_by else '',
)
# Fetch flexible attributes for items matching the main query
flex_sql = ("""
SELECT * FROM {0} WHERE entity_id IN
(SELECT id FROM {1} WHERE {2});
""".format(
model_cls._flex_table,
model_cls._table,
where or '1',
)
)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
flex_rows = tx.query(flex_sql, subvals)
return Results(
model_cls, rows, self,
model_cls, rows, self, flex_rows,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)