Merge pull request #4495 from Serene-Arc/dbcore_typing

This commit is contained in:
Serene 2023-06-02 16:23:31 +10:00 committed by GitHub
commit f68ff90899
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 395 additions and 231 deletions

View file

@ -15,6 +15,7 @@
"""The central Model and Database constructs for DBCore.
"""
from __future__ import annotations
import time
import os
import re
@ -22,6 +23,10 @@ from collections import defaultdict
import threading
import sqlite3
import contextlib
from sqlite3 import Connection
from types import TracebackType
from typing import Iterable, Type, List, Tuple, Optional, Union, \
Dict, Any, Generator, Iterator, Callable
from unidecode import unidecode
@ -29,9 +34,15 @@ import beets
from beets.util import functemplate
from beets.util import py3_path
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery, AndQuery
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
FieldQuery, Sort
from collections.abc import Mapping
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beets.library import LibModel
from ..util.functemplate import Template
class DBAccessError(Exception):
"""The SQLite database became inaccessible.
@ -58,7 +69,12 @@ class FormattedMapping(Mapping):
ALL_KEYS = '*'
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
def __init__(
self,
model: 'Model',
included_keys: str = ALL_KEYS,
for_path: bool = False,
):
self.for_path = for_path
self.model = model
if included_keys == self.ALL_KEYS:
@ -73,10 +89,10 @@ class FormattedMapping(Mapping):
else:
raise KeyError(key)
def __iter__(self):
def __iter__(self) -> Iterable[str]:
return iter(self.model_keys)
def __len__(self):
def __len__(self) -> int:
return len(self.model_keys)
def get(self, key, default=None):
@ -107,7 +123,7 @@ class LazyConvertDict:
"""Lazily convert types for attributes fetched from the database
"""
def __init__(self, model_cls):
def __init__(self, model_cls: 'Model'):
"""Initialize the object empty
"""
self.data = {}
@ -148,12 +164,12 @@ class LazyConvertDict:
if key in self.data:
del self.data[key]
def keys(self):
def keys(self) -> List[str]:
"""Get a list of available field names for this object.
"""
return list(self._converted.keys()) + list(self.data.keys())
def copy(self):
def copy(self) -> 'LazyConvertDict':
"""Create a copy of the object.
"""
new = self.__class__(self.model_cls)
@ -169,7 +185,7 @@ class LazyConvertDict:
for key, value in values.items():
self[key] = value
def items(self):
def items(self) -> Iterable[Tuple[str, Any]]:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
@ -185,12 +201,12 @@ class LazyConvertDict:
else:
return default
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys()
def __iter__(self):
def __iter__(self) -> Iterable[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
@ -269,14 +285,14 @@ class Model:
"""
@classmethod
def _getters(cls):
def _getters(cls: Type['Model']):
"""Return a mapping from field names to getter functions.
"""
# We could cache this if it becomes a performance problem to
# gather the getter mapping every time.
raise NotImplementedError()
def _template_funcs(self):
def _template_funcs(self) -> Mapping[str, Callable[[str], str]]:
"""Return a mapping from function names to text-transformer
functions.
"""
@ -285,7 +301,7 @@ class Model:
# Basic operation.
def __init__(self, db=None, **values):
def __init__(self, db: Optional['Database'] = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
@ -299,7 +315,12 @@ class Model:
self.clear_dirty()
@classmethod
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
def _awaken(
cls: Type['Model'],
db: 'Database' = None,
fixed_values: Mapping = {},
flex_values: Mapping = {},
) -> 'Model':
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
@ -312,7 +333,7 @@ class Model:
return obj
def __repr__(self):
def __repr__(self) -> str:
return '{}({})'.format(
type(self).__name__,
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
@ -326,7 +347,7 @@ class Model:
if self._db:
self._revision = self._db.revision
def _check_db(self, need_id=True):
def _check_db(self, need_id: bool = True):
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
@ -338,7 +359,7 @@ class Model:
if need_id and not self.id:
raise ValueError('{} has no id'.format(type(self).__name__))
def copy(self):
def copy(self) -> 'Model':
"""Create a copy of the model object.
The field values and other state are duplicated, but the new copy
@ -356,7 +377,7 @@ class Model:
# Essential field accessors.
@classmethod
def _type(cls, key):
def _type(cls, key) -> types.Type:
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
@ -364,7 +385,7 @@ class Model:
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def _get(self, key, default=None, raise_=False):
def _get(self, key, default: bool = None, raise_: bool = False):
"""Get the value for a field, or `default`. Alternatively,
raise a KeyError if the field is not available.
"""
@ -431,7 +452,7 @@ class Model:
else:
raise KeyError(f'no such field {key}')
def keys(self, computed=False):
def keys(self, computed: bool = False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
@ -457,19 +478,19 @@ class Model:
for key, value in values.items():
self[key] = value
def items(self):
def items(self) -> Iterator[Tuple[str, Any]]:
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(computed=True)
def __iter__(self):
def __iter__(self) -> Iterable[str]:
"""Iterate over the available field names (excluding computed
fields).
"""
@ -500,7 +521,7 @@ class Model:
# Database interaction (CRUD methods).
def store(self, fields=None):
def store(self, fields: bool = None):
"""Save the object's metadata into the library database.
:param fields: the fields to be stored. If not specified, all fields
will be.
@ -581,7 +602,7 @@ class Model:
(self.id,)
)
def add(self, db=None):
def add(self, db: Optional['Database'] = None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
@ -610,13 +631,21 @@ class Model:
_formatter = FormattedMapping
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
def formatted(
self,
included_keys: str = _formatter.ALL_KEYS,
for_path: bool = False,
):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, included_keys, for_path)
def evaluate_template(self, template, for_path=False):
def evaluate_template(
self,
template: Union[str, Template],
for_path: bool = False,
) -> str:
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
@ -630,7 +659,7 @@ class Model:
# Parsing.
@classmethod
def _parse(cls, key, string):
def _parse(cls, key, string: str) -> Any:
"""Parse a string as a value for the given key.
"""
if not isinstance(string, str):
@ -638,7 +667,7 @@ class Model:
return cls._type(key).parse(string)
def set_parse(self, key, string):
def set_parse(self, key, string: str):
"""Set the object's key to a value represented by a string.
"""
self[key] = self._parse(key, string)
@ -646,12 +675,21 @@ class Model:
# Convenient queries.
@classmethod
def field_query(cls, field, pattern, query_cls=MatchQuery):
def field_query(
cls,
field,
pattern,
query_cls: Type[FieldQuery] = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)
@classmethod
def all_fields_query(cls, pats, query_cls=MatchQuery):
def all_fields_query(
cls: Type['Model'],
pats: Mapping,
query_cls: Type[FieldQuery] = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
@ -670,8 +708,15 @@ class Results:
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, flex_rows,
query=None, sort=None):
def __init__(
self,
model_class: Type['LibModel'],
rows: List[Mapping],
db: 'Database',
flex_rows,
query: Optional[FieldQuery] = None,
sort=None,
):
"""Create a result set that will construct objects of type
`model_class`.
@ -703,7 +748,7 @@ class Results:
# consumed.
self._objects = []
def _get_objects(self):
def _get_objects(self) -> Iterable[Model]:
"""Construct and generate Model objects for the query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
@ -738,7 +783,7 @@ class Results:
yield obj
break
def __iter__(self):
def __iter__(self) -> Iterable[Model]:
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
@ -751,7 +796,7 @@ class Results:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _get_indexed_flex_attrs(self):
def _get_indexed_flex_attrs(self) -> Mapping:
""" Index flexible attributes by the entity id they belong to
"""
flex_values = {}
@ -763,7 +808,7 @@ class Results:
return flex_values
def _make_model(self, row, flex_values={}):
def _make_model(self, row, flex_values: Dict = {}) -> Model:
""" Create a Model object for the given row
"""
cols = dict(row)
@ -774,7 +819,7 @@ class Results:
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
def __len__(self):
def __len__(self) -> int:
"""Get the number of matching objects.
"""
if not self._rows:
@ -792,12 +837,12 @@ class Results:
# A fast query. Just count the rows.
return self._row_count
def __nonzero__(self):
def __nonzero__(self) -> bool:
"""Does this result contain any objects?
"""
return self.__bool__()
def __bool__(self):
def __bool__(self) -> bool:
"""Does this result contain any objects?
"""
return bool(len(self))
@ -819,7 +864,7 @@ class Results:
except StopIteration:
raise IndexError(f'result index {n} out of range')
def get(self):
def get(self) -> Optional[Model]:
"""Return the first matching object, or None if no objects
match.
"""
@ -840,10 +885,10 @@ class Transaction:
current transaction.
"""
def __init__(self, db):
def __init__(self, db: 'Database'):
self.db = db
def __enter__(self):
def __enter__(self) -> 'Transaction':
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
@ -856,7 +901,12 @@ class Transaction:
self.db._db_lock.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Type[Exception],
exc_value: Exception,
traceback: TracebackType,
):
"""Complete a transaction. This must be the most recently
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
@ -872,14 +922,14 @@ class Transaction:
self._mutated = False
self.db._db_lock.release()
def query(self, statement, subvals=()):
def query(self, statement: str, subvals: Iterable = ()) -> List:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
def mutate(self, statement, subvals=()):
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
@ -898,7 +948,7 @@ class Transaction:
self._mutated = True
return cursor.lastrowid
def script(self, statements):
def script(self, statements: str):
"""Execute a string containing multiple SQL statements."""
# We don't know whether this mutates, but quite likely it does.
self._mutated = True
@ -922,7 +972,7 @@ class Database:
data is written in a transaction.
"""
def __init__(self, path, timeout=5.0):
def __init__(self, path, timeout: float = 5.0):
if sqlite3.threadsafety == 0:
raise RuntimeError(
"sqlite3 must be compiled with multi-threading support"
@ -956,7 +1006,7 @@ class Database:
# Primitive access control: connections and transactions.
def _connection(self):
def _connection(self) -> Connection:
"""Get a SQLite connection object to the underlying database.
One connection object is created per thread.
"""
@ -969,7 +1019,7 @@ class Database:
self._connections[thread_id] = conn
return conn
def _create_connection(self):
def _create_connection(self) -> Connection:
"""Create a SQLite connection to the underlying database.
Makes a new connection every time. If you need to configure the
@ -1019,7 +1069,7 @@ class Database:
conn.close()
@contextlib.contextmanager
def _tx_stack(self):
def _tx_stack(self) -> Generator[List, None, None]:
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
@ -1028,7 +1078,7 @@ class Database:
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
def transaction(self):
def transaction(self) -> Transaction:
"""Get a :class:`Transaction` object for interacting directly
with the underlying SQLite database.
"""
@ -1048,7 +1098,7 @@ class Database:
# Schema setup and migration.
def _make_table(self, table, fields):
def _make_table(self, table: str, fields: Mapping[str, types.Type]):
"""Set up the schema of the database. `fields` is a mapping
from field names to `Type`s. Columns are added if necessary.
"""
@ -1083,7 +1133,7 @@ class Database:
with self.transaction() as tx:
tx.script(setup_sql)
def _make_attribute_table(self, flex_table):
def _make_attribute_table(self, flex_table: str):
"""Create a table and associated index for flexible attributes
for the given entity (if they don't exist).
"""
@ -1101,7 +1151,12 @@ class Database:
# Querying.
def _fetch(self, model_cls, query=None, sort=None):
def _fetch(
self,
model_cls: Type['LibModel'],
query: Optional[Query] = None,
sort: Optional[Sort] = None,
) -> Results:
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is an
@ -1141,7 +1196,7 @@ class Database:
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls, id):
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
"""Get a Model object by its id or None if the id does not
exist.
"""

View file

@ -15,13 +15,24 @@
"""The Query type hierarchy for DBCore.
"""
from __future__ import annotations
import re
from operator import mul
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\
Collection, MutableMapping, Sequence
from beets import util
from datetime import datetime, timedelta
import unicodedata
from functools import reduce
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beets.library import Item
from beets.dbcore import Model
class ParsingError(ValueError):
"""Abstract class for any unparseable user-requested album/query
@ -60,7 +71,7 @@ class Query:
"""An abstract class representing a query into the item database.
"""
def clause(self):
def clause(self) -> Tuple[None, Tuple]:
"""Generate an SQLite expression implementing the query.
Return (clause, subvals) where clause is a valid sqlite
@ -69,19 +80,19 @@ class Query:
"""
return None, ()
def match(self, item):
def match(self, item: Item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
"""
raise NotImplementedError
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
def __eq__(self, other):
def __eq__(self, other) -> bool:
return type(self) == type(other)
def __hash__(self):
def __hash__(self) -> int:
return 0
@ -93,12 +104,12 @@ class FieldQuery(Query):
same matching functionality in SQLite.
"""
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self):
def col_clause(self) -> Union[None, Tuple]:
return None, ()
def clause(self):
@ -109,51 +120,51 @@ class FieldQuery(Query):
return None, ()
@classmethod
def value_match(cls, pattern, value):
def value_match(cls, pattern: str, value: str):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
raise NotImplementedError()
def match(self, item):
def match(self, item: Model):
return self.value_match(self.pattern, item.get(self.field))
def __repr__(self):
def __repr__(self) -> str:
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
"{0.fast})".format(self))
def __eq__(self, other):
def __eq__(self, other) -> bool:
return super().__eq__(other) and \
self.field == other.field and self.pattern == other.pattern
def __hash__(self):
def __hash__(self) -> int:
return hash((self.field, hash(self.pattern)))
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
def col_clause(self):
def col_clause(self) -> Tuple[str, List[str]]:
return self.field + " = ?", [self.pattern]
@classmethod
def value_match(cls, pattern, value):
def value_match(cls, pattern: str, value: str) -> bool:
return pattern == value
class NoneQuery(FieldQuery):
"""A query that checks whether a field is null."""
def __init__(self, field, fast=True):
def __init__(self, field, fast: bool = True):
super().__init__(field, None, fast)
def col_clause(self):
def col_clause(self) -> Tuple[str, Tuple]:
return self.field + " IS NULL", ()
def match(self, item):
def match(self, item: 'Item') -> bool:
return item.get(self.field) is None
def __repr__(self):
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
@ -163,14 +174,18 @@ class StringFieldQuery(FieldQuery):
"""
@classmethod
def value_match(cls, pattern, value):
def value_match(cls, pattern: str, value: Any):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
return cls.string_match(pattern, util.as_string(value))
@classmethod
def string_match(cls, pattern, value):
def string_match(
cls,
pattern: str,
value: str,
) -> bool:
"""Determine whether the value matches the pattern. Both
arguments are strings. Subclasses implement this method.
"""
@ -180,7 +195,7 @@ class StringFieldQuery(FieldQuery):
class StringQuery(StringFieldQuery):
"""A query that matches a whole string in a specific item field."""
def col_clause(self):
def col_clause(self) -> Tuple[str, List[str]]:
search = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -190,14 +205,14 @@ class StringQuery(StringFieldQuery):
return clause, subvals
@classmethod
def string_match(cls, pattern, value):
def string_match(cls, pattern: str, value: str) -> bool:
return pattern.lower() == value.lower()
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
def col_clause(self):
def col_clause(self) -> Tuple[str, List[str]]:
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
@ -208,7 +223,7 @@ class SubstringQuery(StringFieldQuery):
return clause, subvals
@classmethod
def string_match(cls, pattern, value):
def string_match(cls, pattern: str, value: str) -> bool:
return pattern.lower() in value.lower()
@ -220,7 +235,7 @@ class RegexpQuery(StringFieldQuery):
expression.
"""
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
pattern = self._normalize(pattern)
try:
@ -235,14 +250,14 @@ class RegexpQuery(StringFieldQuery):
return f" regexp({self.field}, ?)", [self.pattern.pattern]
@staticmethod
def _normalize(s):
def _normalize(s: str) -> str:
"""Normalize a Unicode string's representation (used on both
patterns and matched values).
"""
return unicodedata.normalize('NFC', s)
@classmethod
def string_match(cls, pattern, value):
def string_match(cls, pattern: Pattern, value: str) -> bool:
return pattern.search(cls._normalize(value)) is not None
@ -251,7 +266,12 @@ class BooleanQuery(MatchQuery):
string reflecting a boolean.
"""
def __init__(self, field, pattern, fast=True):
def __init__(
self,
field: str,
pattern: Union[bool, str],
fast: bool = True,
):
super().__init__(field, pattern, fast)
if isinstance(pattern, str):
self.pattern = util.str2bool(pattern)
@ -265,7 +285,7 @@ class BytesQuery(MatchQuery):
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field, pattern):
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
super().__init__(field, pattern)
# Use a buffer/memoryview representation of the pattern for SQLite
@ -279,7 +299,7 @@ class BytesQuery(MatchQuery):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
def col_clause(self):
def col_clause(self) -> Tuple[str, List[memoryview]]:
return self.field + " = ?", [self.buf_pattern]
@ -292,7 +312,7 @@ class NumericQuery(FieldQuery):
a float.
"""
def _convert(self, s):
def _convert(self, s: str) -> Union[float, int, None]:
"""Convert a string to a numeric type (float or int).
Return None if `s` is empty.
@ -309,7 +329,7 @@ class NumericQuery(FieldQuery):
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")
def __init__(self, field, pattern, fast=True):
def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
parts = pattern.split('..', 1)
@ -324,7 +344,7 @@ class NumericQuery(FieldQuery):
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item):
def match(self, item: 'Item') -> bool:
if self.field not in item:
return False
value = item[self.field]
@ -340,7 +360,7 @@ class NumericQuery(FieldQuery):
return False
return True
def col_clause(self):
def col_clause(self) -> Tuple[str, Tuple]:
if self.point is not None:
return self.field + '=?', (self.point,)
else:
@ -360,24 +380,27 @@ class CollectionQuery(Query):
indexed like a list to access the sub-queries.
"""
def __init__(self, subqueries=()):
def __init__(self, subqueries: Sequence = ()):
self.subqueries = subqueries
# Act like a sequence.
def __len__(self):
def __len__(self) -> int:
return len(self.subqueries)
def __getitem__(self, key):
return self.subqueries[key]
def __iter__(self):
def __iter__(self) -> Iterator:
return iter(self.subqueries)
def __contains__(self, item):
def __contains__(self, item) -> bool:
return item in self.subqueries
def clause_with_joiner(self, joiner):
def clause_with_joiner(
self,
joiner: str,
) -> Tuple[Optional[str], Collection]:
"""Return a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
@ -393,14 +416,14 @@ class CollectionQuery(Query):
clause = (' ' + joiner + ' ').join(clause_parts)
return clause, subvals
def __repr__(self):
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return super().__eq__(other) and \
self.subqueries == other.subqueries
def __hash__(self):
def __hash__(self) -> int:
"""Since subqueries are mutable, this object should not be hashable.
However, for convenience purposes, it can be hashed.
"""
@ -413,7 +436,7 @@ class AnyFieldQuery(CollectionQuery):
constructor.
"""
def __init__(self, pattern, fields, cls):
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
self.pattern = pattern
self.fields = fields
self.query_class = cls
@ -421,26 +444,27 @@ class AnyFieldQuery(CollectionQuery):
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
# TYPING ERROR
super().__init__(subqueries)
def clause(self):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item):
def match(self, item: 'Item') -> bool:
for subq in self.subqueries:
if subq.match(item):
return True
return False
def __repr__(self):
def __repr__(self) -> str:
return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
"{0.query_class.__name__})".format(self))
def __eq__(self, other):
def __eq__(self, other) -> bool:
return super().__eq__(other) and \
self.query_class == other.query_class
def __hash__(self):
def __hash__(self) -> int:
return hash((self.pattern, tuple(self.fields), self.query_class))
@ -448,6 +472,7 @@ class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
"""
subqueries: MutableMapping
def __setitem__(self, key, value):
self.subqueries[key] = value
@ -459,20 +484,20 @@ class MutableCollectionQuery(CollectionQuery):
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('and')
def match(self, item):
def match(self, item) -> bool:
return all(q.match(item) for q in self.subqueries)
class OrQuery(MutableCollectionQuery):
"""A disjunction of a list of other queries."""
def clause(self):
def clause(self) -> Tuple[Union[str, None], Collection]:
return self.clause_with_joiner('or')
def match(self, item):
def match(self, item) -> bool:
return any(q.match(item) for q in self.subqueries)
@ -493,43 +518,43 @@ class NotQuery(Query):
# is handled by match() for slow queries.
return clause, subvals
def match(self, item):
def match(self, item) -> bool:
return not self.subquery.match(item)
def __repr__(self):
def __repr__(self) -> str:
return "{0.__class__.__name__}({0.subquery!r})".format(self)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return super().__eq__(other) and \
self.subquery == other.subquery
def __hash__(self):
def __hash__(self) -> int:
return hash(('not', hash(self.subquery)))
class TrueQuery(Query):
"""A query that always matches."""
def clause(self):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '1', ()
def match(self, item):
def match(self, item) -> bool:
return True
class FalseQuery(Query):
"""A query that never matches."""
def clause(self):
def clause(self) -> Tuple[Union[str, None], Collection]:
return '0', ()
def match(self, item):
def match(self, item) -> bool:
return False
# Time/date queries.
def _parse_periods(pattern):
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
@ -563,7 +588,7 @@ class Period:
relative_re = '(?P<sign>[+|-]?)(?P<quantity>[0-9]+)' + \
'(?P<timespan>[y|m|w|d])'
def __init__(self, date, precision):
def __init__(self, date: datetime, precision: str):
"""Create a period with the given date (a `datetime` object) and
precision (a string, one of "year", "month", "day", "hour", "minute",
or "second").
@ -574,7 +599,7 @@ class Period:
self.precision = precision
@classmethod
def parse(cls, string):
def parse(cls: Type['Period'], string: str) -> Optional['Period']:
"""Parse a date and return a `Period` object or `None` if the
string is empty, or raise an InvalidQueryArgumentValueError if
the string cannot be parsed to a date.
@ -591,7 +616,8 @@ class Period:
and a "year" is exactly 365 days.
"""
def find_date_and_format(string):
def find_date_and_format(string: str) -> \
Union[Tuple[None, None], Tuple[datetime, int]]:
for ord, format in enumerate(cls.date_formats):
for format_option in format:
try:
@ -628,7 +654,7 @@ class Period:
precision = cls.precisions[ordinal]
return cls(date, precision)
def open_right_endpoint(self):
def open_right_endpoint(self) -> datetime:
"""Based on the precision, convert the period to a precise
`datetime` for use as a right endpoint in a right-open interval.
"""
@ -660,7 +686,7 @@ class DateInterval:
A right endpoint of None means towards infinity.
"""
def __init__(self, start, end):
def __init__(self, start: Optional[datetime], end: Optional[datetime]):
if start is not None and end is not None and not start < end:
raise ValueError("start date {} is not before end date {}"
.format(start, end))
@ -668,21 +694,21 @@ class DateInterval:
self.end = end
@classmethod
def from_periods(cls, start, end):
def from_periods(cls, start: Period, end: Period) -> 'DateInterval':
"""Create an interval with two Periods as the endpoints.
"""
end_date = end.open_right_endpoint() if end is not None else None
start_date = start.date if start is not None else None
return cls(start_date, end_date)
def contains(self, date):
def contains(self, date: datetime) -> bool:
if self.start is not None and date < self.start:
return False
if self.end is not None and date >= self.end:
return False
return True
def __str__(self):
def __str__(self) -> str:
return f'[{self.start}, {self.end})'
@ -696,12 +722,12 @@ class DateQuery(FieldQuery):
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field, pattern, fast=True):
def __init__(self, field, pattern, fast: bool = True):
super().__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, item):
def match(self, item: 'Item') -> bool:
if self.field not in item:
return False
timestamp = float(item[self.field])
@ -710,7 +736,7 @@ class DateQuery(FieldQuery):
_clause_tmpl = "{0} {1} ?"
def col_clause(self):
def col_clause(self) -> Tuple[Union[str, None], Collection]:
clause_parts = []
subvals = []
@ -742,7 +768,7 @@ class DurationQuery(NumericQuery):
or M:SS time interval.
"""
def _convert(self, s):
def _convert(self, s: str) -> Optional[float]:
"""Convert a M:SS or numeric string to a float.
Return None if `s` is empty.
@ -768,27 +794,27 @@ class Sort:
the item database.
"""
def order_clause(self):
def order_clause(self) -> None:
"""Generates a SQL fragment to be used in an ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
return None
def sort(self, items):
def sort(self, items: List) -> List:
"""Sort the list of objects and return a list.
"""
return sorted(items)
def is_slow(self):
def is_slow(self) -> bool:
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False
def __hash__(self):
def __hash__(self) -> int:
return 0
def __eq__(self, other):
def __eq__(self, other) -> bool:
return type(self) == type(other)
@ -796,13 +822,13 @@ class MultipleSort(Sort):
"""Sort that encapsulates multiple sub-sorts.
"""
def __init__(self, sorts=None):
def __init__(self, sorts: Optional[List[Sort]] = None):
self.sorts = sorts or []
def add_sort(self, sort):
def add_sort(self, sort: Sort):
self.sorts.append(sort)
def _sql_sorts(self):
def _sql_sorts(self) -> List[Sort]:
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
@ -819,15 +845,16 @@ class MultipleSort(Sort):
sql_sorts.reverse()
return sql_sorts
def order_clause(self):
def order_clause(self) -> str:
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
order_strings.append(order)
# TYPING ERROR
return ", ".join(order_strings)
def is_slow(self):
def is_slow(self) -> bool:
for sort in self.sorts:
if sort.is_slow():
return True
@ -865,17 +892,22 @@ class FieldSort(Sort):
any kind).
"""
def __init__(self, field, ascending=True, case_insensitive=True):
def __init__(
self,
field,
ascending: bool = True,
case_insensitive: bool = True,
):
self.field = field
self.ascending = ascending
self.case_insensitive = case_insensitive
def sort(self, objs):
def sort(self, objs: Collection):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
def key(item):
def key(item: 'Item'):
field_val = item.get(self.field, '')
if self.case_insensitive and isinstance(field_val, str):
field_val = field_val.lower()
@ -883,17 +915,17 @@ class FieldSort(Sort):
return sorted(objs, key=key, reverse=not self.ascending)
def __repr__(self):
def __repr__(self) -> str:
return '<{}: {}{}>'.format(
type(self).__name__,
self.field,
'+' if self.ascending else '-',
)
def __hash__(self):
def __hash__(self) -> int:
return hash((self.field, self.ascending))
def __eq__(self, other):
def __eq__(self, other) -> bool:
return super().__eq__(other) and \
self.field == other.field and \
self.ascending == other.ascending
@ -903,7 +935,7 @@ class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
def order_clause(self):
def order_clause(self) -> str:
order = "ASC" if self.ascending else "DESC"
if self.case_insensitive:
field = '(CASE ' \
@ -920,24 +952,24 @@ class SlowFieldSort(FieldSort):
i.e., a computed or flexible field.
"""
def is_slow(self):
def is_slow(self) -> bool:
return True
class NullSort(Sort):
"""No sorting. Leave results unsorted."""
def sort(self, items):
def sort(self, items: List) -> List:
return items
def __nonzero__(self):
def __nonzero__(self) -> bool:
return self.__bool__()
def __bool__(self):
def __bool__(self) -> bool:
return False
def __eq__(self, other):
def __eq__(self, other) -> bool:
return type(self) == type(other) or other is None
def __hash__(self):
def __hash__(self) -> int:
return 0

View file

@ -17,7 +17,11 @@
import re
import itertools
from . import query
from typing import Dict, Type, Tuple, Optional, Collection, List, \
Sequence
from . import query, Model
from .query import Sort
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
@ -34,8 +38,12 @@ PARSE_QUERY_PART_REGEX = re.compile(
)
def parse_query_part(part, query_classes={}, prefixes={},
default_class=query.SubstringQuery):
def parse_query_part(
part: str,
query_classes: Dict = {},
prefixes: Dict = {},
default_class: Type[query.SubstringQuery] = query.SubstringQuery,
) -> Tuple[Optional[str], str, Type[query.Query], bool]:
"""Parse a single *query part*, which is a chunk of a complete query
string representing a single criterion.
@ -100,7 +108,11 @@ def parse_query_part(part, query_classes={}, prefixes={},
return key, term, query_class, negate
def construct_query_part(model_cls, prefixes, query_part):
def construct_query_part(
model_cls: Type[Model],
prefixes: Dict,
query_part: str,
) -> query.Query:
"""Parse a *query part* string and return a :class:`Query` object.
:param model_cls: The :class:`Model` class that this is a query for.
@ -158,7 +170,13 @@ def construct_query_part(model_cls, prefixes, query_part):
return out_query
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
# TYPING ERROR
def query_from_strings(
query_cls: Type[query.Query],
model_cls: Type[Model],
prefixes: Dict,
query_parts: Collection[str],
) -> query.Query:
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
@ -171,7 +189,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
return query_cls(subqueries)
def construct_sort_part(model_cls, part, case_insensitive=True):
def construct_sort_part(
model_cls: Type[Model],
part: str,
case_insensitive: bool = True,
) -> Sort:
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
@ -197,7 +219,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
return sort
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
def sort_from_strings(
model_cls: Type[Model],
sort_parts: Sequence[str],
case_insensitive: bool = True,
) -> Sort:
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
@ -212,8 +238,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
return sort
def parse_sorted_query(model_cls, parts, prefixes={},
case_insensitive=True):
def parse_sorted_query(
model_cls: Type[Model],
parts: List[str],
prefixes: Dict = {},
case_insensitive: bool = True,
) -> Tuple[query.Query, Sort]:
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""

