Ensure that any field query uses the table name

In order to include the table name for fields in this query, use the
`field_query` method.

Since `AnyFieldQuery` is just an `OrQuery` under the hood, remove it and
construct `OrQuery` explicitly instead.
This commit is contained in:
Šarūnas Nejus 2024-05-10 09:58:35 +01:00
parent 132acf8077
commit be4fdf492a
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
6 changed files with 29 additions and 103 deletions

View file

@ -507,50 +507,6 @@ class CollectionQuery(Query):
return reduce(mul, map(hash, self.subqueries), 1)
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
@property
def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)
def __init__(self, pattern, fields, cls: FieldQueryType):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)
def clause(self) -> tuple[str, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(obj):
return True
return False
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
f"{self.query_class.__name__})"
)
def __eq__(self, other) -> bool:
return super().__eq__(other) and self.query_class == other.query_class
def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.

View file

@ -149,19 +149,13 @@ def construct_query_part(
query_part, query_classes, prefixes
)
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
# Field queries get constructed according to the name of the field
# they are querying.
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(query_class, pattern)
else:
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.

View file

@ -707,7 +707,7 @@ class ImportTask(BaseImportTask):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
dup_query = library.Album.match_all_query(
{key: tmp_album.get(key) for key in keys}
)
@ -1025,7 +1025,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Item.all_fields_query(
dup_query = library.Item.match_all_query(
{key: tmp_item.get(key) for key in keys}
)

View file

@ -396,7 +396,18 @@ class LibModel(dbcore.Model["Library"]):
return query_cls(field, pattern, fast)
@classmethod
def all_fields_query(
def any_field_query(
cls, query_class: FieldQueryType, pattern: str
) -> dbcore.OrQuery:
return dbcore.OrQuery(
[
cls.field_query(f, pattern, query_class)
for f in cls._search_fields
]
)
@classmethod
def match_all_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.

View file

@ -587,7 +587,7 @@ class QueryFromStringsTest(unittest.TestCase):
q = self.qfs(["foo", "bar:baz"])
assert isinstance(q, dbcore.query.AndQuery)
assert len(q.subqueries) == 2
assert isinstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
assert isinstance(q.subqueries[0], dbcore.query.OrQuery)
assert isinstance(q.subqueries[1], dbcore.query.SubstringQuery)
def test_parse_fixed_type_query(self):

View file

@ -56,40 +56,6 @@ class AssertsMixin:
assert item.id not in result_ids
class AnyFieldQueryTest(ItemInDBTestCase):
def test_no_restriction(self):
q = dbcore.query.AnyFieldQuery(
"title",
beets.library.Item._fields.keys(),
dbcore.query.SubstringQuery,
)
assert self.lib.items(q).get().title == "the title"
def test_restriction_completeness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["title"], dbcore.query.SubstringQuery
)
assert self.lib.items(q).get().title == "the title"
def test_restriction_soundness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["artist"], dbcore.query.SubstringQuery
)
assert self.lib.items(q).get() is None
def test_eq(self):
q1 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
q2 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
assert q1 == q2
q2.query_class = None
assert q1 != q2
# A test case class providing a library with some dummy data and some
# assertions involving that data.
class DummyDataTestCase(BeetsTestCase, AssertsMixin):
@ -965,14 +931,6 @@ class NotQueryTest(DummyDataTestCase):
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
self.assertNegationProperties(q)
def test_type_anyfield(self):
q = dbcore.query.AnyFieldQuery(
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
)
not_results = self.lib.items(dbcore.query.NotQuery(q))
self.assert_items_matched(not_results, ["baz qux"])
self.assertNegationProperties(q)
def test_type_boolean(self):
q = dbcore.query.BooleanQuery("comp", True)
not_results = self.lib.items(dbcore.query.NotQuery(q))
@ -1120,11 +1078,18 @@ class RelatedQueriesTest(BeetsTestCase, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
def test_filter_albums_by_common_field(self):
# title:Album1 ensures that the items table is joined for the query
q = "title:Album1 catalognum:ABC"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_filter_items_by_common_field(self):
# artpath::A ensures that the albums table is joined for the query
q = "artpath::A Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_get_items_filter_by_track_flex(self):
q = "item_flex1:Item1"
results = self.lib.items(q)