Add more typings

This commit is contained in:
Serene-Arc 2022-09-15 20:34:21 +10:00
parent e29337d4e6
commit 13c1561390
2 changed files with 113 additions and 59 deletions

View file

@ -22,6 +22,10 @@ from collections import defaultdict
import threading
import sqlite3
import contextlib
from sqlite3 import Connection
from types import TracebackType
from typing import Iterable, Type, List, Tuple, NoReturn, Optional, Union, \
Dict, Any, Generator
from unidecode import unidecode
@ -29,9 +33,13 @@ import beets
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery, AndQuery
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
FieldQuery, Sort
from collections.abc import Mapping
from ..library import LibModel
from ..util.functemplate import Template
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
@ -58,7 +66,12 @@ class FormattedMapping(Mapping):
ALL_KEYS = '*'
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
def __init__(
self,
model: 'Model',
included_keys: str = ALL_KEYS,
for_path: bool = False,
):
self.for_path = for_path
self.model = model
if included_keys == self.ALL_KEYS:
@ -73,10 +86,10 @@ class FormattedMapping(Mapping):
else:
raise KeyError(key)
def __iter__(self):
def __iter__(self) -> Iterable:
return iter(self.model_keys)
def __len__(self):
def __len__(self) -> int:
return len(self.model_keys)
def get(self, key, default=None):
@ -107,7 +120,7 @@ class LazyConvertDict:
"""Lazily convert types for attributes fetched from the database
"""
def __init__(self, model_cls):
def __init__(self, model_cls: 'Model'):
"""Initialize the object empty
"""
self.data = {}
@ -148,12 +161,12 @@ class LazyConvertDict:
if key in self.data:
del self.data[key]
def keys(self):
def keys(self) -> List:
"""Get a list of available field names for this object.
"""
return list(self._converted.keys()) + list(self.data.keys())
def copy(self):
def copy(self) -> 'LazyConvertDict':
"""Create a copy of the object.
"""
new = self.__class__(self.model_cls)
@ -169,7 +182,7 @@ class LazyConvertDict:
for key, value in values.items():
self[key] = value
def items(self):
def items(self) -> Generator[Tuple]:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
@ -185,12 +198,12 @@ class LazyConvertDict:
else:
return default
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys()
def __iter__(self):
def __iter__(self) -> Iterable:
"""Iterate over the available field names (excluding computed
fields).
"""
@ -269,14 +282,14 @@ class Model:
"""
@classmethod
def _getters(cls):
def _getters(cls: Type['Model']):
"""Return a mapping from field names to getter functions.
"""
# We could cache this if it becomes a performance problem to
# gather the getter mapping every time.
raise NotImplementedError()
def _template_funcs(self):
def _template_funcs(self) -> NoReturn:
"""Return a mapping from function names to text-transformer
functions.
"""
@ -285,7 +298,7 @@ class Model:
# Basic operation.
def __init__(self, db=None, **values):
def __init__(self, db: Optional['Database'] = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
@ -299,7 +312,12 @@ class Model:
self.clear_dirty()
@classmethod
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
def _awaken(
cls: Type['Model'],
db: 'Database' = None,
fixed_values: Mapping = {},
flex_values: Mapping = {},
) -> 'Model':
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
@ -312,7 +330,7 @@ class Model:
return obj
def __repr__(self):
def __repr__(self) -> str:
return '{}({})'.format(
type(self).__name__,
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
@ -326,7 +344,7 @@ class Model:
if self._db:
self._revision = self._db.revision
def _check_db(self, need_id=True):
def _check_db(self, need_id: bool = True):
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
@ -338,7 +356,7 @@ class Model:
if need_id and not self.id:
raise ValueError('{} has no id'.format(type(self).__name__))
def copy(self):
def copy(self) -> 'Model':
"""Create a copy of the model object.
The field values and other state is duplicated, but the new copy
@ -356,7 +374,7 @@ class Model:
# Essential field accessors.
@classmethod
def _type(cls, key):
def _type(cls, key) -> types.Type:
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
@ -364,7 +382,7 @@ class Model:
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def _get(self, key, default=None, raise_=False):
def _get(self, key, default: bool = None, raise_: bool = False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
@ -431,7 +449,7 @@ class Model:
else:
raise KeyError(f'no such field {key}')
def keys(self, computed=False):
def keys(self, computed: bool = False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
@ -457,19 +475,19 @@ class Model:
for key, value in values.items():
self[key] = value
def items(self):
def items(self) -> Generator:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(computed=True)
def __iter__(self):
def __iter__(self) -> Iterable:
"""Iterate over the available field names (excluding computed
fields).
"""
@ -500,7 +518,7 @@ class Model:
# Database interaction (CRUD methods).
def store(self, fields=None):
def store(self, fields: bool = None):
"""Save the object's metadata into the library database.
:param fields: the fields to be stored. If not specified, all fields
will be.
@ -581,7 +599,7 @@ class Model:
(self.id,)
)
def add(self, db=None):
def add(self, db: Optional['Database'] = None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
@ -610,13 +628,21 @@ class Model:
_formatter = FormattedMapping
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
def formatted(
self,
included_keys: str = _formatter.ALL_KEYS,
for_path: bool = False,
):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, included_keys, for_path)
def evaluate_template(self, template, for_path=False):
def evaluate_template(
self,
template: Union[str, Template],
for_path: bool = False,
) -> str:
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
@ -630,7 +656,7 @@ class Model:
# Parsing.
@classmethod
def _parse(cls, key, string):
def _parse(cls, key, string: str) -> types.Type:
"""Parse a string as a value for the given key.
"""
if not isinstance(string, str):
@ -638,7 +664,7 @@ class Model:
return cls._type(key).parse(string)
def set_parse(self, key, string):
def set_parse(self, key, string: str):
"""Set the object's key to a value represented by a string.
"""
self[key] = self._parse(key, string)
@ -646,12 +672,21 @@ class Model:
# Convenient queries.
@classmethod
def field_query(cls, field, pattern, query_cls=MatchQuery):
def field_query(
cls,
field,
pattern,
query_cls: Type[FieldQuery] = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(cls, pats, query_cls=MatchQuery):
def all_fields_query(
cls: Type['Model'],
pats: Mapping,
query_cls: Type[FieldQuery] = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
@ -670,8 +705,15 @@ class Results:
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, flex_rows,
query=None, sort=None):
def __init__(
self,
model_class: Type[LibModel],
rows: List[Mapping],
db: 'Database',
flex_rows,
query: Optional[FieldQuery] = None,
sort=None,
):
"""Create a result set that will construct objects of type
`model_class`.
@ -703,7 +745,7 @@ class Results:
# consumed.
self._objects = []
def _get_objects(self):
def _get_objects(self) -> Generator[Model]:
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
@ -738,7 +780,7 @@ class Results:
yield obj
break
def __iter__(self):
def __iter__(self) -> Iterable[Model]:
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
@ -751,7 +793,7 @@ class Results:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _get_indexed_flex_attrs(self):
def _get_indexed_flex_attrs(self) -> Mapping:
""" Index flexible attributes by the entity id they belong to
"""
flex_values = {}
@ -763,7 +805,7 @@ class Results:
return flex_values
def _make_model(self, row, flex_values={}):
def _make_model(self, row, flex_values: Dict = {}) -> Model:
""" Create a Model object for the given row
"""
cols = dict(row)
@ -774,7 +816,7 @@ class Results:
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
def __len__(self):
def __len__(self) -> int:
"""Get the number of matching objects.
"""
if not self._rows:
@ -792,12 +834,12 @@ class Results:
# A fast query. Just count the rows.
return self._row_count
def __nonzero__(self):
def __nonzero__(self) -> bool:
"""Does this result contain any objects?
"""
return self.__bool__()
def __bool__(self):
def __bool__(self) -> bool:
"""Does this result contain any objects?
"""
return bool(len(self))
@ -819,7 +861,7 @@ class Results:
except StopIteration:
raise IndexError(f'result index {n} out of range')
def get(self):
def get(self) -> Optional[Model]:
"""Return the first matching object, or None if no objects
match.
"""
@ -840,10 +882,10 @@ class Transaction:
current transaction.
"""
def __init__(self, db):
def __init__(self, db: 'Database'):
self.db = db
def __enter__(self):
def __enter__(self) -> 'Transaction':
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
@ -856,7 +898,12 @@ class Transaction:
self.db._db_lock.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Type[Exception],
exc_value: Exception,
traceback: TracebackType,
):
"""Complete a transaction. This must be the most recently
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
@ -872,14 +919,14 @@ class Transaction:
self._mutated = False
self.db._db_lock.release()
def query(self, statement, subvals=()):
def query(self, statement: str, subvals: Iterable = ()) -> List:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
def mutate(self, statement, subvals=()):
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
@ -898,7 +945,7 @@ class Transaction:
self._mutated = True
return cursor.lastrowid
def script(self, statements):
def script(self, statements: str):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
@ -922,7 +969,7 @@ class Database:
data is written in a transaction.
"""
def __init__(self, path, timeout=5.0):
def __init__(self, path, timeout: float = 5.0):
if sqlite3.threadsafety == 0:
raise RuntimeError(
"sqlite3 must be compiled with multi-threading support"
@ -956,7 +1003,7 @@ class Database:
# Primitive access control: connections and transactions.
def _connection(self):
def _connection(self) -> Connection:
"""Get a SQLite connection object to the underlying database.
One connection object is created per thread.
"""
@ -969,7 +1016,7 @@ class Database:
self._connections[thread_id] = conn
return conn
def _create_connection(self):
def _create_connection(self) -> Connection:
"""Create a SQLite connection to the underlying database.
Makes a new connection every time. If you need to configure the
@ -1019,7 +1066,7 @@ class Database:
conn.close()
@contextlib.contextmanager
def _tx_stack(self):
def _tx_stack(self) -> Generator[List]:
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
@ -1028,7 +1075,7 @@ class Database:
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
def transaction(self):
def transaction(self) -> Transaction:
"""Get a :class:`Transaction` object for interacting directly
with the underlying SQLite database.
"""
@ -1048,7 +1095,7 @@ class Database:
# Schema setup and migration.
def _make_table(self, table, fields):
def _make_table(self, table: str, fields: Mapping[str, types.Type]):
"""Set up the schema of the database. `fields` is a mapping
from field names to `Type`s. Columns are added if necessary.
"""
@ -1083,7 +1130,7 @@ class Database:
with self.transaction() as tx:
tx.script(setup_sql)
def _make_attribute_table(self, flex_table):
def _make_attribute_table(self, flex_table: str):
"""Create a table and associated index for flexible attributes
for the given entity (if they don't exist).
"""
@ -1101,7 +1148,12 @@ class Database:
# Querying.
def _fetch(self, model_cls, query=None, sort=None):
def _fetch(
self,
model_cls: Type[LibModel],
query: Optional[Query] = None,
sort: Optional[Sort] = None,
) -> Results:
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is an
@ -1141,7 +1193,7 @@ class Database:
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls, id):
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
"""Get a Model object by its id or None if the id does not
exist.
"""

View file

@ -18,13 +18,14 @@
import re
from operator import mul
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \
Collection, Mapping
Collection, Mapping, MutableMapping
from beets import util
from datetime import datetime, timedelta
import unicodedata
from functools import reduce
from beets.dbcore import Model
from beets.library import Item
@ -120,7 +121,7 @@ class FieldQuery(Query):
"""
raise NotImplementedError()
def match(self, item: Item):
def match(self, item: Model):
return self.value_match(self.pattern, item.get(self.field))
def __repr__(self) -> str:
@ -465,6 +466,7 @@ class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
"""
subqueries: MutableMapping
def __setitem__(self, key, value):
self.subqueries[key] = value
@ -786,7 +788,7 @@ class Sort:
the item database.
"""
def order_clause(self):
def order_clause(self) -> None:
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""