generate stubs

This commit is contained in:
Sebastian Mohr 2025-08-21 19:18:23 +02:00
parent b7091bf120
commit 59021add83
3 changed files with 266 additions and 1 deletions

View file

@ -18,9 +18,19 @@ from __future__ import annotations
import re
import time
from types import GenericAlias
import typing
from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Generic,
TypeVar,
cast,
get_args,
get_origin,
get_type_hints,
)
import beets
from beets import util
@ -454,6 +464,58 @@ class DurationType(Float):
return self.null
def get_type_parameters(instance: Type) -> tuple[type, type]:
    """Work out the value type ``T`` and null type ``N`` of a Type instance.

    ``T`` is read from the instance's ``model_type`` attribute (``object``
    when the attribute is missing). ``N`` is derived from the instance's
    ``null`` value; when there is no ``null`` attribute, ``N`` equals ``T``.

    Returns:
        Tuple of (T_type, N_type) or (object, object) if cannot determine
    """
    try:
        value_type = getattr(instance, "model_type", object)
        already_generic = (
            hasattr(value_type, "__origin__")
            and get_origin(value_type) is not None
        )
        # Parameterized generics are kept untouched; anything else that is
        # not a class is replaced by its class.
        if not already_generic and not isinstance(value_type, type):
            value_type = type(value_type)

        if not hasattr(instance, "null"):
            # Without a null value, N mirrors T.
            return value_type, value_type

        null_value = instance.null
        if null_value is None:
            return value_type, type(None)

        null_type = type(null_value)
        owner = instance.__class__
        if isinstance(null_value, list) and hasattr(owner, "__orig_bases__"):
            # ``type([])`` is a bare ``list``; try to recover the item type
            # from the second type argument of a parameterized base class.
            for base in getattr(owner, "__orig_bases__", []):
                if get_origin(base) is None:
                    continue
                type_args = get_args(base)
                if len(type_args) >= 2 and type_args[1] is not Any:
                    null_type = type_args[1]
                    break
        return value_type, null_type
    except (AttributeError, TypeError):
        # Best effort only: fall through to the generic answer below.
        pass

    return object, object
# Shared instances of common types.
DEFAULT = Default()
INTEGER = Integer()

View file

@ -0,0 +1,169 @@
"""
We found that using the beets library models can sometimes be frustrating because
the model attributes and methods are missing type hints.
This script does the following in order:
- generates (or overwrites) the current models.pyi stub file.
- injects type hints for the __getitem__ method depending on the defined fields
"""
from __future__ import annotations
import logging
import re
import subprocess
from pathlib import Path
from typing import Union, get_args, get_origin
from beets.dbcore.types import get_type_parameters
from beets.library import Album, Item, LibModel
log = logging.getLogger(__name__)
def overload_template(key: str, return_type: type) -> list[str]:
    """Build the two stub lines that make up one ``__getitem__`` overload.

    ``key`` is inserted verbatim as the annotation of the ``key`` parameter;
    ``return_type`` is rendered with ``type2str`` for the return annotation.
    """
    rendered = type2str(return_type)
    signature = f"def __getitem__(self, key: {key}) -> {rendered}: ..."
    return ["@overload", signature]
def type2str(tp: type) -> str:
    """Convert a Python type into a PEP 604 style string for stubs.

    Args:
        tp: A class, a parameterized generic (``list[str]``), a union
            (``Union[int, None]`` or ``int | None``) or ``None``.

    Returns:
        The stub spelling of ``tp``: ``None``/``NoneType`` become ``"None"``,
        ``object`` becomes ``"Any"``, unions are joined with ``" | "`` and
        generics are rendered as ``origin[arg, ...]`` with each argument
        converted recursively.
    """
    # Local import: the stdlib ``types`` module is not imported at file level.
    import types

    # NoneType must be spelled ``None`` in annotations.
    if tp is None or tp is type(None):
        return "None"
    if tp is object:
        return "Any"

    origin = get_origin(tp)
    if origin is None:
        # Plain class. Objects that are not classes (``...``, Literal
        # values) have no ``__name__``; fall back to their repr.
        return getattr(tp, "__name__", None) or repr(tp)

    args = get_args(tp)

    # Unions have to be recognised before looking at ``origin.__name__``:
    # since Python 3.10 ``typing.Union`` exposes ``__name__`` as well, which
    # would otherwise render ``Union[int, None]`` instead of ``int | None``.
    # ``types.UnionType`` (the origin of ``int | None``) exists on 3.10+ only.
    runtime_union = getattr(types, "UnionType", None)
    if origin is Union or (runtime_union is not None and origin is runtime_union):
        return " | ".join(type2str(arg) for arg in args)

    origin_name = getattr(origin, "__name__", None)
    if origin_name is None:
        # Other special forms: fall back to the string representation.
        return str(tp)
    if not args:
        return origin_name
    args_str = ", ".join(type2str(arg) for arg in args)
    return f"{origin_name}[{args_str}]"
def generate_overloads(model: type[LibModel]) -> list[str]:
    """Return the ``__getitem__`` overload lines for every field of ``model``.

    One overload is produced per entry of ``model._fields``, keyed by a
    ``Literal`` of the field name and returning the union of the two types
    reported by ``get_type_parameters``. A final overload taking any ``str``
    key is appended after the per-field ones.
    """
    overloads: list[str] = []
    for field_name, field_type in model._fields.items():
        value_type, null_type = get_type_parameters(field_type)
        overloads += overload_template(
            f"Literal['{field_name}']", Union[value_type, null_type]
        )

    # Catch-all for keys that are not one of the declared fields.
    overloads += overload_template("str", object)

    count = len(model._fields)
    log.info(f"Generated {count} overloads for {model.__name__}")
    return overloads
def inject_overloads(stub_path: Path, model: type[LibModel]) -> None:
    """Insert generated ``__getitem__`` overloads into a class in a .pyi file.

    The stub is searched for the ``class <model name>(...):`` line and the
    lines from ``generate_overloads`` are written directly below it,
    indented as class members. The file is rewritten in place.

    Args:
        stub_path: Path of the stub file to modify.
        model: Model class whose name selects the class definition and whose
            fields drive the generated overloads.

    Raises:
        RuntimeError: If no matching class definition is found in the stub.
    """
    text = stub_path.read_text()
    class_name = model.__name__
    log.info(f"Injecting overloads for {class_name} into {stub_path}")

    # Find the class definition; the name is escaped so it is always matched
    # literally, whatever characters it contains.
    class_pattern = rf"^(class {re.escape(class_name)}\(.*\):)"
    match = re.search(class_pattern, text, flags=re.MULTILINE)
    if match is None:
        raise RuntimeError(f"Class {class_name} not found in {stub_path}")

    # Insert right after the class line.
    insert_pos = match.end()

    # Class members must share one indentation level; use four spaces, the
    # standard depth for a class body, so the overloads line up with the
    # members that follow them.
    indent = " " * 4
    overloads = generate_overloads(model)
    overload_text = "\n".join(f"{indent}{line}" for line in overloads)

    new_text = text[:insert_pos] + "\n" + overload_text + text[insert_pos:]
    stub_path.write_text(new_text)
    log.info(f"Injected overloads into {stub_path}")
def run_stubgen(module: str, out_dir: Path) -> Path:
    """Generate a stub for ``module`` with stubgen and return the .pyi path.

    stubgen is invoked with ``--include-private`` and ``-o out_dir``. Because
    the command runs with ``check=True``, a non-zero exit status raises
    ``subprocess.CalledProcessError``.

    Raises:
        FileNotFoundError: If the expected stub file does not exist afterwards.
    """
    command = ["stubgen", "-m", module, "--include-private", "-o", str(out_dir)]
    subprocess.run(command, check=True)

    # The dotted module name maps onto a relative path below ``out_dir``.
    relative_stub = module.replace(".", "/") + ".pyi"
    pyi_path = out_dir / relative_stub
    if pyi_path.exists():
        return pyi_path
    raise FileNotFoundError(f"Stubgen did not generate {pyi_path}")
def format_file(stub_path: Path) -> None:
    """Run ``ruff check`` with automatic (including unsafe) fixes on a stub.

    The command runs with ``check=True``, so a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    ruff_args = ["check", str(stub_path), "--fix", "--unsafe-fixes", "--silent"]
    subprocess.run(["ruff", *ruff_args], check=True)
def ensure_imports(stub_path: Path) -> None:
    """Add ``from typing import Literal, Any, overload`` to a stub if missing.

    Nothing is written when the file already contains the text
    ``from typing import Literal``. Otherwise the import line is placed on
    the line after the first ``from typing`` occurrence, or at the very top
    of the file when there is none.
    """
    content = stub_path.read_text()
    if "from typing import Literal" in content:
        return

    first_typing = content.find("from typing")
    if first_typing < 0:
        # No typing import to anchor on: prepend to the file.
        offset = 0
    else:
        # Step past the end of the line holding the first typing import.
        offset = content.find("\n", first_typing) + 1

    head, tail = content[:offset], content[offset:]
    stub_path.write_text(
        head + "from typing import Literal, Any, overload\n" + tail
    )
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Stubs are written relative to the directory above this script's own.
    stub_root = Path(__file__).parent.parent
    stub_file = run_stubgen("beets.library.models", stub_root)

    # Add the typing imports first, then the overloads for each model class.
    ensure_imports(stub_file)
    for model in (Item, Album):
        inject_overloads(stub_file, model)

View file

@ -1,4 +1,7 @@
import time
from tkinter import N
import pytest
import beets
from beets.dbcore import types
@ -56,3 +59,34 @@ def test_durationtype():
beets.config["format_raw_length"] = True
assert 61.23 == t.format(61.23)
assert 3601.23 == t.format(3601.23)
@pytest.mark.parametrize(
    "type_instance, expected_params",
    [
        (types.DEFAULT, (str, type(None))),
        # ints
        (types.Integer(), (int, int)),
        (types.NullInteger(), (int, type(None))),
        (types.PaddedInt(2), (int, int)),
        (types.NullPaddedInt(3), (int, type(None))),
        (types.ScaledInt(2, "foo"), (int, int)),
        (types.Id(), (int, type(None))),
        # floats
        (types.Float(), (float, float)),
        (types.NullFloat(), (float, type(None))),
        (types.DateType(), (float, float)),
        (types.DurationType(), (float, float)),
        # Strings
        (types.String(), (str, str)),
        (types.DelimitedString(","), (list[str], list[str])),
        (types.MusicalKey(), (str, type(None))),
        # Other
        (types.Boolean(), (bool, bool)),
        # Paths
        (types.PathType(), (bytes, bytes)),
        (types.NullPathType(), (bytes, type(None))),
    ],
)
def test_get_type_parameters(type_instance: types.Type, expected_params):
    """Check the ``(T, N)`` pair reported for each built-in type instance.

    The parameter is deliberately not called ``type`` so that the builtin,
    used above as ``type(None)``, is never shadowed.
    """
    assert types.get_type_parameters(type_instance) == expected_params