mirror of
https://github.com/beetbox/beets.git
synced 2025-12-06 08:39:17 +01:00
Make sure we can filter common fields
This commit is contained in:
parent
981a61bd56
commit
bcc2826000
3 changed files with 45 additions and 27 deletions
|
|
@ -134,18 +134,24 @@ class FieldQuery(Query, Generic[P]):
|
|||
same matching functionality in SQLite.
|
||||
"""
|
||||
|
||||
@property
|
||||
def field(self) -> str:
|
||||
return (
|
||||
f"{self.table}.{self.field_name}" if self.table else self.field_name
|
||||
)
|
||||
|
||||
@property
|
||||
def field_names(self) -> Set[str]:
|
||||
"""Return a set with field names that this query operates on."""
|
||||
return {self.field}
|
||||
return {self.field_name}
|
||||
|
||||
def __init__(self, field: str, pattern: P, fast: bool = True):
|
||||
self.field = field
|
||||
def __init__(self, field_name: str, pattern: P, fast: bool = True):
|
||||
self.table, _, self.field_name = field_name.rpartition(".")
|
||||
self.pattern = pattern
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
return None, ()
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.field, ()
|
||||
|
||||
def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
|
||||
if self.fast:
|
||||
|
|
@ -160,23 +166,23 @@ class FieldQuery(Query, Generic[P]):
|
|||
raise NotImplementedError()
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
return self.value_match(self.pattern, obj.get(self.field))
|
||||
return self.value_match(self.pattern, obj.get(self.field_name))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
|
||||
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
|
||||
f"fast={self.fast})"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return (
|
||||
super().__eq__(other)
|
||||
and self.field == other.field
|
||||
and self.field_name == other.field_name
|
||||
and self.pattern == other.pattern
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.field, hash(self.pattern)))
|
||||
return hash((self.field_name, hash(self.pattern)))
|
||||
|
||||
|
||||
class MatchQuery(FieldQuery[AnySQLiteType]):
|
||||
|
|
@ -200,10 +206,10 @@ class NoneQuery(FieldQuery[None]):
|
|||
return self.field + " IS NULL", ()
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
return obj.get(self.field) is None
|
||||
return obj.get(self.field_name) is None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
|
||||
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"
|
||||
|
||||
|
||||
class StringFieldQuery(FieldQuery[P]):
|
||||
|
|
@ -274,7 +280,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
|
|||
expression.
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
def __init__(self, field_name: str, pattern: str, fast: bool = True):
|
||||
pattern = self._normalize(pattern)
|
||||
try:
|
||||
pattern_re = re.compile(pattern)
|
||||
|
|
@ -284,7 +290,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
|
|||
pattern, "a regular expression", format(exc)
|
||||
)
|
||||
|
||||
super().__init__(field, pattern_re, fast)
|
||||
super().__init__(field_name, pattern_re, fast)
|
||||
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return f" regexp({self.field}, ?)", [self.pattern.pattern]
|
||||
|
|
@ -308,7 +314,7 @@ class BooleanQuery(MatchQuery[int]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
field_name: str,
|
||||
pattern: bool,
|
||||
fast: bool = True,
|
||||
):
|
||||
|
|
@ -317,7 +323,7 @@ class BooleanQuery(MatchQuery[int]):
|
|||
|
||||
pattern_int = int(pattern)
|
||||
|
||||
super().__init__(field, pattern_int, fast)
|
||||
super().__init__(field_name, pattern_int, fast)
|
||||
|
||||
|
||||
class BytesQuery(FieldQuery[bytes]):
|
||||
|
|
@ -327,7 +333,7 @@ class BytesQuery(FieldQuery[bytes]):
|
|||
`MatchQuery` when matching on BLOB values.
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
|
||||
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
|
||||
# Use a buffer/memoryview representation of the pattern for SQLite
|
||||
# matching. This instructs SQLite to treat the blob as binary
|
||||
# rather than encoded Unicode.
|
||||
|
|
@ -343,7 +349,7 @@ class BytesQuery(FieldQuery[bytes]):
|
|||
else:
|
||||
raise ValueError("pattern must be bytes, str, or memoryview")
|
||||
|
||||
super().__init__(field, bytes_pattern)
|
||||
super().__init__(field_name, bytes_pattern)
|
||||
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
return self.field + " = ?", [self.buf_pattern]
|
||||
|
|
@ -379,8 +385,8 @@ class NumericQuery(FieldQuery[str]):
|
|||
except ValueError:
|
||||
raise InvalidQueryArgumentValueError(s, "an int or a float")
|
||||
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
def __init__(self, field_name: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field_name, pattern, fast)
|
||||
|
||||
parts = pattern.split("..", 1)
|
||||
if len(parts) == 1:
|
||||
|
|
@ -395,9 +401,9 @@ class NumericQuery(FieldQuery[str]):
|
|||
self.rangemax = self._convert(parts[1])
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
if self.field not in obj:
|
||||
if self.field_name not in obj:
|
||||
return False
|
||||
value = obj[self.field]
|
||||
value = obj[self.field_name]
|
||||
if isinstance(value, str):
|
||||
value = self._convert(value)
|
||||
|
||||
|
|
@ -430,7 +436,7 @@ class NumericQuery(FieldQuery[str]):
|
|||
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
|
||||
"""Query which matches values in the given set."""
|
||||
|
||||
field: str
|
||||
field_name: str
|
||||
pattern: Sequence[AnySQLiteType]
|
||||
fast: bool = True
|
||||
|
||||
|
|
@ -440,7 +446,7 @@ class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
|
|||
|
||||
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
|
||||
placeholders = ", ".join(["?"] * len(self.subvals))
|
||||
return f"{self.field} IN ({placeholders})", self.subvals
|
||||
return f"{self.field_name} IN ({placeholders})", self.subvals
|
||||
|
||||
@classmethod
|
||||
def value_match(
|
||||
|
|
@ -823,15 +829,15 @@ class DateQuery(FieldQuery[str]):
|
|||
using an ellipsis interval syntax similar to that of NumericQuery.
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
def __init__(self, field_name: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field_name, pattern, fast)
|
||||
start, end = _parse_periods(pattern)
|
||||
self.interval = DateInterval.from_periods(start, end)
|
||||
|
||||
def match(self, obj: Model) -> bool:
|
||||
if self.field not in obj:
|
||||
if self.field_name not in obj:
|
||||
return False
|
||||
timestamp = float(obj[self.field])
|
||||
timestamp = float(obj[self.field_name])
|
||||
date = datetime.fromtimestamp(timestamp)
|
||||
return self.interval.contains(date)
|
||||
|
||||
|
|
|
|||
|
|
@ -153,6 +153,12 @@ def construct_query_part(
|
|||
# they are querying.
|
||||
else:
|
||||
key = key.lower()
|
||||
if key in model_cls.shared_db_fields:
|
||||
# This field exists in both tables, so SQLite will encounter
|
||||
# an OperationalError if we try to query it in a join.
|
||||
# Using an explicit table name resolves this.
|
||||
key = f"{model_cls._table}.{key}"
|
||||
|
||||
out_query = query_class(key, pattern, key in model_cls.all_db_fields)
|
||||
|
||||
# Apply negation.
|
||||
|
|
|
|||
|
|
@ -1148,6 +1148,7 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
album_items.append(item)
|
||||
album = self.lib.add_album(album_items)
|
||||
album.artpath = f"{album_name} Artpath"
|
||||
album.catalognum = "ABC"
|
||||
album.store()
|
||||
albums.append(album)
|
||||
|
||||
|
|
@ -1163,6 +1164,11 @@ class RelatedQueriesTest(_common.TestCase, AssertsMixin):
|
|||
results = self.lib.items(q)
|
||||
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
|
||||
|
||||
def test_filter_by_common_field(self):
|
||||
q = "catalognum:ABC Album1"
|
||||
results = self.lib.albums(q)
|
||||
self.assert_albums_matched(results, ["Album1"])
|
||||
|
||||
|
||||
def suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue