Merge pull request #4858 from wisp3rwind/dbcore_typing_4

typing: Wrap up dbcore
This commit is contained in:
Benedikt 2023-07-26 13:40:57 +02:00 committed by GitHub
commit ab3e2a98d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 168 additions and 102 deletions

View file

@ -16,6 +16,7 @@
"""
from __future__ import annotations
from abc import ABC
import time
import os
import re
@ -25,22 +26,36 @@ import sqlite3
import contextlib
from sqlite3 import Connection
from types import TracebackType
from typing import Iterable, Type, List, Tuple, Optional, Union, \
Dict, Any, Generator, Iterator, Callable
from typing import (
Any,
Callable,
cast,
DefaultDict,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from unidecode import unidecode
import beets
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
from . import types
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
FieldQuery, Sort
from collections.abc import Mapping
FieldQuery, Sort, FieldSort
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beets.library import LibModel
from ..util.functemplate import Template
@ -53,7 +68,7 @@ class DBAccessError(Exception):
"""
class FormattedMapping(Mapping):
class FormattedMapping(Mapping[str, str]):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
@ -71,7 +86,7 @@ class FormattedMapping(Mapping):
def __init__(
self,
model: 'Model',
model: Model,
included_keys: str = ALL_KEYS,
for_path: bool = False,
):
@ -83,31 +98,39 @@ class FormattedMapping(Mapping):
else:
self.model_keys = included_keys
def __getitem__(self, key):
def __getitem__(self, key: str) -> str:
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __iter__(self) -> Iterable[str]:
def __iter__(self) -> Iterator[str]:
return iter(self.model_keys)
def __len__(self) -> int:
return len(self.model_keys)
def get(self, key, default=None):
# The following signature is incompatible with `Mapping[str, str]`, since
# the return type doesn't include `None` (but `default` can be `None`).
def get( # type: ignore
self,
key: str,
default: Optional[str] = None,
) -> str:
"""Similar to Mapping.get(key, default), but always formats to str.
"""
if default is None:
default = self.model._type(key).format(None)
return super().get(key, default)
def _get_formatted(self, model, key):
def _get_formatted(self, model: Model, key: str) -> str:
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf-8', 'ignore')
if self.for_path:
sep_repl = beets.config['path_sep_replace'].as_str()
sep_drive = beets.config['drive_sep_replace'].as_str()
sep_repl = cast(str, beets.config['path_sep_replace'].as_str())
sep_drive = cast(str, beets.config['drive_sep_replace'].as_str())
if re.match(r'^\w:', value):
value = re.sub(r'(?<=^\w):', sep_drive, value)
@ -119,6 +142,15 @@ class FormattedMapping(Mapping):
return value
# NOTE: This seems like it should be a `Mapping`, i.e.
# ```
# class LazyConvertDict(Mapping[str, Any])
# ```
# but there are some conflicts with the `Mapping` protocol such that we
# can't do this without changing behaviour: In particular, iterators returned
# by some methods build intermediate lists, such that modification of the
# `LazyConvertDict` becomes safe during iteration. Some code does in fact rely
# on this.
class LazyConvertDict:
"""Lazily convert types for attributes fetched from the database
"""
@ -126,60 +158,61 @@ class LazyConvertDict:
def __init__(self, model_cls: 'Model'):
"""Initialize the object empty
"""
self.data = {}
# FIXME: Dict[str, SQLiteType]
self._data: Dict[str, Any] = {}
self.model_cls = model_cls
self._converted = {}
self._converted: Dict[str, Any] = {}
def init(self, data):
def init(self, data: Dict[str, Any]):
"""Set the base data that should be lazily converted
"""
self.data = data
self._data = data
def _convert(self, key, value):
"""Convert the attribute type according the the SQL type
def _convert(self, key: str, value: Any):
"""Convert the attribute type according to the SQL type
"""
return self.model_cls._type(key).from_sql(value)
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any):
"""Set an attribute value, assume it's already converted
"""
self._converted[key] = value
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
"""Get an attribute value, converting the type on demand
if needed
"""
if key in self._converted:
return self._converted[key]
elif key in self.data:
value = self._convert(key, self.data[key])
elif key in self._data:
value = self._convert(key, self._data[key])
self._converted[key] = value
return value
def __delitem__(self, key):
def __delitem__(self, key: str):
"""Delete both converted and base data
"""
if key in self._converted:
del self._converted[key]
if key in self.data:
del self.data[key]
if key in self._data:
del self._data[key]
def keys(self) -> List[str]:
"""Get a list of available field names for this object.
"""
return list(self._converted.keys()) + list(self.data.keys())
return list(self._converted.keys()) + list(self._data.keys())
def copy(self) -> 'LazyConvertDict':
def copy(self) -> LazyConvertDict:
"""Create a copy of the object.
"""
new = self.__class__(self.model_cls)
new.data = self.data.copy()
new._data = self._data.copy()
new._converted = self._converted.copy()
return new
# Act like a dictionary.
def update(self, values):
def update(self, values: Mapping[str, Any]):
"""Assign all values in the given dict.
"""
for key, value in values.items():
@ -192,7 +225,7 @@ class LazyConvertDict:
for key in self:
yield key, self[key]
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None):
"""Get the value for a given key or `default` if it does not
exist.
"""
@ -201,21 +234,30 @@ class LazyConvertDict:
else:
return default
def __contains__(self, key) -> bool:
def __contains__(self, key: Any) -> bool:
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys()
return key in self._converted or key in self._data
def __iter__(self) -> Iterable[str]:
def __iter__(self) -> Iterator[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
# NOTE: It would be nice to use the following:
# yield from self._converted
# yield from self._data
# but that won't work since some code relies on modifying `self`
# during iteration.
return iter(self.keys())
def __len__(self) -> int:
# FIXME: This is incorrect due to duplication of keys
return len(self._converted) + len(self._data)
# Abstract base for model classes.
class Model:
class Model(ABC):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
@ -241,34 +283,34 @@ class Model:
# Abstract components (to be provided by subclasses).
_table = None
_table: str
"""The main SQLite table name.
"""
_flex_table = None
_flex_table: str
"""The flex field SQLite table name.
"""
_fields = {}
_fields: Dict[str, types.Type] = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are `Type` objects.
"""
_search_fields = ()
_search_fields: Sequence[str] = ()
"""The fields that should be queried by default by unqualified query
terms.
"""
_types = {}
_types: Dict[str, types.Type] = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
_sorts = {}
_sorts: Dict[str, Type[FieldSort]] = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""
_queries = {}
_queries: Dict[str, Type[Query]] = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
@ -301,12 +343,12 @@ class Model:
# Basic operation.
def __init__(self, db: Optional['Database'] = None, **values):
def __init__(self, db: Optional[Database] = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
self._db = db
self._dirty = set()
self._dirty: Set[str] = set()
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@ -316,11 +358,11 @@ class Model:
@classmethod
def _awaken(
cls: Type['Model'],
db: 'Database' = None,
fixed_values: Mapping = {},
flex_values: Mapping = {},
) -> 'Model':
cls: Type[AnyModel],
db: Optional[Database] = None,
fixed_values: Dict[str, Any] = {},
flex_values: Dict[str, Any] = {},
) -> AnyModel:
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
@ -347,7 +389,7 @@ class Model:
if self._db:
self._revision = self._db.revision
def _check_db(self, need_id: bool = True):
def _check_db(self, need_id: bool = True) -> Database:
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
@ -359,6 +401,8 @@ class Model:
if need_id and not self.id:
raise ValueError('{} has no id'.format(type(self).__name__))
return self._db
def copy(self) -> 'Model':
"""Create a copy of the model object.
@ -490,7 +534,7 @@ class Model:
"""
return key in self.keys(computed=True)
def __iter__(self) -> Iterable[str]:
def __iter__(self) -> Iterator[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
@ -521,14 +565,14 @@ class Model:
# Database interaction (CRUD methods).
def store(self, fields: bool = None):
def store(self, fields: Optional[Iterable[str]] = None):
"""Save the object's metadata into the library database.
:param fields: the fields to be stored. If not specified, all fields
will be.
"""
if fields is None:
fields = self._fields
self._check_db()
db = self._check_db()
# Build assignments for query.
assignments = []
@ -539,13 +583,13 @@ class Model:
assignments.append(key + '=?')
value = self._type(key).to_sql(self[key])
subvars.append(value)
assignments = ','.join(assignments)
with self._db.transaction() as tx:
with db.transaction() as tx:
# Main table update.
if assignments:
query = 'UPDATE {} SET {} WHERE id=?'.format(
self._table, assignments
self._table,
','.join(assignments)
)
subvars.append(self.id)
tx.mutate(query, subvars)
@ -577,11 +621,11 @@ class Model:
If check_revision is true, the database is only queried loaded when a
transaction has been committed since the item was last loaded.
"""
self._check_db()
if not self._dirty and self._db.revision == self._revision:
db = self._check_db()
if not self._dirty and db.revision == self._revision:
# Exit early
return
stored_obj = self._db._get(type(self), self.id)
stored_obj = db._get(type(self), self.id)
assert stored_obj is not None, f"object {self.id} not in DB"
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
@ -591,8 +635,8 @@ class Model:
def remove(self):
"""Remove the object's associated rows from the database.
"""
self._check_db()
with self._db.transaction() as tx:
db = self._check_db()
with db.transaction() as tx:
tx.mutate(
f'DELETE FROM {self._table} WHERE id=?',
(self.id,)
@ -612,9 +656,9 @@ class Model:
"""
if db:
self._db = db
self._check_db(False)
db = self._check_db(False)
with self._db.transaction() as tx:
with db.transaction() as tx:
new_id = tx.mutate(
f'INSERT INTO {self._table} DEFAULT VALUES'
)
@ -652,9 +696,12 @@ class Model:
"""
# Perform substitution.
if isinstance(template, str):
template = functemplate.template(template)
return template.substitute(self.formatted(for_path=for_path),
self._template_funcs())
t = functemplate.template(template)
else:
# Help out mypy
t = template
return t.substitute(self.formatted(for_path=for_path),
self._template_funcs())
# Parsing.
@ -703,24 +750,28 @@ class Model:
# Database controller and supporting interfaces.
class Results:
AnyModel = TypeVar("AnyModel", bound=Model)
class Results(Generic[AnyModel]):
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
constructs Model objects that reflect database rows.
"""
def __init__(
self,
model_class: Type['LibModel'],
model_class: Type[AnyModel],
rows: List[Mapping],
db: 'Database',
flex_rows,
query: Optional[FieldQuery] = None,
query: Optional[Query] = None,
sort=None,
):
"""Create a result set that will construct objects of type
`model_class`.
`model_class` is a subclass of `LibModel` that will be
`model_class` is a subclass of `Model` that will be
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
@ -746,9 +797,9 @@ class Results:
# The materialized objects corresponding to rows that have been
# consumed.
self._objects = []
self._objects: List[AnyModel] = []
def _get_objects(self) -> Iterable[Model]:
def _get_objects(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
@ -783,7 +834,7 @@ class Results:
yield obj
break
def __iter__(self) -> Iterable[Model]:
def __iter__(self) -> Iterator[AnyModel]:
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
@ -799,7 +850,7 @@ class Results:
def _get_indexed_flex_attrs(self) -> Mapping:
""" Index flexible attributes by the entity id they belong to
"""
flex_values = {}
flex_values: Dict[int, Dict[str, Any]] = {}
for row in self.flex_rows:
if row['entity_id'] not in flex_values:
flex_values[row['entity_id']] = {}
@ -808,7 +859,7 @@ class Results:
return flex_values
def _make_model(self, row, flex_values: Dict = {}) -> Model:
def _make_model(self, row, flex_values: Dict = {}) -> AnyModel:
""" Create a Model object for the given row
"""
cols = dict(row)
@ -864,7 +915,7 @@ class Results:
except StopIteration:
raise IndexError(f'result index {n} out of range')
def get(self) -> Optional[Model]:
def get(self) -> Optional[AnyModel]:
"""Return the first matching object, or None if no objects
match.
"""
@ -922,14 +973,14 @@ class Transaction:
self._mutated = False
self.db._db_lock.release()
def query(self, statement: str, subvals: Iterable = ()) -> List:
def query(self, statement: str, subvals: Sequence = ()) -> List:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
def mutate(self, statement: str, subvals: Sequence = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
@ -960,7 +1011,7 @@ class Database:
the backend.
"""
_models = ()
_models: Sequence[Type[Model]] = ()
"""The Model subclasses representing tables in this database.
"""
@ -981,9 +1032,10 @@ class Database:
self.path = path
self.timeout = timeout
self._connections = {}
self._tx_stacks = defaultdict(list)
self._extensions = []
self._connections: Dict[int, sqlite3.Connection] = {}
self._tx_stacks: DefaultDict[int, List[Transaction]] = \
defaultdict(list)
self._extensions: List[str] = []
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
@ -1011,6 +1063,11 @@ class Database:
One connection object is created per thread.
"""
thread_id = threading.current_thread().ident
# Help the type checker: ident can only be None if the thread has not
# been started yet; but since this results from current_thread(), that
# can't happen
assert thread_id is not None
with self._shared_map_lock:
if thread_id in self._connections:
return self._connections[thread_id]
@ -1075,6 +1132,11 @@ class Database:
the stack map. Transactions should never migrate across threads.
"""
thread_id = threading.current_thread().ident
# Help the type checker: ident can only be None if the thread has not
# been started yet; but since this results from current_thread(), that
# can't happen
assert thread_id is not None
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
@ -1084,7 +1146,7 @@ class Database:
"""
return Transaction(self)
def load_extension(self, path):
def load_extension(self, path: str):
"""Load an SQLite extension into all open connections."""
if not self.supports_extensions:
raise ValueError(
@ -1152,11 +1214,11 @@ class Database:
# Querying.
def _fetch(
self,
model_cls: Type['LibModel'],
query: Optional[Query] = None,
sort: Optional[Sort] = None,
) -> Results:
self,
model_cls: Type[AnyModel],
query: Optional[Query] = None,
sort: Optional[Sort] = None,
) -> Results[AnyModel]:
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is an
@ -1180,10 +1242,10 @@ class Database:
SELECT * FROM {} WHERE entity_id IN
(SELECT id FROM {} WHERE {});
""".format(
model_cls._flex_table,
model_cls._table,
where or '1',
)
model_cls._flex_table,
model_cls._table,
where or '1',
)
)
with self.transaction() as tx:
@ -1196,7 +1258,11 @@ class Database:
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
def _get(
self,
model_cls: Type[AnyModel],
id,
) -> Optional[AnyModel]:
"""Get a Model object by its id or None if the id does not
exist.
"""

View file

@ -334,9 +334,9 @@ class BytesQuery(FieldQuery[bytes]):
else:
bytes_pattern = pattern
self.buf_pattern = memoryview(bytes_pattern)
elif isinstance(self.pattern, memoryview):
self.buf_pattern = self.pattern
bytes_pattern = bytes(self.pattern)
elif isinstance(pattern, memoryview):
self.buf_pattern = pattern
bytes_pattern = bytes(pattern)
else:
raise ValueError("pattern must be bytes, str, or memoryview")

View file

@ -21,7 +21,7 @@ from typing import Dict, Type, Tuple, Optional, Collection, List, \
Sequence
from . import query, Model
from .query import Sort
from .query import Query, Sort
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
@ -132,7 +132,7 @@ def construct_query_part(
# Use `model_cls` to build up a map from field (or query) names to
# `Query` classes.
query_classes = {}
query_classes: Dict[str, Type[Query]] = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query