diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 1b8434a0b..f03731a55 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -18,9 +18,19 @@ from __future__ import annotations import re import time +from types import GenericAlias import typing from abc import ABC -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeVar, + cast, + get_args, + get_origin, + get_type_hints, +) import beets from beets import util @@ -454,6 +464,58 @@ class DurationType(Float): return self.null +def get_type_parameters(instance: Type) -> tuple[type, type]: + """ + Extract T and N types from a Type instance. + + Returns: + Tuple of (T_type, N_type) or (object, object) if cannot determine + """ + n_type: type + t_type: type + try: + # Get T type from model_type attribute or default to object + t_type = getattr(instance, "model_type", object) + + if hasattr(t_type, "__origin__") and get_origin(t_type) is not None: + # This is already a generic type, keep it as-is + pass + else: + # For non-generic types, ensure we have the actual type + t_type = type(t_type) if not isinstance(t_type, type) else t_type + + # Get N type from null attribute with proper handling + if hasattr(instance, "null"): + null_value = instance.null + if null_value is None: + n_type = type(None) + else: + n_type = type(null_value) + # For generic types in null values, try to infer from class definition + if isinstance(null_value, list) and hasattr( + instance.__class__, "__orig_bases__" + ): + # Try to extract the generic type from class inheritance + for base in getattr( + instance.__class__, "__orig_bases__", [] + ): + if get_origin(base) is not None: + base_args = get_args(base) + if len(base_args) >= 2 and base_args[1] is not Any: + n_type = base_args[1] + break + else: + # Default N type is the same as T for non-nullable types + n_type = t_type + + return t_type, n_type + except (AttributeError, TypeError): + pass + + # Final fallback: return object types + return object, object + + # Shared instances of common types. DEFAULT = Default() INTEGER = Integer() diff --git a/scripts/generate_model_stubs.py b/scripts/generate_model_stubs.py new file mode 100644 index 000000000..aa8702878 --- /dev/null +++ b/scripts/generate_model_stubs.py @@ -0,0 +1,169 @@ +""" +We found that using the beets library models can sometimes be frustrating because of +missing typehints for the model attributes and methods. + +This script does the following in order: +- generates or overwrite the current models.pyi stub files. +- injects type hints for the __getitem__ method depending on the defined fields +""" + +from __future__ import annotations + +import logging +import re +import subprocess +from pathlib import Path +from typing import Union, get_args, get_origin + +from beets.dbcore.types import get_type_parameters +from beets.library import Album, Item, LibModel + +log = logging.getLogger(__name__) + + +def overload_template(key: str, return_type: type) -> list[str]: + """Generate an overload for __getitem__ with a specific key and return type.""" + type_str = type2str(return_type) + return [ + "@overload", + f"def __getitem__(self, key: {key}) -> {type_str}: ...", + ] + + +def type2str(tp: type) -> str: + """Convert a Python type into a PEP 604 style string for stubs.""" + # Handle NoneType first since it doesn't have __name__ + if tp is None or tp is type(None): + return "None" + if tp is object: + return "Any" + + origin = get_origin(tp) + if origin is None: + # It's a simple type + return tp.__name__ + elif hasattr(origin, "__name__"): + # It's a regular generic type like list, dict, etc. + args = get_args(tp) + if args: + args_str = ", ".join(type2str(arg) for arg in args) + return f"{origin.__name__}[{args_str}]" + else: + return origin.__name__ + else: + # Handle special forms like Union, Optional, etc. + args = get_args(tp) + if origin is Union: + return " | ".join(type2str(arg) for arg in args) + else: + # For other special forms, fall back to string representation + return str(tp) + + +def generate_overloads(model: type[LibModel]) -> list[str]: + """Generate overloads for __getitem__ based on Item._fields.""" + lines: list[str] = [] + count = 0 + for name, field_type in model._fields.items(): + return_type, null_type = get_type_parameters(field_type) + lines.extend( + overload_template( + f"Literal['{name}']", Union[return_type, null_type] + ) + ) + count += 1 + + # Default str + lines.extend(overload_template("str", object)) + + log.info(f"Generated {count} overloads for {model.__name__}") + return lines + + +def inject_overloads(stub_path: Path, model: type[LibModel]) -> None: + """Insert generated overloads into the class definition in a .pyi file.""" + text = stub_path.read_text() + + class_name = model.__name__ + log.info(f"Injecting overloads for {class_name} into {stub_path}") + + # Find the class definition + class_pattern = rf"^(class {class_name}\(.*\):)" + match = re.search(class_pattern, text, flags=re.MULTILINE) + if not match: + raise RuntimeError(f"Class {class_name} not found in {stub_path}") + + # Where to insert + insert_pos = match.end() + + # Prepare overload block and indent + overloads = generate_overloads(model) + overload_text = "\n".join(f" {line}" for line in overloads) + + # Insert after class line + new_text = text[:insert_pos] + "\n" + overload_text + text[insert_pos:] + + # Write result + stub_path.write_text(new_text) + log.info(f"Injected overloads into {stub_path}") + + +def run_stubgen(module: str, out_dir: Path) -> Path: + """Run stubgen for a module and return the generated pyi path.""" + subprocess.run( + ["stubgen", "-m", module, "--include-private", "-o", str(out_dir)], + check=True, + ) + # Figure out the generated file path + pyi_path = out_dir / Path(module.replace(".", "/") + ".pyi") + if not pyi_path.exists(): + raise FileNotFoundError(f"Stubgen did not generate {pyi_path}") + return pyi_path + + +def format_file(stub_path: Path) -> None: + """Run ruff fix on the generated stub file.""" + subprocess.run( + [ + "ruff", + "check", + str(stub_path), + "--fix", + "--unsafe-fixes", + "--silent", + ], + check=True, + ) + + +def ensure_imports(stub_path: Path) -> None: + """Ensure multiple imports are present in the generated stub file.""" + text = stub_path.read_text() + + if "from typing import Literal" not in text: + insert_pos = text.find( + "from typing" + ) # Attempt to find the first typing import + if insert_pos == -1: + # No existing import, add at the top of the file + insert_pos = 0 + else: + # Otherwise, find the position after the first import line + insert_pos = text.find("\n", insert_pos) + 1 + + # Add Literal import + new_text = ( + text[:insert_pos] + + "from typing import Literal, Any, overload\n" + + text[insert_pos:] + ) + stub_path.write_text(new_text) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + out_dir = Path(__file__).parent.parent + file = run_stubgen("beets.library.models", out_dir) + ensure_imports(file) + inject_overloads(file, Item) + inject_overloads(file, Album) diff --git a/test/test_types.py b/test/test_types.py index 6727917d8..457f5297f 100644 --- a/test/test_types.py +++ b/test/test_types.py @@ -1,4 +1,7 @@ import time +from tkinter import N + +import pytest import beets from beets.dbcore import types @@ -56,3 +59,34 @@ def test_durationtype(): beets.config["format_raw_length"] = True assert 61.23 == t.format(61.23) assert 3601.23 == t.format(3601.23) + + +@pytest.mark.parametrize( + "type, type_params", + [ + (types.DEFAULT, (str, type(None))), + # ints + (types.Integer(), (int, int)), + (types.NullInteger(), (int, type(None))), + (types.PaddedInt(2), (int, int)), + (types.NullPaddedInt(3), (int, type(None))), + (types.ScaledInt(2, "foo"), (int, int)), + (types.Id(), (int, type(None))), + # floats + (types.Float(), (float, float)), + (types.NullFloat(), (float, type(None))), + (types.DateType(), (float, float)), + (types.DurationType(), (float, float)), + # Strings + (types.String(), (str, str)), + (types.DelimitedString(","), (list[str], list[str])), + (types.MusicalKey(), (str, type(None))), + # Other + (types.Boolean(), (bool, bool)), + # Paths + (types.PathType(), (bytes, bytes)), + (types.NullPathType(), (bytes, type(None))), + ], +) +def test_get_type_parameters(type: types.Type, type_params): + assert type_params == types.get_type_parameters(type)