mirror of
https://github.com/beetbox/beets.git
synced 2026-01-30 12:02:41 +01:00
Merge pull request #4858 from wisp3rwind/dbcore_typing_4
typing: Wrap up dbcore
This commit is contained in:
commit
ab3e2a98d1
3 changed files with 168 additions and 102 deletions
|
|
@ -16,6 +16,7 @@
|
|||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from abc import ABC
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
|
|
@ -25,22 +26,36 @@ import sqlite3
|
|||
import contextlib
|
||||
from sqlite3 import Connection
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Type, List, Tuple, Optional, Union, \
|
||||
Dict, Any, Generator, Iterator, Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from unidecode import unidecode
|
||||
|
||||
import beets
|
||||
from beets.util import functemplate
|
||||
from beets.util import py3_path
|
||||
from beets.dbcore import types
|
||||
from . import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
|
||||
FieldQuery, Sort
|
||||
from collections.abc import Mapping
|
||||
FieldQuery, Sort, FieldSort
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from beets.library import LibModel
|
||||
from ..util.functemplate import Template
|
||||
|
||||
|
||||
|
|
@ -53,7 +68,7 @@ class DBAccessError(Exception):
|
|||
"""
|
||||
|
||||
|
||||
class FormattedMapping(Mapping):
|
||||
class FormattedMapping(Mapping[str, str]):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor `mapping[key]` returns the formatted version of
|
||||
|
|
@ -71,7 +86,7 @@ class FormattedMapping(Mapping):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: 'Model',
|
||||
model: Model,
|
||||
included_keys: str = ALL_KEYS,
|
||||
for_path: bool = False,
|
||||
):
|
||||
|
|
@ -83,31 +98,39 @@ class FormattedMapping(Mapping):
|
|||
else:
|
||||
self.model_keys = included_keys
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> str:
|
||||
if key in self.model_keys:
|
||||
return self._get_formatted(self.model, key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=None):
|
||||
# The following signature is incompatible with `Mapping[str, str]`, since
|
||||
# the return type doesn't include `None` (but `default` can be `None`).
|
||||
def get( # type: ignore
|
||||
self,
|
||||
key: str,
|
||||
default: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Similar to Mapping.get(key, default), but always formats to str.
|
||||
"""
|
||||
if default is None:
|
||||
default = self.model._type(key).format(None)
|
||||
return super().get(key, default)
|
||||
|
||||
def _get_formatted(self, model, key):
|
||||
def _get_formatted(self, model: Model, key: str) -> str:
|
||||
value = model._type(key).format(model.get(key))
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8', 'ignore')
|
||||
|
||||
if self.for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].as_str()
|
||||
sep_drive = beets.config['drive_sep_replace'].as_str()
|
||||
sep_repl = cast(str, beets.config['path_sep_replace'].as_str())
|
||||
sep_drive = cast(str, beets.config['drive_sep_replace'].as_str())
|
||||
|
||||
if re.match(r'^\w:', value):
|
||||
value = re.sub(r'(?<=^\w):', sep_drive, value)
|
||||
|
|
@ -119,6 +142,15 @@ class FormattedMapping(Mapping):
|
|||
return value
|
||||
|
||||
|
||||
# NOTE: This seems like it should be a `Mapping`, i.e.
|
||||
# ```
|
||||
# class LazyConvertDict(Mapping[str, Any])
|
||||
# ```
|
||||
# but there are some conflicts with the `Mapping` protocol such that we
|
||||
# can't do this without changing behaviour: In particular, iterators returned
|
||||
# by some methods build intermediate lists, such that modification of the
|
||||
# `LazyConvertDict` becomes safe during iteration. Some code does in fact rely
|
||||
# on this.
|
||||
class LazyConvertDict:
|
||||
"""Lazily convert types for attributes fetched from the database
|
||||
"""
|
||||
|
|
@ -126,60 +158,61 @@ class LazyConvertDict:
|
|||
def __init__(self, model_cls: 'Model'):
|
||||
"""Initialize the object empty
|
||||
"""
|
||||
self.data = {}
|
||||
# FIXME: Dict[str, SQLiteType]
|
||||
self._data: Dict[str, Any] = {}
|
||||
self.model_cls = model_cls
|
||||
self._converted = {}
|
||||
self._converted: Dict[str, Any] = {}
|
||||
|
||||
def init(self, data):
|
||||
def init(self, data: Dict[str, Any]):
|
||||
"""Set the base data that should be lazily converted
|
||||
"""
|
||||
self.data = data
|
||||
self._data = data
|
||||
|
||||
def _convert(self, key, value):
|
||||
"""Convert the attribute type according the the SQL type
|
||||
def _convert(self, key: str, value: Any):
|
||||
"""Convert the attribute type according to the SQL type
|
||||
"""
|
||||
return self.model_cls._type(key).from_sql(value)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
"""Set an attribute value, assume it's already converted
|
||||
"""
|
||||
self._converted[key] = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Get an attribute value, converting the type on demand
|
||||
if needed
|
||||
"""
|
||||
if key in self._converted:
|
||||
return self._converted[key]
|
||||
elif key in self.data:
|
||||
value = self._convert(key, self.data[key])
|
||||
elif key in self._data:
|
||||
value = self._convert(key, self._data[key])
|
||||
self._converted[key] = value
|
||||
return value
|
||||
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key: str):
|
||||
"""Delete both converted and base data
|
||||
"""
|
||||
if key in self._converted:
|
||||
del self._converted[key]
|
||||
if key in self.data:
|
||||
del self.data[key]
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
"""Get a list of available field names for this object.
|
||||
"""
|
||||
return list(self._converted.keys()) + list(self.data.keys())
|
||||
return list(self._converted.keys()) + list(self._data.keys())
|
||||
|
||||
def copy(self) -> 'LazyConvertDict':
|
||||
def copy(self) -> LazyConvertDict:
|
||||
"""Create a copy of the object.
|
||||
"""
|
||||
new = self.__class__(self.model_cls)
|
||||
new.data = self.data.copy()
|
||||
new._data = self._data.copy()
|
||||
new._converted = self._converted.copy()
|
||||
return new
|
||||
|
||||
# Act like a dictionary.
|
||||
|
||||
def update(self, values):
|
||||
def update(self, values: Mapping[str, Any]):
|
||||
"""Assign all values in the given dict.
|
||||
"""
|
||||
for key, value in values.items():
|
||||
|
|
@ -192,7 +225,7 @@ class LazyConvertDict:
|
|||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key: str, default: Optional[Any] = None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
|
|
@ -201,21 +234,30 @@ class LazyConvertDict:
|
|||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys()
|
||||
return key in self._converted or key in self._data
|
||||
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
# NOTE: It would be nice to use the following:
|
||||
# yield from self._converted
|
||||
# yield from self._data
|
||||
# but that won't work since some code relies on modifying `self`
|
||||
# during iteration.
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
# FIXME: This is incorrect due to duplication of keys
|
||||
return len(self._converted) + len(self._data)
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
||||
class Model:
|
||||
class Model(ABC):
|
||||
"""An abstract object representing an object in the database. Model
|
||||
objects act like dictionaries (i.e., they allow subscript access like
|
||||
``obj['field']``). The same field set is available via attribute
|
||||
|
|
@ -241,34 +283,34 @@ class Model:
|
|||
|
||||
# Abstract components (to be provided by subclasses).
|
||||
|
||||
_table = None
|
||||
_table: str
|
||||
"""The main SQLite table name.
|
||||
"""
|
||||
|
||||
_flex_table = None
|
||||
_flex_table: str
|
||||
"""The flex field SQLite table name.
|
||||
"""
|
||||
|
||||
_fields = {}
|
||||
_fields: Dict[str, types.Type] = {}
|
||||
"""A mapping indicating available "fixed" fields on this type. The
|
||||
keys are field names and the values are `Type` objects.
|
||||
"""
|
||||
|
||||
_search_fields = ()
|
||||
_search_fields: Sequence[str] = ()
|
||||
"""The fields that should be queried by default by unqualified query
|
||||
terms.
|
||||
"""
|
||||
|
||||
_types = {}
|
||||
_types: Dict[str, types.Type] = {}
|
||||
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
|
||||
"""
|
||||
|
||||
_sorts = {}
|
||||
_sorts: Dict[str, Type[FieldSort]] = {}
|
||||
"""Optional named sort criteria. The keys are strings and the values
|
||||
are subclasses of `Sort`.
|
||||
"""
|
||||
|
||||
_queries = {}
|
||||
_queries: Dict[str, Type[Query]] = {}
|
||||
"""Named queries that use a field-like `name:value` syntax but which
|
||||
do not relate to any specific field.
|
||||
"""
|
||||
|
|
@ -301,12 +343,12 @@ class Model:
|
|||
|
||||
# Basic operation.
|
||||
|
||||
def __init__(self, db: Optional['Database'] = None, **values):
|
||||
def __init__(self, db: Optional[Database] = None, **values):
|
||||
"""Create a new object with an optional Database association and
|
||||
initial field values.
|
||||
"""
|
||||
self._db = db
|
||||
self._dirty = set()
|
||||
self._dirty: Set[str] = set()
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
self._values_flex = LazyConvertDict(self)
|
||||
|
||||
|
|
@ -316,11 +358,11 @@ class Model:
|
|||
|
||||
@classmethod
|
||||
def _awaken(
|
||||
cls: Type['Model'],
|
||||
db: 'Database' = None,
|
||||
fixed_values: Mapping = {},
|
||||
flex_values: Mapping = {},
|
||||
) -> 'Model':
|
||||
cls: Type[AnyModel],
|
||||
db: Optional[Database] = None,
|
||||
fixed_values: Dict[str, Any] = {},
|
||||
flex_values: Dict[str, Any] = {},
|
||||
) -> AnyModel:
|
||||
"""Create an object with values drawn from the database.
|
||||
|
||||
This is a performance optimization: the checks involved with
|
||||
|
|
@ -347,7 +389,7 @@ class Model:
|
|||
if self._db:
|
||||
self._revision = self._db.revision
|
||||
|
||||
def _check_db(self, need_id: bool = True):
|
||||
def _check_db(self, need_id: bool = True) -> Database:
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
has a reference to a database (`_db`) and an id. A ValueError
|
||||
exception is raised otherwise.
|
||||
|
|
@ -359,6 +401,8 @@ class Model:
|
|||
if need_id and not self.id:
|
||||
raise ValueError('{} has no id'.format(type(self).__name__))
|
||||
|
||||
return self._db
|
||||
|
||||
def copy(self) -> 'Model':
|
||||
"""Create a copy of the model object.
|
||||
|
||||
|
|
@ -490,7 +534,7 @@ class Model:
|
|||
"""
|
||||
return key in self.keys(computed=True)
|
||||
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
|
|
@ -521,14 +565,14 @@ class Model:
|
|||
|
||||
# Database interaction (CRUD methods).
|
||||
|
||||
def store(self, fields: bool = None):
|
||||
def store(self, fields: Optional[Iterable[str]] = None):
|
||||
"""Save the object's metadata into the library database.
|
||||
:param fields: the fields to be stored. If not specified, all fields
|
||||
will be.
|
||||
"""
|
||||
if fields is None:
|
||||
fields = self._fields
|
||||
self._check_db()
|
||||
db = self._check_db()
|
||||
|
||||
# Build assignments for query.
|
||||
assignments = []
|
||||
|
|
@ -539,13 +583,13 @@ class Model:
|
|||
assignments.append(key + '=?')
|
||||
value = self._type(key).to_sql(self[key])
|
||||
subvars.append(value)
|
||||
assignments = ','.join(assignments)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
with db.transaction() as tx:
|
||||
# Main table update.
|
||||
if assignments:
|
||||
query = 'UPDATE {} SET {} WHERE id=?'.format(
|
||||
self._table, assignments
|
||||
self._table,
|
||||
','.join(assignments)
|
||||
)
|
||||
subvars.append(self.id)
|
||||
tx.mutate(query, subvars)
|
||||
|
|
@ -577,11 +621,11 @@ class Model:
|
|||
If check_revision is true, the database is only queried loaded when a
|
||||
transaction has been committed since the item was last loaded.
|
||||
"""
|
||||
self._check_db()
|
||||
if not self._dirty and self._db.revision == self._revision:
|
||||
db = self._check_db()
|
||||
if not self._dirty and db.revision == self._revision:
|
||||
# Exit early
|
||||
return
|
||||
stored_obj = self._db._get(type(self), self.id)
|
||||
stored_obj = db._get(type(self), self.id)
|
||||
assert stored_obj is not None, f"object {self.id} not in DB"
|
||||
self._values_fixed = LazyConvertDict(self)
|
||||
self._values_flex = LazyConvertDict(self)
|
||||
|
|
@ -591,8 +635,8 @@ class Model:
|
|||
def remove(self):
|
||||
"""Remove the object's associated rows from the database.
|
||||
"""
|
||||
self._check_db()
|
||||
with self._db.transaction() as tx:
|
||||
db = self._check_db()
|
||||
with db.transaction() as tx:
|
||||
tx.mutate(
|
||||
f'DELETE FROM {self._table} WHERE id=?',
|
||||
(self.id,)
|
||||
|
|
@ -612,9 +656,9 @@ class Model:
|
|||
"""
|
||||
if db:
|
||||
self._db = db
|
||||
self._check_db(False)
|
||||
db = self._check_db(False)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
with db.transaction() as tx:
|
||||
new_id = tx.mutate(
|
||||
f'INSERT INTO {self._table} DEFAULT VALUES'
|
||||
)
|
||||
|
|
@ -652,9 +696,12 @@ class Model:
|
|||
"""
|
||||
# Perform substitution.
|
||||
if isinstance(template, str):
|
||||
template = functemplate.template(template)
|
||||
return template.substitute(self.formatted(for_path=for_path),
|
||||
self._template_funcs())
|
||||
t = functemplate.template(template)
|
||||
else:
|
||||
# Help out mypy
|
||||
t = template
|
||||
return t.substitute(self.formatted(for_path=for_path),
|
||||
self._template_funcs())
|
||||
|
||||
# Parsing.
|
||||
|
||||
|
|
@ -703,24 +750,28 @@ class Model:
|
|||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
class Results:
|
||||
|
||||
AnyModel = TypeVar("AnyModel", bound=Model)
|
||||
|
||||
|
||||
class Results(Generic[AnyModel]):
|
||||
"""An item query result set. Iterating over the collection lazily
|
||||
constructs LibModel objects that reflect database rows.
|
||||
constructs Model objects that reflect database rows.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Type['LibModel'],
|
||||
model_class: Type[AnyModel],
|
||||
rows: List[Mapping],
|
||||
db: 'Database',
|
||||
flex_rows,
|
||||
query: Optional[FieldQuery] = None,
|
||||
query: Optional[Query] = None,
|
||||
sort=None,
|
||||
):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
`model_class` is a subclass of `LibModel` that will be
|
||||
`model_class` is a subclass of `Model` that will be
|
||||
constructed. `rows` is a query result: a list of mappings. The
|
||||
new objects will be associated with the database `db`.
|
||||
|
||||
|
|
@ -746,9 +797,9 @@ class Results:
|
|||
|
||||
# The materialized objects corresponding to rows that have been
|
||||
# consumed.
|
||||
self._objects = []
|
||||
self._objects: List[AnyModel] = []
|
||||
|
||||
def _get_objects(self) -> Iterable[Model]:
|
||||
def _get_objects(self) -> Iterator[AnyModel]:
|
||||
"""Construct and generate Model objects for they query. The
|
||||
objects are returned in the order emitted from the database; no
|
||||
slow sort is applied.
|
||||
|
|
@ -783,7 +834,7 @@ class Results:
|
|||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self) -> Iterable[Model]:
|
||||
def __iter__(self) -> Iterator[AnyModel]:
|
||||
"""Construct and generate Model objects for all matching
|
||||
objects, in sorted order.
|
||||
"""
|
||||
|
|
@ -799,7 +850,7 @@ class Results:
|
|||
def _get_indexed_flex_attrs(self) -> Mapping:
|
||||
""" Index flexible attributes by the entity id they belong to
|
||||
"""
|
||||
flex_values = {}
|
||||
flex_values: Dict[int, Dict[str, Any]] = {}
|
||||
for row in self.flex_rows:
|
||||
if row['entity_id'] not in flex_values:
|
||||
flex_values[row['entity_id']] = {}
|
||||
|
|
@ -808,7 +859,7 @@ class Results:
|
|||
|
||||
return flex_values
|
||||
|
||||
def _make_model(self, row, flex_values: Dict = {}) -> Model:
|
||||
def _make_model(self, row, flex_values: Dict = {}) -> AnyModel:
|
||||
""" Create a Model object for the given row
|
||||
"""
|
||||
cols = dict(row)
|
||||
|
|
@ -864,7 +915,7 @@ class Results:
|
|||
except StopIteration:
|
||||
raise IndexError(f'result index {n} out of range')
|
||||
|
||||
def get(self) -> Optional[Model]:
|
||||
def get(self) -> Optional[AnyModel]:
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
|
|
@ -922,14 +973,14 @@ class Transaction:
|
|||
self._mutated = False
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement: str, subvals: Iterable = ()) -> List:
|
||||
def query(self, statement: str, subvals: Sequence = ()) -> List:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
a list of rows from the database.
|
||||
"""
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.fetchall()
|
||||
|
||||
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
|
||||
def mutate(self, statement: str, subvals: Sequence = ()) -> Any:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
the row ID of the last affected row.
|
||||
"""
|
||||
|
|
@ -960,7 +1011,7 @@ class Database:
|
|||
the backend.
|
||||
"""
|
||||
|
||||
_models = ()
|
||||
_models: Sequence[Type[Model]] = ()
|
||||
"""The Model subclasses representing tables in this database.
|
||||
"""
|
||||
|
||||
|
|
@ -981,9 +1032,10 @@ class Database:
|
|||
self.path = path
|
||||
self.timeout = timeout
|
||||
|
||||
self._connections = {}
|
||||
self._tx_stacks = defaultdict(list)
|
||||
self._extensions = []
|
||||
self._connections: Dict[int, sqlite3.Connection] = {}
|
||||
self._tx_stacks: DefaultDict[int, List[Transaction]] = \
|
||||
defaultdict(list)
|
||||
self._extensions: List[str] = []
|
||||
|
||||
# A lock to protect the _connections and _tx_stacks maps, which
|
||||
# both map thread IDs to private resources.
|
||||
|
|
@ -1011,6 +1063,11 @@ class Database:
|
|||
One connection object is created per thread.
|
||||
"""
|
||||
thread_id = threading.current_thread().ident
|
||||
# Help the type checker: ident can only be None if the thread has not
|
||||
# been started yet; but since this results from current_thread(), that
|
||||
# can't happen
|
||||
assert thread_id is not None
|
||||
|
||||
with self._shared_map_lock:
|
||||
if thread_id in self._connections:
|
||||
return self._connections[thread_id]
|
||||
|
|
@ -1075,6 +1132,11 @@ class Database:
|
|||
the stack map. Transactions should never migrate across threads.
|
||||
"""
|
||||
thread_id = threading.current_thread().ident
|
||||
# Help the type checker: ident can only be None if the thread has not
|
||||
# been started yet; but since this results from current_thread(), that
|
||||
# can't happen
|
||||
assert thread_id is not None
|
||||
|
||||
with self._shared_map_lock:
|
||||
yield self._tx_stacks[thread_id]
|
||||
|
||||
|
|
@ -1084,7 +1146,7 @@ class Database:
|
|||
"""
|
||||
return Transaction(self)
|
||||
|
||||
def load_extension(self, path):
|
||||
def load_extension(self, path: str):
|
||||
"""Load an SQLite extension into all open connections."""
|
||||
if not self.supports_extensions:
|
||||
raise ValueError(
|
||||
|
|
@ -1152,11 +1214,11 @@ class Database:
|
|||
# Querying.
|
||||
|
||||
def _fetch(
|
||||
self,
|
||||
model_cls: Type['LibModel'],
|
||||
query: Optional[Query] = None,
|
||||
sort: Optional[Sort] = None,
|
||||
) -> Results:
|
||||
self,
|
||||
model_cls: Type[AnyModel],
|
||||
query: Optional[Query] = None,
|
||||
sort: Optional[Sort] = None,
|
||||
) -> Results[AnyModel]:
|
||||
"""Fetch the objects of type `model_cls` matching the given
|
||||
query. The query may be given as a string, string sequence, a
|
||||
Query object, or None (to fetch everything). `sort` is an
|
||||
|
|
@ -1180,10 +1242,10 @@ class Database:
|
|||
SELECT * FROM {} WHERE entity_id IN
|
||||
(SELECT id FROM {} WHERE {});
|
||||
""".format(
|
||||
model_cls._flex_table,
|
||||
model_cls._table,
|
||||
where or '1',
|
||||
)
|
||||
model_cls._flex_table,
|
||||
model_cls._table,
|
||||
where or '1',
|
||||
)
|
||||
)
|
||||
|
||||
with self.transaction() as tx:
|
||||
|
|
@ -1196,7 +1258,11 @@ class Database:
|
|||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
|
||||
def _get(
|
||||
self,
|
||||
model_cls: Type[AnyModel],
|
||||
id,
|
||||
) -> Optional[AnyModel]:
|
||||
"""Get a Model object by its id or None if the id does not
|
||||
exist.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -334,9 +334,9 @@ class BytesQuery(FieldQuery[bytes]):
|
|||
else:
|
||||
bytes_pattern = pattern
|
||||
self.buf_pattern = memoryview(bytes_pattern)
|
||||
elif isinstance(self.pattern, memoryview):
|
||||
self.buf_pattern = self.pattern
|
||||
bytes_pattern = bytes(self.pattern)
|
||||
elif isinstance(pattern, memoryview):
|
||||
self.buf_pattern = pattern
|
||||
bytes_pattern = bytes(pattern)
|
||||
else:
|
||||
raise ValueError("pattern must be bytes, str, or memoryview")
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from typing import Dict, Type, Tuple, Optional, Collection, List, \
|
|||
Sequence
|
||||
|
||||
from . import query, Model
|
||||
from .query import Sort
|
||||
from .query import Query, Sort
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
|
|
@ -132,7 +132,7 @@ def construct_query_part(
|
|||
|
||||
# Use `model_cls` to build up a map from field (or query) names to
|
||||
# `Query` classes.
|
||||
query_classes = {}
|
||||
query_classes: Dict[str, Type[Query]] = {}
|
||||
for k, t in itertools.chain(model_cls._fields.items(),
|
||||
model_cls._types.items()):
|
||||
query_classes[k] = t.query
|
||||
|
|
|
|||
Loading…
Reference in a new issue