Ensure that any field query uses the table name

In order to include the table name for fields in this query, use the
`field_query` method.

Since `AnyFieldQuery` is just an `OrQuery` under the hood, remove it and
construct `OrQuery` explicitly instead.
This commit is contained in:
Šarūnas Nejus 2024-05-10 09:58:35 +01:00
parent e0c50c5501
commit 6792a75c7e
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 29 additions and 102 deletions

View file

@ -515,49 +515,6 @@ class CollectionQuery(Query):
return reduce(mul, map(hash, self.subqueries), 1)
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)
@property
def field_names(self) -> Set[str]:
return set(self.fields)
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(obj):
return True
return False
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
f"{self.query_class.__name__})"
)
def __eq__(self, other) -> bool:
return super().__eq__(other) and self.query_class == other.query_class
def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.

View file

@ -151,19 +151,13 @@ def construct_query_part(
query_part, query_classes, prefixes
)
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(query_class, pattern)
else:
# Field queries get constructed according to the name of the field
# they are querying.
else:
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.

View file

@ -708,7 +708,7 @@ class ImportTask(BaseImportTask):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
dup_query = library.Album.match_all_query(
{key: tmp_album.get(key) for key in keys}
)
@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Item.all_fields_query(
dup_query = library.Item.match_all_query(
{key: tmp_item.get(key) for key in keys}
)

View file

@ -443,7 +443,18 @@ class LibModel(dbcore.Model):
return query_cls(field, pattern, fast)
@classmethod
def all_fields_query(
def any_field_query(
cls, query_class: Type[dbcore.FieldQuery], pattern: str
) -> dbcore.OrQuery:
return dbcore.OrQuery(
[
cls.field_query(f, pattern, query_class)
for f in cls._search_fields
]
)
@classmethod
def match_all_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.

View file

@ -590,7 +590,7 @@ class QueryFromStringsTest(unittest.TestCase):
q = self.qfs(["foo", "bar:baz"])
self.assertIsInstance(q, dbcore.query.AndQuery)
self.assertEqual(len(q.subqueries), 2)
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
def test_parse_fixed_type_query(self):

View file

@ -48,40 +48,6 @@ class TestHelper(helper.TestHelper):
self.assertNotIn(item.id, result_ids)
class AnyFieldQueryTest(_common.LibTestCase):
def test_no_restriction(self):
q = dbcore.query.AnyFieldQuery(
"title",
beets.library.Item._fields.keys(),
dbcore.query.SubstringQuery,
)
self.assertEqual(self.lib.items(q).get().title, "the title")
def test_restriction_completeness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["title"], dbcore.query.SubstringQuery
)
self.assertEqual(self.lib.items(q).get().title, "the title")
def test_restriction_soundness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["artist"], dbcore.query.SubstringQuery
)
self.assertIsNone(self.lib.items(q).get())
def test_eq(self):
q1 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
q2 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
self.assertEqual(q1, q2)
q2.query_class = None
self.assertNotEqual(q1, q2)
class AssertsMixin:
def assert_items_matched(self, results, titles):
self.assertEqual({i.title for i in results}, set(titles))
@ -981,14 +947,6 @@ class NotQueryTest(DummyDataTestCase):
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
self.assertNegationProperties(q)
def test_type_anyfield(self):
q = dbcore.query.AnyFieldQuery(
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
)
not_results = self.lib.items(dbcore.query.NotQuery(q))
self.assert_items_matched(not_results, ["baz qux"])
self.assertNegationProperties(q)
def test_type_boolean(self):
q = dbcore.query.BooleanQuery("comp", True)
not_results = self.lib.items(dbcore.query.NotQuery(q))
@ -1137,11 +1095,18 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
def test_filter_albums_by_common_field(self):
# title:Album1 ensures that the items table is joined for the query
q = "title:Album1 catalognum:ABC"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_filter_items_by_common_field(self):
# artpath::A ensures that the albums table is joined for the query
q = "artpath::A Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)