From 9d5fdbb37f406953bfb3ec95ec85a9c9bc7b116d Mon Sep 17 00:00:00 2001 From: Adrian Sampson Date: Sat, 11 Oct 2014 20:30:33 -0700 Subject: [PATCH] Fix concurrent iterators in incremental results Some weird behavior was possible when having two iterators on the same Results object. May seem far-fetched, but it is possible. As an added bonus, this saves a little memory by disposing of rows when they have been materialized into model objects. --- beets/dbcore/db.py | 58 +++++++++++++++++++++++++++------------------ test/test_dbcore.py | 27 +++++++++++++++++++++ 2 files changed, 62 insertions(+), 23 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 723e07ce9..0548c6c2f 100644 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -478,9 +478,15 @@ class Results(object): self.query = query self.sort = sort - self._objects = [] # Model objects materialized *so far*. - self._row_iter = iter(self.rows) # Indicate next row to materialize. - self._materialized = False # All objects have been materialized. + # We keep a queue of rows we haven't yet consumed for + # materialization. We preserve the original total number of + # rows. + self._rows = rows + self._row_count = len(rows) + + # The materialized objects corresponding to rows that have been + # consumed. + self._objects = [] def _get_objects(self): """Construct and generate Model objects for they query. The @@ -492,23 +498,26 @@ class Results(object): a `Results` object a second time should be much faster than the first. """ - # Get the previously-materialized objects. - for object in self._objects: - yield object + index = 0 # Position in the materialized objects. + while index < len(self._objects) or self._rows: + # Are there previously-materialized objects to produce? + if index < len(self._objects): + yield self._objects[index] + index += 1 - # Next, for the rows that have not yet been processed, materialize - # objects and add them to the list. 
- for row in self._row_iter: - obj = self._make_model(row) - # If there is a slow-query predicate, ensurer that the - # object passes it. - if not self.query or self.query.match(obj): - self._objects.append(obj) - yield obj - - # Now that all the rows have been materialized, set a flag so we - # can take a shortcut in certain other situations. - self._materialized = True + # Otherwise, we consume another row, materialize its object + # and produce it. + else: + while self._rows: + row = self._rows.pop(0) + obj = self._make_model(row) + # If there is a slow-query predicate, ensure that the + # object passes it. + if not self.query or self.query.match(obj): + self._objects.append(obj) + index += 1 + yield obj + break def __iter__(self): """Construct and generate Model objects for all matching @@ -545,10 +554,11 @@ class Results(object): def __len__(self): """Get the number of matching objects. """ - if self._materialized: + if not self._rows: + # Fully materialized. Just count the objects. return len(self._objects) - if self.query: + elif self.query: # A slow query. Fall back to testing every object. count = 0 for obj in self: @@ -557,7 +567,7 @@ class Results(object): else: # A fast query. Just count the rows. - return len(self.rows) + return self._row_count def __nonzero__(self): """Does this result contain any objects? @@ -568,7 +578,9 @@ class Results(object): """Get the nth item in this result set. This is inefficient: all items up to n are materialized and thrown away. """ - if self._materialized and not self.sort: + if not self._rows and not self.sort: + # Fully materialized and already in order. Just look up the + # object. 
return self._objects[n] it = iter(self) diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 395c3cf9c..d2bdb907b 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -467,6 +467,33 @@ class SortFromStringsTest(unittest.TestCase): self.assertIsInstance(s.sorts[0], TestSort) +class ResultsIteratorTest(unittest.TestCase): + def setUp(self): + self.db = TestDatabase1(':memory:') + TestModel1().add(self.db) + TestModel1().add(self.db) + + def tearDown(self): + self.db._connection().close() + + def test_iterate_once(self): + objs = self.db._fetch(TestModel1) + self.assertEqual(len(list(objs)), 2) + + def test_iterate_twice(self): + objs = self.db._fetch(TestModel1) + list(objs) + self.assertEqual(len(list(objs)), 2) + + def test_concurrent_iterators(self): + results = self.db._fetch(TestModel1) + it1 = iter(results) + it2 = iter(results) + it1.next() + list(it2) + self.assertEqual(len(list(it1)), 1) + + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)