beets/scripts/generate_model_stubs.py
2025-08-21 19:18:23 +02:00

169 lines
5.2 KiB
Python

"""
We found that using the beets library models can sometimes be frustrating because of
missing type hints for the model attributes and methods.
This script does the following, in order:
- generates (or overwrites) the models.pyi stub file using stubgen;
- makes sure the typing names used by the generated overloads are imported;
- injects type hints for the __getitem__ method, based on each model's defined fields.
"""
from __future__ import annotations

import ast
import logging
import re
import subprocess
import types
from pathlib import Path
from typing import Literal, Union, get_args, get_origin

from beets.dbcore.types import get_type_parameters
from beets.library import Album, Item, LibModel
# Logger used for this script's progress messages.
log: logging.Logger = logging.getLogger(name=__name__)
def overload_template(key: str, return_type: type) -> list[str]:
    """Build a single ``@overload`` stanza for ``__getitem__``.

    :param key: Annotation source used for the ``key`` parameter, for
        example ``"Literal['title']"`` or ``"str"``.
    :param return_type: Runtime type rendered (via ``type2str``) as the
        return annotation.
    :return: Two stub lines: the decorator and the ``def`` signature.
    """
    rendered = type2str(return_type)
    decorator = "@overload"
    signature = f"def __getitem__(self, key: {key}) -> {rendered}: ..."
    return [decorator, signature]
def type2str(tp: type) -> str:
"""Convert a Python type into a PEP 604 style string for stubs."""
# Handle NoneType first since it doesn't have __name__
if tp is None or tp is type(None):
return "None"
if tp is object:
return "Any"
origin = get_origin(tp)
if origin is None:
# It's a simple type
return tp.__name__
elif hasattr(origin, "__name__"):
# It's a regular generic type like list, dict, etc.
args = get_args(tp)
if args:
args_str = ", ".join(type2str(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
else:
return origin.__name__
else:
# Handle special forms like Union, Optional, etc.
args = get_args(tp)
if origin is Union:
return " | ".join(type2str(arg) for arg in args)
else:
# For other special forms, fall back to string representation
return str(tp)
def generate_overloads(model: type[LibModel]) -> list[str]:
    """Generate ``__getitem__`` overload lines for every field of *model*.

    For each entry in ``model._fields`` an overload maps
    ``Literal['<field name>']`` to the union of the two types returned by
    ``get_type_parameters`` for that field. A final catch-all overload maps
    any other ``str`` key to ``Any`` (``object`` is rendered as ``Any``).

    :param model: The library model class, e.g. ``Item`` or ``Album``.
    :return: Unindented stub source lines, two per overload.
    """
    lines: list[str] = []
    for name, field_type in model._fields.items():
        return_type, null_type = get_type_parameters(field_type)
        # repr() produces a correctly quoted/escaped string literal.
        key = f"Literal[{name!r}]"
        lines.extend(overload_template(key, Union[return_type, null_type]))
    # Catch-all for any key not listed above.
    lines.extend(overload_template("str", object))
    log.info(
        "Generated %d overloads for %s", len(model._fields), model.__name__
    )
    return lines
def inject_overloads(stub_path: Path, model: type[LibModel]) -> None:
    """Insert the generated ``__getitem__`` overloads into *model*'s class.

    The overloads are placed directly after the ``class <Name>...:`` header
    in the stub. They are indented to match the first non-blank line of the
    existing class body, or four spaces when that cannot be determined. The
    file is rewritten in place.

    :param stub_path: Path of the ``.pyi`` file to modify.
    :param model: Model class; its ``__name__`` locates the class in the
        stub and its fields drive the overloads.
    :raises RuntimeError: If no matching class definition is found.
    """
    text = stub_path.read_text()
    class_name = model.__name__
    log.info(f"Injecting overloads for {class_name} into {stub_path}")

    # Accept headers with or without a base-class list. The name must be
    # followed by "(" or ":", so e.g. "ItemFoo" does not match "Item".
    class_pattern = rf"^class {re.escape(class_name)}(\(.*?\))?:"
    match = re.search(class_pattern, text, flags=re.MULTILINE)
    if not match:
        raise RuntimeError(f"Class {class_name} not found in {stub_path}")
    insert_pos = match.end()

    # Reuse the indentation of the class body so the inserted lines form a
    # consistent block with it (a mismatch is an IndentationError).
    body = re.compile(r"[^\n]*\n(?:[ \t]*\n)*([ \t]+)\S").match(text, insert_pos)
    indent = body.group(1) if body else "    "

    overloads = generate_overloads(model)
    overload_text = "\n".join(f"{indent}{line}" for line in overloads)
    # Insert right after the header; the original text following the header
    # begins with the newline that terminates the last inserted line.
    new_text = text[:insert_pos] + "\n" + overload_text + text[insert_pos:]
    stub_path.write_text(new_text)
    log.info(f"Injected overloads into {stub_path}")
def run_stubgen(module: str, out_dir: Path) -> Path:
    """Generate a stub for *module* using the ``stubgen`` tool.

    Private names are included (``--include-private``). The stub is expected
    beneath *out_dir* following the dotted module path, i.e.
    ``a.b.c`` -> ``out_dir/a/b/c.pyi``.

    :param module: Dotted name of the module to stub.
    :param out_dir: Output root passed to ``stubgen -o``.
    :return: Path of the generated ``.pyi`` file.
    :raises subprocess.CalledProcessError: If ``stubgen`` exits non-zero.
    :raises FileNotFoundError: If the expected stub file is not present.
    """
    command = ["stubgen", "-m", module, "--include-private", "-o", str(out_dir)]
    subprocess.run(command, check=True)

    # Map the dotted module name onto a relative file path.
    parts = module.split(".")
    parts[-1] += ".pyi"
    pyi_path = out_dir.joinpath(*parts)
    if pyi_path.exists():
        return pyi_path
    raise FileNotFoundError(f"Stubgen did not generate {pyi_path}")
def format_file(stub_path: Path) -> None:
    """Apply ``ruff check --fix`` (unsafe fixes included) to *stub_path*.

    Ruff is run with ``--silent``. This lints and auto-fixes the file; it
    does not invoke ``ruff format``.

    :param stub_path: Path of the file to fix in place.
    :raises subprocess.CalledProcessError: If ruff exits with a non-zero
        status.
    """
    ruff_cmd = ["ruff", "check", str(stub_path)]
    ruff_cmd += ["--fix", "--unsafe-fixes", "--silent"]
    subprocess.run(ruff_cmd, check=True)
def ensure_imports(stub_path: Path) -> None:
    """Ensure the ``typing`` names used by the injected overloads are imported.

    The generated overloads reference ``Literal``, ``Any`` and ``overload``.
    The top-level ``from typing import ...`` statements of the stub are
    inspected. If any of those names is not yet bound, one line
    ``from typing import <missing names>`` is inserted:

    - after the first top-level ``from typing`` import (after its closing
      line, so a parenthesized multi-line import is never split), or
    - after the last top-level ``from __future__`` import when there is no
      ``typing`` import, because future imports must come first, or
    - at the very top of the file otherwise.

    The file is not rewritten when nothing is missing.

    :param stub_path: Path of the ``.pyi`` file to update in place.
    :raises SyntaxError: If the stub cannot be parsed.
    """
    required = ("Literal", "Any", "overload")
    text = stub_path.read_text()

    imported: set[str] = set()
    first_typing = None
    last_future = None
    for node in ast.parse(text).body:
        # Skip non-imports and relative imports such as "from .typing import X".
        if not isinstance(node, ast.ImportFrom) or node.level:
            continue
        if node.module == "typing":
            # "from typing import Any as A" binds A, not Any.
            imported.update(alias.asname or alias.name for alias in node.names)
            if first_typing is None:
                first_typing = node
        elif node.module == "__future__":
            last_future = node

    missing = [name for name in required if name not in imported]
    if not missing:
        return

    anchor = first_typing if first_typing is not None else last_future
    # end_lineno is 1-based and inclusive, which makes it the 0-based index
    # of the line right after the anchor statement.
    insert_at = anchor.end_lineno if anchor is not None else 0
    lines = text.split("\n")
    lines.insert(insert_at, f"from typing import {', '.join(missing)}")
    stub_path.write_text("\n".join(lines))
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Stubs are written relative to the parent of this script's directory.
    stub_file = run_stubgen("beets.library.models", Path(__file__).parent.parent)
    ensure_imports(stub_file)
    for model in (Item, Album):
        inject_overloads(stub_file, model)