From a84b3542f9753bbecab51b642fb0cdeaa5218969 Mon Sep 17 00:00:00 2001 From: wisp3rwind <17089248+wisp3rwind@users.noreply.github.com> Date: Thu, 23 Feb 2023 23:16:22 +0100 Subject: [PATCH] typing: corrections for dbcore/types.py tricky... - the only way I found to express the concept of the "associated type" (in Rust lingo) model_type was by making Type generic over its value and null types. - in addition, the class hierarchy of Integer and Float types had to be modified, since previously some of them would have conflicting null types relative to their super class (this required a change to the edit plugin; hopefully no more breakage is caused by these changes) - don't import the query module, but only the relevant Query's to avoid confusing the module query and the class variable query --- beets/dbcore/types.py | 142 ++++++++++++++++++++++++++++++------------ beetsplug/edit.py | 2 +- 2 files changed, 103 insertions(+), 41 deletions(-) diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index ac8dd762b..24cee4a1f 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -15,28 +15,47 @@ """Representation of type information for DBCore model fields. """ -from typing import Union, Any, Callable -from . import query +from abc import ABC +import typing +from typing import Any, cast, Generic, List, TypeVar, Union +from .query import BooleanQuery, FieldQuery, NumericQuery, SubstringQuery from beets.util import str2bool # Abstract base. -class Type: + +class ModelType(Protocol): + """Protocol that specifies the required constructor for model types, i.e. + a function that takes any argument and attempts to parse it to the given + type. + """ + def __init__(self, value: Any = None): + ... + + +# Generic type variables, used for the value type T and null type N (if +# nullable, else T and N are set to the same type for the concrete subclasses +# of Type). +N = TypeVar("N") +T = TypeVar("T", bound=ModelType) + + +class Type(ABC, Generic[T, N]): """An object encapsulating the type of a model field. Includes information about how to store, query, format, and parse a given field. """ - sql = 'TEXT' + sql: str = 'TEXT' """The SQLite column type for the value. """ - query = query.SubstringQuery + query: typing.Type[FieldQuery] = SubstringQuery """The `Query` subclass to be used when querying the field. """ - model_type: Callable[[Any], str] = str + model_type: typing.Type[T] """The Python type that is used to represent the value in the model. The model is guaranteed to return a value of this type if the field @@ -45,12 +64,15 @@ class Type: """ @property - def null(self) -> model_type: + def null(self) -> N: """The value to be exposed when the underlying value is None. """ - return self.model_type() + # Note that this default implementation only makes sense for T = N. + # It would be better to implement `null()` only in subclasses, or + # have a field null_type similar to `model_type` and use that here. + return cast(N, self.model_type()) - def format(self, value: model_type) -> str: + def format(self, value: Union[N, T]) -> str: """Given a value of this type, produce a Unicode string representing the value. This is used in template evaluation. """ @@ -58,13 +80,13 @@ class Type: value = self.null # `self.null` might be `None` if value is None: - value = '' - if isinstance(value, bytes): - value = value.decode('utf-8', 'ignore') + return '' + elif isinstance(value, bytes): + return value.decode('utf-8', 'ignore') + else: + return str(value) - return str(value) - - def parse(self, string: str) -> model_type: + def parse(self, string: str) -> Union[T, N]: """Parse a (possibly human-written) string and return the indicated value of this type. """ @@ -73,7 +95,7 @@ class Type: except ValueError: return self.null - def normalize(self, value: Union[None, int, float, bytes]) -> model_type: + def normalize(self, value: Any) -> Union[T, N]: """Given a value that will be assigned into a field of this type, normalize the value to have the appropriate type. This base implementation only reinterprets `None`. @@ -84,12 +106,12 @@ class Type: else: # TODO This should eventually be replaced by # `self.model_type(value)` - return value + return cast(T, value) def from_sql( self, sql_value: Union[None, int, float, str, bytes], - ) -> model_type: + ) -> Union[T, N]: """Receives the value stored in the SQL backend and return the value to be stored in the model. @@ -119,18 +141,22 @@ class Type: # Reusable types. -class Default(Type): - null = None +class Default(Type[str, None]): + model_type = str + + @property + def null(self): + return None -class Integer(Type): +class BaseInteger(Type[int, N]): """A basic integer type. """ sql = 'INTEGER' - query = query.NumericQuery + query = NumericQuery model_type = int - def normalize(self, value: str) -> Union[int, str]: + def normalize(self, value: Any) -> Union[int, N]: try: return self.model_type(round(float(value))) except ValueError: @@ -139,21 +165,39 @@ class Integer(Type): return self.null -class PaddedInt(Integer): +class Integer(BaseInteger[int]): + @property + def null(self) -> int: + return 0 + + +class NullInteger(BaseInteger[None]): + @property + def null(self) -> None: + return None + + +class BasePaddedInt(BaseInteger[N]): """An integer field that is formatted with a given number of digits, padded with zeroes. """ def __init__(self, digits: int): self.digits = digits - def format(self, value: int) -> str: + def format(self, value: Union[int, N]) -> str: return '{0:0{1}d}'.format(value or 0, self.digits) -class NullPaddedInt(PaddedInt): - """Same as `PaddedInt`, but does not normalize `None` to `0.0`. +class PaddedInt(BasePaddedInt[int]): + pass + + +class NullPaddedInt(BasePaddedInt[None]): + """Same as `PaddedInt`, but does not normalize `None` to `0`. """ - null = None + @property + def null(self) -> None: + return None class ScaledInt(Integer): @@ -168,52 +212,70 @@ class ScaledInt(Integer): return '{}{}'.format((value or 0) // self.unit, self.suffix) -class Id(Integer): +class Id(NullInteger): """An integer used as the row id or a foreign key in a SQLite table. This type is nullable: None values are not translated to zero. """ - null = None + @property + def null(self) -> None: + return None def __init__(self, primary: bool = True): if primary: self.sql = 'INTEGER PRIMARY KEY' -class Float(Type): +class BaseFloat(Type[float, N]): """A basic floating-point type. The `digits` parameter specifies how many decimal places to use in the human-readable representation. """ sql = 'REAL' - query = query.NumericQuery + query = NumericQuery model_type = float def __init__(self, digits: int = 1): self.digits = digits - def format(self, value: float) -> str: + def format(self, value: Union[float, N]) -> str: return '{0:.{1}f}'.format(value or 0, self.digits) -class NullFloat(Float): +class Float(BaseFloat[float]): + """Floating-point type that normalizes `None` to `0.0`. + """ + @property + def null(self) -> float: + return 0.0 + + +class NullFloat(BaseFloat[None]): """Same as `Float`, but does not normalize `None` to `0.0`. """ - null = None + @property + def null(self) -> None: + return None -class String(Type): +class BaseString(Type[T, N]): """A Unicode string type. """ sql = 'TEXT' - query = query.SubstringQuery + query = SubstringQuery - def normalize(self, value: str) -> str: + def normalize(self, value: Any) -> Union[T, N]: if value is None: return self.null else: return self.model_type(value) -class DelimitedString(String): +class String(BaseString[str, Any]): + """A Unicode string type. + """ + model_type = str + + +class DelimitedString(BaseString[List[str], List[str]]): """A list of Unicode strings, represented in-database by a single string containing delimiter-separated values. """ @@ -238,7 +300,7 @@ class Boolean(Type): """A boolean type. """ sql = 'INTEGER' - query = query.BooleanQuery + query = BooleanQuery model_type = bool def format(self, value: bool) -> str: diff --git a/beetsplug/edit.py b/beetsplug/edit.py index 6f03fa4d8..6cd0c0df5 100644 --- a/beetsplug/edit.py +++ b/beetsplug/edit.py @@ -31,7 +31,7 @@ import shlex # These "safe" types can avoid the format/parse cycle that most fields go # through: they are safe to edit with native YAML types. -SAFE_TYPES = (types.Float, types.Integer, types.Boolean) +SAFE_TYPES = (types.BaseFloat, types.BaseInteger, types.Boolean) class ParseError(Exception):