diff --git a/.mypy.ini b/.mypy.ini index b47e5dff3..6bad7a0b6 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] +allow_any_generics = false # FIXME: Would be better to actually type the libraries (if under our control), # or write our own stubs. For now, silence errors -ignore_missing_imports = True - +ignore_missing_imports = true diff --git a/beets/dbcore/__init__.py b/beets/dbcore/__init__.py index baeb10d26..06d0b3dc9 100644 --- a/beets/dbcore/__init__.py +++ b/beets/dbcore/__init__.py @@ -22,7 +22,6 @@ from .query import ( FieldQuery, InvalidQueryError, MatchQuery, - NamedQuery, OrQuery, Query, ) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 7fbf646dc..369b1ffe0 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -12,8 +12,7 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. -"""The central Model and Database constructs for DBCore. -""" +"""The central Model and Database constructs for DBCore.""" from __future__ import annotations @@ -309,7 +308,7 @@ class Model(ABC): are subclasses of `Sort`. """ - _queries: Dict[str, Type[Query]] = {} + _queries: Dict[str, Type[FieldQuery]] = {} """Named queries that use a field-like `name:value` syntax but which do not relate to any specific field. """ @@ -599,8 +598,7 @@ class Model(ABC): # Deleted flexible attributes. for key in self._dirty: tx.mutate( - "DELETE FROM {} " - "WHERE entity_id=? AND key=?".format(self._flex_table), + f"DELETE FROM {self._flex_table} WHERE entity_id=? AND key=?", (self.id, key), ) diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 890acbe72..ffa89168e 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -12,8 +12,7 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. -"""The Query type hierarchy for DBCore. -""" +"""The Query type hierarchy for DBCore.""" from __future__ import annotations @@ -116,17 +115,8 @@ class Query(ABC): return hash(type(self)) -class NamedQuery(Query): - """Non-field query, i.e. the query prefix is not a field but identifies the - query class. - """ - - @abstractmethod - def __init__(self, pattern): ... - - P = TypeVar("P") -SQLiteType = Union[str, float, int, memoryview] +SQLiteType = Union[str, bytes, float, int, memoryview] AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType) @@ -155,9 +145,7 @@ class FieldQuery(Query, Generic[P]): @classmethod def value_match(cls, pattern: P, value: Any): - """Determine whether the value matches the pattern. Both - arguments are strings. - """ + """Determine whether the value matches the pattern.""" raise NotImplementedError() def match(self, obj: Model) -> bool: @@ -428,6 +416,28 @@ class NumericQuery(FieldQuery): return "1", () +class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]): + """Query which matches values in the given set.""" + + field: str + pattern: Sequence[AnySQLiteType] + fast: bool = True + + @property + def subvals(self) -> Sequence[SQLiteType]: + return self.pattern + + def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]: + placeholders = ", ".join(["?"] * len(self.subvals)) + return f"{self.field} IN ({placeholders})", self.subvals + + @classmethod + def value_match( + cls, pattern: Sequence[AnySQLiteType], value: AnySQLiteType + ) -> bool: + return value in pattern + + class CollectionQuery(Query): """An abstract query class that aggregates other queries. Can be indexed like a list to access the sub-queries. diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index e2b082ecc..ea6f16922 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -12,15 +12,14 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. -"""Parsing of strings into DBCore queries. -""" +"""Parsing of strings into DBCore queries.""" import itertools import re from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type from . import Model, query -from .query import Query, Sort +from .query import Sort PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. @@ -36,10 +35,10 @@ PARSE_QUERY_PART_REGEX = re.compile( def parse_query_part( part: str, - query_classes: Dict = {}, + query_classes: Dict[str, Type[query.FieldQuery]] = {}, prefixes: Dict = {}, default_class: Type[query.SubstringQuery] = query.SubstringQuery, -) -> Tuple[Optional[str], str, Type[query.Query], bool]: +) -> Tuple[Optional[str], str, Type[query.FieldQuery], bool]: """Parse a single *query part*, which is a chunk of a complete query string representing a single criterion. @@ -128,7 +127,7 @@ def construct_query_part( # Use `model_cls` to build up a map from field (or query) names to # `Query` classes. - query_classes: Dict[str, Type[Query]] = {} + query_classes: Dict[str, Type[query.FieldQuery]] = {} for k, t in itertools.chain( model_cls._fields.items(), model_cls._types.items() ): @@ -143,30 +142,17 @@ def construct_query_part( # If there's no key (field name) specified, this is a "match # anything" query. if key is None: - if issubclass(query_class, query.FieldQuery): - # The query type matches a specific field, but none was - # specified. So we use a version of the query that matches - # any field. - out_query = query.AnyFieldQuery( - pattern, model_cls._search_fields, query_class - ) - elif issubclass(query_class, query.NamedQuery): - # Non-field query type. - out_query = query_class(pattern) - else: - assert False, "Unexpected query type" + # The query type matches a specific field, but none was + # specified. So we use a version of the query that matches + # any field. + out_query = query.AnyFieldQuery( + pattern, model_cls._search_fields, query_class + ) # Field queries get constructed according to the name of the field # they are querying. - elif issubclass(query_class, query.FieldQuery): - key = key.lower() - out_query = query_class(key.lower(), pattern, key in model_cls._fields) - - # Non-field (named) query. - elif issubclass(query_class, query.NamedQuery): - out_query = query_class(pattern) else: - assert False, "Unexpected query type" + out_query = query_class(key.lower(), pattern, key in model_cls._fields) # Apply negation. if negate: diff --git a/beets/util/__init__.py b/beets/util/__init__.py index ccb95b459..4335e0f3e 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -427,7 +427,7 @@ def displayable_path( return path.decode("utf-8", "ignore") -def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String: +def syspath(path: Bytes_or_String, prefix: bool = True) -> Bytes_or_String: """Convert a path for use by the operating system. In particular, paths on Windows must receive a magic prefix and must be converted to Unicode before they are sent to the OS. To disable the magic diff --git a/beetsplug/playlist.py b/beetsplug/playlist.py index d40f4125f..83f95796e 100644 --- a/beetsplug/playlist.py +++ b/beetsplug/playlist.py @@ -15,17 +15,22 @@ import fnmatch import os import tempfile -from typing import Any, Optional, Sequence, Tuple +from typing import Sequence import beets +from beets.dbcore.query import InQuery +from beets.library import BLOB_TYPE from beets.util import path_as_posix -class PlaylistQuery(beets.dbcore.NamedQuery): +class PlaylistQuery(InQuery[bytes]): """Matches files listed by a playlist file.""" - def __init__(self, pattern): - self.pattern = pattern + @property + def subvals(self) -> Sequence[BLOB_TYPE]: + return [BLOB_TYPE(p) for p in self.pattern] + + def __init__(self, _, pattern: str, __): config = beets.config["playlist"] # Get the full path to the playlist @@ -39,7 +44,7 @@ class PlaylistQuery(beets.dbcore.NamedQuery): ), ) - self.paths = [] + paths = [] for playlist_path in playlist_paths: if not fnmatch.fnmatch(playlist_path, "*.[mM]3[uU]"): # This is not am M3U playlist, skip this candidate @@ -63,23 +68,14 @@ class PlaylistQuery(beets.dbcore.NamedQuery): # ignore comments, and extm3u extension continue - self.paths.append( + paths.append( beets.util.normpath( os.path.join(relative_to, line.rstrip()) ) ) f.close() break - - def clause(self) -> Tuple[Optional[str], Sequence[Any]]: - if not self.paths: - # Playlist is empty - return "0", () - clause = "path IN ({})".format(", ".join("?" for path in self.paths)) - return clause, (beets.library.BLOB_TYPE(p) for p in self.paths) - - def match(self, item): - return item.path in self.paths + super().__init__("path", paths) class PlaylistPlugin(beets.plugins.BeetsPlugin): diff --git a/test/test_dbcore.py b/test/test_dbcore.py index e5ab1910b..763601b7f 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -12,8 +12,7 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. -"""Tests for the DBCore database abstraction. -""" +"""Tests for the DBCore database abstraction.""" import os import shutil @@ -32,7 +31,7 @@ class SortFixture(dbcore.query.FieldSort): pass -class QueryFixture(dbcore.query.NamedQuery): +class QueryFixture(dbcore.query.FieldQuery): def __init__(self, pattern): self.pattern = pattern @@ -605,10 +604,6 @@ class QueryFromStringsTest(unittest.TestCase): q = self.qfs([""]) self.assertIsInstance(q.subqueries[0], dbcore.query.TrueQuery) - def test_parse_named_query(self): - q = self.qfs(["some_query:foo"]) - self.assertIsInstance(q.subqueries[0], QueryFixture) - class SortFromStringsTest(unittest.TestCase): def sfs(self, strings):