diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 24020e94c..b1314d1f8 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -73,6 +73,9 @@ class Query(object): """ raise NotImplementedError + def __eq__(self, other): + return type(self) == type(other) + class FieldQuery(Query): """An abstract query that searches in a specific field for a @@ -106,6 +109,10 @@ class FieldQuery(Query): def match(self, item): return self.value_match(self.pattern, item.get(self.field)) + def __eq__(self, other): + return super(FieldQuery, self).__eq__(other) and \ + self.field == other.field and self.pattern == other.pattern + class MatchQuery(FieldQuery): """A query that looks for exact matches in an item field.""" @@ -120,8 +127,7 @@ class MatchQuery(FieldQuery): class NoneQuery(FieldQuery): def __init__(self, field, fast=True): - self.field = field - self.fast = fast + super(NoneQuery, self).__init__(field, None, fast) def col_clause(self): return self.field + " IS NULL", () @@ -337,6 +343,10 @@ class CollectionQuery(Query): clause = (' ' + joiner + ' ').join(clause_parts) return clause, subvals + def __eq__(self, other): + return super(CollectionQuery, self).__eq__(other) and \ + self.subqueries == other.subqueries + class AnyFieldQuery(CollectionQuery): """A query that matches if a given FieldQuery subclass matches in @@ -362,6 +372,10 @@ class AnyFieldQuery(CollectionQuery): return True return False + def __eq__(self, other): + return super(AnyFieldQuery, self).__eq__(other) and \ + self.query_class == other.query_class + class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the diff --git a/test/test_query.py b/test/test_query.py index d512e02b8..6d8d744fe 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -60,6 +60,16 @@ class AnyFieldQueryTest(_common.LibTestCase): dbcore.query.SubstringQuery) self.assertEqual(self.lib.items(q).get(), None) + def test_eq(self): + q1 = dbcore.query.AnyFieldQuery('foo', ['bar'], + dbcore.query.SubstringQuery) + q2 = dbcore.query.AnyFieldQuery('foo', ['bar'], + dbcore.query.SubstringQuery) + self.assertEqual(q1, q2) + + q2.query_class = None + self.assertNotEqual(q1, q2) + class AssertsMixin(object): def assert_items_matched(self, results, titles): @@ -344,6 +354,16 @@ class MatchTest(_common.TestCase): def test_open_range(self): dbcore.query.NumericQuery('bitrate', '100000..') + def test_eq(self): + q1 = dbcore.query.MatchQuery('foo', 'bar') + q2 = dbcore.query.MatchQuery('foo', 'bar') + q3 = dbcore.query.MatchQuery('foo', 'baz') + q4 = dbcore.query.StringFieldQuery('foo', 'bar') + self.assertEqual(q1, q2) + self.assertNotEqual(q1, q3) + self.assertNotEqual(q1, q4) + self.assertNotEqual(q3, q4) + class PathQueryTest(_common.LibTestCase, TestHelper, AssertsMixin): def setUp(self):