Implement __eq__ for all Query subclasses

Tests are a bit light.
This commit is contained in:
Bruno Cauet 2015-03-16 16:56:25 +01:00
parent a1b048f5de
commit 4151e81969
2 changed files with 36 additions and 2 deletions

View file

@ -73,6 +73,9 @@ class Query(object):
"""
raise NotImplementedError
def __eq__(self, other):
return type(self) == type(other)
class FieldQuery(Query):
"""An abstract query that searches in a specific field for a
@ -106,6 +109,10 @@ class FieldQuery(Query):
def match(self, item):
return self.value_match(self.pattern, item.get(self.field))
def __eq__(self, other):
return super(FieldQuery, self).__eq__(other) and \
self.field == other.field and self.pattern == other.pattern
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
@ -120,8 +127,7 @@ class MatchQuery(FieldQuery):
class NoneQuery(FieldQuery):
def __init__(self, field, fast=True):
self.field = field
self.fast = fast
super(NoneQuery, self).__init__(field, None, fast)
def col_clause(self):
return self.field + " IS NULL", ()
@ -337,6 +343,10 @@ class CollectionQuery(Query):
clause = (' ' + joiner + ' ').join(clause_parts)
return clause, subvals
def __eq__(self, other):
return super(CollectionQuery, self).__eq__(other) and \
self.subqueries == other.subqueries
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
@ -362,6 +372,10 @@ class AnyFieldQuery(CollectionQuery):
return True
return False
def __eq__(self, other):
return super(AnyFieldQuery, self).__eq__(other) and \
self.query_class == other.query_class
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the

View file

@ -60,6 +60,16 @@ class AnyFieldQueryTest(_common.LibTestCase):
dbcore.query.SubstringQuery)
self.assertEqual(self.lib.items(q).get(), None)
def test_eq(self):
q1 = dbcore.query.AnyFieldQuery('foo', ['bar'],
dbcore.query.SubstringQuery)
q2 = dbcore.query.AnyFieldQuery('foo', ['bar'],
dbcore.query.SubstringQuery)
self.assertEqual(q1, q2)
q2.query_class = None
self.assertNotEqual(q1, q2)
class AssertsMixin(object):
def assert_items_matched(self, results, titles):
@ -344,6 +354,16 @@ class MatchTest(_common.TestCase):
def test_open_range(self):
dbcore.query.NumericQuery('bitrate', '100000..')
def test_eq(self):
q1 = dbcore.query.MatchQuery('foo', 'bar')
q2 = dbcore.query.MatchQuery('foo', 'bar')
q3 = dbcore.query.MatchQuery('foo', 'baz')
q4 = dbcore.query.StringFieldQuery('foo', 'bar')
self.assertEqual(q1, q2)
self.assertNotEqual(q1, q3)
self.assertNotEqual(q1, q4)
self.assertNotEqual(q3, q4)
class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin):
def setUp(self):