typings: corrections for dbcore/queryparse

- Add NamedQuery abstract class to be able to express the expectation
  that a query should be such a query (and have a specific constructor
  signature) in construct_query_part
- slightly (and probably completely irrelevantly) improve Query.__hash__
- also, sprinkle some ABC/abstractmethod around to clarify things
This commit is contained in:
wisp3rwind 2023-02-23 20:43:24 +01:00
parent 7d05e01b85
commit 09d2c87f29
5 changed files with 55 additions and 18 deletions

View file

@ -17,7 +17,7 @@ Library.
"""
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .query import Query, FieldQuery, MatchQuery, NamedQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings

View file

@ -16,10 +16,23 @@
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import re
from operator import mul
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\
Collection, MutableMapping, Sequence
from typing import (
Any,
Collection,
Iterator,
List,
MutableMapping,
Optional,
Pattern,
Sequence,
Tuple,
Type,
Union,
)
from beets import util
from datetime import datetime, timedelta
@ -67,24 +80,28 @@ class InvalidQueryArgumentValueError(ParsingError):
super().__init__(message)
class Query:
class Query(ABC):
"""An abstract class representing a query into the item database.
"""
def clause(self) -> Tuple[None, Tuple]:
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
The default implementation returns None, falling back to a slow query
using `match()`.
"""
return None, ()
@abstractmethod
def match(self, item: Item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
"""
raise NotImplementedError
...
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
@ -93,7 +110,21 @@ class Query:
return type(self) == type(other)
def __hash__(self) -> int:
return 0
"""Minimalistic default implementation of a hash.
Given the implementation if __eq__ above, this is
certainly correct.
"""
return hash(type(self))
class NamedQuery(Query):
"""Non-field query, i.e. the query prefix is not a field but identifies the
query class.
"""
@abstractmethod
def __init__(self, pattern):
...
class FieldQuery(Query):
@ -104,12 +135,12 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
def __init__(self, field: str, pattern: str, fast: bool = True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self) -> Union[None, Tuple]:
def col_clause(self) -> Union[Optional[str], Sequence[Any]]:
return None, ()
def clause(self):
@ -156,7 +187,7 @@ class NoneQuery(FieldQuery):
"""A query that checks whether a field is null."""
def __init__(self, field, fast: bool = True):
super().__init__(field, None, fast)
super().__init__(field, "", fast)
def col_clause(self) -> Tuple[str, Tuple]:
return self.field + " IS NULL", ()

View file

@ -128,6 +128,8 @@ def construct_query_part(
if not query_part:
return query.TrueQuery()
out_query: query.Query
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes = {}
@ -149,9 +151,11 @@ def construct_query_part(
# any field.
out_query = query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
else:
elif issubclass(query_class, query.NamedQuery):
# Non-field query type.
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
# Field queries get constructed according to the name of the field
# they are querying.
@ -160,8 +164,10 @@ def construct_query_part(
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Non-field (named) query.
else:
elif issubclass(query_class, query.NamedQuery):
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
# Apply negation.
if negate:
@ -172,7 +178,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: Type[query.Query],
query_cls: Type[query.CollectionQuery],
model_cls: Type[Model],
prefixes: Dict,
query_parts: Collection[str],
@ -227,15 +233,15 @@ def sort_from_strings(
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
sort = query.NullSort()
return query.NullSort()
elif len(sort_parts) == 1:
sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive)
return construct_sort_part(model_cls, sort_parts[0], case_insensitive)
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part,
case_insensitive))
return sort
return sort
def parse_sorted_query(

View file

@ -19,7 +19,7 @@ import beets
from beets.util import path_as_posix
class PlaylistQuery(beets.dbcore.Query):
class PlaylistQuery(beets.dbcore.NamedQuery):
"""Matches files listed by a playlist file.
"""
def __init__(self, pattern):

View file

@ -32,7 +32,7 @@ class SortFixture(dbcore.query.FieldSort):
pass
class QueryFixture(dbcore.query.Query):
class QueryFixture(dbcore.query.NamedQuery):
def __init__(self, pattern):
self.pattern = pattern