diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index accb62327..8635dd1c1 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -16,6 +16,7 @@ """ from __future__ import annotations +from abc import ABC import time import os import re @@ -25,22 +26,36 @@ import sqlite3 import contextlib from sqlite3 import Connection from types import TracebackType -from typing import Iterable, Type, List, Tuple, Optional, Union, \ - Dict, Any, Generator, Iterator, Callable +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Dict, + Generator, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from unidecode import unidecode import beets from beets.util import functemplate from beets.util import py3_path -from beets.dbcore import types +from . import types from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \ - FieldQuery, Sort -from collections.abc import Mapping + FieldQuery, Sort, FieldSort -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from beets.library import LibModel from ..util.functemplate import Template @@ -53,7 +68,7 @@ class DBAccessError(Exception): """ -class FormattedMapping(Mapping): +class FormattedMapping(Mapping[str, str]): """A `dict`-like formatted view of a model. The accessor `mapping[key]` returns the formatted version of @@ -71,7 +86,7 @@ class FormattedMapping(Mapping): def __init__( self, - model: 'Model', + model: Model, included_keys: str = ALL_KEYS, for_path: bool = False, ): @@ -83,31 +98,39 @@ class FormattedMapping(Mapping): else: self.model_keys = included_keys - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: if key in self.model_keys: return self._get_formatted(self.model, key) else: raise KeyError(key) - def __iter__(self) -> Iterable[str]: + def __iter__(self) -> Iterator[str]: return iter(self.model_keys) def __len__(self) -> int: return len(self.model_keys) - def get(self, key, default=None): + # The following signature is incompatible with `Mapping[str, str]`, since + # the return type doesn't include `None` (but `default` can be `None`). + def get( # type: ignore + self, + key: str, + default: Optional[str] = None, + ) -> str: + """Similar to Mapping.get(key, default), but always formats to str. + """ if default is None: default = self.model._type(key).format(None) return super().get(key, default) - def _get_formatted(self, model, key): + def _get_formatted(self, model: Model, key: str) -> str: value = model._type(key).format(model.get(key)) if isinstance(value, bytes): value = value.decode('utf-8', 'ignore') if self.for_path: - sep_repl = beets.config['path_sep_replace'].as_str() - sep_drive = beets.config['drive_sep_replace'].as_str() + sep_repl = cast(str, beets.config['path_sep_replace'].as_str()) + sep_drive = cast(str, beets.config['drive_sep_replace'].as_str()) if re.match(r'^\w:', value): value = re.sub(r'(?<=^\w):', sep_drive, value) @@ -119,6 +142,15 @@ class FormattedMapping(Mapping): return value +# NOTE: This seems like it should be a `Mapping`, i.e. +# ``` +# class LazyConvertDict(Mapping[str, Any]) +# ``` +# but there are some conflicts with the `Mapping` protocol such that we +# can't do this without changing behaviour: In particular, iterators returned +# by some methods build intermediate lists, such that modification of the +# `LazyConvertDict` becomes safe during iteration. Some code does in fact rely +# on this. class LazyConvertDict: """Lazily convert types for attributes fetched from the database """ @@ -126,60 +158,61 @@ class LazyConvertDict: def __init__(self, model_cls: 'Model'): """Initialize the object empty """ - self.data = {} + # FIXME: Dict[str, SQLiteType] + self._data: Dict[str, Any] = {} self.model_cls = model_cls - self._converted = {} + self._converted: Dict[str, Any] = {} - def init(self, data): + def init(self, data: Dict[str, Any]): """Set the base data that should be lazily converted """ - self.data = data + self._data = data - def _convert(self, key, value): - """Convert the attribute type according the the SQL type + def _convert(self, key: str, value: Any): + """Convert the attribute type according to the SQL type """ return self.model_cls._type(key).from_sql(value) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any): """Set an attribute value, assume it's already converted """ self._converted[key] = value - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: """Get an attribute value, converting the type on demand if needed """ if key in self._converted: return self._converted[key] - elif key in self.data: - value = self._convert(key, self.data[key]) + elif key in self._data: + value = self._convert(key, self._data[key]) self._converted[key] = value return value - def __delitem__(self, key): + def __delitem__(self, key: str): """Delete both converted and base data """ if key in self._converted: del self._converted[key] - if key in self.data: - del self.data[key] + if key in self._data: + del self._data[key] def keys(self) -> List[str]: """Get a list of available field names for this object. """ - return list(self._converted.keys()) + list(self.data.keys()) + return list(self._converted.keys()) + list(self._data.keys()) - def copy(self) -> 'LazyConvertDict': + def copy(self) -> LazyConvertDict: """Create a copy of the object. """ new = self.__class__(self.model_cls) - new.data = self.data.copy() + new._data = self._data.copy() new._converted = self._converted.copy() return new # Act like a dictionary. - def update(self, values): + def update(self, values: Mapping[str, Any]): """Assign all values in the given dict. """ for key, value in values.items(): @@ -192,7 +225,7 @@ class LazyConvertDict: for key in self: yield key, self[key] - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None): """Get the value for a given key or `default` if it does not exist. """ @@ -201,21 +234,30 @@ class LazyConvertDict: else: return default - def __contains__(self, key) -> bool: + def __contains__(self, key: Any) -> bool: """Determine whether `key` is an attribute on this object. """ - return key in self.keys() + return key in self._converted or key in self._data - def __iter__(self) -> Iterable[str]: + def __iter__(self) -> Iterator[str]: """Iterate over the available field names (excluding computed fields). """ + # NOTE: It would be nice to use the following: + # yield from self._converted + # yield from self._data + # but that won't work since some code relies on modifying `self` + # during iteration. return iter(self.keys()) + def __len__(self) -> int: + # FIXME: This is incorrect due to duplication of keys + return len(self._converted) + len(self._data) + # Abstract base for model classes. -class Model: +class Model(ABC): """An abstract object representing an object in the database. Model objects act like dictionaries (i.e., they allow subscript access like ``obj['field']``). The same field set is available via attribute @@ -241,34 +283,34 @@ class Model: # Abstract components (to be provided by subclasses). - _table = None + _table: str """The main SQLite table name. """ - _flex_table = None + _flex_table: str """The flex field SQLite table name. """ - _fields = {} + _fields: Dict[str, types.Type] = {} """A mapping indicating available "fixed" fields on this type. The keys are field names and the values are `Type` objects. """ - _search_fields = () + _search_fields: Sequence[str] = () """The fields that should be queried by default by unqualified query terms. """ - _types = {} + _types: Dict[str, types.Type] = {} """Optional Types for non-fixed (i.e., flexible and computed) fields. """ - _sorts = {} + _sorts: Dict[str, Type[FieldSort]] = {} """Optional named sort criteria. The keys are strings and the values are subclasses of `Sort`. """ - _queries = {} + _queries: Dict[str, Type[Query]] = {} """Named queries that use a field-like `name:value` syntax but which do not relate to any specific field. """ @@ -301,12 +343,12 @@ class Model: # Basic operation. - def __init__(self, db: Optional['Database'] = None, **values): + def __init__(self, db: Optional[Database] = None, **values): """Create a new object with an optional Database association and initial field values. """ self._db = db - self._dirty = set() + self._dirty: Set[str] = set() self._values_fixed = LazyConvertDict(self) self._values_flex = LazyConvertDict(self) @@ -316,11 +358,11 @@ class Model: @classmethod def _awaken( - cls: Type['Model'], - db: 'Database' = None, - fixed_values: Mapping = {}, - flex_values: Mapping = {}, - ) -> 'Model': + cls: Type[AnyModel], + db: Optional[Database] = None, + fixed_values: Dict[str, Any] = {}, + flex_values: Dict[str, Any] = {}, + ) -> AnyModel: """Create an object with values drawn from the database. This is a performance optimization: the checks involved with @@ -347,7 +389,7 @@ class Model: if self._db: self._revision = self._db.revision - def _check_db(self, need_id: bool = True): + def _check_db(self, need_id: bool = True) -> Database: """Ensure that this object is associated with a database row: it has a reference to a database (`_db`) and an id. A ValueError exception is raised otherwise. @@ -359,6 +401,8 @@ class Model: if need_id and not self.id: raise ValueError('{} has no id'.format(type(self).__name__)) + return self._db + def copy(self) -> 'Model': """Create a copy of the model object. @@ -490,7 +534,7 @@ class Model: """ return key in self.keys(computed=True) - def __iter__(self) -> Iterable[str]: + def __iter__(self) -> Iterator[str]: """Iterate over the available field names (excluding computed fields). """ @@ -521,14 +565,14 @@ class Model: # Database interaction (CRUD methods). - def store(self, fields: bool = None): + def store(self, fields: Optional[Iterable[str]] = None): """Save the object's metadata into the library database. :param fields: the fields to be stored. If not specified, all fields will be. """ if fields is None: fields = self._fields - self._check_db() + db = self._check_db() # Build assignments for query. assignments = [] @@ -539,13 +583,13 @@ class Model: assignments.append(key + '=?') value = self._type(key).to_sql(self[key]) subvars.append(value) - assignments = ','.join(assignments) - with self._db.transaction() as tx: + with db.transaction() as tx: # Main table update. if assignments: query = 'UPDATE {} SET {} WHERE id=?'.format( - self._table, assignments + self._table, + ','.join(assignments) ) subvars.append(self.id) tx.mutate(query, subvars) @@ -577,11 +621,11 @@ class Model: If check_revision is true, the database is only queried loaded when a transaction has been committed since the item was last loaded. """ - self._check_db() - if not self._dirty and self._db.revision == self._revision: + db = self._check_db() + if not self._dirty and db.revision == self._revision: # Exit early return - stored_obj = self._db._get(type(self), self.id) + stored_obj = db._get(type(self), self.id) assert stored_obj is not None, f"object {self.id} not in DB" self._values_fixed = LazyConvertDict(self) self._values_flex = LazyConvertDict(self) @@ -591,8 +635,8 @@ class Model: def remove(self): """Remove the object's associated rows from the database. """ - self._check_db() - with self._db.transaction() as tx: + db = self._check_db() + with db.transaction() as tx: tx.mutate( f'DELETE FROM {self._table} WHERE id=?', (self.id,) @@ -612,9 +656,9 @@ class Model: """ if db: self._db = db - self._check_db(False) + db = self._check_db(False) - with self._db.transaction() as tx: + with db.transaction() as tx: new_id = tx.mutate( f'INSERT INTO {self._table} DEFAULT VALUES' ) @@ -652,9 +696,12 @@ class Model: """ # Perform substitution. if isinstance(template, str): - template = functemplate.template(template) - return template.substitute(self.formatted(for_path=for_path), - self._template_funcs()) + t = functemplate.template(template) + else: + # Help out mypy + t = template + return t.substitute(self.formatted(for_path=for_path), + self._template_funcs()) # Parsing. @@ -703,24 +750,28 @@ class Model: # Database controller and supporting interfaces. -class Results: + +AnyModel = TypeVar("AnyModel", bound=Model) + + +class Results(Generic[AnyModel]): """An item query result set. Iterating over the collection lazily - constructs LibModel objects that reflect database rows. + constructs Model objects that reflect database rows. """ def __init__( self, - model_class: Type['LibModel'], + model_class: Type[AnyModel], rows: List[Mapping], db: 'Database', flex_rows, - query: Optional[FieldQuery] = None, + query: Optional[Query] = None, sort=None, ): """Create a result set that will construct objects of type `model_class`. - `model_class` is a subclass of `LibModel` that will be + `model_class` is a subclass of `Model` that will be constructed. `rows` is a query result: a list of mappings. The new objects will be associated with the database `db`. @@ -746,9 +797,9 @@ class Results: # The materialized objects corresponding to rows that have been # consumed. - self._objects = [] + self._objects: List[AnyModel] = [] - def _get_objects(self) -> Iterable[Model]: + def _get_objects(self) -> Iterator[AnyModel]: """Construct and generate Model objects for they query. The objects are returned in the order emitted from the database; no slow sort is applied. @@ -783,7 +834,7 @@ class Results: yield obj break - def __iter__(self) -> Iterable[Model]: + def __iter__(self) -> Iterator[AnyModel]: """Construct and generate Model objects for all matching objects, in sorted order. """ @@ -799,7 +850,7 @@ class Results: def _get_indexed_flex_attrs(self) -> Mapping: """ Index flexible attributes by the entity id they belong to """ - flex_values = {} + flex_values: Dict[int, Dict[str, Any]] = {} for row in self.flex_rows: if row['entity_id'] not in flex_values: flex_values[row['entity_id']] = {} @@ -808,7 +859,7 @@ class Results: return flex_values - def _make_model(self, row, flex_values: Dict = {}) -> Model: + def _make_model(self, row, flex_values: Dict = {}) -> AnyModel: """ Create a Model object for the given row """ cols = dict(row) @@ -864,7 +915,7 @@ class Results: except StopIteration: raise IndexError(f'result index {n} out of range') - def get(self) -> Optional[Model]: + def get(self) -> Optional[AnyModel]: """Return the first matching object, or None if no objects match. """ @@ -922,14 +973,14 @@ class Transaction: self._mutated = False self.db._db_lock.release() - def query(self, statement: str, subvals: Iterable = ()) -> List: + def query(self, statement: str, subvals: Sequence = ()) -> List: """Execute an SQL statement with substitution values and return a list of rows from the database. """ cursor = self.db._connection().execute(statement, subvals) return cursor.fetchall() - def mutate(self, statement: str, subvals: Iterable = ()) -> Any: + def mutate(self, statement: str, subvals: Sequence = ()) -> Any: """Execute an SQL statement with substitution values and return the row ID of the last affected row. """ @@ -960,7 +1011,7 @@ class Database: the backend. """ - _models = () + _models: Sequence[Type[Model]] = () """The Model subclasses representing tables in this database. """ @@ -981,9 +1032,10 @@ class Database: self.path = path self.timeout = timeout - self._connections = {} - self._tx_stacks = defaultdict(list) - self._extensions = [] + self._connections: Dict[int, sqlite3.Connection] = {} + self._tx_stacks: DefaultDict[int, List[Transaction]] = \ + defaultdict(list) + self._extensions: List[str] = [] # A lock to protect the _connections and _tx_stacks maps, which # both map thread IDs to private resources. @@ -1011,6 +1063,11 @@ class Database: One connection object is created per thread. """ thread_id = threading.current_thread().ident + # Help the type checker: ident can only be None if the thread has not + # been started yet; but since this results from current_thread(), that + # can't happen + assert thread_id is not None + with self._shared_map_lock: if thread_id in self._connections: return self._connections[thread_id] @@ -1075,6 +1132,11 @@ class Database: the stack map. Transactions should never migrate across threads. """ thread_id = threading.current_thread().ident + # Help the type checker: ident can only be None if the thread has not + # been started yet; but since this results from current_thread(), that + # can't happen + assert thread_id is not None + with self._shared_map_lock: yield self._tx_stacks[thread_id] @@ -1084,7 +1146,7 @@ class Database: """ return Transaction(self) - def load_extension(self, path): + def load_extension(self, path: str): """Load an SQLite extension into all open connections.""" if not self.supports_extensions: raise ValueError( @@ -1152,11 +1214,11 @@ class Database: # Querying. def _fetch( - self, - model_cls: Type['LibModel'], - query: Optional[Query] = None, - sort: Optional[Sort] = None, - ) -> Results: + self, + model_cls: Type[AnyModel], + query: Optional[Query] = None, + sort: Optional[Sort] = None, + ) -> Results[AnyModel]: """Fetch the objects of type `model_cls` matching the given query. The query may be given as a string, string sequence, a Query object, or None (to fetch everything). `sort` is an @@ -1180,10 +1242,10 @@ class Database: SELECT * FROM {} WHERE entity_id IN (SELECT id FROM {} WHERE {}); """.format( - model_cls._flex_table, - model_cls._table, - where or '1', - ) + model_cls._flex_table, + model_cls._table, + where or '1', + ) ) with self.transaction() as tx: @@ -1196,7 +1258,11 @@ class Database: sort if sort.is_slow() else None, # Slow sort component. ) - def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model: + def _get( + self, + model_cls: Type[AnyModel], + id, + ) -> Optional[AnyModel]: """Get a Model object by its id or None if the id does not exist. """ diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index fbc080426..9c04dc0ee 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -334,9 +334,9 @@ class BytesQuery(FieldQuery[bytes]): else: bytes_pattern = pattern self.buf_pattern = memoryview(bytes_pattern) - elif isinstance(self.pattern, memoryview): - self.buf_pattern = self.pattern - bytes_pattern = bytes(self.pattern) + elif isinstance(pattern, memoryview): + self.buf_pattern = pattern + bytes_pattern = bytes(pattern) else: raise ValueError("pattern must be bytes, str, or memoryview") diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 2fa7bcfbb..dc51a5065 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -21,7 +21,7 @@ from typing import Dict, Type, Tuple, Optional, Collection, List, \ Sequence from . import query, Model -from .query import Sort +from .query import Query, Sort PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. @@ -132,7 +132,7 @@ def construct_query_part( # Use `model_cls` to build up a map from field (or query) names to # `Query` classes. - query_classes = {} + query_classes: Dict[str, Type[Query]] = {} for k, t in itertools.chain(model_cls._fields.items(), model_cls._types.items()): query_classes[k] = t.query