Merge pull request #4826 from wisp3rwind/dbcore_typing_0

typings: corrections for dbcore/queryparse
This commit is contained in:
Adrian Sampson 2023-06-23 17:44:13 -07:00 committed by GitHub
commit 854fec2634
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 20 deletions

View file

@ -17,7 +17,7 @@ Library.
"""
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .query import Query, FieldQuery, MatchQuery, NamedQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings

View file

@ -16,10 +16,23 @@
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import re
from operator import mul
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\
Collection, MutableMapping, Sequence
from typing import (
Any,
Collection,
Iterator,
List,
MutableMapping,
Optional,
Pattern,
Sequence,
Tuple,
Type,
Union,
)
from beets import util
from datetime import datetime, timedelta
@ -66,24 +79,28 @@ class InvalidQueryArgumentValueError(ParsingError):
super().__init__(message)
class Query:
class Query(ABC):
"""An abstract class representing a query into the database.
"""
def clause(self) -> Tuple[None, Tuple]:
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
WHERE clause implementing the query and subvals is a list of
items to be substituted for ?s in the clause.
The default implementation returns None, falling back to a slow query
using `match()`.
"""
return None, ()
def match(self, obj: Model) -> bool:
@abstractmethod
def match(self, obj: Model):
"""Check whether this query matches a given Model. Can be used to
perform queries on arbitrary sets of Model.
"""
raise NotImplementedError
...
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
@ -92,7 +109,21 @@ class Query:
return type(self) == type(other)
def __hash__(self) -> int:
return 0
"""Minimalistic default implementation of a hash.
Given the implementation if __eq__ above, this is
certainly correct.
"""
return hash(type(self))
class NamedQuery(Query):
"""Non-field query, i.e. the query prefix is not a field but identifies the
query class.
"""
@abstractmethod
def __init__(self, pattern):
...
class FieldQuery(Query):
@ -103,12 +134,12 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
def __init__(self, field: str, pattern: str, fast: bool = True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self) -> Union[None, Tuple]:
def col_clause(self) -> Union[Optional[str], Sequence[Any]]:
return None, ()
def clause(self):
@ -155,7 +186,7 @@ class NoneQuery(FieldQuery):
"""A query that checks whether a field is null."""
def __init__(self, field, fast: bool = True):
super().__init__(field, None, fast)
super().__init__(field, "", fast)
def col_clause(self) -> Tuple[str, Tuple]:
return self.field + " IS NULL", ()

View file

@ -128,6 +128,8 @@ def construct_query_part(
if not query_part:
return query.TrueQuery()
out_query: query.Query
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes = {}
@ -149,9 +151,11 @@ def construct_query_part(
# any field.
out_query = query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
else:
elif issubclass(query_class, query.NamedQuery):
# Non-field query type.
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
# Field queries get constructed according to the name of the field
# they are querying.
@ -160,8 +164,10 @@ def construct_query_part(
out_query = query_class(key.lower(), pattern, key in model_cls._fields)
# Non-field (named) query.
else:
elif issubclass(query_class, query.NamedQuery):
out_query = query_class(pattern)
else:
assert False, "Unexpected query type"
# Apply negation.
if negate:
@ -172,7 +178,7 @@ def construct_query_part(
# TYPING ERROR
def query_from_strings(
query_cls: Type[query.Query],
query_cls: Type[query.CollectionQuery],
model_cls: Type[Model],
prefixes: Dict,
query_parts: Collection[str],
@ -227,15 +233,15 @@ def sort_from_strings(
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
sort = query.NullSort()
return query.NullSort()
elif len(sort_parts) == 1:
sort = construct_sort_part(model_cls, sort_parts[0], case_insensitive)
return construct_sort_part(model_cls, sort_parts[0], case_insensitive)
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part,
case_insensitive))
return sort
return sort
def parse_sorted_query(

View file

@ -15,11 +15,12 @@
import os
import fnmatch
import tempfile
from typing import Any, Optional, Sequence, Tuple
import beets
from beets.util import path_as_posix
class PlaylistQuery(beets.dbcore.Query):
class PlaylistQuery(beets.dbcore.NamedQuery):
"""Matches files listed by a playlist file.
"""
def __init__(self, pattern):
@ -65,7 +66,7 @@ class PlaylistQuery(beets.dbcore.Query):
f.close()
break
def col_clause(self):
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
if not self.paths:
# Playlist is empty
return '0', ()

View file

@ -32,7 +32,7 @@ class SortFixture(dbcore.query.FieldSort):
pass
class QueryFixture(dbcore.query.Query):
class QueryFixture(dbcore.query.NamedQuery):
def __init__(self, pattern):
self.pattern = pattern