Make sure we can filter common fields

This commit is contained in:
Šarūnas Nejus 2024-06-19 22:41:06 +01:00
parent 981a61bd56
commit bcc2826000
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
3 changed files with 45 additions and 27 deletions

View file

@ -134,18 +134,24 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""
@property
def field(self) -> str:
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name
)
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field}
return {self.field_name}
def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
def __init__(self, field_name: str, pattern: P, fast: bool = True):
self.table, _, self.field_name = field_name.rpartition(".")
self.pattern = pattern
self.fast = fast
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return None, ()
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field, ()
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
@ -160,23 +166,23 @@ class FieldQuery(Query, Generic[P]):
raise NotImplementedError()
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
return self.value_match(self.pattern, obj.get(self.field_name))
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
f"fast={self.fast})"
)
def __eq__(self, other) -> bool:
return (
super().__eq__(other)
and self.field == other.field
and self.field_name == other.field_name
and self.pattern == other.pattern
)
def __hash__(self) -> int:
return hash((self.field, hash(self.pattern)))
return hash((self.field_name, hash(self.pattern)))
class MatchQuery(FieldQuery[AnySQLiteType]):
@ -200,10 +206,10 @@ class NoneQuery(FieldQuery[None]):
return self.field + " IS NULL", ()
def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
return obj.get(self.field_name) is None
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
class StringFieldQuery(FieldQuery[P]):
@ -274,7 +280,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
expression.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
def __init__(self, field_name: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
pattern_re = re.compile(pattern)
@ -284,7 +290,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
pattern, "a regular expression", format(exc)
)
super().__init__(field, pattern_re, fast)
super().__init__(field_name, pattern_re, fast)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@ -308,7 +314,7 @@ class BooleanQuery(MatchQuery[int]):
def __init__(
self,
field: str,
field_name: str,
pattern: bool,
fast: bool = True,
):
@ -317,7 +323,7 @@ class BooleanQuery(MatchQuery[int]):
pattern_int = int(pattern)
super().__init__(field, pattern_int, fast)
super().__init__(field_name, pattern_int, fast)
class BytesQuery(FieldQuery[bytes]):
@ -327,7 +333,7 @@ class BytesQuery(FieldQuery[bytes]):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
@ -343,7 +349,7 @@ class BytesQuery(FieldQuery[bytes]):
else:
raise ValueError("pattern must be bytes, str, or memoryview")
super().__init__(field, bytes_pattern)
super().__init__(field_name, bytes_pattern)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
@ -379,8 +385,8 @@ class NumericQuery(FieldQuery[str]):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
parts = pattern.split("..", 1)
if len(parts) == 1:
@ -395,9 +401,9 @@ class NumericQuery(FieldQuery[str]):
self.rangemax = self._convert(parts[1])
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
value = obj[self.field]
value = obj[self.field_name]
if isinstance(value, str):
value = self._convert(value)
@ -430,7 +436,7 @@ class NumericQuery(FieldQuery[str]):
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field: str
field_name: str
pattern: Sequence[AnySQLiteType]
fast: bool = True
@ -440,7 +446,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals
return f"{self.field_name} IN ({placeholders})", self.subvals
@classmethod
def value_match(
@ -823,15 +829,15 @@ class DateQuery(FieldQuery[str]):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
timestamp = float(obj[self.field])
timestamp = float(obj[self.field_name])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)

View file

@ -153,6 +153,12 @@ def construct_query_part(
# they are querying.
else:
key = key.lower()
if key in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
key = f"{model_cls._table}.{key}"
out_query = query_class(key, pattern, key in model_cls.all_db_fields)
# Apply negation.

View file

@ -1148,6 +1148,7 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
album_items.append(item)
album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath"
album.catalognum = "ABC"
album.store()
albums.append(album)
@ -1163,6 +1164,11 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])
def suite():
return unittest.TestLoader().loadTestsFromName(__name__)