mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 08:39:17 +01:00
Ensure that any field query uses the table name
In order to include the table name for fields in this query, use the `field_query` method. Since `AnyFieldQuery` is just an `OrQuery` under the hood, remove it and construct `OrQuery` explicitly instead.
This commit is contained in:
parent
e0c50c5501
commit
6792a75c7e
6 changed files with 29 additions and 102 deletions
|
|
@ -515,49 +515,6 @@ class CollectionQuery(Query):
|
|||
return reduce(mul, map(hash, self.subqueries), 1)
|
||||
|
||||
|
||||
class AnyFieldQuery(CollectionQuery):
|
||||
"""A query that matches if a given FieldQuery subclass matches in
|
||||
any field. The individual field query class is provided to the
|
||||
constructor.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
|
||||
self.pattern = pattern
|
||||
self.fields = fields
|
||||
self.query_class = cls
|
||||
|
||||
subqueries = []
|
||||
for field in self.fields:
|
||||
subqueries.append(cls(field, pattern, True))
|
||||
# TYPING ERROR
|
||||
super().__init__(subqueries)
|
||||
|
||||
@property
|
||||
def field_names(self) -> Set[str]:
|
||||
return set(self.fields)
|
||||
|
||||
def clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.clause_with_joiner("or")
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
for subq in self.subqueries:
|
||||
if subq.match(obj):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
|
||||
f"{self.query_class.__name__})"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and self.query_class == other.query_class
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.pattern, tuple(self.fields), self.query_class))
|
||||
|
||||
|
||||
class MutableCollectionQuery(CollectionQuery):
|
||||
"""A collection query whose subqueries may be modified after the
|
||||
query is initialized.
|
||||
|
|
|
|||
|
|
@ -151,19 +151,13 @@ def construct_query_part(
|
|||
query_part, query_classes, prefixes
|
||||
)
|
||||
|
||||
# If there's no key (field name) specified, this is a "match
|
||||
# anything" query.
|
||||
if key is None:
|
||||
# The query type matches a specific field, but none was
|
||||
# specified. So we use a version of the query that matches
|
||||
# any field.
|
||||
out_query = query.AnyFieldQuery(
|
||||
pattern, model_cls._search_fields, query_class
|
||||
)
|
||||
|
||||
# Field queries get constructed according to the name of the field
|
||||
# they are querying.
|
||||
# If there's no key (field name) specified, this is a "match anything"
|
||||
# query.
|
||||
out_query = model_cls.any_field_query(query_class, pattern)
|
||||
else:
|
||||
# Field queries get constructed according to the name of the field
|
||||
# they are querying.
|
||||
out_query = model_cls.field_query(key.lower(), pattern, query_class)
|
||||
|
||||
# Apply negation.
|
||||
|
|
|
|||
|
|
@ -708,7 +708,7 @@ class ImportTask(BaseImportTask):
|
|||
# use a temporary Album object to generate any computed fields.
|
||||
tmp_album = library.Album(lib, **info)
|
||||
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
|
||||
dup_query = library.Album.all_fields_query(
|
||||
dup_query = library.Album.match_all_query(
|
||||
{key: tmp_album.get(key) for key in keys}
|
||||
)
|
||||
|
||||
|
|
@ -1019,7 +1019,7 @@ class SingletonImportTask(ImportTask):
|
|||
# temporary `Item` object to generate any computed fields.
|
||||
tmp_item = library.Item(lib, **info)
|
||||
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
|
||||
dup_query = library.Item.all_fields_query(
|
||||
dup_query = library.Item.match_all_query(
|
||||
{key: tmp_item.get(key) for key in keys}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -443,7 +443,18 @@ class LibModel(dbcore.Model):
|
|||
return query_cls(field, pattern, fast)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(
|
||||
def any_field_query(
|
||||
cls, query_class: Type[dbcore.FieldQuery], pattern: str
|
||||
) -> dbcore.OrQuery:
|
||||
return dbcore.OrQuery(
|
||||
[
|
||||
cls.field_query(f, pattern, query_class)
|
||||
for f in cls._search_fields
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def match_all_query(
|
||||
cls, pattern_by_field: Mapping[str, str]
|
||||
) -> dbcore.AndQuery:
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
|
|
|||
|
|
@ -590,7 +590,7 @@ class QueryFromStringsTest(unittest.TestCase):
|
|||
q = self.qfs(["foo", "bar:baz"])
|
||||
self.assertIsInstance(q, dbcore.query.AndQuery)
|
||||
self.assertEqual(len(q.subqueries), 2)
|
||||
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
|
||||
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
|
||||
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
|
||||
|
||||
def test_parse_fixed_type_query(self):
|
||||
|
|
|
|||
|
|
@ -48,40 +48,6 @@ class TestHelper(helper.TestHelper):
|
|||
self.assertNotIn(item.id, result_ids)
|
||||
|
||||
|
||||
class AnyFieldQueryTest(_common.LibTestCase):
|
||||
def test_no_restriction(self):
|
||||
q = dbcore.query.AnyFieldQuery(
|
||||
"title",
|
||||
beets.library.Item._fields.keys(),
|
||||
dbcore.query.SubstringQuery,
|
||||
)
|
||||
self.assertEqual(self.lib.items(q).get().title, "the title")
|
||||
|
||||
def test_restriction_completeness(self):
|
||||
q = dbcore.query.AnyFieldQuery(
|
||||
"title", ["title"], dbcore.query.SubstringQuery
|
||||
)
|
||||
self.assertEqual(self.lib.items(q).get().title, "the title")
|
||||
|
||||
def test_restriction_soundness(self):
|
||||
q = dbcore.query.AnyFieldQuery(
|
||||
"title", ["artist"], dbcore.query.SubstringQuery
|
||||
)
|
||||
self.assertIsNone(self.lib.items(q).get())
|
||||
|
||||
def test_eq(self):
|
||||
q1 = dbcore.query.AnyFieldQuery(
|
||||
"foo", ["bar"], dbcore.query.SubstringQuery
|
||||
)
|
||||
q2 = dbcore.query.AnyFieldQuery(
|
||||
"foo", ["bar"], dbcore.query.SubstringQuery
|
||||
)
|
||||
self.assertEqual(q1, q2)
|
||||
|
||||
q2.query_class = None
|
||||
self.assertNotEqual(q1, q2)
|
||||
|
||||
|
||||
class AssertsMixin:
|
||||
def assert_items_matched(self, results, titles):
|
||||
self.assertEqual({i.title for i in results}, set(titles))
|
||||
|
|
@ -981,14 +947,6 @@ class NotQueryTest(DummyDataTestCase):
|
|||
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
|
||||
self.assertNegationProperties(q)
|
||||
|
||||
def test_type_anyfield(self):
|
||||
q = dbcore.query.AnyFieldQuery(
|
||||
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
|
||||
)
|
||||
not_results = self.lib.items(dbcore.query.NotQuery(q))
|
||||
self.assert_items_matched(not_results, ["baz qux"])
|
||||
self.assertNegationProperties(q)
|
||||
|
||||
def test_type_boolean(self):
|
||||
q = dbcore.query.BooleanQuery("comp", True)
|
||||
not_results = self.lib.items(dbcore.query.NotQuery(q))
|
||||
|
|
@ -1137,11 +1095,18 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
|
||||
|
||||
def test_filter_by_common_field(self):
|
||||
q = "catalognum:ABC Album1"
|
||||
def test_filter_albums_by_common_field(self):
|
||||
# title:Album1 ensures that the items table is joined for the query
|
||||
q = "title:Album1 catalognum:ABC"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
def test_filter_items_by_common_field(self):
|
||||
# artpath::A ensures that the albums table is joined for the query
|
||||
q = "artpath::A Album1"
|
||||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
|
||||
|
||||
def test_get_items_filter_by_track_flex(self):
|
||||
q = "item_flex1:Item1"
|
||||
results = self.lib.items(q)
|
||||
|
|
|
|||
Loading…
Reference in a new issue