Add support for a specific type in InQuery

This commit is contained in:
Šarūnas Nejus 2024-04-30 12:59:01 +01:00
parent a57c164348
commit 7d636d8f22
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435
3 changed files with 10 additions and 10 deletions

View file

@ -116,7 +116,7 @@ class Query(ABC):
P = TypeVar("P")
SQLiteType = Union[str, float, int, memoryview]
SQLiteType = Union[str, bytes, float, int, memoryview]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
@ -416,7 +416,7 @@ class NumericQuery(FieldQuery):
return "1", ()
class InQuery(FieldQuery[Sequence[AnySQLiteType]]):
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""
field: str
@ -424,10 +424,10 @@ class InQuery(FieldQuery[Sequence[AnySQLiteType]]):
fast: bool = True
@property
def subvals(self) -> Sequence[AnySQLiteType]:
def subvals(self) -> Sequence[SQLiteType]:
return self.pattern
def col_clause(self) -> Tuple[str, Sequence[AnySQLiteType]]:
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals

View file

@ -427,7 +427,7 @@ def displayable_path(
return path.decode("utf-8", "ignore")
def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String:
def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String:
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic

View file

@ -23,9 +23,13 @@ from beets.library import BLOB_TYPE
from beets.util import path_as_posix
class PlaylistQuery(InQuery):
class PlaylistQuery(InQuery[bytes]):
"""Matches files listed by a playlist file."""
@property
def subvals(self) -> Sequence[BLOB_TYPE]:
return [BLOB_TYPE(p) for p in self.pattern]
def __init__(self, _, pattern: str, __):
config = beets.config["playlist"]
@ -73,10 +77,6 @@ class PlaylistQuery(InQuery):
break
super().__init__("path", paths)
@property
def subvals(self) -> Sequence[BLOB_TYPE]:
return [BLOB_TYPE(p) for p in self.pattern]
class PlaylistPlugin(beets.plugins.BeetsPlugin):
item_queries = {"playlist": PlaylistQuery}