Add support for a specific type in InQuery

This commit is contained in:
Šarūnas Nejus 2024-04-30 12:59:01 +01:00
parent a57c164348
commit 7d636d8f22
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
3 changed files with 10 additions and 10 deletions

View file

@ -116,7 +116,7 @@ class Query(ABC):
P = TypeVar("P") P = TypeVar("P")
SQLiteType = Union[str, float, int, memoryview] SQLiteType = Union[str, bytes, float, int, memoryview]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType) AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
@ -416,7 +416,7 @@ class NumericQuery(FieldQuery):
return "1", () return "1", ()
class InQuery(FieldQuery[Sequence[AnySQLiteType]]): class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set.""" """Query which matches values in the given set."""
field: str field: str
@ -424,10 +424,10 @@ class InQuery(FieldQuery[Sequence[AnySQLiteType]]):
fast: bool = True fast: bool = True
@property @property
def subvals(self) -> Sequence[AnySQLiteType]: def subvals(self) -> Sequence[SQLiteType]:
return self.pattern return self.pattern
def col_clause(self) -> Tuple[str, Sequence[AnySQLiteType]]: def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals)) placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals return f"{self.field} IN ({placeholders})", self.subvals

View file

@ -427,7 +427,7 @@ def displayable_path(
return path.decode("utf-8", "ignore") return path.decode("utf-8", "ignore")
def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String: def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String:
"""Convert a path for use by the operating system. In particular, """Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic to Unicode before they are sent to the OS. To disable the magic

View file

@ -23,9 +23,13 @@ from beets.library import BLOB_TYPE
from beets.util import path_as_posix from beets.util import path_as_posix
class PlaylistQuery(InQuery): class PlaylistQuery(InQuery[bytes]):
"""Matches files listed by a playlist file.""" """Matches files listed by a playlist file."""
@property
def subvals(self) -> Sequence[BLOB_TYPE]:
return [BLOB_TYPE(p) for p in self.pattern]
def __init__(self, _, pattern: str, __): def __init__(self, _, pattern: str, __):
config = beets.config["playlist"] config = beets.config["playlist"]
@ -73,10 +77,6 @@ class PlaylistQuery(InQuery):
break break
super().__init__("path", paths) super().__init__("path", paths)
@property
def subvals(self) -> Sequence[BLOB_TYPE]:
return [BLOB_TYPE(p) for p in self.pattern]
class PlaylistPlugin(beets.plugins.BeetsPlugin): class PlaylistPlugin(beets.plugins.BeetsPlugin):
item_queries = {"playlist": PlaylistQuery} item_queries = {"playlist": PlaylistQuery}