Keep the same FieldQuery.field interface as before

This commit is contained in:
Šarūnas Nejus 2024-06-17 08:59:20 +01:00
parent 2f80ff07e4
commit 2c4b42d167
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
4 changed files with 48 additions and 46 deletions

View file

@ -365,7 +365,7 @@ class Model(ABC):
Availability of the 'flex_attrs' means we can query flexible attributes
in the same manner we query other entity fields, see
`FieldQuery.col_name`. This way, we also remove the need for an
`FieldQuery.field`. This way, we also remove the need for an
additional query to fetch them.
Note: we use LEFT join to include entities without flexible attributes.

View file

@ -131,22 +131,24 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: P, fast: bool = True):
self.table, _, self.field = field.rpartition(".")
def __init__(self, field_name: str, pattern: P, fast: bool = True):
self.table, _, self.field_name = field_name.rpartition(".")
self.pattern = pattern
self.fast = fast
@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field}
return {self.field_name}
@property
def col_name(self) -> str:
def field(self) -> str:
if not self.fast:
return f'json_extract(all_flex_attrs, "$.{self.field}")'
return f'json_extract(all_flex_attrs, "$.{self.field_name}")'
return f"{self.table}.{self.field}" if self.table else self.field
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name
)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
raise NotImplementedError
@ -160,30 +162,30 @@ class FieldQuery(Query, Generic[P]):
raise NotImplementedError()
def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
return self.value_match(self.pattern, obj.get(self.field_name))
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
f"fast={self.fast})"
)
def __eq__(self, other) -> bool:
return (
super().__eq__(other)
and self.field == other.field
and self.field_name == other.field_name
and self.pattern == other.pattern
)
def __hash__(self) -> int:
return hash((self.field, hash(self.pattern)))
return hash((self.field_name, hash(self.pattern)))
class MatchQuery(FieldQuery[AnySQLiteType]):
"""A query that looks for exact matches in an Model field."""
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_name + " = ?", [self.pattern]
return self.field + " = ?", [self.pattern]
@classmethod
def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool:
@ -197,13 +199,13 @@ class NoneQuery(FieldQuery[None]):
super().__init__(field, None, fast)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_name + " IS NULL", ()
return self.field + " IS NULL", ()
def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
return obj.get(self.field_name) is None
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
class StringFieldQuery(FieldQuery[P]):
@ -239,7 +241,7 @@ class StringQuery(StringFieldQuery[str]):
.replace("%", "\\%")
.replace("_", "\\_")
)
clause = self.col_name + " like ? escape '\\'"
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@ -258,7 +260,7 @@ class SubstringQuery(StringFieldQuery[str]):
.replace("_", "\\_")
)
search = "%" + pattern + "%"
clause = self.col_name + " like ? escape '\\'"
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@ -274,7 +276,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
expression.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
def __init__(self, field_name: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
pattern_re = re.compile(pattern)
@ -284,10 +286,10 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
pattern, "a regular expression", format(exc)
)
super().__init__(field, pattern_re, fast)
super().__init__(field_name, pattern_re, fast)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.col_name}, ?)", [self.pattern.pattern]
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@staticmethod
def _normalize(s: str) -> str:
@ -305,10 +307,10 @@ class NumericColumnQuery(MatchQuery[AnySQLiteType]):
"""A base class for queries that work with NUMERIC SQLite affinity."""
@property
def col_name(self) -> str:
def field(self) -> str:
"""Cast a flexible attribute column (string) to NUMERIC affinity."""
col_name = super().col_name
return col_name if self.fast else f"CAST({col_name} AS NUMERIC)"
field = super().field
return field if self.fast else f"CAST({field} AS NUMERIC)"
class BooleanQuery(NumericColumnQuery[bool]):
@ -318,7 +320,7 @@ class BooleanQuery(NumericColumnQuery[bool]):
def __init__(
self,
field: str,
field_name: str,
pattern: bool,
fast: bool = True,
):
@ -327,7 +329,7 @@ class BooleanQuery(NumericColumnQuery[bool]):
pattern_int = int(pattern)
super().__init__(field, pattern_int, fast)
super().__init__(field_name, pattern_int, fast)
class BytesQuery(FieldQuery[bytes]):
@ -337,7 +339,7 @@ class BytesQuery(FieldQuery[bytes]):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
@ -353,10 +355,10 @@ class BytesQuery(FieldQuery[bytes]):
else:
raise ValueError("pattern must be bytes, str, or memoryview")
super().__init__(field, bytes_pattern)
super().__init__(field_name, bytes_pattern)
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.col_name + " = ?", [self.buf_pattern]
return self.field + " = ?", [self.buf_pattern]
@classmethod
def value_match(cls, pattern: bytes, value: Any) -> bool:
@ -389,8 +391,8 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
parts = pattern.split("..", 1)
if len(parts) == 1:
@ -405,9 +407,9 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
self.rangemax = self._convert(parts[1])
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
value = obj[self.field]
value = obj[self.field_name]
if isinstance(value, str):
value = self._convert(value)
@ -422,17 +424,17 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
if self.point is not None:
return self.col_name + "=?", (self.point,)
return self.field + "=?", (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
return (
"{0} >= ? AND {0} <= ?".format(self.col_name),
"{0} >= ? AND {0} <= ?".format(self.field),
(self.rangemin, self.rangemax),
)
elif self.rangemin is not None:
return f"{self.col_name} >= ?", (self.rangemin,)
return f"{self.field} >= ?", (self.rangemin,)
elif self.rangemax is not None:
return f"{self.col_name} <= ?", (self.rangemax,)
return f"{self.field} <= ?", (self.rangemax,)
else:
return "1", ()
@ -440,7 +442,7 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]):
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field: str
field_name: str
pattern: Sequence[AnySQLiteType]
fast: bool = True
@ -450,7 +452,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.col_name} IN ({placeholders})", self.subvals
return f"{self.field_name} IN ({placeholders})", self.subvals
@classmethod
def value_match(
@ -781,15 +783,15 @@ class DateQuery(NumericColumnQuery[int]):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
timestamp = float(obj[self.field])
timestamp = float(obj[self.field_name])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)
@ -802,11 +804,11 @@ class DateQuery(NumericColumnQuery[int]):
# Convert the `datetime` objects to an integer number of seconds since
# the (local) Unix epoch using `datetime.timestamp()`.
if self.interval.start:
clause_parts.append(self._clause_tmpl.format(self.col_name, ">="))
clause_parts.append(self._clause_tmpl.format(self.field, ">="))
subvals.append(int(self.interval.start.timestamp()))
if self.interval.end:
clause_parts.append(self._clause_tmpl.format(self.col_name, "<"))
clause_parts.append(self._clause_tmpl.format(self.field, "<"))
subvals.append(int(self.interval.end.timestamp()))
if clause_parts:

View file

@ -148,7 +148,7 @@ class PathQuery(dbcore.FieldQuery[bytes]):
query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \
(substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))"
return query_part.format(self.col_name), (
return query_part.format(self.field), (
file_blob,
len(dir_blob),
dir_blob,

View file

@ -46,7 +46,7 @@ class BareascQuery(StringFieldQuery[str]):
def col_clause(self):
"""Compare ascii version of the pattern."""
clause = f"unidecode({self.col_name})"
clause = f"unidecode({self.field})"
if self.pattern.islower():
clause = f"lower({clause})"