View file

@ -15,6 +15,7 @@
"""Representation of type information for DBCore model fields.
"""
from typing import Union, Any, Callable
from . import query
from beets.util import str2bool
@ -35,7 +36,7 @@ class Type:
"""The `Query` subclass to be used when querying the field.
"""
model_type = str
model_type: Callable[[Any], str] = str
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
@ -44,12 +45,12 @@ class Type:
"""
@property
def null(self):
def null(self) -> model_type:
"""The value to be exposed when the underlying value is None.
"""
return self.model_type()
def format(self, value):
def format(self, value: model_type) -> str:
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
@ -63,7 +64,7 @@ class Type:
return str(value)
def parse(self, string):
def parse(self, string: str) -> model_type:
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
@ -72,11 +73,12 @@ class Type:
except ValueError:
return self.null
def normalize(self, value):
def normalize(self, value: Union[None, int, float, bytes]) -> model_type:
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
# TYPING ERROR
if value is None:
return self.null
else:
@ -84,7 +86,10 @@ class Type:
# `self.model_type(value)`
return value
def from_sql(self, sql_value):
def from_sql(
self,
sql_value: Union[None, int, float, str, bytes],
) -> model_type:
"""Receives the value stored in the SQL backend and return the
value to be stored in the model.
@ -105,7 +110,7 @@ class Type:
else:
return self.normalize(sql_value)
def to_sql(self, model_value):
def to_sql(self, model_value: Any) -> Union[None, int, float, str, bytes]:
"""Convert a value as stored in the model object to a value used
by the database adapter.
"""
@ -125,7 +130,7 @@ class Integer(Type):
query = query.NumericQuery
model_type = int
def normalize(self, value):
def normalize(self, value: str) -> Union[int, str]:
try:
return self.model_type(round(float(value)))
except ValueError:
@ -138,10 +143,10 @@ class PaddedInt(Integer):
"""An integer field that is formatted with a given number of digits,
padded with zeroes.
"""
def __init__(self, digits):
def __init__(self, digits: int):
self.digits = digits
def format(self, value):
def format(self, value: int) -> str:
return '{0:0{1}d}'.format(value or 0, self.digits)
@ -155,11 +160,11 @@ class ScaledInt(Integer):
"""An integer whose formatting operation scales the number by a
constant and adds a suffix. Good for units with large magnitudes.
"""
def __init__(self, unit, suffix=''):
def __init__(self, unit: int, suffix: str = ''):
self.unit = unit
self.suffix = suffix
def format(self, value):
def format(self, value: int) -> str:
return '{}{}'.format((value or 0) // self.unit, self.suffix)
@ -169,7 +174,7 @@ class Id(Integer):
"""
null = None
def __init__(self, primary=True):
def __init__(self, primary: bool = True):
if primary:
self.sql = 'INTEGER PRIMARY KEY'
@ -182,10 +187,10 @@ class Float(Type):
query = query.NumericQuery
model_type = float
def __init__(self, digits=1):
def __init__(self, digits: int = 1):
self.digits = digits
def format(self, value):
def format(self, value: float) -> str:
return '{0:.{1}f}'.format(value or 0, self.digits)
@ -201,7 +206,7 @@ class String(Type):
sql = 'TEXT'
query = query.SubstringQuery
def normalize(self, value):
def normalize(self, value: str) -> str:
if value is None:
return self.null
else:
@ -236,10 +241,10 @@ class Boolean(Type):
query = query.BooleanQuery
model_type = bool
def format(self, value):
def format(self, value: bool) -> str:
return str(bool(value))
def parse(self, string):
def parse(self, string: str) -> bool:
return str2bool(string)

View file

@ -23,11 +23,17 @@ import shutil
import fnmatch
import functools
from collections import Counter, namedtuple
from logging import Logger
from multiprocessing.pool import ThreadPool
import traceback
import subprocess
import platform
import shlex
from typing import Callable, List, Optional, Sequence, Pattern, \
Tuple, MutableSequence, AnyStr, TypeVar, Generator, Any, \
Iterable, Union
from typing_extensions import TypeAlias
from beets.util import hidden
from unidecode import unidecode
from enum import Enum
@ -35,6 +41,8 @@ from enum import Enum
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = '\\\\?\\'
T = TypeVar('T')
Bytes_or_String: TypeAlias = Union[str, bytes]
class HumanReadableException(Exception):
@ -135,7 +143,7 @@ class MoveOperation(Enum):
REFLINK_AUTO = 5
def normpath(path):
def normpath(path: bytes) -> bytes:
"""Provide the canonical form of the path suitable for storing in
the database.
"""
@ -144,11 +152,11 @@ def normpath(path):
return bytestring_path(path)
def ancestry(path):
def ancestry(path: bytes) -> List[str]:
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
>>> ancestry('/a/b/c')
>>> ancestry(b'/a/b/c')
['/', '/a', '/a/b']
The argument should *not* be the result of a call to `syspath`.
@ -168,7 +176,12 @@ def ancestry(path):
return out
def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
def sorted_walk(
path: AnyStr,
ignore: Sequence = (),
ignore_hidden: bool = False,
logger: Optional[Logger] = None,
) -> Generator[Tuple, None, None]:
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
pattern in `ignore` are skipped. If `logger` is provided, then
@ -225,14 +238,14 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
def path_as_posix(path):
def path_as_posix(path: bytes) -> bytes:
"""Return the string representation of the path with forward (/)
slashes.
"""
return path.replace(b'\\', b'/')
def mkdirall(path):
def mkdirall(path: bytes):
"""Make all the enclosing directories of path (like mkdir -p on the
parent).
"""
@ -245,7 +258,7 @@ def mkdirall(path):
traceback.format_exc())
def fnmatch_all(names, patterns):
def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool:
"""Determine whether all strings in `names` match at least one of
the `patterns`, which should be shell glob expressions.
"""
@ -260,7 +273,11 @@ def fnmatch_all(names, patterns):
return True
def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
def prune_dirs(
path: str,
root: Optional[Bytes_or_String] = None,
clutter: Sequence[str] = ('.DS_Store', 'Thumbs.db'),
):
"""If path is an empty directory, then remove it. Recursively remove
path's ancestry up to root (which is never removed) where there are
empty directories. If path is not contained in root, then nothing is
@ -291,7 +308,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
if not os.path.exists(directory):
# Directory gone already.
continue
clutter = [bytestring_path(c) for c in clutter]
clutter: List[bytes] = [bytestring_path(c) for c in clutter]
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
try:
if fnmatch_all(match_paths, clutter):
@ -303,10 +320,10 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
break
def components(path):
def components(path: AnyStr) -> MutableSequence[AnyStr]:
"""Return a list of the path components in path. For instance:
>>> components('/a/b/c')
>>> components(b'/a/b/c')
['a', 'b', 'c']
The argument should *not* be the result of a call to `syspath`.
@ -327,14 +344,14 @@ def components(path):
return comps
def arg_encoding():
def arg_encoding() -> str:
"""Get the encoding for command-line arguments (and other OS
locale-sensitive strings).
"""
return sys.getfilesystemencoding()
def _fsencoding():
def _fsencoding() -> str:
"""Get the system's filesystem encoding. On Windows, this is always
UTF-8 (not MBCS).
"""
@ -349,9 +366,10 @@ def _fsencoding():
return encoding
def bytestring_path(path):
def bytestring_path(path: Bytes_or_String) -> bytes:
"""Given a path, which is either a bytes or a unicode, returns a str
path (ensuring that we never deal with Unicode pathnames).
path (ensuring that we never deal with Unicode pathnames). Path should be
bytes but has safeguards for strings to be converted.
"""
# Pass through bytestrings.
if isinstance(path, bytes):
@ -370,10 +388,10 @@ def bytestring_path(path):
return path.encode('utf-8')
PATH_SEP = bytestring_path(os.sep)
PATH_SEP: bytes = bytestring_path(os.sep)
def displayable_path(path, separator='; '):
def displayable_path(path: bytes, separator: str = '; ') -> str:
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
list or a tuple, the elements are joined with `separator`.
@ -392,7 +410,7 @@ def displayable_path(path, separator='; '):
return path.decode('utf-8', 'ignore')
def syspath(path, prefix=True):
def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String:
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic
@ -412,6 +430,7 @@ def syspath(path, prefix=True):
except UnicodeError:
# The encoding should always be MBCS, Windows' broken
# Unicode representation.
assert isinstance(path, bytes)
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
path = path.decode(encoding, 'replace')
@ -426,14 +445,14 @@ def syspath(path, prefix=True):
return path
def samefile(p1, p2):
def samefile(p1: bytes, p2: bytes) -> bool:
"""Safer equality for paths."""
if p1 == p2:
return True
return shutil._samefile(syspath(p1), syspath(p2))
def remove(path, soft=True):
def remove(path: bytes, soft: bool = True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
"""
@ -446,7 +465,7 @@ def remove(path, soft=True):
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
def copy(path, dest, replace=False):
def copy(path: bytes, dest: bytes, replace: bool = False):
"""Copy a plain file. Permissions are not copied. If `dest` already
exists, raises a FilesystemError unless `replace` is True. Has no
effect if `path` is the same as `dest`. Paths are translated to
@ -465,7 +484,7 @@ def copy(path, dest, replace=False):
traceback.format_exc())
def move(path, dest, replace=False):
def move(path: bytes, dest: bytes, replace: bool = False):
"""Rename a file. `dest` may not be a directory. If `dest` already
exists, raises an OSError unless `replace` is True. Has no effect if
`path` is the same as `dest`. If the paths are on different
@ -515,7 +534,7 @@ def move(path, dest, replace=False):
os.remove(tmp)
def link(path, dest, replace=False):
def link(path: bytes, dest: bytes, replace: bool = False):
"""Create a symbolic link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`.
@ -536,7 +555,7 @@ def link(path, dest, replace=False):
traceback.format_exc())
def hardlink(path, dest, replace=False):
def hardlink(path: bytes, dest: bytes, replace: bool = False):
"""Create a hard link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`.
@ -560,7 +579,12 @@ def hardlink(path, dest, replace=False):
traceback.format_exc())
def reflink(path, dest, replace=False, fallback=False):
def reflink(
path: bytes,
dest: bytes,
replace: bool = False,
fallback: bool = False,
):
"""Create a reflink from `dest` to `path`.
Raise an `OSError` if `dest` already exists, unless `replace` is
@ -589,7 +613,7 @@ def reflink(path, dest, replace=False, fallback=False):
'link', (path, dest), traceback.format_exc())
def unique_path(path):
def unique_path(path: bytes) -> bytes:
"""Returns a version of ``path`` that does not exist on the
filesystem. Specifically, if ``path`` itself already exists, then
something unique is appended to the path.
@ -616,7 +640,7 @@ def unique_path(path):
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
CHAR_REPLACE = [
CHAR_REPLACE: List[Tuple[Pattern, str]] = [
(re.compile(r'[\\/]'), '_'), # / and \ -- forbidden everywhere.
(re.compile(r'^\.'), '_'), # Leading dot (hidden files on Unix).
(re.compile(r'[\x00-\x1f]'), ''), # Control characters.
@ -626,7 +650,10 @@ CHAR_REPLACE = [
]
def sanitize_path(path, replacements=None):
def sanitize_path(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None,
) -> str:
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
reliably on Windows when a path begins with a drive letter. Path
@ -647,7 +674,7 @@ def sanitize_path(path, replacements=None):
return os.path.join(*comps)
def truncate_path(path, length=MAX_FILENAME_LENGTH):
def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr:
"""Given a bytestring path or a Unicode path fragment, truncate the
components to a legal length. In the last component, the extension
is preserved.
@ -664,7 +691,13 @@ def truncate_path(path, length=MAX_FILENAME_LENGTH):
return os.path.join(*out)
def _legalize_stage(path, replacements, length, extension, fragment):
def _legalize_stage(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
length: int,
extension: str,
fragment: bool,
) -> Tuple[Bytes_or_String, bool]:
"""Perform a single round of path legalization steps
(sanitation/replacement, encoding from Unicode to bytes,
extension-appending, and truncation). Return the path (Unicode if
@ -676,7 +709,7 @@ def _legalize_stage(path, replacements, length, extension, fragment):
# Encode for the filesystem.
if not fragment:
path = bytestring_path(path)
path = bytestring_path(path) # type: ignore
# Preserve extension.
path += extension.lower()
@ -688,7 +721,13 @@ def _legalize_stage(path, replacements, length, extension, fragment):
return path, path != pre_truncate_path
def legalize_path(path, replacements, length, extension, fragment):
def legalize_path(
path: str,
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
length: int,
extension: bytes,
fragment: bool,
) -> Tuple[Union[Bytes_or_String, bool]]:
"""Given a path-like Unicode string, produce a legal path. Return
the path and a flag indicating whether some replacements had to be
ignored (see below).
@ -736,7 +775,7 @@ def legalize_path(path, replacements, length, extension, fragment):
return second_stage_path, retruncated
def py3_path(path):
def py3_path(path: AnyStr) -> str:
"""Convert a bytestring path to Unicode.
This helps deal with APIs on Python 3 that *only* accept Unicode
@ -751,12 +790,12 @@ def py3_path(path):
return os.fsdecode(path)
def str2bool(value):
def str2bool(value: str) -> bool:
"""Returns a boolean reflecting a human-entered string."""
return value.lower() in ('yes', '1', 'true', 't', 'y')
def as_string(value):
def as_string(value: Any) -> str:
"""Convert a value to a Unicode object for matching with a query.
None becomes the empty string. Bytestrings are silently decoded.
"""
@ -770,7 +809,7 @@ def as_string(value):
return str(value)
def plurality(objs):
def plurality(objs: Sequence[T]) -> T:
"""Given a sequence of hashable objects, returns the object that
is most common in the set and its number of appearances. The
sequence must contain at least one object.
@ -781,7 +820,7 @@ def plurality(objs):
return c.most_common(1)[0]
def cpu_count():
def cpu_count() -> int:
"""Return the number of hardware thread contexts (cores or SMT
threads) in the system.
"""
@ -812,13 +851,12 @@ def cpu_count():
return 1
def convert_command_args(args):
def convert_command_args(args: List[bytes]) -> List[str]:
"""Convert command arguments, which may either be `bytes` or `str`
objects, to uniformly surrogate-escaped strings.
"""
objects, to uniformly surrogate-escaped strings. """
assert isinstance(args, list)
def convert(arg):
def convert(arg) -> str:
if isinstance(arg, bytes):
return os.fsdecode(arg)
return arg
@ -830,7 +868,10 @@ def convert_command_args(args):
CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr"))
def command_output(cmd, shell=False):
def command_output(
cmd: List[Bytes_or_String],
shell: bool = False,
) -> CommandOutput:
"""Runs the command and returns its output after it has exited.
Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain
@ -870,7 +911,7 @@ def command_output(cmd, shell=False):
return CommandOutput(stdout, stderr)
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
def max_filename_length(path: AnyStr, limit=MAX_FILENAME_LENGTH) -> int:
"""Attempt to determine the maximum filename length for the
filesystem containing `path`. If the value is greater than `limit`,
then `limit` is used instead (to prevent errors when a filesystem
@ -887,7 +928,7 @@ def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
return limit
def open_anything():
def open_anything() -> str:
"""Return the system command that dispatches execution to the correct
program.
"""
@ -901,7 +942,7 @@ def open_anything():
return base_cmd
def editor_command():
def editor_command() -> str:
"""Get a command for opening a text file.
Use the `EDITOR` environment variable by default. If it is not
@ -914,7 +955,7 @@ def editor_command():
return open_anything()
def interactive_open(targets, command):
def interactive_open(targets: Sequence[str], command: str):
"""Open the files in `targets` by `exec`ing a new `command`, given
as a Unicode string. (The new program takes over, and Python
execution ends: this does not fork a subprocess.)
@ -936,7 +977,7 @@ def interactive_open(targets, command):
return os.execlp(*args)
def case_sensitive(path):
def case_sensitive(path: bytes) -> bool:
"""Check whether the filesystem at the given path is case sensitive.
To work best, the path should point to a file or a directory. If the path
@ -984,7 +1025,7 @@ def case_sensitive(path):
return not os.path.samefile(lower_sys, upper_sys)
def raw_seconds_short(string):
def raw_seconds_short(string: str) -> float:
"""Formats a human-readable M:SS string as a float (number of seconds).
Raises ValueError if the conversion cannot take place due to `string` not
@ -997,7 +1038,7 @@ def raw_seconds_short(string):
return float(minutes * 60 + seconds)
def asciify_path(path, sep_replace):
def asciify_path(path: str, sep_replace: str) -> str:
"""Decodes all unicode characters in a path into ASCII equivalents.
Substitutions are provided by the unidecode module. Path separators in the
@ -1010,7 +1051,7 @@ def asciify_path(path, sep_replace):
# if this platform has an os.altsep, change it to os.sep.
if os.altsep:
path = path.replace(os.altsep, os.sep)
path_components = path.split(os.sep)
path_components: List[Bytes_or_String] = path.split(os.sep)
for index, item in enumerate(path_components):
path_components[index] = unidecode(item).replace(os.sep, sep_replace)
if os.altsep:
@ -1021,7 +1062,7 @@ def asciify_path(path, sep_replace):
return os.sep.join(path_components)
def par_map(transform, items):
def par_map(transform: Callable, items: Iterable):
"""Apply the function `transform` to all the elements in the
iterable `items`, like `map(transform, items)` but with no return
value.
@ -1035,7 +1076,7 @@ def par_map(transform, items):
pool.join()
def lazy_property(func):
def lazy_property(func: Callable) -> Callable:
"""A decorator that creates a lazily evaluated property. On first access,
the property is assigned the return value of `func`. This first value is
stored, so that future accesses do not have to evaluate `func` again.

View file

@ -92,6 +92,7 @@ setup(
'confuse>=1.5.0',
'munkres>=1.0.0',
'jellyfish',
'typing_extensions',
] + (
# Support for ANSI console colors on Windows.
['colorama'] if (sys.platform == 'win32') else []