mirror of
https://github.com/beetbox/beets.git
synced 2025-12-26 10:34:09 +01:00
Fix concurrent iterators in incremental results
Some weird behavior was possible when two iterators were active on the same Results object. This may seem far-fetched, but it can happen. As an added bonus, this saves a little memory by disposing of rows once they have been materialized into model objects.
This commit is contained in:
parent
b777bde0af
commit
9d5fdbb37f
2 changed files with 62 additions and 23 deletions
|
|
@ -478,9 +478,15 @@ class Results(object):
|
|||
self.query = query
|
||||
self.sort = sort
|
||||
|
||||
self._objects = [] # Model objects materialized *so far*.
|
||||
self._row_iter = iter(self.rows) # Indicate next row to materialize.
|
||||
self._materialized = False # All objects have been materialized.
|
||||
# We keep a queue of rows we haven't yet consumed for
|
||||
# materialization. We preserve the original total number of
|
||||
# rows.
|
||||
self._rows = rows
|
||||
self._row_count = len(rows)
|
||||
|
||||
# The materialized objects corresponding to rows that have been
|
||||
# consumed.
|
||||
self._objects = []
|
||||
|
||||
def _get_objects(self):
|
||||
"""Construct and generate Model objects for the query. The
|
||||
|
|
@ -492,23 +498,26 @@ class Results(object):
|
|||
a `Results` object a second time should be much faster than the
|
||||
first.
|
||||
"""
|
||||
# Get the previously-materialized objects.
|
||||
for object in self._objects:
|
||||
yield object
|
||||
index = 0 # Position in the materialized objects.
|
||||
while index < len(self._objects) or self._rows:
|
||||
# Are there previously-materialized objects to produce?
|
||||
if index < len(self._objects):
|
||||
yield self._objects[index]
|
||||
index += 1
|
||||
|
||||
# Next, for the rows that have not yet been processed, materialize
|
||||
# objects and add them to the list.
|
||||
for row in self._row_iter:
|
||||
obj = self._make_model(row)
|
||||
# If there is a slow-query predicate, ensure that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
self._objects.append(obj)
|
||||
yield obj
|
||||
|
||||
# Now that all the rows have been materialized, set a flag so we
|
||||
# can take a shortcut in certain other situations.
|
||||
self._materialized = True
|
||||
# Otherwise, we consume another row, materialize its object
|
||||
# and produce it.
|
||||
else:
|
||||
while self._rows:
|
||||
row = self._rows.pop(0)
|
||||
obj = self._make_model(row)
|
||||
# If there is a slow-query predicate, ensure that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
self._objects.append(obj)
|
||||
index += 1
|
||||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
"""Construct and generate Model objects for all matching
|
||||
|
|
@ -545,10 +554,11 @@ class Results(object):
|
|||
def __len__(self):
|
||||
"""Get the number of matching objects.
|
||||
"""
|
||||
if self._materialized:
|
||||
if not self._rows:
|
||||
# Fully materialized. Just count the objects.
|
||||
return len(self._objects)
|
||||
|
||||
if self.query:
|
||||
elif self.query:
|
||||
# A slow query. Fall back to testing every object.
|
||||
count = 0
|
||||
for obj in self:
|
||||
|
|
@ -557,7 +567,7 @@ class Results(object):
|
|||
|
||||
else:
|
||||
# A fast query. Just count the rows.
|
||||
return len(self.rows)
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self):
|
||||
"""Does this result contain any objects?
|
||||
|
|
@ -568,7 +578,9 @@ class Results(object):
|
|||
"""Get the nth item in this result set. This is inefficient: all
|
||||
items up to n are materialized and thrown away.
|
||||
"""
|
||||
if self._materialized and not self.sort:
|
||||
if not self._rows and not self.sort:
|
||||
# Fully materialized and already in order. Just look up the
|
||||
# object.
|
||||
return self._objects[n]
|
||||
|
||||
it = iter(self)
|
||||
|
|
|
|||
|
|
@ -467,6 +467,33 @@ class SortFromStringsTest(unittest.TestCase):
|
|||
self.assertIsInstance(s.sorts[0], TestSort)
|
||||
|
||||
|
||||
class ResultsIteratorTest(unittest.TestCase):
    """Exercise iteration over the results object returned by
    `_fetch`: a single pass, a repeated pass, and two iterators that
    are advanced in an interleaved fashion.
    """
    def setUp(self):
        # Open the test database on ':memory:' and add two model rows,
        # so every fetch below is expected to produce two objects.
        self.db = TestDatabase1(':memory:')
        TestModel1().add(self.db)
        TestModel1().add(self.db)

    def tearDown(self):
        self.db._connection().close()

    def test_iterate_once(self):
        """A single pass yields both objects."""
        objs = self.db._fetch(TestModel1)
        self.assertEqual(len(list(objs)), 2)

    def test_iterate_twice(self):
        """A second pass over the same results yields both objects
        again.
        """
        objs = self.db._fetch(TestModel1)
        list(objs)
        self.assertEqual(len(list(objs)), 2)

    def test_concurrent_iterators(self):
        """Two iterators over the same results do not disturb each
        other's position.
        """
        results = self.db._fetch(TestModel1)
        it1 = iter(results)
        it2 = iter(results)
        # Use the `next` builtin rather than the iterator's `.next()`
        # method, which only exists on Python 2.
        next(it1)
        # Exhaust the second iterator while the first is mid-way.
        list(it2)
        # The first iterator must still have exactly one object left.
        self.assertEqual(len(list(it1)), 1)
|
||||
|
||||
|
||||
def suite():
    """Build a test suite from the tests found in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue