diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 6621b55d3..084ceef99 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -15,6 +15,7 @@ """The central Model and Database constructs for DBCore. """ +from __future__ import annotations import time import os import re @@ -22,6 +23,10 @@ from collections import defaultdict import threading import sqlite3 import contextlib +from sqlite3 import Connection +from types import TracebackType +from typing import Iterable, Type, List, Tuple, Optional, Union, \ + Dict, Any, Generator, Iterator, Callable from unidecode import unidecode @@ -29,9 +34,15 @@ import beets from beets.util import functemplate from beets.util import py3_path from beets.dbcore import types -from .query import MatchQuery, NullSort, TrueQuery, AndQuery +from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \ + FieldQuery, Sort from collections.abc import Mapping +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from beets.library import LibModel +from ..util.functemplate import Template + class DBAccessError(Exception): """The SQLite database became inaccessible. @@ -58,7 +69,12 @@ class FormattedMapping(Mapping): ALL_KEYS = '*' - def __init__(self, model, included_keys=ALL_KEYS, for_path=False): + def __init__( + self, + model: 'Model', + included_keys: str = ALL_KEYS, + for_path: bool = False, + ): self.for_path = for_path self.model = model if included_keys == self.ALL_KEYS: @@ -73,10 +89,10 @@ class FormattedMapping(Mapping): else: raise KeyError(key) - def __iter__(self): + def __iter__(self) -> Iterable[str]: return iter(self.model_keys) - def __len__(self): + def __len__(self) -> int: return len(self.model_keys) def get(self, key, default=None): @@ -107,7 +123,7 @@ class LazyConvertDict: """Lazily convert types for attributes fetched from the database """ - def __init__(self, model_cls): + def __init__(self, model_cls: 'Model'): """Initialize the object empty """ self.data = {} @@ -148,12 +164,12 @@ class LazyConvertDict: if key in self.data: del self.data[key] - def keys(self): + def keys(self) -> List[str]: """Get a list of available field names for this object. """ return list(self._converted.keys()) + list(self.data.keys()) - def copy(self): + def copy(self) -> 'LazyConvertDict': """Create a copy of the object. """ new = self.__class__(self.model_cls) @@ -169,7 +185,7 @@ class LazyConvertDict: for key, value in values.items(): self[key] = value - def items(self): + def items(self) -> Iterable[Tuple[str, Any]]: """Iterate over (key, value) pairs that this object contains. Computed fields are not included. """ @@ -185,12 +201,12 @@ class LazyConvertDict: else: return default - def __contains__(self, key): + def __contains__(self, key) -> bool: """Determine whether `key` is an attribute on this object. """ return key in self.keys() - def __iter__(self): + def __iter__(self) -> Iterable[str]: """Iterate over the available field names (excluding computed fields). """ @@ -269,14 +285,14 @@ class Model: """ @classmethod - def _getters(cls): + def _getters(cls: Type['Model']): """Return a mapping from field names to getter functions. """ # We could cache this if it becomes a performance problem to # gather the getter mapping every time. raise NotImplementedError() - def _template_funcs(self): + def _template_funcs(self) -> Mapping[str, Callable[[str], str]]: """Return a mapping from function names to text-transformer functions. """ @@ -285,7 +301,7 @@ class Model: # Basic operation. - def __init__(self, db=None, **values): + def __init__(self, db: Optional['Database'] = None, **values): """Create a new object with an optional Database association and initial field values. """ @@ -299,7 +315,12 @@ class Model: self.clear_dirty() @classmethod - def _awaken(cls, db=None, fixed_values={}, flex_values={}): + def _awaken( + cls: Type['Model'], + db: 'Database' = None, + fixed_values: Mapping = {}, + flex_values: Mapping = {}, + ) -> 'Model': """Create an object with values drawn from the database. This is a performance optimization: the checks involved with @@ -312,7 +333,7 @@ class Model: return obj - def __repr__(self): + def __repr__(self) -> str: return '{}({})'.format( type(self).__name__, ', '.join(f'{k}={v!r}' for k, v in dict(self).items()), @@ -326,7 +347,7 @@ class Model: if self._db: self._revision = self._db.revision - def _check_db(self, need_id=True): + def _check_db(self, need_id: bool = True): """Ensure that this object is associated with a database row: it has a reference to a database (`_db`) and an id. A ValueError exception is raised otherwise. @@ -338,7 +359,7 @@ class Model: if need_id and not self.id: raise ValueError('{} has no id'.format(type(self).__name__)) - def copy(self): + def copy(self) -> 'Model': """Create a copy of the model object. The field values and other state is duplicated, but the new copy @@ -356,7 +377,7 @@ class Model: # Essential field accessors. @classmethod - def _type(cls, key): + def _type(cls, key) -> types.Type: """Get the type of a field, a `Type` instance. If the field has no explicit type, it is given the base `Type`, @@ -364,7 +385,7 @@ class Model: """ return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT - def _get(self, key, default=None, raise_=False): + def _get(self, key, default: bool = None, raise_: bool = False): """Get the value for a field, or `default`. Alternatively, raise a KeyError if the field is not available. """ @@ -431,7 +452,7 @@ class Model: else: raise KeyError(f'no such field {key}') - def keys(self, computed=False): + def keys(self, computed: bool = False): """Get a list of available field names for this object. The `computed` parameter controls whether computed (plugin-provided) fields are included in the key list. @@ -457,19 +478,19 @@ class Model: for key, value in values.items(): self[key] = value - def items(self): + def items(self) -> Iterator[Tuple[str, Any]]: """Iterate over (key, value) pairs that this object contains. Computed fields are not included. """ for key in self: yield key, self[key] - def __contains__(self, key): + def __contains__(self, key) -> bool: """Determine whether `key` is an attribute on this object. """ return key in self.keys(computed=True) - def __iter__(self): + def __iter__(self) -> Iterable[str]: """Iterate over the available field names (excluding computed fields). """ @@ -500,7 +521,7 @@ class Model: # Database interaction (CRUD methods). - def store(self, fields=None): + def store(self, fields: bool = None): """Save the object's metadata into the library database. :param fields: the fields to be stored. If not specified, all fields will be. @@ -581,7 +602,7 @@ class Model: (self.id,) ) - def add(self, db=None): + def add(self, db: Optional['Database'] = None): """Add the object to the library database. This object must be associated with a database; you can provide one via the `db` parameter or use the currently associated database. @@ -610,13 +631,21 @@ class Model: _formatter = FormattedMapping - def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False): + def formatted( + self, + included_keys: str = _formatter.ALL_KEYS, + for_path: bool = False, + ): """Get a mapping containing all values on this object formatted as human-readable unicode strings. """ return self._formatter(self, included_keys, for_path) - def evaluate_template(self, template, for_path=False): + def evaluate_template( + self, + template: Union[str, Template], + for_path: bool = False, + ) -> str: """Evaluate a template (a string or a `Template` object) using the object's fields. If `for_path` is true, then no new path separators will be added to the template. @@ -630,7 +659,7 @@ class Model: # Parsing. @classmethod - def _parse(cls, key, string): + def _parse(cls, key, string: str) -> Any: """Parse a string as a value for the given key. """ if not isinstance(string, str): @@ -638,7 +667,7 @@ class Model: return cls._type(key).parse(string) - def set_parse(self, key, string): + def set_parse(self, key, string: str): """Set the object's key to a value represented by a string. """ self[key] = self._parse(key, string) @@ -646,12 +675,21 @@ class Model: # Convenient queries. @classmethod - def field_query(cls, field, pattern, query_cls=MatchQuery): + def field_query( + cls, + field, + pattern, + query_cls: Type[FieldQuery] = MatchQuery, + ) -> FieldQuery: """Get a `FieldQuery` for this model.""" return query_cls(field, pattern, field in cls._fields) @classmethod - def all_fields_query(cls, pats, query_cls=MatchQuery): + def all_fields_query( + cls: Type['Model'], + pats: Mapping, + query_cls: Type[FieldQuery] = MatchQuery, + ): """Get a query that matches many fields with different patterns. `pats` should be a mapping from field names to patterns. The @@ -670,8 +708,15 @@ class Results: constructs LibModel objects that reflect database rows. """ - def __init__(self, model_class, rows, db, flex_rows, - query=None, sort=None): + def __init__( + self, + model_class: Type['LibModel'], + rows: List[Mapping], + db: 'Database', + flex_rows, + query: Optional[FieldQuery] = None, + sort=None, + ): """Create a result set that will construct objects of type `model_class`. @@ -703,7 +748,7 @@ class Results: # consumed. self._objects = [] - def _get_objects(self): + def _get_objects(self) -> Iterable[Model]: """Construct and generate Model objects for they query. The objects are returned in the order emitted from the database; no slow sort is applied. @@ -738,7 +783,7 @@ class Results: yield obj break - def __iter__(self): + def __iter__(self) -> Iterable[Model]: """Construct and generate Model objects for all matching objects, in sorted order. """ @@ -751,7 +796,7 @@ class Results: # Objects are pre-sorted (i.e., by the database). return self._get_objects() - def _get_indexed_flex_attrs(self): + def _get_indexed_flex_attrs(self) -> Mapping: """ Index flexible attributes by the entity id they belong to """ flex_values = {} @@ -763,7 +808,7 @@ class Results: return flex_values - def _make_model(self, row, flex_values={}): + def _make_model(self, row, flex_values: Dict = {}) -> Model: """ Create a Model object for the given row """ cols = dict(row) @@ -774,7 +819,7 @@ class Results: obj = self.model_class._awaken(self.db, values, flex_values) return obj - def __len__(self): + def __len__(self) -> int: """Get the number of matching objects. """ if not self._rows: @@ -792,12 +837,12 @@ class Results: # A fast query. Just count the rows. return self._row_count - def __nonzero__(self): + def __nonzero__(self) -> bool: """Does this result contain any objects? """ return self.__bool__() - def __bool__(self): + def __bool__(self) -> bool: """Does this result contain any objects? """ return bool(len(self)) @@ -819,7 +864,7 @@ class Results: except StopIteration: raise IndexError(f'result index {n} out of range') - def get(self): + def get(self) -> Optional[Model]: """Return the first matching object, or None if no objects match. """ @@ -840,10 +885,10 @@ class Transaction: current transaction. """ - def __init__(self, db): + def __init__(self, db: 'Database'): self.db = db - def __enter__(self): + def __enter__(self) -> 'Transaction': """Begin a transaction. This transaction may be created while another is active in a different thread. """ @@ -856,7 +901,12 @@ class Transaction: self.db._db_lock.acquire() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Type[Exception], + exc_value: Exception, + traceback: TracebackType, + ): """Complete a transaction. This must be the most recently entered but not yet exited transaction. If it is the last active transaction, the database updates are committed. @@ -872,14 +922,14 @@ class Transaction: self._mutated = False self.db._db_lock.release() - def query(self, statement, subvals=()): + def query(self, statement: str, subvals: Iterable = ()) -> List: """Execute an SQL statement with substitution values and return a list of rows from the database. """ cursor = self.db._connection().execute(statement, subvals) return cursor.fetchall() - def mutate(self, statement, subvals=()): + def mutate(self, statement: str, subvals: Iterable = ()) -> Any: """Execute an SQL statement with substitution values and return the row ID of the last affected row. """ @@ -898,7 +948,7 @@ class Transaction: self._mutated = True return cursor.lastrowid - def script(self, statements): + def script(self, statements: str): """Execute a string containing multiple SQL statements.""" # We don't know whether this mutates, but quite likely it does. self._mutated = True @@ -922,7 +972,7 @@ class Database: data is written in a transaction. """ - def __init__(self, path, timeout=5.0): + def __init__(self, path, timeout: float = 5.0): if sqlite3.threadsafety == 0: raise RuntimeError( "sqlite3 must be compiled with multi-threading support" @@ -956,7 +1006,7 @@ class Database: # Primitive access control: connections and transactions. - def _connection(self): + def _connection(self) -> Connection: """Get a SQLite connection object to the underlying database. One connection object is created per thread. """ @@ -969,7 +1019,7 @@ class Database: self._connections[thread_id] = conn return conn - def _create_connection(self): + def _create_connection(self) -> Connection: """Create a SQLite connection to the underlying database. Makes a new connection every time. If you need to configure the @@ -1019,7 +1069,7 @@ class Database: conn.close() @contextlib.contextmanager - def _tx_stack(self): + def _tx_stack(self) -> Generator[List, None, None]: """A context manager providing access to the current thread's transaction stack. The context manager synchronizes access to the stack map. Transactions should never migrate across threads. @@ -1028,7 +1078,7 @@ class Database: with self._shared_map_lock: yield self._tx_stacks[thread_id] - def transaction(self): + def transaction(self) -> Transaction: """Get a :class:`Transaction` object for interacting directly with the underlying SQLite database. """ @@ -1048,7 +1098,7 @@ class Database: # Schema setup and migration. - def _make_table(self, table, fields): + def _make_table(self, table: str, fields: Mapping[str, types.Type]): """Set up the schema of the database. `fields` is a mapping from field names to `Type`s. Columns are added if necessary. """ @@ -1083,7 +1133,7 @@ class Database: with self.transaction() as tx: tx.script(setup_sql) - def _make_attribute_table(self, flex_table): + def _make_attribute_table(self, flex_table: str): """Create a table and associated index for flexible attributes for the given entity (if they don't exist). """ @@ -1101,7 +1151,12 @@ class Database: # Querying. - def _fetch(self, model_cls, query=None, sort=None): + def _fetch( + self, + model_cls: Type['LibModel'], + query: Optional[Query] = None, + sort: Optional[Sort] = None, + ) -> Results: """Fetch the objects of type `model_cls` matching the given query. The query may be given as a string, string sequence, a Query object, or None (to fetch everything). `sort` is an @@ -1141,7 +1196,7 @@ class Database: sort if sort.is_slow() else None, # Slow sort component. ) - def _get(self, model_cls, id): + def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model: """Get a Model object by its id or None if the id does not exist. """ diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 016fe2c1a..5a9ea7059 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -15,13 +15,24 @@ """The Query type hierarchy for DBCore. """ +from __future__ import annotations import re from operator import mul +from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\ + Collection, MutableMapping, Sequence + from beets import util from datetime import datetime, timedelta import unicodedata from functools import reduce +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from beets.library import Item + from beets.dbcore import Model + class ParsingError(ValueError): """Abstract class for any unparseable user-requested album/query @@ -60,7 +71,7 @@ class Query: """An abstract class representing a query into the item database. """ - def clause(self): + def clause(self) -> Tuple[None, Tuple]: """Generate an SQLite expression implementing the query. Return (clause, subvals) where clause is a valid sqlite @@ -69,19 +80,19 @@ class Query: """ return None, () - def match(self, item): + def match(self, item: Item): """Check whether this query matches a given Item. Can be used to perform queries on arbitrary sets of Items. """ raise NotImplementedError - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) - def __hash__(self): + def __hash__(self) -> int: return 0 @@ -93,12 +104,12 @@ class FieldQuery(Query): same matching functionality in SQLite. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: Optional[str], fast: bool = True): self.field = field self.pattern = pattern self.fast = fast - def col_clause(self): + def col_clause(self) -> Union[None, Tuple]: return None, () def clause(self): @@ -109,51 +120,51 @@ class FieldQuery(Query): return None, () @classmethod - def value_match(cls, pattern, value): + def value_match(cls, pattern: str, value: str): """Determine whether the value matches the pattern. Both arguments are strings. """ raise NotImplementedError() - def match(self, item): + def match(self, item: Model): return self.value_match(self.pattern, item.get(self.field)) - def __repr__(self): + def __repr__(self) -> str: return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, " "{0.fast})".format(self)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.field == other.field and self.pattern == other.pattern - def __hash__(self): + def __hash__(self) -> int: return hash((self.field, hash(self.pattern))) class MatchQuery(FieldQuery): """A query that looks for exact matches in an item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: return self.field + " = ?", [self.pattern] @classmethod - def value_match(cls, pattern, value): + def value_match(cls, pattern: str, value: str) -> bool: return pattern == value class NoneQuery(FieldQuery): """A query that checks whether a field is null.""" - def __init__(self, field, fast=True): + def __init__(self, field, fast: bool = True): super().__init__(field, None, fast) - def col_clause(self): + def col_clause(self) -> Tuple[str, Tuple]: return self.field + " IS NULL", () - def match(self, item): + def match(self, item: 'Item') -> bool: return item.get(self.field) is None - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self) @@ -163,14 +174,18 @@ class StringFieldQuery(FieldQuery): """ @classmethod - def value_match(cls, pattern, value): + def value_match(cls, pattern: str, value: Any): """Determine whether the value matches the pattern. The value may have any type. """ return cls.string_match(pattern, util.as_string(value)) @classmethod - def string_match(cls, pattern, value): + def string_match( + cls, + pattern: str, + value: str, + ) -> bool: """Determine whether the value matches the pattern. Both arguments are strings. Subclasses implement this method. """ @@ -180,7 +195,7 @@ class StringFieldQuery(FieldQuery): class StringQuery(StringFieldQuery): """A query that matches a whole string in a specific item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: search = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -190,14 +205,14 @@ class StringQuery(StringFieldQuery): return clause, subvals @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: str, value: str) -> bool: return pattern.lower() == value.lower() class SubstringQuery(StringFieldQuery): """A query that matches a substring in a specific item field.""" - def col_clause(self): + def col_clause(self) -> Tuple[str, List[str]]: pattern = (self.pattern .replace('\\', '\\\\') .replace('%', '\\%') @@ -208,7 +223,7 @@ class SubstringQuery(StringFieldQuery): return clause, subvals @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: str, value: str) -> bool: return pattern.lower() in value.lower() @@ -220,7 +235,7 @@ class RegexpQuery(StringFieldQuery): expression. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: str, fast: bool = True): super().__init__(field, pattern, fast) pattern = self._normalize(pattern) try: @@ -235,14 +250,14 @@ class RegexpQuery(StringFieldQuery): return f" regexp({self.field}, ?)", [self.pattern.pattern] @staticmethod - def _normalize(s): + def _normalize(s: str) -> str: """Normalize a Unicode string's representation (used on both patterns and matched values). """ return unicodedata.normalize('NFC', s) @classmethod - def string_match(cls, pattern, value): + def string_match(cls, pattern: Pattern, value: str) -> bool: return pattern.search(cls._normalize(value)) is not None @@ -251,7 +266,12 @@ class BooleanQuery(MatchQuery): string reflecting a boolean. """ - def __init__(self, field, pattern, fast=True): + def __init__( + self, + field: str, + pattern: Union[bool, str], + fast: bool = True, + ): super().__init__(field, pattern, fast) if isinstance(pattern, str): self.pattern = util.str2bool(pattern) @@ -265,7 +285,7 @@ class BytesQuery(MatchQuery): `MatchQuery` when matching on BLOB values. """ - def __init__(self, field, pattern): + def __init__(self, field: str, pattern: Union[bytes, str, memoryview]): super().__init__(field, pattern) # Use a buffer/memoryview representation of the pattern for SQLite @@ -279,7 +299,7 @@ class BytesQuery(MatchQuery): self.buf_pattern = self.pattern self.pattern = bytes(self.pattern) - def col_clause(self): + def col_clause(self) -> Tuple[str, List[memoryview]]: return self.field + " = ?", [self.buf_pattern] @@ -292,7 +312,7 @@ class NumericQuery(FieldQuery): a float. """ - def _convert(self, s): + def _convert(self, s: str) -> Union[float, int, None]: """Convert a string to a numeric type (float or int). Return None if `s` is empty. @@ -309,7 +329,7 @@ class NumericQuery(FieldQuery): except ValueError: raise InvalidQueryArgumentValueError(s, "an int or a float") - def __init__(self, field, pattern, fast=True): + def __init__(self, field: str, pattern: str, fast: bool = True): super().__init__(field, pattern, fast) parts = pattern.split('..', 1) @@ -324,7 +344,7 @@ class NumericQuery(FieldQuery): self.rangemin = self._convert(parts[0]) self.rangemax = self._convert(parts[1]) - def match(self, item): + def match(self, item: 'Item') -> bool: if self.field not in item: return False value = item[self.field] @@ -340,7 +360,7 @@ class NumericQuery(FieldQuery): return False return True - def col_clause(self): + def col_clause(self) -> Tuple[str, Tuple]: if self.point is not None: return self.field + '=?', (self.point,) else: @@ -360,24 +380,27 @@ class CollectionQuery(Query): indexed like a list to access the sub-queries. """ - def __init__(self, subqueries=()): + def __init__(self, subqueries: Sequence = ()): self.subqueries = subqueries # Act like a sequence. - def __len__(self): + def __len__(self) -> int: return len(self.subqueries) def __getitem__(self, key): return self.subqueries[key] - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.subqueries) - def __contains__(self, item): + def __contains__(self, item) -> bool: return item in self.subqueries - def clause_with_joiner(self, joiner): + def clause_with_joiner( + self, + joiner: str, + ) -> Tuple[Optional[str], Collection]: """Return a clause created by joining together the clauses of all subqueries with the string joiner (padded by spaces). """ @@ -393,14 +416,14 @@ class CollectionQuery(Query): clause = (' ' + joiner + ' ').join(clause_parts) return clause, subvals - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.subqueries!r})".format(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.subqueries == other.subqueries - def __hash__(self): + def __hash__(self) -> int: """Since subqueries are mutable, this object should not be hashable. However and for conveniences purposes, it can be hashed. """ @@ -413,7 +436,7 @@ class AnyFieldQuery(CollectionQuery): constructor. """ - def __init__(self, pattern, fields, cls): + def __init__(self, pattern, fields, cls: Type[FieldQuery]): self.pattern = pattern self.fields = fields self.query_class = cls @@ -421,26 +444,27 @@ class AnyFieldQuery(CollectionQuery): subqueries = [] for field in self.fields: subqueries.append(cls(field, pattern, True)) + # TYPING ERROR super().__init__(subqueries) - def clause(self): + def clause(self) -> Tuple[Union[str, None], Collection]: return self.clause_with_joiner('or') - def match(self, item): + def match(self, item: 'Item') -> bool: for subq in self.subqueries: if subq.match(item): return True return False - def __repr__(self): + def __repr__(self) -> str: return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, " "{0.query_class.__name__})".format(self)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.query_class == other.query_class - def __hash__(self): + def __hash__(self) -> int: return hash((self.pattern, tuple(self.fields), self.query_class)) @@ -448,6 +472,7 @@ class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the query is initialized. """ + subqueries: MutableMapping def __setitem__(self, key, value): self.subqueries[key] = value @@ -459,20 +484,20 @@ class MutableCollectionQuery(CollectionQuery): class AndQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self): + def clause(self) -> Tuple[Union[str, None], Collection]: return self.clause_with_joiner('and') - def match(self, item): + def match(self, item) -> bool: return all(q.match(item) for q in self.subqueries) class OrQuery(MutableCollectionQuery): """A conjunction of a list of other queries.""" - def clause(self): + def clause(self) -> Tuple[Union[str, None], Collection]: return self.clause_with_joiner('or') - def match(self, item): + def match(self, item) -> bool: return any(q.match(item) for q in self.subqueries) @@ -493,43 +518,43 @@ class NotQuery(Query): # is handled by match() for slow queries. return clause, subvals - def match(self, item): + def match(self, item) -> bool: return not self.subquery.match(item) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0.subquery!r})".format(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.subquery == other.subquery - def __hash__(self): + def __hash__(self) -> int: return hash(('not', hash(self.subquery))) class TrueQuery(Query): """A query that always matches.""" - def clause(self): + def clause(self) -> Tuple[Union[str, None], Collection]: return '1', () - def match(self, item): + def match(self, item) -> bool: return True class FalseQuery(Query): """A query that never matches.""" - def clause(self): + def clause(self) -> Tuple[Union[str, None], Collection]: return '0', () - def match(self, item): + def match(self, item) -> bool: return False # Time/date queries. -def _parse_periods(pattern): +def _parse_periods(pattern: str) -> Tuple['Period', 'Period']: """Parse a string containing two dates separated by two dots (..). Return a pair of `Period` objects. """ @@ -563,7 +588,7 @@ class Period: relative_re = '(?P[+|-]?)(?P[0-9]+)' + \ '(?P[y|m|w|d])' - def __init__(self, date, precision): + def __init__(self, date: datetime, precision: str): """Create a period with the given date (a `datetime` object) and precision (a string, one of "year", "month", "day", "hour", "minute", or "second"). @@ -574,7 +599,7 @@ class Period: self.precision = precision @classmethod - def parse(cls, string): + def parse(cls: Type['Period'], string: str) -> Optional['Period']: """Parse a date and return a `Period` object or `None` if the string is empty, or raise an InvalidQueryArgumentValueError if the string cannot be parsed to a date. @@ -591,7 +616,8 @@ class Period: and a "year" is exactly 365 days. """ - def find_date_and_format(string): + def find_date_and_format(string: str) -> \ + Union[Tuple[None, None], Tuple[datetime, int]]: for ord, format in enumerate(cls.date_formats): for format_option in format: try: @@ -628,7 +654,7 @@ class Period: precision = cls.precisions[ordinal] return cls(date, precision) - def open_right_endpoint(self): + def open_right_endpoint(self) -> datetime: """Based on the precision, convert the period to a precise `datetime` for use as a right endpoint in a right-open interval. """ @@ -660,7 +686,7 @@ class DateInterval: A right endpoint of None means towards infinity. """ - def __init__(self, start, end): + def __init__(self, start: Optional[datetime], end: Optional[datetime]): if start is not None and end is not None and not start < end: raise ValueError("start date {} is not before end date {}" .format(start, end)) @@ -668,21 +694,21 @@ class DateInterval: self.end = end @classmethod - def from_periods(cls, start, end): + def from_periods(cls, start: Period, end: Period) -> 'DateInterval': """Create an interval with two Periods as the endpoints. """ end_date = end.open_right_endpoint() if end is not None else None start_date = start.date if start is not None else None return cls(start_date, end_date) - def contains(self, date): + def contains(self, date: datetime) -> bool: if self.start is not None and date < self.start: return False if self.end is not None and date >= self.end: return False return True - def __str__(self): + def __str__(self) -> str: return f'[{self.start}, {self.end})' @@ -696,12 +722,12 @@ class DateQuery(FieldQuery): using an ellipsis interval syntax similar to that of NumericQuery. """ - def __init__(self, field, pattern, fast=True): + def __init__(self, field, pattern, fast: bool = True): super().__init__(field, pattern, fast) start, end = _parse_periods(pattern) self.interval = DateInterval.from_periods(start, end) - def match(self, item): + def match(self, item: 'Item') -> bool: if self.field not in item: return False timestamp = float(item[self.field]) @@ -710,7 +736,7 @@ class DateQuery(FieldQuery): _clause_tmpl = "{0} {1} ?" - def col_clause(self): + def col_clause(self) -> Tuple[Union[str, None], Collection]: clause_parts = [] subvals = [] @@ -742,7 +768,7 @@ class DurationQuery(NumericQuery): or M:SS time interval. """ - def _convert(self, s): + def _convert(self, s: str) -> Optional[float]: """Convert a M:SS or numeric string to a float. Return None if `s` is empty. @@ -768,27 +794,27 @@ class Sort: the item database. """ - def order_clause(self): + def order_clause(self) -> None: """Generates a SQL fragment to be used in a ORDER BY clause, or None if no fragment is used (i.e., this is a slow sort). """ return None - def sort(self, items): + def sort(self, items: List) -> List: """Sort the list of objects and return a list. """ return sorted(items) - def is_slow(self): + def is_slow(self) -> bool: """Indicate whether this query is *slow*, meaning that it cannot be executed in SQL and must be executed in Python. """ return False - def __hash__(self): + def __hash__(self) -> int: return 0 - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) @@ -796,13 +822,13 @@ class MultipleSort(Sort): """Sort that encapsulates multiple sub-sorts. """ - def __init__(self, sorts=None): + def __init__(self, sorts: Optional[List[Sort]] = None): self.sorts = sorts or [] - def add_sort(self, sort): + def add_sort(self, sort: Sort): self.sorts.append(sort) - def _sql_sorts(self): + def _sql_sorts(self) -> List[Sort]: """Return the list of sub-sorts for which we can be (at least partially) fast. @@ -819,15 +845,16 @@ class MultipleSort(Sort): sql_sorts.reverse() return sql_sorts - def order_clause(self): + def order_clause(self) -> str: order_strings = [] for sort in self._sql_sorts(): order = sort.order_clause() order_strings.append(order) + # TYPING ERROR return ", ".join(order_strings) - def is_slow(self): + def is_slow(self) -> bool: for sort in self.sorts: if sort.is_slow(): return True @@ -865,17 +892,22 @@ class FieldSort(Sort): any kind). """ - def __init__(self, field, ascending=True, case_insensitive=True): + def __init__( + self, + field, + ascending: bool = True, + case_insensitive: bool = True, + ): self.field = field self.ascending = ascending self.case_insensitive = case_insensitive - def sort(self, objs): + def sort(self, objs: Collection): # TODO: Conversion and null-detection here. In Python 3, # comparisons with None fail. We should also support flexible # attributes with different types without falling over. - def key(item): + def key(item: 'Item'): field_val = item.get(self.field, '') if self.case_insensitive and isinstance(field_val, str): field_val = field_val.lower() @@ -883,17 +915,17 @@ class FieldSort(Sort): return sorted(objs, key=key, reverse=not self.ascending) - def __repr__(self): + def __repr__(self) -> str: return '<{}: {}{}>'.format( type(self).__name__, self.field, '+' if self.ascending else '-', ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.field, self.ascending)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return super().__eq__(other) and \ self.field == other.field and \ self.ascending == other.ascending @@ -903,7 +935,7 @@ class FixedFieldSort(FieldSort): """Sort object to sort on a fixed field. """ - def order_clause(self): + def order_clause(self) -> str: order = "ASC" if self.ascending else "DESC" if self.case_insensitive: field = '(CASE ' \ @@ -920,24 +952,24 @@ class SlowFieldSort(FieldSort): i.e., a computed or flexible field. """ - def is_slow(self): + def is_slow(self) -> bool: return True class NullSort(Sort): """No sorting. Leave results unsorted.""" - def sort(self, items): + def sort(self, items: List) -> List: return items - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.__bool__() - def __bool__(self): + def __bool__(self) -> bool: return False - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) or other is None - def __hash__(self): + def __hash__(self) -> int: return 0 diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 3bf02e4d2..00d393cf8 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -17,7 +17,11 @@ import re import itertools -from . import query +from typing import Dict, Type, Tuple, Optional, Collection, List, \ + Sequence + +from . import query, Model +from .query import Sort PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. @@ -34,8 +38,12 @@ PARSE_QUERY_PART_REGEX = re.compile( ) -def parse_query_part(part, query_classes={}, prefixes={}, - default_class=query.SubstringQuery): +def parse_query_part( + part: str, + query_classes: Dict = {}, + prefixes: Dict = {}, + default_class: Type[query.SubstringQuery] = query.SubstringQuery, +) -> Tuple[Optional[str], str, Type[query.Query], bool]: """Parse a single *query part*, which is a chunk of a complete query string representing a single criterion. @@ -100,7 +108,11 @@ def parse_query_part(part, query_classes={}, prefixes={}, return key, term, query_class, negate -def construct_query_part(model_cls, prefixes, query_part): +def construct_query_part( + model_cls: Type[Model], + prefixes: Dict, + query_part: str, +) -> query.Query: """Parse a *query part* string and return a :class:`Query` object. :param model_cls: The :class:`Model` class that this is a query for. @@ -158,7 +170,13 @@ def construct_query_part(model_cls, prefixes, query_part): return out_query -def query_from_strings(query_cls, model_cls, prefixes, query_parts): +# TYPING ERROR +def query_from_strings( + query_cls: Type[query.Query], + model_cls: Type[Model], + prefixes: Dict, + query_parts: Collection[str], +) -> query.Query: """Creates a collection query of type `query_cls` from a list of strings in the format used by parse_query_part. `model_cls` determines how queries are constructed from strings. @@ -171,7 +189,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts): return query_cls(subqueries) -def construct_sort_part(model_cls, part, case_insensitive=True): +def construct_sort_part( + model_cls: Type[Model], + part: str, + case_insensitive: bool = True, +) -> Sort: """Create a `Sort` from a single string criterion. `model_cls` is the `Model` being queried. `part` is a single string @@ -197,7 +219,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True): return sort -def sort_from_strings(model_cls, sort_parts, case_insensitive=True): +def sort_from_strings( + model_cls: Type[Model], + sort_parts: Sequence[str], + case_insensitive: bool = True, +) -> Sort: """Create a `Sort` from a list of sort criteria (strings). """ if not sort_parts: @@ -212,8 +238,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True): return sort -def parse_sorted_query(model_cls, parts, prefixes={}, - case_insensitive=True): +def parse_sorted_query( + model_cls: Type[Model], + parts: List[str], + prefixes: Dict = {}, + case_insensitive: bool = True, +) -> Tuple[query.Query, Sort]: """Given a list of strings, create the `Query` and `Sort` that they represent. """ diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 8c8bfa3f6..ac8dd762b 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -15,6 +15,7 @@ """Representation of type information for DBCore model fields. """ +from typing import Union, Any, Callable from . import query from beets.util import str2bool @@ -35,7 +36,7 @@ class Type: """The `Query` subclass to be used when querying the field. """ - model_type = str + model_type: Callable[[Any], str] = str """The Python type that is used to represent the value in the model. The model is guaranteed to return a value of this type if the field @@ -44,12 +45,12 @@ class Type: """ @property - def null(self): + def null(self) -> model_type: """The value to be exposed when the underlying value is None. """ return self.model_type() - def format(self, value): + def format(self, value: model_type) -> str: """Given a value of this type, produce a Unicode string representing the value. This is used in template evaluation. """ @@ -63,7 +64,7 @@ class Type: return str(value) - def parse(self, string): + def parse(self, string: str) -> model_type: """Parse a (possibly human-written) string and return the indicated value of this type. """ @@ -72,11 +73,12 @@ class Type: except ValueError: return self.null - def normalize(self, value): + def normalize(self, value: Union[None, int, float, bytes]) -> model_type: """Given a value that will be assigned into a field of this type, normalize the value to have the appropriate type. This base implementation only reinterprets `None`. """ + # TYPING ERROR if value is None: return self.null else: @@ -84,7 +86,10 @@ class Type: # `self.model_type(value)` return value - def from_sql(self, sql_value): + def from_sql( + self, + sql_value: Union[None, int, float, str, bytes], + ) -> model_type: """Receives the value stored in the SQL backend and return the value to be stored in the model. @@ -105,7 +110,7 @@ class Type: else: return self.normalize(sql_value) - def to_sql(self, model_value): + def to_sql(self, model_value: Any) -> Union[None, int, float, str, bytes]: """Convert a value as stored in the model object to a value used by the database adapter. """ @@ -125,7 +130,7 @@ class Integer(Type): query = query.NumericQuery model_type = int - def normalize(self, value): + def normalize(self, value: str) -> Union[int, str]: try: return self.model_type(round(float(value))) except ValueError: @@ -138,10 +143,10 @@ class PaddedInt(Integer): """An integer field that is formatted with a given number of digits, padded with zeroes. """ - def __init__(self, digits): + def __init__(self, digits: int): self.digits = digits - def format(self, value): + def format(self, value: int) -> str: return '{0:0{1}d}'.format(value or 0, self.digits) @@ -155,11 +160,11 @@ class ScaledInt(Integer): """An integer whose formatting operation scales the number by a constant and adds a suffix. Good for units with large magnitudes. """ - def __init__(self, unit, suffix=''): + def __init__(self, unit: int, suffix: str = ''): self.unit = unit self.suffix = suffix - def format(self, value): + def format(self, value: int) -> str: return '{}{}'.format((value or 0) // self.unit, self.suffix) @@ -169,7 +174,7 @@ class Id(Integer): """ null = None - def __init__(self, primary=True): + def __init__(self, primary: bool = True): if primary: self.sql = 'INTEGER PRIMARY KEY' @@ -182,10 +187,10 @@ class Float(Type): query = query.NumericQuery model_type = float - def __init__(self, digits=1): + def __init__(self, digits: int = 1): self.digits = digits - def format(self, value): + def format(self, value: float) -> str: return '{0:.{1}f}'.format(value or 0, self.digits) @@ -201,7 +206,7 @@ class String(Type): sql = 'TEXT' query = query.SubstringQuery - def normalize(self, value): + def normalize(self, value: str) -> str: if value is None: return self.null else: @@ -236,10 +241,10 @@ class Boolean(Type): query = query.BooleanQuery model_type = bool - def format(self, value): + def format(self, value: bool) -> str: return str(bool(value)) - def parse(self, string): + def parse(self, string: str) -> bool: return str2bool(string) diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 2319890a3..2dff3ed7c 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -23,11 +23,17 @@ import shutil import fnmatch import functools from collections import Counter, namedtuple +from logging import Logger from multiprocessing.pool import ThreadPool import traceback import subprocess import platform import shlex +from typing import Callable, List, Optional, Sequence, Pattern, \ + Tuple, MutableSequence, AnyStr, TypeVar, Generator, Any, \ + Iterable, Union +from typing_extensions import TypeAlias + from beets.util import hidden from unidecode import unidecode from enum import Enum @@ -35,6 +41,8 @@ from enum import Enum MAX_FILENAME_LENGTH = 200 WINDOWS_MAGIC_PREFIX = '\\\\?\\' +T = TypeVar('T') +Bytes_or_String: TypeAlias = Union[str, bytes] class HumanReadableException(Exception): @@ -135,7 +143,7 @@ class MoveOperation(Enum): REFLINK_AUTO = 5 -def normpath(path): +def normpath(path: bytes) -> bytes: """Provide the canonical form of the path suitable for storing in the database. """ @@ -144,11 +152,11 @@ def normpath(path): return bytestring_path(path) -def ancestry(path): +def ancestry(path: bytes) -> List[str]: """Return a list consisting of path's parent directory, its grandparent, and so on. For instance: - >>> ancestry('/a/b/c') + >>> ancestry(b'/a/b/c') ['/', '/a', '/a/b'] The argument should *not* be the result of a call to `syspath`. @@ -168,7 +176,12 @@ def ancestry(path): return out -def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None): +def sorted_walk( + path: AnyStr, + ignore: Sequence = (), + ignore_hidden: bool = False, + logger: Optional[Logger] = None, +) -> Generator[Tuple, None, None]: """Like `os.walk`, but yields things in case-insensitive sorted, breadth-first order. Directory and file names matching any glob pattern in `ignore` are skipped. If `logger` is provided, then @@ -225,14 +238,14 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None): yield from sorted_walk(cur, ignore, ignore_hidden, logger) -def path_as_posix(path): +def path_as_posix(path: bytes) -> bytes: """Return the string representation of the path with forward (/) slashes. """ return path.replace(b'\\', b'/') -def mkdirall(path): +def mkdirall(path: bytes): """Make all the enclosing directories of path (like mkdir -p on the parent). """ @@ -245,7 +258,7 @@ def mkdirall(path): traceback.format_exc()) -def fnmatch_all(names, patterns): +def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool: """Determine whether all strings in `names` match at least one of the `patterns`, which should be shell glob expressions. """ @@ -260,7 +273,11 @@ def fnmatch_all(names, patterns): return True -def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')): +def prune_dirs( + path: str, + root: Optional[Bytes_or_String] = None, + clutter: Sequence[str] = ('.DS_Store', 'Thumbs.db'), +): """If path is an empty directory, then remove it. Recursively remove path's ancestry up to root (which is never removed) where there are empty directories. If path is not contained in root, then nothing is @@ -291,7 +308,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')): if not os.path.exists(directory): # Directory gone already. continue - clutter = [bytestring_path(c) for c in clutter] + clutter: List[bytes] = [bytestring_path(c) for c in clutter] match_paths = [bytestring_path(d) for d in os.listdir(directory)] try: if fnmatch_all(match_paths, clutter): @@ -303,10 +320,10 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')): break -def components(path): +def components(path: AnyStr) -> MutableSequence[AnyStr]: """Return a list of the path components in path. For instance: - >>> components('/a/b/c') + >>> components(b'/a/b/c') ['a', 'b', 'c'] The argument should *not* be the result of a call to `syspath`. @@ -327,14 +344,14 @@ def components(path): return comps -def arg_encoding(): +def arg_encoding() -> str: """Get the encoding for command-line arguments (and other OS locale-sensitive strings). """ return sys.getfilesystemencoding() -def _fsencoding(): +def _fsencoding() -> str: """Get the system's filesystem encoding. On Windows, this is always UTF-8 (not MBCS). """ @@ -349,9 +366,10 @@ def _fsencoding(): return encoding -def bytestring_path(path): +def bytestring_path(path: Bytes_or_String) -> bytes: """Given a path, which is either a bytes or a unicode, returns a str - path (ensuring that we never deal with Unicode pathnames). + path (ensuring that we never deal with Unicode pathnames). Path should be + bytes but has safeguards for strings to be converted. """ # Pass through bytestrings. if isinstance(path, bytes): @@ -370,10 +388,10 @@ def bytestring_path(path): return path.encode('utf-8') -PATH_SEP = bytestring_path(os.sep) +PATH_SEP: bytes = bytestring_path(os.sep) -def displayable_path(path, separator='; '): +def displayable_path(path: bytes, separator: str = '; ') -> str: """Attempts to decode a bytestring path to a unicode object for the purpose of displaying it to the user. If the `path` argument is a list or a tuple, the elements are joined with `separator`. @@ -392,7 +410,7 @@ def displayable_path(path, separator='; '): return path.decode('utf-8', 'ignore') -def syspath(path, prefix=True): +def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String: """Convert a path for use by the operating system. In particular, paths on Windows must receive a magic prefix and must be converted to Unicode before they are sent to the OS. To disable the magic @@ -412,6 +430,7 @@ def syspath(path, prefix=True): except UnicodeError: # The encoding should always be MBCS, Windows' broken # Unicode representation. + assert isinstance(path, bytes) encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() path = path.decode(encoding, 'replace') @@ -426,14 +445,14 @@ def syspath(path, prefix=True): return path -def samefile(p1, p2): +def samefile(p1: bytes, p2: bytes) -> bool: """Safer equality for paths.""" if p1 == p2: return True return shutil._samefile(syspath(p1), syspath(p2)) -def remove(path, soft=True): +def remove(path: bytes, soft: bool = True): """Remove the file. If `soft`, then no error will be raised if the file does not exist. """ @@ -446,7 +465,7 @@ def remove(path, soft=True): raise FilesystemError(exc, 'delete', (path,), traceback.format_exc()) -def copy(path, dest, replace=False): +def copy(path: bytes, dest: bytes, replace: bool = False): """Copy a plain file. Permissions are not copied. If `dest` already exists, raises a FilesystemError unless `replace` is True. Has no effect if `path` is the same as `dest`. Paths are translated to @@ -465,7 +484,7 @@ def copy(path, dest, replace=False): traceback.format_exc()) -def move(path, dest, replace=False): +def move(path: bytes, dest: bytes, replace: bool = False): """Rename a file. `dest` may not be a directory. If `dest` already exists, raises an OSError unless `replace` is True. Has no effect if `path` is the same as `dest`. If the paths are on different @@ -515,7 +534,7 @@ def move(path, dest, replace=False): os.remove(tmp) -def link(path, dest, replace=False): +def link(path: bytes, dest: bytes, replace: bool = False): """Create a symbolic link from path to `dest`. Raises an OSError if `dest` already exists, unless `replace` is True. Does nothing if `path` == `dest`. @@ -536,7 +555,7 @@ def link(path, dest, replace=False): traceback.format_exc()) -def hardlink(path, dest, replace=False): +def hardlink(path: bytes, dest: bytes, replace: bool = False): """Create a hard link from path to `dest`. Raises an OSError if `dest` already exists, unless `replace` is True. Does nothing if `path` == `dest`. @@ -560,7 +579,12 @@ def hardlink(path, dest, replace=False): traceback.format_exc()) -def reflink(path, dest, replace=False, fallback=False): +def reflink( + path: bytes, + dest: bytes, + replace: bool = False, + fallback: bool = False, +): """Create a reflink from `dest` to `path`. Raise an `OSError` if `dest` already exists, unless `replace` is @@ -589,7 +613,7 @@ def reflink(path, dest, replace=False, fallback=False): 'link', (path, dest), traceback.format_exc()) -def unique_path(path): +def unique_path(path: bytes) -> bytes: """Returns a version of ``path`` that does not exist on the filesystem. Specifically, if ``path` itself already exists, then something unique is appended to the path. @@ -616,7 +640,7 @@ def unique_path(path): # Unix. They are forbidden here because they cause problems on Samba # shares, which are sufficiently common as to cause frequent problems. # https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx -CHAR_REPLACE = [ +CHAR_REPLACE: List[Tuple[Pattern, str]] = [ (re.compile(r'[\\/]'), '_'), # / and \ -- forbidden everywhere. (re.compile(r'^\.'), '_'), # Leading dot (hidden files on Unix). (re.compile(r'[\x00-\x1f]'), ''), # Control characters. @@ -626,7 +650,10 @@ CHAR_REPLACE = [ ] -def sanitize_path(path, replacements=None): +def sanitize_path( + path: str, + replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None, +) -> str: """Takes a path (as a Unicode string) and makes sure that it is legal. Returns a new path. Only works with fragments; won't work reliably on Windows when a path begins with a drive letter. Path @@ -647,7 +674,7 @@ def sanitize_path(path, replacements=None): return os.path.join(*comps) -def truncate_path(path, length=MAX_FILENAME_LENGTH): +def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr: """Given a bytestring path or a Unicode path fragment, truncate the components to a legal length. In the last component, the extension is preserved. @@ -664,7 +691,13 @@ def truncate_path(path, length=MAX_FILENAME_LENGTH): return os.path.join(*out) -def _legalize_stage(path, replacements, length, extension, fragment): +def _legalize_stage( + path: str, + replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]], + length: int, + extension: str, + fragment: bool, +) -> Tuple[Bytes_or_String, bool]: """Perform a single round of path legalization steps (sanitation/replacement, encoding from Unicode to bytes, extension-appending, and truncation). Return the path (Unicode if @@ -676,7 +709,7 @@ def _legalize_stage(path, replacements, length, extension, fragment): # Encode for the filesystem. if not fragment: - path = bytestring_path(path) + path = bytestring_path(path) # type: ignore # Preserve extension. path += extension.lower() @@ -688,7 +721,13 @@ def _legalize_stage(path, replacements, length, extension, fragment): return path, path != pre_truncate_path -def legalize_path(path, replacements, length, extension, fragment): +def legalize_path( + path: str, + replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]], + length: int, + extension: bytes, + fragment: bool, +) -> Tuple[Union[Bytes_or_String, bool]]: """Given a path-like Unicode string, produce a legal path. Return the path and a flag indicating whether some replacements had to be ignored (see below). @@ -736,7 +775,7 @@ def legalize_path(path, replacements, length, extension, fragment): return second_stage_path, retruncated -def py3_path(path): +def py3_path(path: AnyStr) -> str: """Convert a bytestring path to Unicode. This helps deal with APIs on Python 3 that *only* accept Unicode @@ -751,12 +790,12 @@ def py3_path(path): return os.fsdecode(path) -def str2bool(value): +def str2bool(value: str) -> bool: """Returns a boolean reflecting a human-entered string.""" return value.lower() in ('yes', '1', 'true', 't', 'y') -def as_string(value): +def as_string(value: Any) -> str: """Convert a value to a Unicode object for matching with a query. None becomes the empty string. Bytestrings are silently decoded. """ @@ -770,7 +809,7 @@ def as_string(value): return str(value) -def plurality(objs): +def plurality(objs: Sequence[T]) -> T: """Given a sequence of hashble objects, returns the object that is most common in the set and the its number of appearance. The sequence must contain at least one object. @@ -781,7 +820,7 @@ def plurality(objs): return c.most_common(1)[0] -def cpu_count(): +def cpu_count() -> int: """Return the number of hardware thread contexts (cores or SMT threads) in the system. """ @@ -812,13 +851,12 @@ def cpu_count(): return 1 -def convert_command_args(args): +def convert_command_args(args: List[bytes]) -> List[str]: """Convert command arguments, which may either be `bytes` or `str` - objects, to uniformly surrogate-escaped strings. - """ + objects, to uniformly surrogate-escaped strings. """ assert isinstance(args, list) - def convert(arg): + def convert(arg) -> str: if isinstance(arg, bytes): return os.fsdecode(arg) return arg @@ -830,7 +868,10 @@ def convert_command_args(args): CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr")) -def command_output(cmd, shell=False): +def command_output( + cmd: List[Bytes_or_String], + shell: bool = False, +) -> CommandOutput: """Runs the command and returns its output after it has exited. Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain @@ -870,7 +911,7 @@ def command_output(cmd, shell=False): return CommandOutput(stdout, stderr) -def max_filename_length(path, limit=MAX_FILENAME_LENGTH): +def max_filename_length(path: AnyStr, limit=MAX_FILENAME_LENGTH) -> int: """Attempt to determine the maximum filename length for the filesystem containing `path`. If the value is greater than `limit`, then `limit` is used instead (to prevent errors when a filesystem @@ -887,7 +928,7 @@ def max_filename_length(path, limit=MAX_FILENAME_LENGTH): return limit -def open_anything(): +def open_anything() -> str: """Return the system command that dispatches execution to the correct program. """ @@ -901,7 +942,7 @@ def open_anything(): return base_cmd -def editor_command(): +def editor_command() -> str: """Get a command for opening a text file. Use the `EDITOR` environment variable by default. If it is not @@ -914,7 +955,7 @@ def editor_command(): return open_anything() -def interactive_open(targets, command): +def interactive_open(targets: Sequence[str], command: str): """Open the files in `targets` by `exec`ing a new `command`, given as a Unicode string. (The new program takes over, and Python execution ends: this does not fork a subprocess.) @@ -936,7 +977,7 @@ def interactive_open(targets, command): return os.execlp(*args) -def case_sensitive(path): +def case_sensitive(path: bytes) -> bool: """Check whether the filesystem at the given path is case sensitive. To work best, the path should point to a file or a directory. If the path @@ -984,7 +1025,7 @@ def case_sensitive(path): return not os.path.samefile(lower_sys, upper_sys) -def raw_seconds_short(string): +def raw_seconds_short(string: str) -> float: """Formats a human-readable M:SS string as a float (number of seconds). Raises ValueError if the conversion cannot take place due to `string` not @@ -997,7 +1038,7 @@ def raw_seconds_short(string): return float(minutes * 60 + seconds) -def asciify_path(path, sep_replace): +def asciify_path(path: str, sep_replace: str) -> str: """Decodes all unicode characters in a path into ASCII equivalents. Substitutions are provided by the unidecode module. Path separators in the @@ -1010,7 +1051,7 @@ def asciify_path(path, sep_replace): # if this platform has an os.altsep, change it to os.sep. if os.altsep: path = path.replace(os.altsep, os.sep) - path_components = path.split(os.sep) + path_components: List[Bytes_or_String] = path.split(os.sep) for index, item in enumerate(path_components): path_components[index] = unidecode(item).replace(os.sep, sep_replace) if os.altsep: @@ -1021,7 +1062,7 @@ def asciify_path(path, sep_replace): return os.sep.join(path_components) -def par_map(transform, items): +def par_map(transform: Callable, items: Iterable): """Apply the function `transform` to all the elements in the iterable `items`, like `map(transform, items)` but with no return value. @@ -1035,7 +1076,7 @@ def par_map(transform, items): pool.join() -def lazy_property(func): +def lazy_property(func: Callable) -> Callable: """A decorator that creates a lazily evaluated property. On first access, the property is assigned the return value of `func`. This first value is stored, so that future accesses do not have to evaluate `func` again. diff --git a/setup.py b/setup.py index 26cebaa87..729d5003f 100755 --- a/setup.py +++ b/setup.py @@ -92,6 +92,7 @@ setup( 'confuse>=1.5.0', 'munkres>=1.0.0', 'jellyfish', + 'typing_extensions', ] + ( # Support for ANSI console colors on Windows. ['colorama'] if (sys.platform == 'win32') else []