mirror of
https://github.com/beetbox/beets.git
synced 2025-12-27 19:12:40 +01:00
Add more typings
This commit is contained in:
parent
e29337d4e6
commit
13c1561390
2 changed files with 113 additions and 59 deletions
|
|
@ -22,6 +22,10 @@ from collections import defaultdict
|
|||
import threading
|
||||
import sqlite3
|
||||
import contextlib
|
||||
from sqlite3 import Connection
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Type, List, Tuple, NoReturn, Optional, Union, \
|
||||
Dict, Any, Generator
|
||||
|
||||
from unidecode import unidecode
|
||||
|
||||
|
|
@ -29,9 +33,13 @@ import beets
|
|||
from beets.util import functemplate
|
||||
from beets.util import py3_path
|
||||
from beets.dbcore import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery, AndQuery
|
||||
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
|
||||
FieldQuery, Sort
|
||||
from collections.abc import Mapping
|
||||
|
||||
from ..library import LibModel
|
||||
from ..util.functemplate import Template
|
||||
|
||||
|
||||
class DBAccessError(Exception):
|
||||
"""The SQLite database became inaccessible.
|
||||
|
|
@ -58,7 +66,12 @@ class FormattedMapping(Mapping):
|
|||
|
||||
ALL_KEYS = '*'
|
||||
|
||||
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
|
||||
def __init__(
|
||||
self,
|
||||
model: 'Model',
|
||||
included_keys: str = ALL_KEYS,
|
||||
for_path: bool = False,
|
||||
):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
if included_keys == self.ALL_KEYS:
|
||||
|
|
@ -73,10 +86,10 @@ class FormattedMapping(Mapping):
|
|||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable:
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=None):
|
||||
|
|
@ -107,7 +120,7 @@ class LazyConvertDict:
|
|||
"""Lazily convert types for attributes fetched from the database
|
||||
"""
|
||||
|
||||
def __init__(self, model_cls):
|
||||
def __init__(self, model_cls: 'Model'):
|
||||
"""Initialize the object empty
|
||||
"""
|
||||
self.data = {}
|
||||
|
|
@ -148,12 +161,12 @@ class LazyConvertDict:
|
|||
if key in self.data:
|
||||
del self.data[key]
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> List:
|
||||
"""Get a list of available field names for this object.
|
||||
"""
|
||||
return list(self._converted.keys()) + list(self.data.keys())
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> 'LazyConvertDict':
|
||||
"""Create a copy of the object.
|
||||
"""
|
||||
new = self.__class__(self.model_cls)
|
||||
|
|
@ -169,7 +182,7 @@ class LazyConvertDict:
|
|||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Generator[Tuple]:
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
|
|
@ -185,12 +198,12 @@ class LazyConvertDict:
|
|||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
|
|
@ -269,14 +282,14 @@ class Model:
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
def _getters(cls: Type['Model']):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
"""
|
||||
# We could cache this if it becomes a performance problem to
|
||||
# gather the getter mapping every time.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _template_funcs(self):
|
||||
def _template_funcs(self) -> NoReturn:
|
||||
"""Return a mapping from function names to text-transformer
|
||||
functions.
|
||||
"""
|
||||
|
|
@ -285,7 +298,7 @@ class Model:
|
|||
|
||||
# Basic operation.
|
||||
|
||||
def __init__(self, db=None, **values):
|
||||
def __init__(self, db: Optional['Database'] = None, **values):
|
||||
"""Create a new object with an optional Database association and
|
||||
initial field values.
|
||||
"""
|
||||
|
|
@ -299,7 +312,12 @@ class Model:
|
|||
self.clear_dirty()
|
||||
|
||||
@classmethod
|
||||
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
|
||||
def _awaken(
|
||||
cls: Type['Model'],
|
||||
db: 'Database' = None,
|
||||
fixed_values: Mapping = {},
|
||||
flex_values: Mapping = {},
|
||||
) -> 'Model':
|
||||
"""Create an object with values drawn from the database.
|
||||
|
||||
This is a performance optimization: the checks involved with
|
||||
|
|
@ -312,7 +330,7 @@ class Model:
|
|||
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return '{}({})'.format(
|
||||
type(self).__name__,
|
||||
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
|
||||
|
|
@ -326,7 +344,7 @@ class Model:
|
|||
if self._db:
|
||||
self._revision = self._db.revision
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
def _check_db(self, need_id: bool = True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
has a reference to a database (`_db`) and an id. A ValueError
|
||||
exception is raised otherwise.
|
||||
|
|
@ -338,7 +356,7 @@ class Model:
|
|||
if need_id and not self.id:
|
||||
raise ValueError('{} has no id'.format(type(self).__name__))
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> 'Model':
|
||||
"""Create a copy of the model object.
|
||||
|
||||
The field values and other state is duplicated, but the new copy
|
||||
|
|
@ -356,7 +374,7 @@ class Model:
|
|||
# Essential field accessors.
|
||||
|
||||
@classmethod
|
||||
def _type(cls, key):
|
||||
def _type(cls, key) -> types.Type:
|
||||
"""Get the type of a field, a `Type` instance.
|
||||
|
||||
If the field has no explicit type, it is given the base `Type`,
|
||||
|
|
@ -364,7 +382,7 @@ class Model:
|
|||
"""
|
||||
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
|
||||
|
||||
def _get(self, key, default=None, raise_=False):
|
||||
def _get(self, key, default: bool = None, raise_: bool = False):
|
||||
"""Get the value for a field, or `default`. Alternatively,
|
||||
raise a KeyError if the field is not available.
|
||||
"""
|
||||
|
|
@ -431,7 +449,7 @@ class Model:
|
|||
else:
|
||||
raise KeyError(f'no such field {key}')
|
||||
|
||||
def keys(self, computed=False):
|
||||
def keys(self, computed: bool = False):
|
||||
"""Get a list of available field names for this object. The
|
||||
`computed` parameter controls whether computed (plugin-provided)
|
||||
fields are included in the key list.
|
||||
|
|
@ -457,19 +475,19 @@ class Model:
|
|||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Generator:
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(computed=True)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
|
|
@ -500,7 +518,7 @@ class Model:
|
|||
|
||||
# Database interaction (CRUD methods).
|
||||
|
||||
def store(self, fields=None):
|
||||
def store(self, fields: bool = None):
|
||||
"""Save the object's metadata into the library database.
|
||||
:param fields: the fields to be stored. If not specified, all fields
|
||||
will be.
|
||||
|
|
@ -581,7 +599,7 @@ class Model:
|
|||
(self.id,)
|
||||
)
|
||||
|
||||
def add(self, db=None):
|
||||
def add(self, db: Optional['Database'] = None):
|
||||
"""Add the object to the library database. This object must be
|
||||
associated with a database; you can provide one via the `db`
|
||||
parameter or use the currently associated database.
|
||||
|
|
@ -610,13 +628,21 @@ class Model:
|
|||
|
||||
_formatter = FormattedMapping
|
||||
|
||||
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
|
||||
def formatted(
|
||||
self,
|
||||
included_keys: str = _formatter.ALL_KEYS,
|
||||
for_path: bool = False,
|
||||
):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return self._formatter(self, included_keys, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
def evaluate_template(
|
||||
self,
|
||||
template: Union[str, Template],
|
||||
for_path: bool = False,
|
||||
) -> str:
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
the object's fields. If `for_path` is true, then no new path
|
||||
separators will be added to the template.
|
||||
|
|
@ -630,7 +656,7 @@ class Model:
|
|||
# Parsing.
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, key, string):
|
||||
def _parse(cls, key, string: str) -> types.Type:
|
||||
"""Parse a string as a value for the given key.
|
||||
"""
|
||||
if not isinstance(string, str):
|
||||
|
|
@ -638,7 +664,7 @@ class Model:
|
|||
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
def set_parse(self, key, string):
|
||||
def set_parse(self, key, string: str):
|
||||
"""Set the object's key to a value represented by a string.
|
||||
"""
|
||||
self[key] = self._parse(key, string)
|
||||
|
|
@ -646,12 +672,21 @@ class Model:
|
|||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(cls, field, pattern, query_cls=MatchQuery):
|
||||
def field_query(
|
||||
cls,
|
||||
field,
|
||||
pattern,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
) -> FieldQuery:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
return query_cls(field, pattern, field in cls._fields)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(cls, pats, query_cls=MatchQuery):
|
||||
def all_fields_query(
|
||||
cls: Type['Model'],
|
||||
pats: Mapping,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
):
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pats` should be a mapping from field names to patterns. The
|
||||
|
|
@ -670,8 +705,15 @@ class Results:
|
|||
constructs LibModel objects that reflect database rows.
|
||||
"""
|
||||
|
||||
def __init__(self, model_class, rows, db, flex_rows,
|
||||
query=None, sort=None):
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Type[LibModel],
|
||||
rows: List[Mapping],
|
||||
db: 'Database',
|
||||
flex_rows,
|
||||
query: Optional[FieldQuery] = None,
|
||||
sort=None,
|
||||
):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
|
|
@ -703,7 +745,7 @@ class Results:
|
|||
# consumed.
|
||||
self._objects = []
|
||||
|
||||
def _get_objects(self):
|
||||
def _get_objects(self) -> Generator[Model]:
|
||||
"""Construct and generate Model objects for they query. The
|
||||
objects are returned in the order emitted from the database; no
|
||||
slow sort is applied.
|
||||
|
|
@ -738,7 +780,7 @@ class Results:
|
|||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable[Model]:
|
||||
"""Construct and generate Model objects for all matching
|
||||
objects, in sorted order.
|
||||
"""
|
||||
|
|
@ -751,7 +793,7 @@ class Results:
|
|||
# Objects are pre-sorted (i.e., by the database).
|
||||
return self._get_objects()
|
||||
|
||||
def _get_indexed_flex_attrs(self):
|
||||
def _get_indexed_flex_attrs(self) -> Mapping:
|
||||
""" Index flexible attributes by the entity id they belong to
|
||||
"""
|
||||
flex_values = {}
|
||||
|
|
@ -763,7 +805,7 @@ class Results:
|
|||
|
||||
return flex_values
|
||||
|
||||
def _make_model(self, row, flex_values={}):
|
||||
def _make_model(self, row, flex_values: Dict = {}) -> Model:
|
||||
""" Create a Model object for the given row
|
||||
"""
|
||||
cols = dict(row)
|
||||
|
|
@ -774,7 +816,7 @@ class Results:
|
|||
obj = self.model_class._awaken(self.db, values, flex_values)
|
||||
return obj
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""Get the number of matching objects.
|
||||
"""
|
||||
if not self._rows:
|
||||
|
|
@ -792,12 +834,12 @@ class Results:
|
|||
# A fast query. Just count the rows.
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self):
|
||||
def __nonzero__(self) -> bool:
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return self.__bool__()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return bool(len(self))
|
||||
|
|
@ -819,7 +861,7 @@ class Results:
|
|||
except StopIteration:
|
||||
raise IndexError(f'result index {n} out of range')
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Optional[Model]:
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
|
|
@ -840,10 +882,10 @@ class Transaction:
|
|||
current transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
def __init__(self, db: 'Database'):
|
||||
self.db = db
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> 'Transaction':
|
||||
"""Begin a transaction. This transaction may be created while
|
||||
another is active in a different thread.
|
||||
"""
|
||||
|
|
@ -856,7 +898,12 @@ class Transaction:
|
|||
self.db._db_lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Type[Exception],
|
||||
exc_value: Exception,
|
||||
traceback: TracebackType,
|
||||
):
|
||||
"""Complete a transaction. This must be the most recently
|
||||
entered but not yet exited transaction. If it is the last active
|
||||
transaction, the database updates are committed.
|
||||
|
|
@ -872,14 +919,14 @@ class Transaction:
|
|||
self._mutated = False
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement, subvals=()):
|
||||
def query(self, statement: str, subvals: Iterable = ()) -> List:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
a list of rows from the database.
|
||||
"""
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.fetchall()
|
||||
|
||||
def mutate(self, statement, subvals=()):
|
||||
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
the row ID of the last affected row.
|
||||
"""
|
||||
|
|
@ -898,7 +945,7 @@ class Transaction:
|
|||
self._mutated = True
|
||||
return cursor.lastrowid
|
||||
|
||||
def script(self, statements):
|
||||
def script(self, statements: str):
|
||||
"""Execute a string containing multiple SQL statements."""
|
||||
# We don't know whether this mutates, but quite likely it does.
|
||||
self._mutated = True
|
||||
|
|
@ -922,7 +969,7 @@ class Database:
|
|||
data is written in a transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, path, timeout=5.0):
|
||||
def __init__(self, path, timeout: float = 5.0):
|
||||
if sqlite3.threadsafety == 0:
|
||||
raise RuntimeError(
|
||||
"sqlite3 must be compiled with multi-threading support"
|
||||
|
|
@ -956,7 +1003,7 @@ class Database:
|
|||
|
||||
# Primitive access control: connections and transactions.
|
||||
|
||||
def _connection(self):
|
||||
def _connection(self) -> Connection:
|
||||
"""Get a SQLite connection object to the underlying database.
|
||||
One connection object is created per thread.
|
||||
"""
|
||||
|
|
@ -969,7 +1016,7 @@ class Database:
|
|||
self._connections[thread_id] = conn
|
||||
return conn
|
||||
|
||||
def _create_connection(self):
|
||||
def _create_connection(self) -> Connection:
|
||||
"""Create a SQLite connection to the underlying database.
|
||||
|
||||
Makes a new connection every time. If you need to configure the
|
||||
|
|
@ -1019,7 +1066,7 @@ class Database:
|
|||
conn.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tx_stack(self):
|
||||
def _tx_stack(self) -> Generator[List]:
|
||||
"""A context manager providing access to the current thread's
|
||||
transaction stack. The context manager synchronizes access to
|
||||
the stack map. Transactions should never migrate across threads.
|
||||
|
|
@ -1028,7 +1075,7 @@ class Database:
|
|||
with self._shared_map_lock:
|
||||
yield self._tx_stacks[thread_id]
|
||||
|
||||
def transaction(self):
|
||||
def transaction(self) -> Transaction:
|
||||
"""Get a :class:`Transaction` object for interacting directly
|
||||
with the underlying SQLite database.
|
||||
"""
|
||||
|
|
@ -1048,7 +1095,7 @@ class Database:
|
|||
|
||||
# Schema setup and migration.
|
||||
|
||||
def _make_table(self, table, fields):
|
||||
def _make_table(self, table: str, fields: Mapping[str, types.Type]):
|
||||
"""Set up the schema of the database. `fields` is a mapping
|
||||
from field names to `Type`s. Columns are added if necessary.
|
||||
"""
|
||||
|
|
@ -1083,7 +1130,7 @@ class Database:
|
|||
with self.transaction() as tx:
|
||||
tx.script(setup_sql)
|
||||
|
||||
def _make_attribute_table(self, flex_table):
|
||||
def _make_attribute_table(self, flex_table: str):
|
||||
"""Create a table and associated index for flexible attributes
|
||||
for the given entity (if they don't exist).
|
||||
"""
|
||||
|
|
@ -1101,7 +1148,12 @@ class Database:
|
|||
|
||||
# Querying.
|
||||
|
||||
def _fetch(self, model_cls, query=None, sort=None):
|
||||
def _fetch(
|
||||
self,
|
||||
model_cls: Type[LibModel],
|
||||
query: Optional[Query] = None,
|
||||
sort: Optional[Sort] = None,
|
||||
) -> Results:
|
||||
"""Fetch the objects of type `model_cls` matching the given
|
||||
query. The query may be given as a string, string sequence, a
|
||||
Query object, or None (to fetch everything). `sort` is an
|
||||
|
|
@ -1141,7 +1193,7 @@ class Database:
|
|||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
def _get(self, model_cls, id):
|
||||
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
|
||||
"""Get a Model object by its id or None if the id does not
|
||||
exist.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,13 +18,14 @@
|
|||
import re
|
||||
from operator import mul
|
||||
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \
|
||||
Collection, Mapping
|
||||
Collection, Mapping, MutableMapping
|
||||
|
||||
from beets import util
|
||||
from datetime import datetime, timedelta
|
||||
import unicodedata
|
||||
from functools import reduce
|
||||
|
||||
from beets.dbcore import Model
|
||||
from beets.library import Item
|
||||
|
||||
|
||||
|
|
@ -120,7 +121,7 @@ class FieldQuery(Query):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def match(self, item: Item):
|
||||
def match(self, item: Model):
|
||||
return self.value_match(self.pattern, item.get(self.field))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
@ -465,6 +466,7 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
"""A collection query whose subqueries may be modified after the
|
||||
query is initialized.
|
||||
"""
|
||||
subqueries: MutableMapping
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.subqueries[key] = value
|
||||
|
|
@ -786,7 +788,7 @@ class Sort:
|
|||
the item database.
|
||||
"""
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> None:
|
||||
"""Generates a SQL fragment to be used in a ORDER BY clause, or
|
||||
None if no fragment is used (i.e., this is a slow sort).
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue