diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 890acbe72..e2e5399cf 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -12,8 +12,7 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. -"""The Query type hierarchy for DBCore. -""" +"""The Query type hierarchy for DBCore.""" from __future__ import annotations @@ -155,9 +154,7 @@ class FieldQuery(Query, Generic[P]): @classmethod def value_match(cls, pattern: P, value: Any): - """Determine whether the value matches the pattern. Both - arguments are strings. - """ + """Determine whether the value matches the pattern.""" raise NotImplementedError() def match(self, obj: Model) -> bool: @@ -428,6 +425,28 @@ class NumericQuery(FieldQuery): return "1", () +class InQuery(FieldQuery[Sequence[AnySQLiteType]]): + """Query which matches values in the given set.""" + + field: str + pattern: Sequence[AnySQLiteType] + fast: bool = True + + @property + def subvals(self) -> Sequence[AnySQLiteType]: + return self.pattern + + def col_clause(self) -> Tuple[str, Sequence[AnySQLiteType]]: + placeholders = ", ".join(["?"] * len(self.subvals)) + return f"{self.field} IN ({placeholders})", self.subvals + + @classmethod + def value_match( + cls, pattern: Sequence[AnySQLiteType], value: AnySQLiteType + ) -> bool: + return value in pattern + + class CollectionQuery(Query): """An abstract query class that aggregates other queries. Can be indexed like a list to access the sub-queries. diff --git a/beetsplug/playlist.py b/beetsplug/playlist.py index d40f4125f..401178553 100644 --- a/beetsplug/playlist.py +++ b/beetsplug/playlist.py @@ -15,17 +15,18 @@ import fnmatch import os import tempfile -from typing import Any, Optional, Sequence, Tuple +from typing import Sequence import beets +from beets.dbcore.query import InQuery +from beets.library import BLOB_TYPE from beets.util import path_as_posix -class PlaylistQuery(beets.dbcore.NamedQuery): +class PlaylistQuery(InQuery): """Matches files listed by a playlist file.""" - def __init__(self, pattern): - self.pattern = pattern + def __init__(self, _, pattern: str, __): config = beets.config["playlist"] # Get the full path to the playlist @@ -39,7 +40,7 @@ class PlaylistQuery(beets.dbcore.NamedQuery): ), ) - self.paths = [] + paths = [] for playlist_path in playlist_paths: if not fnmatch.fnmatch(playlist_path, "*.[mM]3[uU]"): # This is not am M3U playlist, skip this candidate @@ -63,23 +64,18 @@ class PlaylistQuery(beets.dbcore.NamedQuery): # ignore comments, and extm3u extension continue - self.paths.append( + paths.append( beets.util.normpath( os.path.join(relative_to, line.rstrip()) ) ) f.close() break + super().__init__("path", paths) - def clause(self) -> Tuple[Optional[str], Sequence[Any]]: - if not self.paths: - # Playlist is empty - return "0", () - clause = "path IN ({})".format(", ".join("?" for path in self.paths)) - return clause, (beets.library.BLOB_TYPE(p) for p in self.paths) - - def match(self, item): - return item.path in self.paths + @property + def subvals(self) -> Sequence[BLOB_TYPE]: + return [BLOB_TYPE(p) for p in self.pattern] class PlaylistPlugin(beets.plugins.BeetsPlugin):