Handle ambiguous column names in queries involving 'any' field and a relation field (#5541)

## Description

Fixes a `sqlite3.OperationalError` that occurs when querying *any* field
and a field that only exists in the relation table. For example, trying
to list albums that contain **keyword** and contain a track with **foo**
in its title:
```
beet list -a keyword title:foo
```

## Root Cause
SQLite fails when JOINs contain ambiguous column references. This
happened because:
- *any* Album field search looks at the `album`, `albumartist` and `genre`
  fields.
- The second part of the query, `title:foo`, queries a field in the
  `items` table, which gets JOINed with `albums`.
- Some fields (like `album`) exist in both the `items` and `albums` tables,
  so SQLite cannot resolve which table's column to use.

## Changes
- Centralize query construction in `LibModel` with consistent table
  qualification
- Add methods:
  - `field_query()` - Creates table-qualified field queries
  - `any_field_query()` - Creates multi-field OR queries
  - `any_writable_media_field_query()` - Similar to the above but for BPD
    / media files
  - `match_all_query()` - Creates multi-field AND queries
- Remove `AnyFieldQuery` in favor of composed `OrQuery`
- Add tests for shared field querying
This commit is contained in:
Šarūnas Nejus 2025-01-19 01:18:52 +00:00 committed by GitHub
commit bd3043935c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 111 additions and 183 deletions

View file

@ -35,8 +35,6 @@ import beets
from ..util import cached_classproperty, functemplate
from . import types
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
@ -718,33 +716,6 @@ class Model(ABC, Generic[D]):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)
# Convenient queries.
@classmethod
def field_query(
cls,
field,
pattern,
query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(
cls: type[Model],
pats: Mapping[str, str],
query_cls: FieldQueryType = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)
# Database controller and supporting interfaces.

View file

@ -97,6 +97,9 @@ class Query(ABC):
"""
...
def __and__(self, other: Query) -> AndQuery:
return AndQuery([self, other])
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
@ -505,50 +508,6 @@ class CollectionQuery(Query):
return reduce(mul, map(hash, self.subqueries), 1)
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
@property
def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)
def __init__(self, pattern, fields, cls: FieldQueryType):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)
def clause(self) -> tuple[str | None, Sequence[SQLiteType]]:
return self.clause_with_joiner("or")
def match(self, obj: Model) -> bool:
for subq in self.subqueries:
if subq.match(obj):
return True
return False
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, "
f"{self.query_class.__name__})"
)
def __eq__(self, other) -> bool:
return super().__eq__(other) and self.query_class == other.query_class
def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.

View file

@ -20,15 +20,17 @@ import itertools
import re
from typing import TYPE_CHECKING
from . import Model, query
from . import query
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
from ..library import LibModel
from .query import FieldQueryType, Sort
Prefixes = dict[str, FieldQueryType]
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
@ -112,7 +114,7 @@ def parse_query_part(
def construct_query_part(
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_part: str,
) -> query.Query:
@ -147,28 +149,14 @@ def construct_query_part(
query_part, query_classes, prefixes
)
# If there's no key (field name) specified, this is a "match
# anything" query.
if key is None:
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
out_query = query.AnyFieldQuery(
pattern, model_cls._search_fields, query_class
)
# Field queries get constructed according to the name of the field
# they are querying.
# If there's no key (field name) specified, this is a "match anything"
# query.
out_query = model_cls.any_field_query(pattern, query_class)
else:
field = table = key.lower()
if field in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
table = f"{model_cls._table}.{field}"
field_in_db = field in model_cls.all_db_fields
out_query = query_class(table, pattern, field_in_db)
# Field queries get constructed according to the name of the field
# they are querying.
out_query = model_cls.field_query(key.lower(), pattern, query_class)
# Apply negation.
if negate:
@ -180,7 +168,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: type[query.CollectionQuery],
model_cls: type[Model],
model_cls: type[LibModel],
prefixes: Prefixes,
query_parts: Collection[str],
) -> query.Query:
@ -197,7 +185,7 @@ def query_from_strings(
def construct_sort_part(
model_cls: type[Model],
model_cls: type[LibModel],
part: str,
case_insensitive: bool = True,
) -> Sort:
@ -228,7 +216,7 @@ def construct_sort_part(
def sort_from_strings(
model_cls: type[Model],
model_cls: type[LibModel],
sort_parts: Sequence[str],
case_insensitive: bool = True,
) -> Sort:
@ -247,7 +235,7 @@ def sort_from_strings(
def parse_sorted_query(
model_cls: type[Model],
model_cls: type[LibModel],
parts: list[str],
prefixes: Prefixes = {},
case_insensitive: bool = True,

View file

@ -707,9 +707,7 @@ class ImportTask(BaseImportTask):
# use a temporary Album object to generate any computed fields.
tmp_album = library.Album(lib, **info)
keys = config["import"]["duplicate_keys"]["album"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_album.get(key) for key in keys}
)
dup_query = tmp_album.duplicates_query(keys)
# Don't count albums with the same files as duplicates.
task_paths = {i.path for i in self.items if i}
@ -1025,9 +1023,7 @@ class SingletonImportTask(ImportTask):
# temporary `Item` object to generate any computed fields.
tmp_item = library.Item(lib, **info)
keys = config["import"]["duplicate_keys"]["item"].as_str_seq()
dup_query = library.Album.all_fields_query(
{key: tmp_item.get(key) for key in keys}
)
dup_query = tmp_item.duplicates_query(keys)
found_items = []
for other_item in lib.items(dup_query):

View file

@ -25,6 +25,7 @@ import time
import unicodedata
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
import platformdirs
from mediafile import MediaFile, UnreadableFileError
@ -42,6 +43,9 @@ from beets.util import (
)
from beets.util.functemplate import Template, template
if TYPE_CHECKING:
from .dbcore.query import FieldQuery, FieldQueryType
# To use the SQLite "blob" type, it doesn't suffice to provide a byte
# string; SQLite treats that as encoded text. Wrapping it in a
# `memoryview` tells it that we actually mean non-text data.
@ -346,6 +350,10 @@ class LibModel(dbcore.Model["Library"]):
# Config key that specifies how an instance should be formatted.
_format_config_key: str
@cached_classproperty
def writable_media_fields(cls) -> set[str]:
return set(MediaFile.fields()) & cls._fields.keys()
def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
funcs.update(plugins.template_funcs())
@ -375,6 +383,44 @@ class LibModel(dbcore.Model["Library"]):
def __bytes__(self):
return self.__str__().encode("utf-8")
# Convenient queries.
@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: FieldQueryType
) -> FieldQuery:
"""Get a `FieldQuery` for the given field on this model."""
fast = field in cls.all_db_fields
if field in cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"
return query_cls(field, pattern, fast)
@classmethod
def any_field_query(cls, *args, **kwargs) -> dbcore.OrQuery:
return dbcore.OrQuery(
[cls.field_query(f, *args, **kwargs) for f in cls._search_fields]
)
@classmethod
def any_writable_media_field_query(cls, *args, **kwargs) -> dbcore.OrQuery:
fields = cls.writable_media_fields
return dbcore.OrQuery(
[cls.field_query(f, *args, **kwargs) for f in fields]
)
def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
"""Return a query for entities with same values in the given fields."""
return dbcore.AndQuery(
[
self.field_query(f, self.get(f), dbcore.MatchQuery)
for f in fields
]
)
class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
@ -648,6 +694,12 @@ class Item(LibModel):
getters["filesize"] = Item.try_filesize # In bytes.
return getters
def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery:
"""Return a query for entities with same values in the given fields."""
return super().duplicates_query(fields) & dbcore.query.NoneQuery(
"album_id"
)
@classmethod
def from_path(cls, path):
"""Create a new item from the media file at the specified path."""
@ -1866,7 +1918,6 @@ class DefaultTemplateFunctions:
Item.all_keys(),
# Do nothing for non singletons.
lambda i: i.album_id is not None,
initial_subqueries=[dbcore.query.NoneQuery("album_id", True)],
)
def _tmpl_unique_memokey(self, name, keys, disam, item_id):
@ -1885,7 +1936,6 @@ class DefaultTemplateFunctions:
db_item,
item_keys,
skip_item,
initial_subqueries=None,
):
"""Generate a string that is guaranteed to be unique among all items of
the same type as "db_item" who share the same set of keys.
@ -1932,15 +1982,7 @@ class DefaultTemplateFunctions:
bracket_r = ""
# Find matching items to disambiguate with.
subqueries = []
if initial_subqueries is not None:
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
query = dbcore.AndQuery(subqueries)
query = db_item.duplicates_query(keys)
ambigous_items = (
self.lib.items(query)
if isinstance(db_item, Item)

View file

@ -186,7 +186,9 @@ class AURADocument:
value = converter(value)
# Add exact match query to list
# Use a slow query so it works with all fields
queries.append(MatchQuery(beets_attr, value, fast=False))
queries.append(
self.model_cls.field_query(beets_attr, value, MatchQuery)
)
# NOTE: AURA doesn't officially support multiple queries
return AndQuery(queries)
@ -318,13 +320,12 @@ class AURADocument:
sort = self.translate_sorts(sort_arg)
# For each sort field add a query which ensures all results
# have a non-empty, non-zero value for that field.
for s in sort.sorts:
query.subqueries.append(
NotQuery(
# Match empty fields (^$) or zero fields, (^0$)
RegexpQuery(s.field, "(^$|^0$)", fast=False)
)
query.subqueries.extend(
NotQuery(
self.model_cls.field_query(s.field, "(^$|^0$)", RegexpQuery)
)
for s in sort.sorts
)
else:
sort = None
# Get information from the library

View file

@ -26,8 +26,7 @@ import sys
import time
import traceback
from string import Template
from mediafile import MediaFile
from typing import TYPE_CHECKING
import beets
import beets.ui
@ -36,6 +35,9 @@ from beets.library import Item
from beets.plugins import BeetsPlugin
from beets.util import bluelet
if TYPE_CHECKING:
from beets.dbcore.query import Query
PROTOCOL_VERSION = "0.16.0"
BUFSIZE = 1024
@ -91,8 +93,6 @@ SUBSYSTEMS = [
"partition",
]
ITEM_KEYS_WRITABLE = set(MediaFile.fields()).intersection(Item._fields.keys())
# Gstreamer import error.
class NoGstreamerError(Exception):
@ -1399,29 +1399,29 @@ class Server(BaseServer):
return test_tag, key
raise BPDError(ERROR_UNKNOWN, "no such tagtype")
def _metadata_query(self, query_type, any_query_type, kv):
def _metadata_query(self, query_type, kv, allow_any_query: bool = False):
"""Helper function returns a query object that will find items
according to the library query type provided and the key-value
pairs specified. The any_query_type is used for queries of
type "any"; if None, then an error is thrown.
"""
if kv: # At least one key-value pair.
queries = []
queries: list[Query] = []
# Iterate pairwise over the arguments.
it = iter(kv)
for tag, value in zip(it, it):
if tag.lower() == "any":
if any_query_type:
if allow_any_query:
queries.append(
any_query_type(
value, ITEM_KEYS_WRITABLE, query_type
Item.any_writable_media_field_query(
query_type, value
)
)
else:
raise BPDError(ERROR_UNKNOWN, "no such tagtype")
else:
_, key = self._tagtype_lookup(tag)
queries.append(query_type(key, value))
queries.append(Item.field_query(key, value, query_type))
return dbcore.query.AndQuery(queries)
else: # No key-value pairs.
return dbcore.query.TrueQuery()
@ -1429,14 +1429,14 @@ class Server(BaseServer):
def cmd_search(self, conn, *kv):
"""Perform a substring match for items."""
query = self._metadata_query(
dbcore.query.SubstringQuery, dbcore.query.AnyFieldQuery, kv
dbcore.query.SubstringQuery, kv, allow_any_query=True
)
for item in self.lib.items(query):
yield self._item_info(item)
def cmd_find(self, conn, *kv):
"""Perform an exact match for items."""
query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
query = self._metadata_query(dbcore.query.MatchQuery, kv)
for item in self.lib.items(query):
yield self._item_info(item)
@ -1456,7 +1456,7 @@ class Server(BaseServer):
raise BPDError(ERROR_ARG, 'should be "Album" for 3 arguments')
elif len(kv) % 2 != 0:
raise BPDError(ERROR_ARG, "Incorrect number of filter arguments")
query = self._metadata_query(dbcore.query.MatchQuery, None, kv)
query = self._metadata_query(dbcore.query.MatchQuery, kv)
clause, subvals = query.clause()
statement = (
@ -1484,7 +1484,9 @@ class Server(BaseServer):
_, key = self._tagtype_lookup(tag)
songs = 0
playtime = 0.0
for item in self.lib.items(dbcore.query.MatchQuery(key, value)):
for item in self.lib.items(
Item.field_query(key, value, dbcore.query.MatchQuery)
):
songs += 1
playtime += item.length
yield "songs: " + str(songs)

View file

@ -38,6 +38,9 @@ Bug fixes:
request their own last.fm genre. Also log messages regarding what's been
tagged are now more polished.
:bug:`5582`
* Fix ambiguous column name ``sqlite3.OperationalError`` that occurred in album
queries that filtered album track titles, for example ``beet list -a keyword
title:foo``.
For packagers:

View file

@ -23,6 +23,7 @@ from tempfile import mkstemp
import pytest
from beets import dbcore
from beets.library import LibModel
from beets.test import _common
# Fixture: concrete database and model classes. For migration tests, we
@ -44,7 +45,7 @@ class QueryFixture(dbcore.query.FieldQuery):
return True
class ModelFixture1(dbcore.Model):
class ModelFixture1(LibModel):
_table = "test"
_flex_table = "testflex"
_fields = {
@ -587,7 +588,7 @@ class QueryFromStringsTest(unittest.TestCase):
q = self.qfs(["foo", "bar:baz"])
assert isinstance(q, dbcore.query.AndQuery)
assert len(q.subqueries) == 2
assert isinstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
assert isinstance(q.subqueries[0], dbcore.query.OrQuery)
assert isinstance(q.subqueries[1], dbcore.query.SubstringQuery)
def test_parse_fixed_type_query(self):

View file

@ -56,40 +56,6 @@ class AssertsMixin:
assert item.id not in result_ids
class AnyFieldQueryTest(ItemInDBTestCase):
def test_no_restriction(self):
q = dbcore.query.AnyFieldQuery(
"title",
beets.library.Item._fields.keys(),
dbcore.query.SubstringQuery,
)
assert self.lib.items(q).get().title == "the title"
def test_restriction_completeness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["title"], dbcore.query.SubstringQuery
)
assert self.lib.items(q).get().title == "the title"
def test_restriction_soundness(self):
q = dbcore.query.AnyFieldQuery(
"title", ["artist"], dbcore.query.SubstringQuery
)
assert self.lib.items(q).get() is None
def test_eq(self):
q1 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
q2 = dbcore.query.AnyFieldQuery(
"foo", ["bar"], dbcore.query.SubstringQuery
)
assert q1 == q2
q2.query_class = None
assert q1 != q2
# A test case class providing a library with some dummy data and some
# assertions involving that data.
class DummyDataTestCase(BeetsTestCase, AssertsMixin):
@ -954,14 +920,6 @@ class NotQueryTest(DummyDataTestCase):
self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"])
self.assertNegationProperties(q)
def test_type_anyfield(self):
q = dbcore.query.AnyFieldQuery(
"foo", ["title", "artist", "album"], dbcore.query.SubstringQuery
)
not_results = self.lib.items(dbcore.query.NotQuery(q))
self.assert_items_matched(not_results, ["baz qux"])
self.assertNegationProperties(q)
def test_type_boolean(self):
q = dbcore.query.BooleanQuery("comp", True)
not_results = self.lib.items(dbcore.query.NotQuery(q))
@ -1135,7 +1093,14 @@ class RelatedQueriesTest(BeetsTestCase, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
def test_filter_albums_by_common_field(self):
# title:Album1 ensures that the items table is joined for the query
q = "title:Album1 Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def test_filter_items_by_common_field(self):
# artpath::A ensures that the albums table is joined for the query
q = "artpath::A Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])