From 13c15613900cf4cfef0b3de41939d8b158539f1d Mon Sep 17 00:00:00 2001 From: Serene-Arc Date: Thu, 15 Sep 2022 20:34:21 +1000 Subject: [PATCH] Add more typings --- beets/dbcore/db.py | 164 +++++++++++++++++++++++++++--------------- beets/dbcore/query.py | 8 ++- 2 files changed, 113 insertions(+), 59 deletions(-) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 6621b55d3..63b94a066 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -22,6 +22,10 @@ from collections import defaultdict import threading import sqlite3 import contextlib +from sqlite3 import Connection +from types import TracebackType +from typing import Iterable, Type, List, Tuple, NoReturn, Optional, Union, \ + Dict, Any, Generator from unidecode import unidecode @@ -29,9 +33,13 @@ import beets from beets.util import functemplate from beets.util import py3_path from beets.dbcore import types -from .query import MatchQuery, NullSort, TrueQuery, AndQuery +from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \ + FieldQuery, Sort from collections.abc import Mapping +from ..library import LibModel +from ..util.functemplate import Template + class DBAccessError(Exception): """The SQLite database became inaccessible. @@ -58,7 +66,12 @@ class FormattedMapping(Mapping): ALL_KEYS = '*' - def __init__(self, model, included_keys=ALL_KEYS, for_path=False): + def __init__( + self, + model: 'Model', + included_keys: str = ALL_KEYS, + for_path: bool = False, + ): self.for_path = for_path self.model = model if included_keys == self.ALL_KEYS: @@ -73,10 +86,10 @@ class FormattedMapping(Mapping): else: raise KeyError(key) - def __iter__(self): + def __iter__(self) -> Iterable: return iter(self.model_keys) - def __len__(self): + def __len__(self) -> int: return len(self.model_keys) def get(self, key, default=None): @@ -107,7 +120,7 @@ class LazyConvertDict: """Lazily convert types for attributes fetched from the database """ - def __init__(self, model_cls): + def __init__(self, model_cls: 'Model'): """Initialize the object empty """ self.data = {} @@ -148,12 +161,12 @@ class LazyConvertDict: if key in self.data: del self.data[key] - def keys(self): + def keys(self) -> List: """Get a list of available field names for this object. """ return list(self._converted.keys()) + list(self.data.keys()) - def copy(self): + def copy(self) -> 'LazyConvertDict': """Create a copy of the object. """ new = self.__class__(self.model_cls) @@ -169,7 +182,7 @@ class LazyConvertDict: for key, value in values.items(): self[key] = value - def items(self): + def items(self) -> Generator[Tuple]: """Iterate over (key, value) pairs that this object contains. Computed fields are not included. """ @@ -185,12 +198,12 @@ class LazyConvertDict: else: return default - def __contains__(self, key): + def __contains__(self, key) -> bool: """Determine whether `key` is an attribute on this object. """ return key in self.keys() - def __iter__(self): + def __iter__(self) -> Iterable: """Iterate over the available field names (excluding computed fields). """ @@ -269,14 +282,14 @@ class Model: """ @classmethod - def _getters(cls): + def _getters(cls: Type['Model']): """Return a mapping from field names to getter functions. """ # We could cache this if it becomes a performance problem to # gather the getter mapping every time. raise NotImplementedError() - def _template_funcs(self): + def _template_funcs(self) -> NoReturn: """Return a mapping from function names to text-transformer functions. """ @@ -285,7 +298,7 @@ class Model: # Basic operation. - def __init__(self, db=None, **values): + def __init__(self, db: Optional['Database'] = None, **values): """Create a new object with an optional Database association and initial field values. """ @@ -299,7 +312,12 @@ class Model: self.clear_dirty() @classmethod - def _awaken(cls, db=None, fixed_values={}, flex_values={}): + def _awaken( + cls: Type['Model'], + db: 'Database' = None, + fixed_values: Mapping = {}, + flex_values: Mapping = {}, + ) -> 'Model': """Create an object with values drawn from the database. This is a performance optimization: the checks involved with @@ -312,7 +330,7 @@ class Model: return obj - def __repr__(self): + def __repr__(self) -> str: return '{}({})'.format( type(self).__name__, ', '.join(f'{k}={v!r}' for k, v in dict(self).items()), @@ -326,7 +344,7 @@ class Model: if self._db: self._revision = self._db.revision - def _check_db(self, need_id=True): + def _check_db(self, need_id: bool = True): """Ensure that this object is associated with a database row: it has a reference to a database (`_db`) and an id. A ValueError exception is raised otherwise. @@ -338,7 +356,7 @@ class Model: if need_id and not self.id: raise ValueError('{} has no id'.format(type(self).__name__)) - def copy(self): + def copy(self) -> 'Model': """Create a copy of the model object. The field values and other state is duplicated, but the new copy @@ -356,7 +374,7 @@ class Model: # Essential field accessors. @classmethod - def _type(cls, key): + def _type(cls, key) -> types.Type: """Get the type of a field, a `Type` instance. If the field has no explicit type, it is given the base `Type`, @@ -364,7 +382,7 @@ class Model: """ return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT - def _get(self, key, default=None, raise_=False): + def _get(self, key, default: bool = None, raise_: bool = False): """Get the value for a field, or `default`. Alternatively, raise a KeyError if the field is not available. """ @@ -431,7 +449,7 @@ class Model: else: raise KeyError(f'no such field {key}') - def keys(self, computed=False): + def keys(self, computed: bool = False): """Get a list of available field names for this object. The `computed` parameter controls whether computed (plugin-provided) fields are included in the key list. @@ -457,19 +475,19 @@ class Model: for key, value in values.items(): self[key] = value - def items(self): + def items(self) -> Generator: """Iterate over (key, value) pairs that this object contains. Computed fields are not included. """ for key in self: yield key, self[key] - def __contains__(self, key): + def __contains__(self, key) -> bool: """Determine whether `key` is an attribute on this object. """ return key in self.keys(computed=True) - def __iter__(self): + def __iter__(self) -> Iterable: """Iterate over the available field names (excluding computed fields). """ @@ -500,7 +518,7 @@ class Model: # Database interaction (CRUD methods). - def store(self, fields=None): + def store(self, fields: bool = None): """Save the object's metadata into the library database. :param fields: the fields to be stored. If not specified, all fields will be. @@ -581,7 +599,7 @@ class Model: (self.id,) ) - def add(self, db=None): + def add(self, db: Optional['Database'] = None): """Add the object to the library database. This object must be associated with a database; you can provide one via the `db` parameter or use the currently associated database. @@ -610,13 +628,21 @@ class Model: _formatter = FormattedMapping - def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False): + def formatted( + self, + included_keys: str = _formatter.ALL_KEYS, + for_path: bool = False, + ): """Get a mapping containing all values on this object formatted as human-readable unicode strings. """ return self._formatter(self, included_keys, for_path) - def evaluate_template(self, template, for_path=False): + def evaluate_template( + self, + template: Union[str, Template], + for_path: bool = False, + ) -> str: """Evaluate a template (a string or a `Template` object) using the object's fields. If `for_path` is true, then no new path separators will be added to the template. @@ -630,7 +656,7 @@ class Model: # Parsing. @classmethod - def _parse(cls, key, string): + def _parse(cls, key, string: str) -> types.Type: """Parse a string as a value for the given key. """ if not isinstance(string, str): @@ -638,7 +664,7 @@ class Model: return cls._type(key).parse(string) - def set_parse(self, key, string): + def set_parse(self, key, string: str): """Set the object's key to a value represented by a string. """ self[key] = self._parse(key, string) @@ -646,12 +672,21 @@ class Model: # Convenient queries. @classmethod - def field_query(cls, field, pattern, query_cls=MatchQuery): + def field_query( + cls, + field, + pattern, + query_cls: Type[FieldQuery] = MatchQuery, + ) -> FieldQuery: """Get a `FieldQuery` for this model.""" return query_cls(field, pattern, field in cls._fields) @classmethod - def all_fields_query(cls, pats, query_cls=MatchQuery): + def all_fields_query( + cls: Type['Model'], + pats: Mapping, + query_cls: Type[FieldQuery] = MatchQuery, + ): """Get a query that matches many fields with different patterns. `pats` should be a mapping from field names to patterns. The @@ -670,8 +705,15 @@ class Results: constructs LibModel objects that reflect database rows. """ - def __init__(self, model_class, rows, db, flex_rows, - query=None, sort=None): + def __init__( + self, + model_class: Type[LibModel], + rows: List[Mapping], + db: 'Database', + flex_rows, + query: Optional[FieldQuery] = None, + sort=None, + ): """Create a result set that will construct objects of type `model_class`. @@ -703,7 +745,7 @@ class Results: # consumed. self._objects = [] - def _get_objects(self): + def _get_objects(self) -> Generator[Model]: """Construct and generate Model objects for they query. The objects are returned in the order emitted from the database; no slow sort is applied. @@ -738,7 +780,7 @@ class Results: yield obj break - def __iter__(self): + def __iter__(self) -> Iterable[Model]: """Construct and generate Model objects for all matching objects, in sorted order. """ @@ -751,7 +793,7 @@ class Results: # Objects are pre-sorted (i.e., by the database). return self._get_objects() - def _get_indexed_flex_attrs(self): + def _get_indexed_flex_attrs(self) -> Mapping: """ Index flexible attributes by the entity id they belong to """ flex_values = {} @@ -763,7 +805,7 @@ class Results: return flex_values - def _make_model(self, row, flex_values={}): + def _make_model(self, row, flex_values: Dict = {}) -> Model: """ Create a Model object for the given row """ cols = dict(row) @@ -774,7 +816,7 @@ class Results: obj = self.model_class._awaken(self.db, values, flex_values) return obj - def __len__(self): + def __len__(self) -> int: """Get the number of matching objects. """ if not self._rows: @@ -792,12 +834,12 @@ class Results: # A fast query. Just count the rows. return self._row_count - def __nonzero__(self): + def __nonzero__(self) -> bool: """Does this result contain any objects? """ return self.__bool__() - def __bool__(self): + def __bool__(self) -> bool: """Does this result contain any objects? """ return bool(len(self)) @@ -819,7 +861,7 @@ class Results: except StopIteration: raise IndexError(f'result index {n} out of range') - def get(self): + def get(self) -> Optional[Model]: """Return the first matching object, or None if no objects match. """ @@ -840,10 +882,10 @@ class Transaction: current transaction. """ - def __init__(self, db): + def __init__(self, db: 'Database'): self.db = db - def __enter__(self): + def __enter__(self) -> 'Transaction': """Begin a transaction. This transaction may be created while another is active in a different thread. """ @@ -856,7 +898,12 @@ class Transaction: self.db._db_lock.acquire() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Type[Exception], + exc_value: Exception, + traceback: TracebackType, + ): """Complete a transaction. This must be the most recently entered but not yet exited transaction. If it is the last active transaction, the database updates are committed. @@ -872,14 +919,14 @@ class Transaction: self._mutated = False self.db._db_lock.release() - def query(self, statement, subvals=()): + def query(self, statement: str, subvals: Iterable = ()) -> List: """Execute an SQL statement with substitution values and return a list of rows from the database. """ cursor = self.db._connection().execute(statement, subvals) return cursor.fetchall() - def mutate(self, statement, subvals=()): + def mutate(self, statement: str, subvals: Iterable = ()) -> Any: """Execute an SQL statement with substitution values and return the row ID of the last affected row. """ @@ -898,7 +945,7 @@ class Transaction: self._mutated = True return cursor.lastrowid - def script(self, statements): + def script(self, statements: str): """Execute a string containing multiple SQL statements.""" # We don't know whether this mutates, but quite likely it does. self._mutated = True @@ -922,7 +969,7 @@ class Database: data is written in a transaction. """ - def __init__(self, path, timeout=5.0): + def __init__(self, path, timeout: float = 5.0): if sqlite3.threadsafety == 0: raise RuntimeError( "sqlite3 must be compiled with multi-threading support" @@ -956,7 +1003,7 @@ class Database: # Primitive access control: connections and transactions. - def _connection(self): + def _connection(self) -> Connection: """Get a SQLite connection object to the underlying database. One connection object is created per thread. """ @@ -969,7 +1016,7 @@ class Database: self._connections[thread_id] = conn return conn - def _create_connection(self): + def _create_connection(self) -> Connection: """Create a SQLite connection to the underlying database. Makes a new connection every time. If you need to configure the @@ -1019,7 +1066,7 @@ class Database: conn.close() @contextlib.contextmanager - def _tx_stack(self): + def _tx_stack(self) -> Generator[List]: """A context manager providing access to the current thread's transaction stack. The context manager synchronizes access to the stack map. Transactions should never migrate across threads. @@ -1028,7 +1075,7 @@ class Database: with self._shared_map_lock: yield self._tx_stacks[thread_id] - def transaction(self): + def transaction(self) -> Transaction: """Get a :class:`Transaction` object for interacting directly with the underlying SQLite database. """ @@ -1048,7 +1095,7 @@ class Database: # Schema setup and migration. - def _make_table(self, table, fields): + def _make_table(self, table: str, fields: Mapping[str, types.Type]): """Set up the schema of the database. `fields` is a mapping from field names to `Type`s. Columns are added if necessary. """ @@ -1083,7 +1130,7 @@ class Database: with self.transaction() as tx: tx.script(setup_sql) - def _make_attribute_table(self, flex_table): + def _make_attribute_table(self, flex_table: str): """Create a table and associated index for flexible attributes for the given entity (if they don't exist). """ @@ -1101,7 +1148,12 @@ class Database: # Querying. - def _fetch(self, model_cls, query=None, sort=None): + def _fetch( + self, + model_cls: Type[LibModel], + query: Optional[Query] = None, + sort: Optional[Sort] = None, + ) -> Results: """Fetch the objects of type `model_cls` matching the given query. The query may be given as a string, string sequence, a Query object, or None (to fetch everything). `sort` is an @@ -1141,7 +1193,7 @@ class Database: sort if sort.is_slow() else None, # Slow sort component. ) - def _get(self, model_cls, id): + def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model: """Get a Model object by its id or None if the id does not exist. """ diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index e190083c5..209c1319f 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -18,13 +18,14 @@ import re from operator import mul from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator, \ - Collection, Mapping + Collection, Mapping, MutableMapping from beets import util from datetime import datetime, timedelta import unicodedata from functools import reduce +from beets.dbcore import Model from beets.library import Item @@ -120,7 +121,7 @@ class FieldQuery(Query): """ raise NotImplementedError() - def match(self, item: Item): + def match(self, item: Model): return self.value_match(self.pattern, item.get(self.field)) def __repr__(self) -> str: @@ -465,6 +466,7 @@ class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the query is initialized. """ + subqueries: MutableMapping def __setitem__(self, key, value): self.subqueries[key] = value @@ -786,7 +788,7 @@ class Sort: the item database. """ - def order_clause(self): + def order_clause(self) -> None: """Generates a SQL fragment to be used in a ORDER BY clause, or None if no fragment is used (i.e., this is a slow sort). """