diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 988752a1e..8b7f3f35c 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -365,7 +365,7 @@ class Model(ABC): Availability of the 'flex_attrs' means we can query flexible attributes in the same manner we query other entity fields, see - `FieldQuery.col_name`. This way, we also remove the need for an + `FieldQuery.field`. This way, we also remove the need for an additional query to fetch them. Note: we use LEFT join to include entities without flexible attributes. diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 5e9e83149..2282c7815 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -131,22 +131,24 @@ class FieldQuery(Query, Generic[P]): same matching functionality in SQLite. """ - def __init__(self, field: str, pattern: P, fast: bool = True): - self.table, _, self.field = field.rpartition(".") + def __init__(self, field_name: str, pattern: P, fast: bool = True): + self.table, _, self.field_name = field_name.rpartition(".") self.pattern = pattern self.fast = fast @property def field_names(self) -> Set[str]: """Return a set with field names that this query operates on.""" - return {self.field} + return {self.field_name} @property - def col_name(self) -> str: + def field(self) -> str: if not self.fast: - return f'json_extract(all_flex_attrs, "$.{self.field}")' + return f'json_extract(all_flex_attrs, "$.{self.field_name}")' - return f"{self.table}.{self.field}" if self.table else self.field + return ( + f"{self.table}.{self.field_name}" if self.table else self.field_name + ) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: raise NotImplementedError @@ -160,30 +162,30 @@ class FieldQuery(Query, Generic[P]): raise NotImplementedError() def match(self, obj: Model) -> bool: - return self.value_match(self.pattern, obj.get(self.field)) + return self.value_match(self.pattern, obj.get(self.field_name)) def __repr__(self) -> str: return ( - f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, " + f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, " f"fast={self.fast})" ) def __eq__(self, other) -> bool: return ( super().__eq__(other) - and self.field == other.field + and self.field_name == other.field_name and self.pattern == other.pattern ) def __hash__(self) -> int: - return hash((self.field, hash(self.pattern))) + return hash((self.field_name, hash(self.pattern))) class MatchQuery(FieldQuery[AnySQLiteType]): """A query that looks for exact matches in an Model field.""" def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return self.col_name + " = ?", [self.pattern] + return self.field + " = ?", [self.pattern] @classmethod def value_match(cls, pattern: AnySQLiteType, value: Any) -> bool: @@ -197,13 +199,13 @@ class NoneQuery(FieldQuery[None]): super().__init__(field, None, fast) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return self.col_name + " IS NULL", () + return self.field + " IS NULL", () def match(self, obj: Model) -> bool: - return obj.get(self.field) is None + return obj.get(self.field_name) is None def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.field!r}, {self.fast})" + return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})" class StringFieldQuery(FieldQuery[P]): @@ -239,7 +241,7 @@ class StringQuery(StringFieldQuery[str]): .replace("%", "\\%") .replace("_", "\\_") ) - clause = self.col_name + " like ? escape '\\'" + clause = self.field + " like ? escape '\\'" subvals = [search] return clause, subvals @@ -258,7 +260,7 @@ class SubstringQuery(StringFieldQuery[str]): .replace("_", "\\_") ) search = "%" + pattern + "%" - clause = self.col_name + " like ? escape '\\'" + clause = self.field + " like ? escape '\\'" subvals = [search] return clause, subvals @@ -274,7 +276,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): expression. """ - def __init__(self, field: str, pattern: str, fast: bool = True): + def __init__(self, field_name: str, pattern: str, fast: bool = True): pattern = self._normalize(pattern) try: pattern_re = re.compile(pattern) @@ -284,10 +286,10 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]): pattern, "a regular expression", format(exc) ) - super().__init__(field, pattern_re, fast) + super().__init__(field_name, pattern_re, fast) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return f" regexp({self.col_name}, ?)", [self.pattern.pattern] + return f" regexp({self.field}, ?)", [self.pattern.pattern] @staticmethod def _normalize(s: str) -> str: @@ -305,10 +307,10 @@ class NumericColumnQuery(MatchQuery[AnySQLiteType]): """A base class for queries that work with NUMERIC SQLite affinity.""" @property - def col_name(self) -> str: + def field(self) -> str: """Cast a flexible attribute column (string) to NUMERIC affinity.""" - col_name = super().col_name - return col_name if self.fast else f"CAST({col_name} AS NUMERIC)" + field = super().field + return field if self.fast else f"CAST({field} AS NUMERIC)" class BooleanQuery(NumericColumnQuery[bool]): @@ -318,7 +320,7 @@ class BooleanQuery(NumericColumnQuery[bool]): def __init__( self, - field: str, + field_name: str, pattern: bool, fast: bool = True, ): @@ -327,7 +329,7 @@ class BooleanQuery(NumericColumnQuery[bool]): pattern_int = int(pattern) - super().__init__(field, pattern_int, fast) + super().__init__(field_name, pattern_int, fast) class BytesQuery(FieldQuery[bytes]): @@ -337,7 +339,7 @@ class BytesQuery(FieldQuery[bytes]): `MatchQuery` when matching on BLOB values. """ - def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): + def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]): # Use a buffer/memoryview representation of the pattern for SQLite # matching. This instructs SQLite to treat the blob as binary # rather than encoded Unicode. @@ -353,10 +355,10 @@ class BytesQuery(FieldQuery[bytes]): else: raise ValueError("pattern must be bytes, str, or memoryview") - super().__init__(field, bytes_pattern) + super().__init__(field_name, bytes_pattern) def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: - return self.col_name + " = ?", [self.buf_pattern] + return self.field + " = ?", [self.buf_pattern] @classmethod def value_match(cls, pattern: bytes, value: Any) -> bool: @@ -389,8 +391,8 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): except ValueError: raise InvalidQueryArgumentValueError(s, "an int or a float") - def __init__(self, field: str, pattern: str, fast: bool = True): - super().__init__(field, pattern, fast) + def __init__(self, field_name: str, pattern: str, fast: bool = True): + super().__init__(field_name, pattern, fast) parts = pattern.split("..", 1) if len(parts) == 1: @@ -405,9 +407,9 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): self.rangemax = self._convert(parts[1]) def match(self, obj: Model) -> bool: - if self.field not in obj: + if self.field_name not in obj: return False - value = obj[self.field] + value = obj[self.field_name] if isinstance(value, str): value = self._convert(value) @@ -422,17 +424,17 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: if self.point is not None: - return self.col_name + "=?", (self.point,) + return self.field + "=?", (self.point,) else: if self.rangemin is not None and self.rangemax is not None: return ( - "{0} >= ? AND {0} <= ?".format(self.col_name), + "{0} >= ? AND {0} <= ?".format(self.field), (self.rangemin, self.rangemax), ) elif self.rangemin is not None: - return f"{self.col_name} >= ?", (self.rangemin,) + return f"{self.field} >= ?", (self.rangemin,) elif self.rangemax is not None: - return f"{self.col_name} <= ?", (self.rangemax,) + return f"{self.field} <= ?", (self.rangemax,) else: return "1", () @@ -440,7 +442,7 @@ class NumericQuery(NumericColumnQuery[Union[int, float]]): class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): """Query which matches values in the given set.""" - field: str + field_name: str pattern: Sequence[AnySQLiteType] fast: bool = True @@ -450,7 +452,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: placeholders = ", ".join(["?"] * len(self.subvals)) - return f"{self.col_name} IN ({placeholders})", self.subvals + return f"{self.field_name} IN ({placeholders})", self.subvals @classmethod def value_match( @@ -781,15 +783,15 @@ class DateQuery(NumericColumnQuery[int]): using an ellipsis interval syntax similar to that of NumericQuery. """ - def __init__(self, field: str, pattern: str, fast: bool = True): - super().__init__(field, pattern, fast) + def __init__(self, field_name: str, pattern: str, fast: bool = True): + super().__init__(field_name, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) def match(self, obj: Model) -> bool: - if self.field not in obj: + if self.field_name not in obj: return False - timestamp = float(obj[self.field]) + timestamp = float(obj[self.field_name]) date = datetime.fromtimestamp(timestamp) return self.interval.contains(date) @@ -802,11 +804,11 @@ class DateQuery(NumericColumnQuery[int]): # Convert the `datetime` objects to an integer number of seconds since # the (local) Unix epoch using `datetime.timestamp()`. if self.interval.start: - clause_parts.append(self._clause_tmpl.format(self.col_name, ">=")) + clause_parts.append(self._clause_tmpl.format(self.field, ">=")) subvals.append(int(self.interval.start.timestamp())) if self.interval.end: - clause_parts.append(self._clause_tmpl.format(self.col_name, "<")) + clause_parts.append(self._clause_tmpl.format(self.field, "<")) subvals.append(int(self.interval.end.timestamp())) if clause_parts: diff --git a/beets/library.py b/beets/library.py index 80ffb4576..2e0003dd4 100644 --- a/beets/library.py +++ b/beets/library.py @@ -148,7 +148,7 @@ class PathQuery(dbcore.FieldQuery[bytes]): query_part = "(BYTELOWER({0}) = BYTELOWER(?)) || \ (substr(BYTELOWER({0}), 1, ?) = BYTELOWER(?))" - return query_part.format(self.col_name), ( + return query_part.format(self.field), ( file_blob, len(dir_blob), dir_blob, diff --git a/beetsplug/bareasc.py b/beetsplug/bareasc.py index 7ee33460d..8cdcbb113 100644 --- a/beetsplug/bareasc.py +++ b/beetsplug/bareasc.py @@ -46,7 +46,7 @@ class BareascQuery(StringFieldQuery[str]): def col_clause(self): """Compare ascii version of the pattern.""" - clause = f"unidecode({self.col_name})" + clause = f"unidecode({self.field})" if self.pattern.islower(): clause = f"lower({clause})"