From bcc28260009f5f13835aa19ff65676ad10e323f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Wed, 19 Jun 2024 22:41:06 +0100 Subject: [PATCH] Make sure we can filter common fields --- beets/dbcore/query.py | 60 +++++++++++++++++++++----------------- beets/dbcore/queryparse.py | 6 ++++ test/test_query.py | 6 ++++ 3 files changed, 45 insertions(+), 27 deletions(-) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 5309ebaf3..f8cf7fe4c 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -134,18 +134,24 @@ class FieldQuery(Query, Generic[P]): same matching functionality in SQLite. """ + @property + def field(self) -> str: + return ( + f"{self.table}.{self.field_name}" if self.table else self.field_name + ) + @property def field_names(self) -> Set[str]: """Return a set with field names that this query operates on.""" - return {self.field} + return {self.field_name} - def __init__(self, field: str, pattern: P, fast: bool = True): - self.field = field + def __init__(self, field_name: str, pattern: P, fast: bool = True): + self.table, _, self.field_name = field_name.rpartition(".") self.pattern = pattern self.fast = fast - def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: - return None, () + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: + return self.field, () def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]: if self.fast: @@ -160,23 +166,23 @@ class FieldQuery(Query, Generic[P]): raise NotImplementedError() def match(self, obj: Model) -> bool: - return self.value_match(self.pattern, obj.get(self.field)) + return self.value_match(self.pattern, obj.get(self.field_name)) def __repr__(self) -> str: return ( - f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, " + f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, " f"fast={self.fast})" ) def __eq__(self, other) -> bool: return ( super().__eq__(other) - and self.field == other.field + and self.field_name == other.field_name and self.pattern == other.pattern ) def __hash__(self) -> int: - return hash((self.field, hash(self.pattern))) + return hash((self.field_name, hash(self.pattern))) class MatchQuery(FieldQuery[AnySQLiteType]): @@ -200,10 +206,10 @@ class NoneQuery(FieldQuery[None]): return self.field + " IS NULL", () def match(self, obj: Model) -> bool: - return obj.get(self.field) is None + return obj.get(self.field_name) is None def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.field!r}, {self.fast})" + return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})" class StringFieldQuery(FieldQuery[P]): @@ -274,7 +280,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): expression. """ - def __init__(self, field: str, pattern: str, fast: bool = True): + def __init__(self, field_name: str, pattern: str, fast: bool = True): pattern = self._normalize(pattern) try: pattern_re = re.compile(pattern) @@ -284,7 +290,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): pattern, "a regular expression", format(exc) ) - super().__init__(field, pattern_re, fast) + super().__init__(field_name, pattern_re, fast) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return f" regexp({self.field}, ?)", [self.pattern.pattern] @@ -308,7 +314,7 @@ class BooleanQuery(MatchQuery[int]): def __init__( self, - field: str, + field_name: str, pattern: bool, fast: bool = True, ): @@ -317,7 +323,7 @@ class BooleanQuery(MatchQuery[int]): pattern_int = int(pattern) - super().__init__(field, pattern_int, fast) + super().__init__(field_name, pattern_int, fast) class BytesQuery(FieldQuery[bytes]): @@ -327,7 +333,7 @@ class BytesQuery(FieldQuery[bytes]): `MatchQuery` when matching on BLOB values. """ - def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): + def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]): # Use a buffer/memoryview representation of the pattern for SQLite # matching. This instructs SQLite to treat the blob as binary # rather than encoded Unicode. @@ -343,7 +349,7 @@ class BytesQuery(FieldQuery[bytes]): else: raise ValueError("pattern must be bytes, str, or memoryview") - super().__init__(field, bytes_pattern) + super().__init__(field_name, bytes_pattern) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: return self.field + " = ?", [self.buf_pattern] @@ -379,8 +385,8 @@ class NumericQuery(FieldQuery[str]): except ValueError: raise InvalidQueryArgumentValueError(s, "an int or a float") - def __init__(self, field: str, pattern: str, fast: bool = True): - super().__init__(field, pattern, fast) + def __init__(self, field_name: str, pattern: str, fast: bool = True): + super().__init__(field_name, pattern, fast) parts = pattern.split("..", 1) if len(parts) == 1: @@ -395,9 +401,9 @@ class NumericQuery(FieldQuery[str]): self.rangemax = self._convert(parts[1]) def match(self, obj: Model) -> bool: - if self.field not in obj: + if self.field_name not in obj: return False - value = obj[self.field] + value = obj[self.field_name] if isinstance(value, str): value = self._convert(value) @@ -430,7 +436,7 @@ class NumericQuery(FieldQuery[str]): class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): """Query which matches values in the given set.""" - field: str + field_name: str pattern: Sequence[AnySQLiteType] fast: bool = True @@ -440,7 +446,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: placeholders = ", ".join(["?"] * len(self.subvals)) - return f"{self.field} IN ({placeholders})", self.subvals + return f"{self.field_name} IN ({placeholders})", self.subvals @classmethod def value_match( @@ -823,15 +829,15 @@ class DateQuery(FieldQuery[str]): using an ellipsis interval syntax similar to that of NumericQuery. """ - def __init__(self, field: str, pattern: str, fast: bool = True): - super().__init__(field, pattern, fast) + def __init__(self, field_name: str, pattern: str, fast: bool = True): + super().__init__(field_name, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) def match(self, obj: Model) -> bool: - if self.field not in obj: + if self.field_name not in obj: return False - timestamp = float(obj[self.field]) + timestamp = float(obj[self.field_name]) date = datetime.fromtimestamp(timestamp) return self.interval.contains(date) diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index fd29aedff..b7558038f 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -153,6 +153,12 @@ def construct_query_part( # they are querying. else: key = key.lower() + if key in model_cls.shared_db_fields: + # This field exists in both tables, so SQLite will encounter + # an OperationalError if we try to query it in a join. + # Using an explicit table name resolves this. + key = f"{model_cls._table}.{key}" + out_query = query_class(key, pattern, key in model_cls.all_db_fields) # Apply negation. diff --git a/test/test_query.py b/test/test_query.py index 41b3a2de3..69277cfcd 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1148,6 +1148,7 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin): album_items.append(item) album = self.lib.add_album(album_items) album.artpath = f"{album_name} Artpath" + album.catalognum = "ABC" album.store() albums.append(album) @@ -1163,6 +1164,11 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin): results = self.lib.items(q) self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) + def test_filter_by_common_field(self): + q = "catalognum:ABC Album1" + results = self.lib.albums(q) + self.assert_albums_matched(results, ["Album1"]) + def suite(): return unittest.TestLoader().loadTestsFromName(__name__)