mirror of
https://github.com/beetbox/beets.git
synced 2025-12-14 20:43:41 +01:00
Merge pull request #4495 from Serene-Arc/dbcore_typing
This commit is contained in:
commit
f68ff90899
6 changed files with 395 additions and 231 deletions
|
|
@ -15,6 +15,7 @@
|
|||
"""The central Model and Database constructs for DBCore.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
|
|
@ -22,6 +23,10 @@ from collections import defaultdict
|
|||
import threading
|
||||
import sqlite3
|
||||
import contextlib
|
||||
from sqlite3 import Connection
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Type, List, Tuple, Optional, Union, \
|
||||
Dict, Any, Generator, Iterator, Callable
|
||||
|
||||
from unidecode import unidecode
|
||||
|
||||
|
|
@ -29,9 +34,15 @@ import beets
|
|||
from beets.util import functemplate
|
||||
from beets.util import py3_path
|
||||
from beets.dbcore import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery, AndQuery
|
||||
from .query import MatchQuery, NullSort, TrueQuery, AndQuery, Query, \
|
||||
FieldQuery, Sort
|
||||
from collections.abc import Mapping
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from beets.library import LibModel
|
||||
from ..util.functemplate import Template
|
||||
|
||||
|
||||
class DBAccessError(Exception):
|
||||
"""The SQLite database became inaccessible.
|
||||
|
|
@ -58,7 +69,12 @@ class FormattedMapping(Mapping):
|
|||
|
||||
ALL_KEYS = '*'
|
||||
|
||||
def __init__(self, model, included_keys=ALL_KEYS, for_path=False):
|
||||
def __init__(
|
||||
self,
|
||||
model: 'Model',
|
||||
included_keys: str = ALL_KEYS,
|
||||
for_path: bool = False,
|
||||
):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
if included_keys == self.ALL_KEYS:
|
||||
|
|
@ -73,10 +89,10 @@ class FormattedMapping(Mapping):
|
|||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=None):
|
||||
|
|
@ -107,7 +123,7 @@ class LazyConvertDict:
|
|||
"""Lazily convert types for attributes fetched from the database
|
||||
"""
|
||||
|
||||
def __init__(self, model_cls):
|
||||
def __init__(self, model_cls: 'Model'):
|
||||
"""Initialize the object empty
|
||||
"""
|
||||
self.data = {}
|
||||
|
|
@ -148,12 +164,12 @@ class LazyConvertDict:
|
|||
if key in self.data:
|
||||
del self.data[key]
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> List[str]:
|
||||
"""Get a list of available field names for this object.
|
||||
"""
|
||||
return list(self._converted.keys()) + list(self.data.keys())
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> 'LazyConvertDict':
|
||||
"""Create a copy of the object.
|
||||
"""
|
||||
new = self.__class__(self.model_cls)
|
||||
|
|
@ -169,7 +185,7 @@ class LazyConvertDict:
|
|||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Iterable[Tuple[str, Any]]:
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
|
|
@ -185,12 +201,12 @@ class LazyConvertDict:
|
|||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
|
|
@ -269,14 +285,14 @@ class Model:
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
def _getters(cls: Type['Model']):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
"""
|
||||
# We could cache this if it becomes a performance problem to
|
||||
# gather the getter mapping every time.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _template_funcs(self):
|
||||
def _template_funcs(self) -> Mapping[str, Callable[[str], str]]:
|
||||
"""Return a mapping from function names to text-transformer
|
||||
functions.
|
||||
"""
|
||||
|
|
@ -285,7 +301,7 @@ class Model:
|
|||
|
||||
# Basic operation.
|
||||
|
||||
def __init__(self, db=None, **values):
|
||||
def __init__(self, db: Optional['Database'] = None, **values):
|
||||
"""Create a new object with an optional Database association and
|
||||
initial field values.
|
||||
"""
|
||||
|
|
@ -299,7 +315,12 @@ class Model:
|
|||
self.clear_dirty()
|
||||
|
||||
@classmethod
|
||||
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
|
||||
def _awaken(
|
||||
cls: Type['Model'],
|
||||
db: 'Database' = None,
|
||||
fixed_values: Mapping = {},
|
||||
flex_values: Mapping = {},
|
||||
) -> 'Model':
|
||||
"""Create an object with values drawn from the database.
|
||||
|
||||
This is a performance optimization: the checks involved with
|
||||
|
|
@ -312,7 +333,7 @@ class Model:
|
|||
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return '{}({})'.format(
|
||||
type(self).__name__,
|
||||
', '.join(f'{k}={v!r}' for k, v in dict(self).items()),
|
||||
|
|
@ -326,7 +347,7 @@ class Model:
|
|||
if self._db:
|
||||
self._revision = self._db.revision
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
def _check_db(self, need_id: bool = True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
has a reference to a database (`_db`) and an id. A ValueError
|
||||
exception is raised otherwise.
|
||||
|
|
@ -338,7 +359,7 @@ class Model:
|
|||
if need_id and not self.id:
|
||||
raise ValueError('{} has no id'.format(type(self).__name__))
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> 'Model':
|
||||
"""Create a copy of the model object.
|
||||
|
||||
The field values and other state is duplicated, but the new copy
|
||||
|
|
@ -356,7 +377,7 @@ class Model:
|
|||
# Essential field accessors.
|
||||
|
||||
@classmethod
|
||||
def _type(cls, key):
|
||||
def _type(cls, key) -> types.Type:
|
||||
"""Get the type of a field, a `Type` instance.
|
||||
|
||||
If the field has no explicit type, it is given the base `Type`,
|
||||
|
|
@ -364,7 +385,7 @@ class Model:
|
|||
"""
|
||||
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
|
||||
|
||||
def _get(self, key, default=None, raise_=False):
|
||||
def _get(self, key, default: bool = None, raise_: bool = False):
|
||||
"""Get the value for a field, or `default`. Alternatively,
|
||||
raise a KeyError if the field is not available.
|
||||
"""
|
||||
|
|
@ -431,7 +452,7 @@ class Model:
|
|||
else:
|
||||
raise KeyError(f'no such field {key}')
|
||||
|
||||
def keys(self, computed=False):
|
||||
def keys(self, computed: bool = False):
|
||||
"""Get a list of available field names for this object. The
|
||||
`computed` parameter controls whether computed (plugin-provided)
|
||||
fields are included in the key list.
|
||||
|
|
@ -457,19 +478,19 @@ class Model:
|
|||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
def items(self) -> Iterator[Tuple[str, Any]]:
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(computed=True)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable[str]:
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
|
|
@ -500,7 +521,7 @@ class Model:
|
|||
|
||||
# Database interaction (CRUD methods).
|
||||
|
||||
def store(self, fields=None):
|
||||
def store(self, fields: bool = None):
|
||||
"""Save the object's metadata into the library database.
|
||||
:param fields: the fields to be stored. If not specified, all fields
|
||||
will be.
|
||||
|
|
@ -581,7 +602,7 @@ class Model:
|
|||
(self.id,)
|
||||
)
|
||||
|
||||
def add(self, db=None):
|
||||
def add(self, db: Optional['Database'] = None):
|
||||
"""Add the object to the library database. This object must be
|
||||
associated with a database; you can provide one via the `db`
|
||||
parameter or use the currently associated database.
|
||||
|
|
@ -610,13 +631,21 @@ class Model:
|
|||
|
||||
_formatter = FormattedMapping
|
||||
|
||||
def formatted(self, included_keys=_formatter.ALL_KEYS, for_path=False):
|
||||
def formatted(
|
||||
self,
|
||||
included_keys: str = _formatter.ALL_KEYS,
|
||||
for_path: bool = False,
|
||||
):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return self._formatter(self, included_keys, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
def evaluate_template(
|
||||
self,
|
||||
template: Union[str, Template],
|
||||
for_path: bool = False,
|
||||
) -> str:
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
the object's fields. If `for_path` is true, then no new path
|
||||
separators will be added to the template.
|
||||
|
|
@ -630,7 +659,7 @@ class Model:
|
|||
# Parsing.
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, key, string):
|
||||
def _parse(cls, key, string: str) -> Any:
|
||||
"""Parse a string as a value for the given key.
|
||||
"""
|
||||
if not isinstance(string, str):
|
||||
|
|
@ -638,7 +667,7 @@ class Model:
|
|||
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
def set_parse(self, key, string):
|
||||
def set_parse(self, key, string: str):
|
||||
"""Set the object's key to a value represented by a string.
|
||||
"""
|
||||
self[key] = self._parse(key, string)
|
||||
|
|
@ -646,12 +675,21 @@ class Model:
|
|||
# Convenient queries.
|
||||
|
||||
@classmethod
|
||||
def field_query(cls, field, pattern, query_cls=MatchQuery):
|
||||
def field_query(
|
||||
cls,
|
||||
field,
|
||||
pattern,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
) -> FieldQuery:
|
||||
"""Get a `FieldQuery` for this model."""
|
||||
return query_cls(field, pattern, field in cls._fields)
|
||||
|
||||
@classmethod
|
||||
def all_fields_query(cls, pats, query_cls=MatchQuery):
|
||||
def all_fields_query(
|
||||
cls: Type['Model'],
|
||||
pats: Mapping,
|
||||
query_cls: Type[FieldQuery] = MatchQuery,
|
||||
):
|
||||
"""Get a query that matches many fields with different patterns.
|
||||
|
||||
`pats` should be a mapping from field names to patterns. The
|
||||
|
|
@ -670,8 +708,15 @@ class Results:
|
|||
constructs LibModel objects that reflect database rows.
|
||||
"""
|
||||
|
||||
def __init__(self, model_class, rows, db, flex_rows,
|
||||
query=None, sort=None):
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Type['LibModel'],
|
||||
rows: List[Mapping],
|
||||
db: 'Database',
|
||||
flex_rows,
|
||||
query: Optional[FieldQuery] = None,
|
||||
sort=None,
|
||||
):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
|
|
@ -703,7 +748,7 @@ class Results:
|
|||
# consumed.
|
||||
self._objects = []
|
||||
|
||||
def _get_objects(self):
|
||||
def _get_objects(self) -> Iterable[Model]:
|
||||
"""Construct and generate Model objects for they query. The
|
||||
objects are returned in the order emitted from the database; no
|
||||
slow sort is applied.
|
||||
|
|
@ -738,7 +783,7 @@ class Results:
|
|||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterable[Model]:
|
||||
"""Construct and generate Model objects for all matching
|
||||
objects, in sorted order.
|
||||
"""
|
||||
|
|
@ -751,7 +796,7 @@ class Results:
|
|||
# Objects are pre-sorted (i.e., by the database).
|
||||
return self._get_objects()
|
||||
|
||||
def _get_indexed_flex_attrs(self):
|
||||
def _get_indexed_flex_attrs(self) -> Mapping:
|
||||
""" Index flexible attributes by the entity id they belong to
|
||||
"""
|
||||
flex_values = {}
|
||||
|
|
@ -763,7 +808,7 @@ class Results:
|
|||
|
||||
return flex_values
|
||||
|
||||
def _make_model(self, row, flex_values={}):
|
||||
def _make_model(self, row, flex_values: Dict = {}) -> Model:
|
||||
""" Create a Model object for the given row
|
||||
"""
|
||||
cols = dict(row)
|
||||
|
|
@ -774,7 +819,7 @@ class Results:
|
|||
obj = self.model_class._awaken(self.db, values, flex_values)
|
||||
return obj
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""Get the number of matching objects.
|
||||
"""
|
||||
if not self._rows:
|
||||
|
|
@ -792,12 +837,12 @@ class Results:
|
|||
# A fast query. Just count the rows.
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self):
|
||||
def __nonzero__(self) -> bool:
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return self.__bool__()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return bool(len(self))
|
||||
|
|
@ -819,7 +864,7 @@ class Results:
|
|||
except StopIteration:
|
||||
raise IndexError(f'result index {n} out of range')
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Optional[Model]:
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
|
|
@ -840,10 +885,10 @@ class Transaction:
|
|||
current transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
def __init__(self, db: 'Database'):
|
||||
self.db = db
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> 'Transaction':
|
||||
"""Begin a transaction. This transaction may be created while
|
||||
another is active in a different thread.
|
||||
"""
|
||||
|
|
@ -856,7 +901,12 @@ class Transaction:
|
|||
self.db._db_lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Type[Exception],
|
||||
exc_value: Exception,
|
||||
traceback: TracebackType,
|
||||
):
|
||||
"""Complete a transaction. This must be the most recently
|
||||
entered but not yet exited transaction. If it is the last active
|
||||
transaction, the database updates are committed.
|
||||
|
|
@ -872,14 +922,14 @@ class Transaction:
|
|||
self._mutated = False
|
||||
self.db._db_lock.release()
|
||||
|
||||
def query(self, statement, subvals=()):
|
||||
def query(self, statement: str, subvals: Iterable = ()) -> List:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
a list of rows from the database.
|
||||
"""
|
||||
cursor = self.db._connection().execute(statement, subvals)
|
||||
return cursor.fetchall()
|
||||
|
||||
def mutate(self, statement, subvals=()):
|
||||
def mutate(self, statement: str, subvals: Iterable = ()) -> Any:
|
||||
"""Execute an SQL statement with substitution values and return
|
||||
the row ID of the last affected row.
|
||||
"""
|
||||
|
|
@ -898,7 +948,7 @@ class Transaction:
|
|||
self._mutated = True
|
||||
return cursor.lastrowid
|
||||
|
||||
def script(self, statements):
|
||||
def script(self, statements: str):
|
||||
"""Execute a string containing multiple SQL statements."""
|
||||
# We don't know whether this mutates, but quite likely it does.
|
||||
self._mutated = True
|
||||
|
|
@ -922,7 +972,7 @@ class Database:
|
|||
data is written in a transaction.
|
||||
"""
|
||||
|
||||
def __init__(self, path, timeout=5.0):
|
||||
def __init__(self, path, timeout: float = 5.0):
|
||||
if sqlite3.threadsafety == 0:
|
||||
raise RuntimeError(
|
||||
"sqlite3 must be compiled with multi-threading support"
|
||||
|
|
@ -956,7 +1006,7 @@ class Database:
|
|||
|
||||
# Primitive access control: connections and transactions.
|
||||
|
||||
def _connection(self):
|
||||
def _connection(self) -> Connection:
|
||||
"""Get a SQLite connection object to the underlying database.
|
||||
One connection object is created per thread.
|
||||
"""
|
||||
|
|
@ -969,7 +1019,7 @@ class Database:
|
|||
self._connections[thread_id] = conn
|
||||
return conn
|
||||
|
||||
def _create_connection(self):
|
||||
def _create_connection(self) -> Connection:
|
||||
"""Create a SQLite connection to the underlying database.
|
||||
|
||||
Makes a new connection every time. If you need to configure the
|
||||
|
|
@ -1019,7 +1069,7 @@ class Database:
|
|||
conn.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tx_stack(self):
|
||||
def _tx_stack(self) -> Generator[List, None, None]:
|
||||
"""A context manager providing access to the current thread's
|
||||
transaction stack. The context manager synchronizes access to
|
||||
the stack map. Transactions should never migrate across threads.
|
||||
|
|
@ -1028,7 +1078,7 @@ class Database:
|
|||
with self._shared_map_lock:
|
||||
yield self._tx_stacks[thread_id]
|
||||
|
||||
def transaction(self):
|
||||
def transaction(self) -> Transaction:
|
||||
"""Get a :class:`Transaction` object for interacting directly
|
||||
with the underlying SQLite database.
|
||||
"""
|
||||
|
|
@ -1048,7 +1098,7 @@ class Database:
|
|||
|
||||
# Schema setup and migration.
|
||||
|
||||
def _make_table(self, table, fields):
|
||||
def _make_table(self, table: str, fields: Mapping[str, types.Type]):
|
||||
"""Set up the schema of the database. `fields` is a mapping
|
||||
from field names to `Type`s. Columns are added if necessary.
|
||||
"""
|
||||
|
|
@ -1083,7 +1133,7 @@ class Database:
|
|||
with self.transaction() as tx:
|
||||
tx.script(setup_sql)
|
||||
|
||||
def _make_attribute_table(self, flex_table):
|
||||
def _make_attribute_table(self, flex_table: str):
|
||||
"""Create a table and associated index for flexible attributes
|
||||
for the given entity (if they don't exist).
|
||||
"""
|
||||
|
|
@ -1101,7 +1151,12 @@ class Database:
|
|||
|
||||
# Querying.
|
||||
|
||||
def _fetch(self, model_cls, query=None, sort=None):
|
||||
def _fetch(
|
||||
self,
|
||||
model_cls: Type['LibModel'],
|
||||
query: Optional[Query] = None,
|
||||
sort: Optional[Sort] = None,
|
||||
) -> Results:
|
||||
"""Fetch the objects of type `model_cls` matching the given
|
||||
query. The query may be given as a string, string sequence, a
|
||||
Query object, or None (to fetch everything). `sort` is an
|
||||
|
|
@ -1141,7 +1196,7 @@ class Database:
|
|||
sort if sort.is_slow() else None, # Slow sort component.
|
||||
)
|
||||
|
||||
def _get(self, model_cls, id):
|
||||
def _get(self, model_cls: Union[Type[Model], Type[LibModel]], id) -> Model:
|
||||
"""Get a Model object by its id or None if the id does not
|
||||
exist.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,13 +15,24 @@
|
|||
"""The Query type hierarchy for DBCore.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from operator import mul
|
||||
from typing import Union, Tuple, List, Optional, Pattern, Any, Type, Iterator,\
|
||||
Collection, MutableMapping, Sequence
|
||||
|
||||
from beets import util
|
||||
from datetime import datetime, timedelta
|
||||
import unicodedata
|
||||
from functools import reduce
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from beets.library import Item
|
||||
from beets.dbcore import Model
|
||||
|
||||
|
||||
class ParsingError(ValueError):
|
||||
"""Abstract class for any unparseable user-requested album/query
|
||||
|
|
@ -60,7 +71,7 @@ class Query:
|
|||
"""An abstract class representing a query into the item database.
|
||||
"""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[None, Tuple]:
|
||||
"""Generate an SQLite expression implementing the query.
|
||||
|
||||
Return (clause, subvals) where clause is a valid sqlite
|
||||
|
|
@ -69,19 +80,19 @@ class Query:
|
|||
"""
|
||||
return None, ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Item):
|
||||
"""Check whether this query matches a given Item. Can be used to
|
||||
perform queries on arbitrary sets of Items.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -93,12 +104,12 @@ class FieldQuery(Query):
|
|||
same matching functionality in SQLite.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: Optional[str], fast: bool = True):
|
||||
self.field = field
|
||||
self.pattern = pattern
|
||||
self.fast = fast
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Union[None, Tuple]:
|
||||
return None, ()
|
||||
|
||||
def clause(self):
|
||||
|
|
@ -109,51 +120,51 @@ class FieldQuery(Query):
|
|||
return None, ()
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls, pattern: str, value: str):
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: Model):
|
||||
return self.value_match(self.pattern, item.get(self.field))
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
|
||||
"{0.fast})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and self.pattern == other.pattern
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.field, hash(self.pattern)))
|
||||
|
||||
|
||||
class MatchQuery(FieldQuery):
|
||||
"""A query that looks for exact matches in an item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
return self.field + " = ?", [self.pattern]
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern == value
|
||||
|
||||
|
||||
class NoneQuery(FieldQuery):
|
||||
"""A query that checks whether a field is null."""
|
||||
|
||||
def __init__(self, field, fast=True):
|
||||
def __init__(self, field, fast: bool = True):
|
||||
super().__init__(field, None, fast)
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
return self.field + " IS NULL", ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: 'Item') -> bool:
|
||||
return item.get(self.field) is None
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
|
||||
|
||||
|
||||
|
|
@ -163,14 +174,18 @@ class StringFieldQuery(FieldQuery):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def value_match(cls, pattern, value):
|
||||
def value_match(cls, pattern: str, value: Any):
|
||||
"""Determine whether the value matches the pattern. The value
|
||||
may have any type.
|
||||
"""
|
||||
return cls.string_match(pattern, util.as_string(value))
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(
|
||||
cls,
|
||||
pattern: str,
|
||||
value: str,
|
||||
) -> bool:
|
||||
"""Determine whether the value matches the pattern. Both
|
||||
arguments are strings. Subclasses implement this method.
|
||||
"""
|
||||
|
|
@ -180,7 +195,7 @@ class StringFieldQuery(FieldQuery):
|
|||
class StringQuery(StringFieldQuery):
|
||||
"""A query that matches a whole string in a specific item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
search = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -190,14 +205,14 @@ class StringQuery(StringFieldQuery):
|
|||
return clause, subvals
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern.lower() == value.lower()
|
||||
|
||||
|
||||
class SubstringQuery(StringFieldQuery):
|
||||
"""A query that matches a substring in a specific item field."""
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[str]]:
|
||||
pattern = (self.pattern
|
||||
.replace('\\', '\\\\')
|
||||
.replace('%', '\\%')
|
||||
|
|
@ -208,7 +223,7 @@ class SubstringQuery(StringFieldQuery):
|
|||
return clause, subvals
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: str, value: str) -> bool:
|
||||
return pattern.lower() in value.lower()
|
||||
|
||||
|
||||
|
|
@ -220,7 +235,7 @@ class RegexpQuery(StringFieldQuery):
|
|||
expression.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
pattern = self._normalize(pattern)
|
||||
try:
|
||||
|
|
@ -235,14 +250,14 @@ class RegexpQuery(StringFieldQuery):
|
|||
return f" regexp({self.field}, ?)", [self.pattern.pattern]
|
||||
|
||||
@staticmethod
|
||||
def _normalize(s):
|
||||
def _normalize(s: str) -> str:
|
||||
"""Normalize a Unicode string's representation (used on both
|
||||
patterns and matched values).
|
||||
"""
|
||||
return unicodedata.normalize('NFC', s)
|
||||
|
||||
@classmethod
|
||||
def string_match(cls, pattern, value):
|
||||
def string_match(cls, pattern: Pattern, value: str) -> bool:
|
||||
return pattern.search(cls._normalize(value)) is not None
|
||||
|
||||
|
||||
|
|
@ -251,7 +266,12 @@ class BooleanQuery(MatchQuery):
|
|||
string reflecting a boolean.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
pattern: Union[bool, str],
|
||||
fast: bool = True,
|
||||
):
|
||||
super().__init__(field, pattern, fast)
|
||||
if isinstance(pattern, str):
|
||||
self.pattern = util.str2bool(pattern)
|
||||
|
|
@ -265,7 +285,7 @@ class BytesQuery(MatchQuery):
|
|||
`MatchQuery` when matching on BLOB values.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern):
|
||||
def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
|
||||
super().__init__(field, pattern)
|
||||
|
||||
# Use a buffer/memoryview representation of the pattern for SQLite
|
||||
|
|
@ -279,7 +299,7 @@ class BytesQuery(MatchQuery):
|
|||
self.buf_pattern = self.pattern
|
||||
self.pattern = bytes(self.pattern)
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, List[memoryview]]:
|
||||
return self.field + " = ?", [self.buf_pattern]
|
||||
|
||||
|
||||
|
|
@ -292,7 +312,7 @@ class NumericQuery(FieldQuery):
|
|||
a float.
|
||||
"""
|
||||
|
||||
def _convert(self, s):
|
||||
def _convert(self, s: str) -> Union[float, int, None]:
|
||||
"""Convert a string to a numeric type (float or int).
|
||||
|
||||
Return None if `s` is empty.
|
||||
|
|
@ -309,7 +329,7 @@ class NumericQuery(FieldQuery):
|
|||
except ValueError:
|
||||
raise InvalidQueryArgumentValueError(s, "an int or a float")
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field: str, pattern: str, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
|
||||
parts = pattern.split('..', 1)
|
||||
|
|
@ -324,7 +344,7 @@ class NumericQuery(FieldQuery):
|
|||
self.rangemin = self._convert(parts[0])
|
||||
self.rangemax = self._convert(parts[1])
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: 'Item') -> bool:
|
||||
if self.field not in item:
|
||||
return False
|
||||
value = item[self.field]
|
||||
|
|
@ -340,7 +360,7 @@ class NumericQuery(FieldQuery):
|
|||
return False
|
||||
return True
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[str, Tuple]:
|
||||
if self.point is not None:
|
||||
return self.field + '=?', (self.point,)
|
||||
else:
|
||||
|
|
@ -360,24 +380,27 @@ class CollectionQuery(Query):
|
|||
indexed like a list to access the sub-queries.
|
||||
"""
|
||||
|
||||
def __init__(self, subqueries=()):
|
||||
def __init__(self, subqueries: Sequence = ()):
|
||||
self.subqueries = subqueries
|
||||
|
||||
# Act like a sequence.
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.subqueries)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.subqueries[key]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.subqueries)
|
||||
|
||||
def __contains__(self, item):
|
||||
def __contains__(self, item) -> bool:
|
||||
return item in self.subqueries
|
||||
|
||||
def clause_with_joiner(self, joiner):
|
||||
def clause_with_joiner(
|
||||
self,
|
||||
joiner: str,
|
||||
) -> Tuple[Optional[str], Collection]:
|
||||
"""Return a clause created by joining together the clauses of
|
||||
all subqueries with the string joiner (padded by spaces).
|
||||
"""
|
||||
|
|
@ -393,14 +416,14 @@ class CollectionQuery(Query):
|
|||
clause = (' ' + joiner + ' ').join(clause_parts)
|
||||
return clause, subvals
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.subqueries!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.subqueries == other.subqueries
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""Since subqueries are mutable, this object should not be hashable.
|
||||
However and for conveniences purposes, it can be hashed.
|
||||
"""
|
||||
|
|
@ -413,7 +436,7 @@ class AnyFieldQuery(CollectionQuery):
|
|||
constructor.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern, fields, cls):
|
||||
def __init__(self, pattern, fields, cls: Type[FieldQuery]):
|
||||
self.pattern = pattern
|
||||
self.fields = fields
|
||||
self.query_class = cls
|
||||
|
|
@ -421,26 +444,27 @@ class AnyFieldQuery(CollectionQuery):
|
|||
subqueries = []
|
||||
for field in self.fields:
|
||||
subqueries.append(cls(field, pattern, True))
|
||||
# TYPING ERROR
|
||||
super().__init__(subqueries)
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: 'Item') -> bool:
|
||||
for subq in self.subqueries:
|
||||
if subq.match(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
|
||||
"{0.query_class.__name__})".format(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.query_class == other.query_class
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.pattern, tuple(self.fields), self.query_class))
|
||||
|
||||
|
||||
|
|
@ -448,6 +472,7 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
"""A collection query whose subqueries may be modified after the
|
||||
query is initialized.
|
||||
"""
|
||||
subqueries: MutableMapping
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.subqueries[key] = value
|
||||
|
|
@ -459,20 +484,20 @@ class MutableCollectionQuery(CollectionQuery):
|
|||
class AndQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
return self.clause_with_joiner('and')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return all(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
class OrQuery(MutableCollectionQuery):
|
||||
"""A conjunction of a list of other queries."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
return self.clause_with_joiner('or')
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return any(q.match(item) for q in self.subqueries)
|
||||
|
||||
|
||||
|
|
@ -493,43 +518,43 @@ class NotQuery(Query):
|
|||
# is handled by match() for slow queries.
|
||||
return clause, subvals
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return not self.subquery.match(item)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{0.__class__.__name__}({0.subquery!r})".format(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.subquery == other.subquery
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(('not', hash(self.subquery)))
|
||||
|
||||
|
||||
class TrueQuery(Query):
|
||||
"""A query that always matches."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
return '1', ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class FalseQuery(Query):
|
||||
"""A query that never matches."""
|
||||
|
||||
def clause(self):
|
||||
def clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
return '0', ()
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# Time/date queries.
|
||||
|
||||
def _parse_periods(pattern):
|
||||
def _parse_periods(pattern: str) -> Tuple['Period', 'Period']:
|
||||
"""Parse a string containing two dates separated by two dots (..).
|
||||
Return a pair of `Period` objects.
|
||||
"""
|
||||
|
|
@ -563,7 +588,7 @@ class Period:
|
|||
relative_re = '(?P<sign>[+|-]?)(?P<quantity>[0-9]+)' + \
|
||||
'(?P<timespan>[y|m|w|d])'
|
||||
|
||||
def __init__(self, date, precision):
|
||||
def __init__(self, date: datetime, precision: str):
|
||||
"""Create a period with the given date (a `datetime` object) and
|
||||
precision (a string, one of "year", "month", "day", "hour", "minute",
|
||||
or "second").
|
||||
|
|
@ -574,7 +599,7 @@ class Period:
|
|||
self.precision = precision
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string):
|
||||
def parse(cls: Type['Period'], string: str) -> Optional['Period']:
|
||||
"""Parse a date and return a `Period` object or `None` if the
|
||||
string is empty, or raise an InvalidQueryArgumentValueError if
|
||||
the string cannot be parsed to a date.
|
||||
|
|
@ -591,7 +616,8 @@ class Period:
|
|||
and a "year" is exactly 365 days.
|
||||
"""
|
||||
|
||||
def find_date_and_format(string):
|
||||
def find_date_and_format(string: str) -> \
|
||||
Union[Tuple[None, None], Tuple[datetime, int]]:
|
||||
for ord, format in enumerate(cls.date_formats):
|
||||
for format_option in format:
|
||||
try:
|
||||
|
|
@ -628,7 +654,7 @@ class Period:
|
|||
precision = cls.precisions[ordinal]
|
||||
return cls(date, precision)
|
||||
|
||||
def open_right_endpoint(self):
|
||||
def open_right_endpoint(self) -> datetime:
|
||||
"""Based on the precision, convert the period to a precise
|
||||
`datetime` for use as a right endpoint in a right-open interval.
|
||||
"""
|
||||
|
|
@ -660,7 +686,7 @@ class DateInterval:
|
|||
A right endpoint of None means towards infinity.
|
||||
"""
|
||||
|
||||
def __init__(self, start, end):
|
||||
def __init__(self, start: Optional[datetime], end: Optional[datetime]):
|
||||
if start is not None and end is not None and not start < end:
|
||||
raise ValueError("start date {} is not before end date {}"
|
||||
.format(start, end))
|
||||
|
|
@ -668,21 +694,21 @@ class DateInterval:
|
|||
self.end = end
|
||||
|
||||
@classmethod
|
||||
def from_periods(cls, start, end):
|
||||
def from_periods(cls, start: Period, end: Period) -> 'DateInterval':
|
||||
"""Create an interval with two Periods as the endpoints.
|
||||
"""
|
||||
end_date = end.open_right_endpoint() if end is not None else None
|
||||
start_date = start.date if start is not None else None
|
||||
return cls(start_date, end_date)
|
||||
|
||||
def contains(self, date):
|
||||
def contains(self, date: datetime) -> bool:
|
||||
if self.start is not None and date < self.start:
|
||||
return False
|
||||
if self.end is not None and date >= self.end:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'[{self.start}, {self.end})'
|
||||
|
||||
|
||||
|
|
@ -696,12 +722,12 @@ class DateQuery(FieldQuery):
|
|||
using an ellipsis interval syntax similar to that of NumericQuery.
|
||||
"""
|
||||
|
||||
def __init__(self, field, pattern, fast=True):
|
||||
def __init__(self, field, pattern, fast: bool = True):
|
||||
super().__init__(field, pattern, fast)
|
||||
start, end = _parse_periods(pattern)
|
||||
self.interval = DateInterval.from_periods(start, end)
|
||||
|
||||
def match(self, item):
|
||||
def match(self, item: 'Item') -> bool:
|
||||
if self.field not in item:
|
||||
return False
|
||||
timestamp = float(item[self.field])
|
||||
|
|
@ -710,7 +736,7 @@ class DateQuery(FieldQuery):
|
|||
|
||||
_clause_tmpl = "{0} {1} ?"
|
||||
|
||||
def col_clause(self):
|
||||
def col_clause(self) -> Tuple[Union[str, None], Collection]:
|
||||
clause_parts = []
|
||||
subvals = []
|
||||
|
||||
|
|
@ -742,7 +768,7 @@ class DurationQuery(NumericQuery):
|
|||
or M:SS time interval.
|
||||
"""
|
||||
|
||||
def _convert(self, s):
|
||||
def _convert(self, s: str) -> Optional[float]:
|
||||
"""Convert a M:SS or numeric string to a float.
|
||||
|
||||
Return None if `s` is empty.
|
||||
|
|
@ -768,27 +794,27 @@ class Sort:
|
|||
the item database.
|
||||
"""
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> None:
|
||||
"""Generates a SQL fragment to be used in a ORDER BY clause, or
|
||||
None if no fragment is used (i.e., this is a slow sort).
|
||||
"""
|
||||
return None
|
||||
|
||||
def sort(self, items):
|
||||
def sort(self, items: List) -> List:
|
||||
"""Sort the list of objects and return a list.
|
||||
"""
|
||||
return sorted(items)
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
"""Indicate whether this query is *slow*, meaning that it cannot
|
||||
be executed in SQL and must be executed in Python.
|
||||
"""
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other)
|
||||
|
||||
|
||||
|
|
@ -796,13 +822,13 @@ class MultipleSort(Sort):
|
|||
"""Sort that encapsulates multiple sub-sorts.
|
||||
"""
|
||||
|
||||
def __init__(self, sorts=None):
|
||||
def __init__(self, sorts: Optional[List[Sort]] = None):
|
||||
self.sorts = sorts or []
|
||||
|
||||
def add_sort(self, sort):
|
||||
def add_sort(self, sort: Sort):
|
||||
self.sorts.append(sort)
|
||||
|
||||
def _sql_sorts(self):
|
||||
def _sql_sorts(self) -> List[Sort]:
|
||||
"""Return the list of sub-sorts for which we can be (at least
|
||||
partially) fast.
|
||||
|
||||
|
|
@ -819,15 +845,16 @@ class MultipleSort(Sort):
|
|||
sql_sorts.reverse()
|
||||
return sql_sorts
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> str:
|
||||
order_strings = []
|
||||
for sort in self._sql_sorts():
|
||||
order = sort.order_clause()
|
||||
order_strings.append(order)
|
||||
|
||||
# TYPING ERROR
|
||||
return ", ".join(order_strings)
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
for sort in self.sorts:
|
||||
if sort.is_slow():
|
||||
return True
|
||||
|
|
@ -865,17 +892,22 @@ class FieldSort(Sort):
|
|||
any kind).
|
||||
"""
|
||||
|
||||
def __init__(self, field, ascending=True, case_insensitive=True):
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
ascending: bool = True,
|
||||
case_insensitive: bool = True,
|
||||
):
|
||||
self.field = field
|
||||
self.ascending = ascending
|
||||
self.case_insensitive = case_insensitive
|
||||
|
||||
def sort(self, objs):
|
||||
def sort(self, objs: Collection):
|
||||
# TODO: Conversion and null-detection here. In Python 3,
|
||||
# comparisons with None fail. We should also support flexible
|
||||
# attributes with different types without falling over.
|
||||
|
||||
def key(item):
|
||||
def key(item: 'Item'):
|
||||
field_val = item.get(self.field, '')
|
||||
if self.case_insensitive and isinstance(field_val, str):
|
||||
field_val = field_val.lower()
|
||||
|
|
@ -883,17 +915,17 @@ class FieldSort(Sort):
|
|||
|
||||
return sorted(objs, key=key, reverse=not self.ascending)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return '<{}: {}{}>'.format(
|
||||
type(self).__name__,
|
||||
self.field,
|
||||
'+' if self.ascending else '-',
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.field, self.ascending))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return super().__eq__(other) and \
|
||||
self.field == other.field and \
|
||||
self.ascending == other.ascending
|
||||
|
|
@ -903,7 +935,7 @@ class FixedFieldSort(FieldSort):
|
|||
"""Sort object to sort on a fixed field.
|
||||
"""
|
||||
|
||||
def order_clause(self):
|
||||
def order_clause(self) -> str:
|
||||
order = "ASC" if self.ascending else "DESC"
|
||||
if self.case_insensitive:
|
||||
field = '(CASE ' \
|
||||
|
|
@ -920,24 +952,24 @@ class SlowFieldSort(FieldSort):
|
|||
i.e., a computed or flexible field.
|
||||
"""
|
||||
|
||||
def is_slow(self):
|
||||
def is_slow(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class NullSort(Sort):
|
||||
"""No sorting. Leave results unsorted."""
|
||||
|
||||
def sort(self, items):
|
||||
def sort(self, items: List) -> List:
|
||||
return items
|
||||
|
||||
def __nonzero__(self):
|
||||
def __nonzero__(self) -> bool:
|
||||
return self.__bool__()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) == type(other) or other is None
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -17,7 +17,11 @@
|
|||
|
||||
import re
|
||||
import itertools
|
||||
from . import query
|
||||
from typing import Dict, Type, Tuple, Optional, Collection, List, \
|
||||
Sequence
|
||||
|
||||
from . import query, Model
|
||||
from .query import Sort
|
||||
|
||||
PARSE_QUERY_PART_REGEX = re.compile(
|
||||
# Non-capturing optional segment for the keyword.
|
||||
|
|
@ -34,8 +38,12 @@ PARSE_QUERY_PART_REGEX = re.compile(
|
|||
)
|
||||
|
||||
|
||||
def parse_query_part(part, query_classes={}, prefixes={},
|
||||
default_class=query.SubstringQuery):
|
||||
def parse_query_part(
|
||||
part: str,
|
||||
query_classes: Dict = {},
|
||||
prefixes: Dict = {},
|
||||
default_class: Type[query.SubstringQuery] = query.SubstringQuery,
|
||||
) -> Tuple[Optional[str], str, Type[query.Query], bool]:
|
||||
"""Parse a single *query part*, which is a chunk of a complete query
|
||||
string representing a single criterion.
|
||||
|
||||
|
|
@ -100,7 +108,11 @@ def parse_query_part(part, query_classes={}, prefixes={},
|
|||
return key, term, query_class, negate
|
||||
|
||||
|
||||
def construct_query_part(model_cls, prefixes, query_part):
|
||||
def construct_query_part(
|
||||
model_cls: Type[Model],
|
||||
prefixes: Dict,
|
||||
query_part: str,
|
||||
) -> query.Query:
|
||||
"""Parse a *query part* string and return a :class:`Query` object.
|
||||
|
||||
:param model_cls: The :class:`Model` class that this is a query for.
|
||||
|
|
@ -158,7 +170,13 @@ def construct_query_part(model_cls, prefixes, query_part):
|
|||
return out_query
|
||||
|
||||
|
||||
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
||||
# TYPING ERROR
|
||||
def query_from_strings(
|
||||
query_cls: Type[query.Query],
|
||||
model_cls: Type[Model],
|
||||
prefixes: Dict,
|
||||
query_parts: Collection[str],
|
||||
) -> query.Query:
|
||||
"""Creates a collection query of type `query_cls` from a list of
|
||||
strings in the format used by parse_query_part. `model_cls`
|
||||
determines how queries are constructed from strings.
|
||||
|
|
@ -171,7 +189,11 @@ def query_from_strings(query_cls, model_cls, prefixes, query_parts):
|
|||
return query_cls(subqueries)
|
||||
|
||||
|
||||
def construct_sort_part(model_cls, part, case_insensitive=True):
|
||||
def construct_sort_part(
|
||||
model_cls: Type[Model],
|
||||
part: str,
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
"""Create a `Sort` from a single string criterion.
|
||||
|
||||
`model_cls` is the `Model` being queried. `part` is a single string
|
||||
|
|
@ -197,7 +219,11 @@ def construct_sort_part(model_cls, part, case_insensitive=True):
|
|||
return sort
|
||||
|
||||
|
||||
def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
|
||||
def sort_from_strings(
|
||||
model_cls: Type[Model],
|
||||
sort_parts: Sequence[str],
|
||||
case_insensitive: bool = True,
|
||||
) -> Sort:
|
||||
"""Create a `Sort` from a list of sort criteria (strings).
|
||||
"""
|
||||
if not sort_parts:
|
||||
|
|
@ -212,8 +238,12 @@ def sort_from_strings(model_cls, sort_parts, case_insensitive=True):
|
|||
return sort
|
||||
|
||||
|
||||
def parse_sorted_query(model_cls, parts, prefixes={},
|
||||
case_insensitive=True):
|
||||
def parse_sorted_query(
|
||||
model_cls: Type[Model],
|
||||
parts: List[str],
|
||||
prefixes: Dict = {},
|
||||
case_insensitive: bool = True,
|
||||
) -> Tuple[query.Query, Sort]:
|
||||
"""Given a list of strings, create the `Query` and `Sort` that they
|
||||
represent.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
"""Representation of type information for DBCore model fields.
|
||||
"""
|
||||
|
||||
from typing import Union, Any, Callable
|
||||
from . import query
|
||||
from beets.util import str2bool
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ class Type:
|
|||
"""The `Query` subclass to be used when querying the field.
|
||||
"""
|
||||
|
||||
model_type = str
|
||||
model_type: Callable[[Any], str] = str
|
||||
"""The Python type that is used to represent the value in the model.
|
||||
|
||||
The model is guaranteed to return a value of this type if the field
|
||||
|
|
@ -44,12 +45,12 @@ class Type:
|
|||
"""
|
||||
|
||||
@property
|
||||
def null(self):
|
||||
def null(self) -> model_type:
|
||||
"""The value to be exposed when the underlying value is None.
|
||||
"""
|
||||
return self.model_type()
|
||||
|
||||
def format(self, value):
|
||||
def format(self, value: model_type) -> str:
|
||||
"""Given a value of this type, produce a Unicode string
|
||||
representing the value. This is used in template evaluation.
|
||||
"""
|
||||
|
|
@ -63,7 +64,7 @@ class Type:
|
|||
|
||||
return str(value)
|
||||
|
||||
def parse(self, string):
|
||||
def parse(self, string: str) -> model_type:
|
||||
"""Parse a (possibly human-written) string and return the
|
||||
indicated value of this type.
|
||||
"""
|
||||
|
|
@ -72,11 +73,12 @@ class Type:
|
|||
except ValueError:
|
||||
return self.null
|
||||
|
||||
def normalize(self, value):
|
||||
def normalize(self, value: Union[None, int, float, bytes]) -> model_type:
|
||||
"""Given a value that will be assigned into a field of this
|
||||
type, normalize the value to have the appropriate type. This
|
||||
base implementation only reinterprets `None`.
|
||||
"""
|
||||
# TYPING ERROR
|
||||
if value is None:
|
||||
return self.null
|
||||
else:
|
||||
|
|
@ -84,7 +86,10 @@ class Type:
|
|||
# `self.model_type(value)`
|
||||
return value
|
||||
|
||||
def from_sql(self, sql_value):
|
||||
def from_sql(
|
||||
self,
|
||||
sql_value: Union[None, int, float, str, bytes],
|
||||
) -> model_type:
|
||||
"""Receives the value stored in the SQL backend and return the
|
||||
value to be stored in the model.
|
||||
|
||||
|
|
@ -105,7 +110,7 @@ class Type:
|
|||
else:
|
||||
return self.normalize(sql_value)
|
||||
|
||||
def to_sql(self, model_value):
|
||||
def to_sql(self, model_value: Any) -> Union[None, int, float, str, bytes]:
|
||||
"""Convert a value as stored in the model object to a value used
|
||||
by the database adapter.
|
||||
"""
|
||||
|
|
@ -125,7 +130,7 @@ class Integer(Type):
|
|||
query = query.NumericQuery
|
||||
model_type = int
|
||||
|
||||
def normalize(self, value):
|
||||
def normalize(self, value: str) -> Union[int, str]:
|
||||
try:
|
||||
return self.model_type(round(float(value)))
|
||||
except ValueError:
|
||||
|
|
@ -138,10 +143,10 @@ class PaddedInt(Integer):
|
|||
"""An integer field that is formatted with a given number of digits,
|
||||
padded with zeroes.
|
||||
"""
|
||||
def __init__(self, digits):
|
||||
def __init__(self, digits: int):
|
||||
self.digits = digits
|
||||
|
||||
def format(self, value):
|
||||
def format(self, value: int) -> str:
|
||||
return '{0:0{1}d}'.format(value or 0, self.digits)
|
||||
|
||||
|
||||
|
|
@ -155,11 +160,11 @@ class ScaledInt(Integer):
|
|||
"""An integer whose formatting operation scales the number by a
|
||||
constant and adds a suffix. Good for units with large magnitudes.
|
||||
"""
|
||||
def __init__(self, unit, suffix=''):
|
||||
def __init__(self, unit: int, suffix: str = ''):
|
||||
self.unit = unit
|
||||
self.suffix = suffix
|
||||
|
||||
def format(self, value):
|
||||
def format(self, value: int) -> str:
|
||||
return '{}{}'.format((value or 0) // self.unit, self.suffix)
|
||||
|
||||
|
||||
|
|
@ -169,7 +174,7 @@ class Id(Integer):
|
|||
"""
|
||||
null = None
|
||||
|
||||
def __init__(self, primary=True):
|
||||
def __init__(self, primary: bool = True):
|
||||
if primary:
|
||||
self.sql = 'INTEGER PRIMARY KEY'
|
||||
|
||||
|
|
@ -182,10 +187,10 @@ class Float(Type):
|
|||
query = query.NumericQuery
|
||||
model_type = float
|
||||
|
||||
def __init__(self, digits=1):
|
||||
def __init__(self, digits: int = 1):
|
||||
self.digits = digits
|
||||
|
||||
def format(self, value):
|
||||
def format(self, value: float) -> str:
|
||||
return '{0:.{1}f}'.format(value or 0, self.digits)
|
||||
|
||||
|
||||
|
|
@ -201,7 +206,7 @@ class String(Type):
|
|||
sql = 'TEXT'
|
||||
query = query.SubstringQuery
|
||||
|
||||
def normalize(self, value):
|
||||
def normalize(self, value: str) -> str:
|
||||
if value is None:
|
||||
return self.null
|
||||
else:
|
||||
|
|
@ -236,10 +241,10 @@ class Boolean(Type):
|
|||
query = query.BooleanQuery
|
||||
model_type = bool
|
||||
|
||||
def format(self, value):
|
||||
def format(self, value: bool) -> str:
|
||||
return str(bool(value))
|
||||
|
||||
def parse(self, string):
|
||||
def parse(self, string: str) -> bool:
|
||||
return str2bool(string)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,11 +23,17 @@ import shutil
|
|||
import fnmatch
|
||||
import functools
|
||||
from collections import Counter, namedtuple
|
||||
from logging import Logger
|
||||
from multiprocessing.pool import ThreadPool
|
||||
import traceback
|
||||
import subprocess
|
||||
import platform
|
||||
import shlex
|
||||
from typing import Callable, List, Optional, Sequence, Pattern, \
|
||||
Tuple, MutableSequence, AnyStr, TypeVar, Generator, Any, \
|
||||
Iterable, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from beets.util import hidden
|
||||
from unidecode import unidecode
|
||||
from enum import Enum
|
||||
|
|
@ -35,6 +41,8 @@ from enum import Enum
|
|||
|
||||
MAX_FILENAME_LENGTH = 200
|
||||
WINDOWS_MAGIC_PREFIX = '\\\\?\\'
|
||||
T = TypeVar('T')
|
||||
Bytes_or_String: TypeAlias = Union[str, bytes]
|
||||
|
||||
|
||||
class HumanReadableException(Exception):
|
||||
|
|
@ -135,7 +143,7 @@ class MoveOperation(Enum):
|
|||
REFLINK_AUTO = 5
|
||||
|
||||
|
||||
def normpath(path):
|
||||
def normpath(path: bytes) -> bytes:
|
||||
"""Provide the canonical form of the path suitable for storing in
|
||||
the database.
|
||||
"""
|
||||
|
|
@ -144,11 +152,11 @@ def normpath(path):
|
|||
return bytestring_path(path)
|
||||
|
||||
|
||||
def ancestry(path):
|
||||
def ancestry(path: bytes) -> List[str]:
|
||||
"""Return a list consisting of path's parent directory, its
|
||||
grandparent, and so on. For instance:
|
||||
|
||||
>>> ancestry('/a/b/c')
|
||||
>>> ancestry(b'/a/b/c')
|
||||
['/', '/a', '/a/b']
|
||||
|
||||
The argument should *not* be the result of a call to `syspath`.
|
||||
|
|
@ -168,7 +176,12 @@ def ancestry(path):
|
|||
return out
|
||||
|
||||
|
||||
def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
|
||||
def sorted_walk(
|
||||
path: AnyStr,
|
||||
ignore: Sequence = (),
|
||||
ignore_hidden: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
) -> Generator[Tuple, None, None]:
|
||||
"""Like `os.walk`, but yields things in case-insensitive sorted,
|
||||
breadth-first order. Directory and file names matching any glob
|
||||
pattern in `ignore` are skipped. If `logger` is provided, then
|
||||
|
|
@ -225,14 +238,14 @@ def sorted_walk(path, ignore=(), ignore_hidden=False, logger=None):
|
|||
yield from sorted_walk(cur, ignore, ignore_hidden, logger)
|
||||
|
||||
|
||||
def path_as_posix(path):
|
||||
def path_as_posix(path: bytes) -> bytes:
|
||||
"""Return the string representation of the path with forward (/)
|
||||
slashes.
|
||||
"""
|
||||
return path.replace(b'\\', b'/')
|
||||
|
||||
|
||||
def mkdirall(path):
|
||||
def mkdirall(path: bytes):
|
||||
"""Make all the enclosing directories of path (like mkdir -p on the
|
||||
parent).
|
||||
"""
|
||||
|
|
@ -245,7 +258,7 @@ def mkdirall(path):
|
|||
traceback.format_exc())
|
||||
|
||||
|
||||
def fnmatch_all(names, patterns):
|
||||
def fnmatch_all(names: Sequence[bytes], patterns: Sequence[bytes]) -> bool:
|
||||
"""Determine whether all strings in `names` match at least one of
|
||||
the `patterns`, which should be shell glob expressions.
|
||||
"""
|
||||
|
|
@ -260,7 +273,11 @@ def fnmatch_all(names, patterns):
|
|||
return True
|
||||
|
||||
|
||||
def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
|
||||
def prune_dirs(
|
||||
path: str,
|
||||
root: Optional[Bytes_or_String] = None,
|
||||
clutter: Sequence[str] = ('.DS_Store', 'Thumbs.db'),
|
||||
):
|
||||
"""If path is an empty directory, then remove it. Recursively remove
|
||||
path's ancestry up to root (which is never removed) where there are
|
||||
empty directories. If path is not contained in root, then nothing is
|
||||
|
|
@ -291,7 +308,7 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
|
|||
if not os.path.exists(directory):
|
||||
# Directory gone already.
|
||||
continue
|
||||
clutter = [bytestring_path(c) for c in clutter]
|
||||
clutter: List[bytes] = [bytestring_path(c) for c in clutter]
|
||||
match_paths = [bytestring_path(d) for d in os.listdir(directory)]
|
||||
try:
|
||||
if fnmatch_all(match_paths, clutter):
|
||||
|
|
@ -303,10 +320,10 @@ def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
|
|||
break
|
||||
|
||||
|
||||
def components(path):
|
||||
def components(path: AnyStr) -> MutableSequence[AnyStr]:
|
||||
"""Return a list of the path components in path. For instance:
|
||||
|
||||
>>> components('/a/b/c')
|
||||
>>> components(b'/a/b/c')
|
||||
['a', 'b', 'c']
|
||||
|
||||
The argument should *not* be the result of a call to `syspath`.
|
||||
|
|
@ -327,14 +344,14 @@ def components(path):
|
|||
return comps
|
||||
|
||||
|
||||
def arg_encoding():
|
||||
def arg_encoding() -> str:
|
||||
"""Get the encoding for command-line arguments (and other OS
|
||||
locale-sensitive strings).
|
||||
"""
|
||||
return sys.getfilesystemencoding()
|
||||
|
||||
|
||||
def _fsencoding():
|
||||
def _fsencoding() -> str:
|
||||
"""Get the system's filesystem encoding. On Windows, this is always
|
||||
UTF-8 (not MBCS).
|
||||
"""
|
||||
|
|
@ -349,9 +366,10 @@ def _fsencoding():
|
|||
return encoding
|
||||
|
||||
|
||||
def bytestring_path(path):
|
||||
def bytestring_path(path: Bytes_or_String) -> bytes:
|
||||
"""Given a path, which is either a bytes or a unicode, returns a str
|
||||
path (ensuring that we never deal with Unicode pathnames).
|
||||
path (ensuring that we never deal with Unicode pathnames). Path should be
|
||||
bytes but has safeguards for strings to be converted.
|
||||
"""
|
||||
# Pass through bytestrings.
|
||||
if isinstance(path, bytes):
|
||||
|
|
@ -370,10 +388,10 @@ def bytestring_path(path):
|
|||
return path.encode('utf-8')
|
||||
|
||||
|
||||
PATH_SEP = bytestring_path(os.sep)
|
||||
PATH_SEP: bytes = bytestring_path(os.sep)
|
||||
|
||||
|
||||
def displayable_path(path, separator='; '):
|
||||
def displayable_path(path: bytes, separator: str = '; ') -> str:
|
||||
"""Attempts to decode a bytestring path to a unicode object for the
|
||||
purpose of displaying it to the user. If the `path` argument is a
|
||||
list or a tuple, the elements are joined with `separator`.
|
||||
|
|
@ -392,7 +410,7 @@ def displayable_path(path, separator='; '):
|
|||
return path.decode('utf-8', 'ignore')
|
||||
|
||||
|
||||
def syspath(path, prefix=True):
|
||||
def syspath(path: bytes, prefix: bool = True) -> Bytes_or_String:
|
||||
"""Convert a path for use by the operating system. In particular,
|
||||
paths on Windows must receive a magic prefix and must be converted
|
||||
to Unicode before they are sent to the OS. To disable the magic
|
||||
|
|
@ -412,6 +430,7 @@ def syspath(path, prefix=True):
|
|||
except UnicodeError:
|
||||
# The encoding should always be MBCS, Windows' broken
|
||||
# Unicode representation.
|
||||
assert isinstance(path, bytes)
|
||||
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
|
||||
path = path.decode(encoding, 'replace')
|
||||
|
||||
|
|
@ -426,14 +445,14 @@ def syspath(path, prefix=True):
|
|||
return path
|
||||
|
||||
|
||||
def samefile(p1, p2):
|
||||
def samefile(p1: bytes, p2: bytes) -> bool:
|
||||
"""Safer equality for paths."""
|
||||
if p1 == p2:
|
||||
return True
|
||||
return shutil._samefile(syspath(p1), syspath(p2))
|
||||
|
||||
|
||||
def remove(path, soft=True):
|
||||
def remove(path: bytes, soft: bool = True):
|
||||
"""Remove the file. If `soft`, then no error will be raised if the
|
||||
file does not exist.
|
||||
"""
|
||||
|
|
@ -446,7 +465,7 @@ def remove(path, soft=True):
|
|||
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
|
||||
|
||||
|
||||
def copy(path, dest, replace=False):
|
||||
def copy(path: bytes, dest: bytes, replace: bool = False):
|
||||
"""Copy a plain file. Permissions are not copied. If `dest` already
|
||||
exists, raises a FilesystemError unless `replace` is True. Has no
|
||||
effect if `path` is the same as `dest`. Paths are translated to
|
||||
|
|
@ -465,7 +484,7 @@ def copy(path, dest, replace=False):
|
|||
traceback.format_exc())
|
||||
|
||||
|
||||
def move(path, dest, replace=False):
|
||||
def move(path: bytes, dest: bytes, replace: bool = False):
|
||||
"""Rename a file. `dest` may not be a directory. If `dest` already
|
||||
exists, raises an OSError unless `replace` is True. Has no effect if
|
||||
`path` is the same as `dest`. If the paths are on different
|
||||
|
|
@ -515,7 +534,7 @@ def move(path, dest, replace=False):
|
|||
os.remove(tmp)
|
||||
|
||||
|
||||
def link(path, dest, replace=False):
|
||||
def link(path: bytes, dest: bytes, replace: bool = False):
|
||||
"""Create a symbolic link from path to `dest`. Raises an OSError if
|
||||
`dest` already exists, unless `replace` is True. Does nothing if
|
||||
`path` == `dest`.
|
||||
|
|
@ -536,7 +555,7 @@ def link(path, dest, replace=False):
|
|||
traceback.format_exc())
|
||||
|
||||
|
||||
def hardlink(path, dest, replace=False):
|
||||
def hardlink(path: bytes, dest: bytes, replace: bool = False):
|
||||
"""Create a hard link from path to `dest`. Raises an OSError if
|
||||
`dest` already exists, unless `replace` is True. Does nothing if
|
||||
`path` == `dest`.
|
||||
|
|
@ -560,7 +579,12 @@ def hardlink(path, dest, replace=False):
|
|||
traceback.format_exc())
|
||||
|
||||
|
||||
def reflink(path, dest, replace=False, fallback=False):
|
||||
def reflink(
|
||||
path: bytes,
|
||||
dest: bytes,
|
||||
replace: bool = False,
|
||||
fallback: bool = False,
|
||||
):
|
||||
"""Create a reflink from `dest` to `path`.
|
||||
|
||||
Raise an `OSError` if `dest` already exists, unless `replace` is
|
||||
|
|
@ -589,7 +613,7 @@ def reflink(path, dest, replace=False, fallback=False):
|
|||
'link', (path, dest), traceback.format_exc())
|
||||
|
||||
|
||||
def unique_path(path):
|
||||
def unique_path(path: bytes) -> bytes:
|
||||
"""Returns a version of ``path`` that does not exist on the
|
||||
filesystem. Specifically, if ``path` itself already exists, then
|
||||
something unique is appended to the path.
|
||||
|
|
@ -616,7 +640,7 @@ def unique_path(path):
|
|||
# Unix. They are forbidden here because they cause problems on Samba
|
||||
# shares, which are sufficiently common as to cause frequent problems.
|
||||
# https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
|
||||
CHAR_REPLACE = [
|
||||
CHAR_REPLACE: List[Tuple[Pattern, str]] = [
|
||||
(re.compile(r'[\\/]'), '_'), # / and \ -- forbidden everywhere.
|
||||
(re.compile(r'^\.'), '_'), # Leading dot (hidden files on Unix).
|
||||
(re.compile(r'[\x00-\x1f]'), ''), # Control characters.
|
||||
|
|
@ -626,7 +650,10 @@ CHAR_REPLACE = [
|
|||
]
|
||||
|
||||
|
||||
def sanitize_path(path, replacements=None):
|
||||
def sanitize_path(
|
||||
path: str,
|
||||
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]] = None,
|
||||
) -> str:
|
||||
"""Takes a path (as a Unicode string) and makes sure that it is
|
||||
legal. Returns a new path. Only works with fragments; won't work
|
||||
reliably on Windows when a path begins with a drive letter. Path
|
||||
|
|
@ -647,7 +674,7 @@ def sanitize_path(path, replacements=None):
|
|||
return os.path.join(*comps)
|
||||
|
||||
|
||||
def truncate_path(path, length=MAX_FILENAME_LENGTH):
|
||||
def truncate_path(path: AnyStr, length: int = MAX_FILENAME_LENGTH) -> AnyStr:
|
||||
"""Given a bytestring path or a Unicode path fragment, truncate the
|
||||
components to a legal length. In the last component, the extension
|
||||
is preserved.
|
||||
|
|
@ -664,7 +691,13 @@ def truncate_path(path, length=MAX_FILENAME_LENGTH):
|
|||
return os.path.join(*out)
|
||||
|
||||
|
||||
def _legalize_stage(path, replacements, length, extension, fragment):
|
||||
def _legalize_stage(
|
||||
path: str,
|
||||
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
|
||||
length: int,
|
||||
extension: str,
|
||||
fragment: bool,
|
||||
) -> Tuple[Bytes_or_String, bool]:
|
||||
"""Perform a single round of path legalization steps
|
||||
(sanitation/replacement, encoding from Unicode to bytes,
|
||||
extension-appending, and truncation). Return the path (Unicode if
|
||||
|
|
@ -676,7 +709,7 @@ def _legalize_stage(path, replacements, length, extension, fragment):
|
|||
|
||||
# Encode for the filesystem.
|
||||
if not fragment:
|
||||
path = bytestring_path(path)
|
||||
path = bytestring_path(path) # type: ignore
|
||||
|
||||
# Preserve extension.
|
||||
path += extension.lower()
|
||||
|
|
@ -688,7 +721,13 @@ def _legalize_stage(path, replacements, length, extension, fragment):
|
|||
return path, path != pre_truncate_path
|
||||
|
||||
|
||||
def legalize_path(path, replacements, length, extension, fragment):
|
||||
def legalize_path(
|
||||
path: str,
|
||||
replacements: Optional[Sequence[Sequence[Union[Pattern, str]]]],
|
||||
length: int,
|
||||
extension: bytes,
|
||||
fragment: bool,
|
||||
) -> Tuple[Union[Bytes_or_String, bool]]:
|
||||
"""Given a path-like Unicode string, produce a legal path. Return
|
||||
the path and a flag indicating whether some replacements had to be
|
||||
ignored (see below).
|
||||
|
|
@ -736,7 +775,7 @@ def legalize_path(path, replacements, length, extension, fragment):
|
|||
return second_stage_path, retruncated
|
||||
|
||||
|
||||
def py3_path(path):
|
||||
def py3_path(path: AnyStr) -> str:
|
||||
"""Convert a bytestring path to Unicode.
|
||||
|
||||
This helps deal with APIs on Python 3 that *only* accept Unicode
|
||||
|
|
@ -751,12 +790,12 @@ def py3_path(path):
|
|||
return os.fsdecode(path)
|
||||
|
||||
|
||||
def str2bool(value):
|
||||
def str2bool(value: str) -> bool:
|
||||
"""Returns a boolean reflecting a human-entered string."""
|
||||
return value.lower() in ('yes', '1', 'true', 't', 'y')
|
||||
|
||||
|
||||
def as_string(value):
|
||||
def as_string(value: Any) -> str:
|
||||
"""Convert a value to a Unicode object for matching with a query.
|
||||
None becomes the empty string. Bytestrings are silently decoded.
|
||||
"""
|
||||
|
|
@ -770,7 +809,7 @@ def as_string(value):
|
|||
return str(value)
|
||||
|
||||
|
||||
def plurality(objs):
|
||||
def plurality(objs: Sequence[T]) -> T:
|
||||
"""Given a sequence of hashble objects, returns the object that
|
||||
is most common in the set and the its number of appearance. The
|
||||
sequence must contain at least one object.
|
||||
|
|
@ -781,7 +820,7 @@ def plurality(objs):
|
|||
return c.most_common(1)[0]
|
||||
|
||||
|
||||
def cpu_count():
|
||||
def cpu_count() -> int:
|
||||
"""Return the number of hardware thread contexts (cores or SMT
|
||||
threads) in the system.
|
||||
"""
|
||||
|
|
@ -812,13 +851,12 @@ def cpu_count():
|
|||
return 1
|
||||
|
||||
|
||||
def convert_command_args(args):
|
||||
def convert_command_args(args: List[bytes]) -> List[str]:
|
||||
"""Convert command arguments, which may either be `bytes` or `str`
|
||||
objects, to uniformly surrogate-escaped strings.
|
||||
"""
|
||||
objects, to uniformly surrogate-escaped strings. """
|
||||
assert isinstance(args, list)
|
||||
|
||||
def convert(arg):
|
||||
def convert(arg) -> str:
|
||||
if isinstance(arg, bytes):
|
||||
return os.fsdecode(arg)
|
||||
return arg
|
||||
|
|
@ -830,7 +868,10 @@ def convert_command_args(args):
|
|||
CommandOutput = namedtuple("CommandOutput", ("stdout", "stderr"))
|
||||
|
||||
|
||||
def command_output(cmd, shell=False):
|
||||
def command_output(
|
||||
cmd: List[Bytes_or_String],
|
||||
shell: bool = False,
|
||||
) -> CommandOutput:
|
||||
"""Runs the command and returns its output after it has exited.
|
||||
|
||||
Returns a CommandOutput. The attributes ``stdout`` and ``stderr`` contain
|
||||
|
|
@ -870,7 +911,7 @@ def command_output(cmd, shell=False):
|
|||
return CommandOutput(stdout, stderr)
|
||||
|
||||
|
||||
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
|
||||
def max_filename_length(path: AnyStr, limit=MAX_FILENAME_LENGTH) -> int:
|
||||
"""Attempt to determine the maximum filename length for the
|
||||
filesystem containing `path`. If the value is greater than `limit`,
|
||||
then `limit` is used instead (to prevent errors when a filesystem
|
||||
|
|
@ -887,7 +928,7 @@ def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
|
|||
return limit
|
||||
|
||||
|
||||
def open_anything():
|
||||
def open_anything() -> str:
|
||||
"""Return the system command that dispatches execution to the correct
|
||||
program.
|
||||
"""
|
||||
|
|
@ -901,7 +942,7 @@ def open_anything():
|
|||
return base_cmd
|
||||
|
||||
|
||||
def editor_command():
|
||||
def editor_command() -> str:
|
||||
"""Get a command for opening a text file.
|
||||
|
||||
Use the `EDITOR` environment variable by default. If it is not
|
||||
|
|
@ -914,7 +955,7 @@ def editor_command():
|
|||
return open_anything()
|
||||
|
||||
|
||||
def interactive_open(targets, command):
|
||||
def interactive_open(targets: Sequence[str], command: str):
|
||||
"""Open the files in `targets` by `exec`ing a new `command`, given
|
||||
as a Unicode string. (The new program takes over, and Python
|
||||
execution ends: this does not fork a subprocess.)
|
||||
|
|
@ -936,7 +977,7 @@ def interactive_open(targets, command):
|
|||
return os.execlp(*args)
|
||||
|
||||
|
||||
def case_sensitive(path):
|
||||
def case_sensitive(path: bytes) -> bool:
|
||||
"""Check whether the filesystem at the given path is case sensitive.
|
||||
|
||||
To work best, the path should point to a file or a directory. If the path
|
||||
|
|
@ -984,7 +1025,7 @@ def case_sensitive(path):
|
|||
return not os.path.samefile(lower_sys, upper_sys)
|
||||
|
||||
|
||||
def raw_seconds_short(string):
|
||||
def raw_seconds_short(string: str) -> float:
|
||||
"""Formats a human-readable M:SS string as a float (number of seconds).
|
||||
|
||||
Raises ValueError if the conversion cannot take place due to `string` not
|
||||
|
|
@ -997,7 +1038,7 @@ def raw_seconds_short(string):
|
|||
return float(minutes * 60 + seconds)
|
||||
|
||||
|
||||
def asciify_path(path, sep_replace):
|
||||
def asciify_path(path: str, sep_replace: str) -> str:
|
||||
"""Decodes all unicode characters in a path into ASCII equivalents.
|
||||
|
||||
Substitutions are provided by the unidecode module. Path separators in the
|
||||
|
|
@ -1010,7 +1051,7 @@ def asciify_path(path, sep_replace):
|
|||
# if this platform has an os.altsep, change it to os.sep.
|
||||
if os.altsep:
|
||||
path = path.replace(os.altsep, os.sep)
|
||||
path_components = path.split(os.sep)
|
||||
path_components: List[Bytes_or_String] = path.split(os.sep)
|
||||
for index, item in enumerate(path_components):
|
||||
path_components[index] = unidecode(item).replace(os.sep, sep_replace)
|
||||
if os.altsep:
|
||||
|
|
@ -1021,7 +1062,7 @@ def asciify_path(path, sep_replace):
|
|||
return os.sep.join(path_components)
|
||||
|
||||
|
||||
def par_map(transform, items):
|
||||
def par_map(transform: Callable, items: Iterable):
|
||||
"""Apply the function `transform` to all the elements in the
|
||||
iterable `items`, like `map(transform, items)` but with no return
|
||||
value.
|
||||
|
|
@ -1035,7 +1076,7 @@ def par_map(transform, items):
|
|||
pool.join()
|
||||
|
||||
|
||||
def lazy_property(func):
|
||||
def lazy_property(func: Callable) -> Callable:
|
||||
"""A decorator that creates a lazily evaluated property. On first access,
|
||||
the property is assigned the return value of `func`. This first value is
|
||||
stored, so that future accesses do not have to evaluate `func` again.
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -92,6 +92,7 @@ setup(
|
|||
'confuse>=1.5.0',
|
||||
'munkres>=1.0.0',
|
||||
'jellyfish',
|
||||
'typing_extensions',
|
||||
] + (
|
||||
# Support for ANSI console colors on Windows.
|
||||
['colorama'] if (sys.platform == 'win32') else []
|
||||
|
|
|
|||
Loading…
Reference in a new issue