diff --git a/beets/dbcore/__init__.py b/beets/dbcore/__init__.py index 923c34cac..7cca828bd 100644 --- a/beets/dbcore/__init__.py +++ b/beets/dbcore/__init__.py @@ -17,7 +17,7 @@ Library. """ from .db import Model, Database -from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery +from .query import Query, FieldQuery, MatchQuery, NamedQuery, AndQuery, OrQuery from .types import Type from .queryparse import query_from_strings from .queryparse import sort_from_strings diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 5a9ea7059..abecd658d 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -16,10 +16,23 @@ """ from __future__ import annotations + +from abc import ABC, abstractmethod import re from operator import mul -from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\ - Collection, MutableMapping, Sequence +from typing import ( + Any, + Collection, + Iterator, + List, + MutableMapping, + Optional, + Pattern, + Sequence, + Tuple, + Type, + Union, +) from beets import util from datetime import datetime, timedelta @@ -67,24 +80,28 @@ class InvalidQueryArgumentValueError(ParsingError): super().__init__(message) -class Query: +class Query(ABC): """An abstract class representing a query into the item database. """ - def clause(self) -> Tuple[None, Tuple]: + def clause(self) -> Tuple[Optional[str], Sequence[Any]]: """Generate an SQLite expression implementing the query. Return (clause, subvals) where clause is a valid sqlite WHERE clause implementing the query and subvals is a list of items to be substituted for ?s in the clause. + + The default implementation returns None, falling back to a slow query + using `match()`. """ return None, () + @abstractmethod def match(self, item: Item): """Check whether this query matches a given Item. Can be used to perform queries on arbitrary sets of Items. """ - raise NotImplementedError + ... def __repr__(self) -> str: return f"{self.__class__.__name__}()" @@ -93,7 +110,21 @@ class Query: return type(self) == type(other) def __hash__(self) -> int: - return 0 + """Minimalistic default implementation of a hash. + + Given the implementation if __eq__ above, this is + certainly correct. + """ + return hash(type(self)) + + +class NamedQuery(Query): + """Non-field query, i.e. the query prefix is not a field but identifies the + query class. + """ + @abstractmethod + def __init__(self, pattern): + ... class FieldQuery(Query): @@ -104,12 +135,12 @@ class FieldQuery(Query): same matching functionality in SQLite. """ - def __init__(self, field: str, pattern: Optional[str], fast: bool = True): + def __init__(self, field: str, pattern: str, fast: bool = True): self.field = field self.pattern = pattern self.fast = fast - def col_clause(self) -> Union[None, Tuple]: + def col_clause(self) -> Union[Optional[str], Sequence[Any]]: return None, () def clause(self): @@ -156,7 +187,7 @@ class NoneQuery(FieldQuery): """A query that checks whether a field is null.""" def __init__(self, field, fast: bool = True): - super().__init__(field, None, fast) + super().__init__(field, "", fast) def col_clause(self) -> Tuple[str, Tuple]: return self.field + " IS NULL", () diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 00d393cf8..2fa7bcfbb 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -128,6 +128,8 @@ def construct_query_part( if not query_part: return query.TrueQuery() + out_query: query.Query + # Use `model_cls` to build up a map from field (or query) names to # `Query` classes. query_classes = {} @@ -149,9 +151,11 @@ def construct_query_part( # any field. out_query = query.AnyFieldQuery(pattern, model_cls._search_fields, query_class) - else: + elif issubclass(query_class, query.NamedQuery): # Non-field query type. out_query = query_class(pattern) + else: + assert False, "Unexpected query type" # Field queries get constructed according to the name of the field # they are querying. @@ -160,8 +164,10 @@ def construct_query_part( out_query = query_class(key.lower(), pattern, key in model_cls._fields) # Non-field (named) query. - else: + elif issubclass(query_class, query.NamedQuery): out_query = query_class(pattern) + else: + assert False, "Unexpected query type" # Apply negation. if negate: @@ -172,7 +178,7 @@ def construct_query_part( # TYPING ERROR def query_from_strings( - query_cls: Type[query.Query], + query_cls: Type[query.CollectionQuery], model_cls: Type[Model], prefixes: Dict, query_parts: Collection[str], @@ -227,15 +233,15 @@ def sort_from_strings( """Create a `Sort` from a list of sort criteria (strings). """ if not sort_parts: - sort = query.NullSort() + return query.NullSort() elif len(sort_parts) == 1: - sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive) + return construct_sort_part(model_cls, sort_parts[0], case_insensitive) else: sort = query.MultipleSort() for part in sort_parts: sort.add_sort(construct_sort_part(model_cls, part, case_insensitive)) - return sort + return sort def parse_sorted_query( diff --git a/beetsplug/playlist.py b/beetsplug/playlist.py index 265b8bad2..9686a046c 100644 --- a/beetsplug/playlist.py +++ b/beetsplug/playlist.py @@ -19,7 +19,7 @@ import beets from beets.util import path_as_posix -class PlaylistQuery(beets.dbcore.Query): +class PlaylistQuery(beets.dbcore.NamedQuery): """Matches files listed by a playlist file. """ def __init__(self, pattern): diff --git a/test/test_dbcore.py b/test/test_dbcore.py index c25157b85..980ebd137 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -32,7 +32,7 @@ class SortFixture(dbcore.query.FieldSort): pass -class QueryFixture(dbcore.query.Query): +class QueryFixture(dbcore.query.NamedQuery): def __init__(self, pattern): self.pattern = pattern