mirror of
https://github.com/beetbox/beets.git
synced 2025-12-07 17:16:07 +01:00
169 lines
5.2 KiB
Python
169 lines
5.2 KiB
Python
"""
|
|
We found that using the beets library models can sometimes be frustrating because of
|
|
missing type hints for the model attributes and methods.
|
|
|
|
This script does the following in order:
|
|
- generates or overwrites the current models.pyi stub files.
|
|
- injects type hints for the __getitem__ method depending on the defined fields
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import Union, get_args, get_origin
|
|
|
|
from beets.dbcore.types import get_type_parameters
|
|
from beets.library import Album, Item, LibModel
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def overload_template(key: str, return_type: type) -> list[str]:
    """Build the stub lines of a single ``__getitem__`` overload.

    The result has two entries: the ``@overload`` decorator line and the
    matching ``def __getitem__`` signature, whose ``key`` annotation is *key*
    verbatim and whose return annotation is *return_type* rendered through
    ``type2str``. The lines carry no indentation.
    """
    type_str = type2str(return_type)
    decorator = "@overload"
    signature = f"def __getitem__(self, key: {key}) -> {type_str}: ..."
    return [decorator, signature]
|
|
|
|
|
|
def type2str(tp: type) -> str:
    """Convert a Python type into a PEP 604 style string for stubs.

    Rendering rules, applied in this order:

    - ``None`` and ``NoneType`` become ``None``; ``object`` becomes ``Any``.
    - ``...`` (as in ``tuple[int, ...]``) stays ``...``.
    - Unparametrised types are rendered by their ``__name__``.
    - Unions (``Union[X, Y]``, ``Optional[X]`` and, where the interpreter
      supports it, ``X | Y``) become ``X | Y`` with each member converted
      recursively.
    - Other generics become ``origin[arg, ...]`` with each argument converted
      recursively, or just the origin name when there are no arguments.
    - Anything else falls back to ``str(tp)``.
    """
    # Imported locally because only this function needs it; UnionType does
    # not exist on every supported interpreter, hence the getattr below.
    import types

    # Handle None/NoneType first: the stub spelling is always "None"
    if tp is None or tp is type(None):
        return "None"
    if tp is object:
        return "Any"
    if tp is Ellipsis:
        # Ellipsis has no __name__ and must stay literal inside tuple[...]
        return "..."

    origin = get_origin(tp)
    if origin is None:
        # It's a simple type
        return tp.__name__

    args = get_args(tp)

    # Unions have to be recognised before the generic branch: on recent
    # Python versions typing.Union itself exposes ``__name__``, so testing
    # ``hasattr(origin, "__name__")`` first would emit ``Union[int, None]``
    # rather than ``int | None``.
    if origin is Union or origin is getattr(types, "UnionType", None):
        return " | ".join(type2str(arg) for arg in args)

    if hasattr(origin, "__name__"):
        # It's a regular generic type like list, dict, etc.
        if args:
            args_str = ", ".join(type2str(arg) for arg in args)
            return f"{origin.__name__}[{args_str}]"
        return origin.__name__

    # For other special forms, fall back to string representation
    return str(tp)
|
|
|
|
|
|
def generate_overloads(model: type[LibModel]) -> list[str]:
    """Build the ``__getitem__`` overload lines for every field of *model*.

    One overload is emitted per entry of ``model._fields``: its key is
    ``Literal['<field name>']`` and its return type is the union of the two
    types that ``get_type_parameters`` reports for the field. A final
    overload for plain ``str`` keys is appended after the per-field ones.
    The number of per-field overloads is logged at INFO level.
    """
    lines: list[str] = []
    count = 0

    fields = model._fields.items()
    for count, (name, field_type) in enumerate(fields, start=1):
        value_type, null_type = get_type_parameters(field_type)
        literal_key = f"Literal['{name}']"
        lines += overload_template(literal_key, Union[value_type, null_type])

    # Fallback overload for any other string key
    lines += overload_template("str", object)

    log.info(f"Generated {count} overloads for {model.__name__}")
    return lines
|
|
|
|
|
|
def inject_overloads(stub_path: Path, model: type[LibModel]) -> None:
    """Insert generated overloads into the class definition in a .pyi file.

    The first line of the form ``class <model name>(...):`` is located in the
    stub, and the lines returned by ``generate_overloads(model)`` are inserted
    directly below it, indented by one level. The file is rewritten in place.

    Raises:
        RuntimeError: If no matching class header is found in the stub.
    """
    text = stub_path.read_text()

    class_name = model.__name__
    log.info(f"Injecting overloads for {class_name} into {stub_path}")

    # Find the class definition. The name is escaped so that it is always
    # matched literally, whatever characters it contains.
    class_pattern = rf"^(class {re.escape(class_name)}\(.*\):)"
    match = re.search(class_pattern, text, flags=re.MULTILINE)
    if match is None:
        raise RuntimeError(f"Class {class_name} not found in {stub_path}")

    # Where to insert: right after the class header line
    insert_pos = match.end()

    # Indent by one full level so the overloads line up with the other class
    # members; any other width leaves the class body inconsistently indented.
    # NOTE(review): assumes the stub uses four-space indentation - confirm
    # against the stubgen output.
    body_indent = " " * 4
    overloads = generate_overloads(model)
    overload_text = "\n".join(f"{body_indent}{line}" for line in overloads)

    # Insert after class line
    new_text = text[:insert_pos] + "\n" + overload_text + text[insert_pos:]

    # Write result
    stub_path.write_text(new_text)
    log.info(f"Injected overloads into {stub_path}")
|
|
|
|
|
|
def run_stubgen(module: str, out_dir: Path) -> Path:
    """Generate a stub for *module* with stubgen and return its path.

    ``stubgen`` is invoked with ``--include-private`` and ``-o out_dir``;
    because of ``check=True`` a non-zero exit status raises
    ``subprocess.CalledProcessError``.

    Raises:
        FileNotFoundError: If the expected ``.pyi`` file does not exist after
            stubgen has run.
    """
    command = [
        "stubgen",
        "-m",
        module,
        "--include-private",
        "-o",
        str(out_dir),
    ]
    subprocess.run(command, check=True)

    # Expected location: the dotted module path mapped onto directories
    # below the output directory, with a .pyi suffix.
    relative_stub = Path(module.replace(".", "/") + ".pyi")
    pyi_path = out_dir / relative_stub
    if pyi_path.exists():
        return pyi_path
    raise FileNotFoundError(f"Stubgen did not generate {pyi_path}")
|
|
|
|
|
|
def format_file(stub_path: Path) -> None:
    """Apply ruff's automatic fixes to the generated stub file.

    Runs ``ruff check`` on *stub_path* with ``--fix``, ``--unsafe-fixes`` and
    ``--silent``. Because of ``check=True`` a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    ruff_options = ("--fix", "--unsafe-fixes", "--silent")
    command = ["ruff", "check", str(stub_path), *ruff_options]
    subprocess.run(command, check=True)
|
|
|
|
|
|
def ensure_imports(stub_path: Path) -> None:
    """Ensure ``Literal``, ``Any`` and ``overload`` are imported in the stub.

    The stub is scanned for ``from typing import ...`` statements, either on
    a single line or as a parenthesised list. Every required name that none
    of them binds is added through one new ``from typing import`` line,
    placed directly above the first existing typing import or, when there is
    none, at the very top of the file.

    The file is only rewritten when a name is missing, so repeated calls
    leave an already complete stub untouched.
    """
    required = ("Literal", "Any", "overload")
    text = stub_path.read_text()

    # Captures the imported-names part of a typing import: either a
    # parenthesised list (which may span several lines) or the rest of the
    # line.
    typing_import = re.compile(
        r"^from typing import[ \t]*(\([^)]*\)|[^\n]*)", flags=re.MULTILINE
    )
    statements = list(typing_import.finditer(text))

    bound: set[str] = set()
    for statement in statements:
        # Drop trailing comments, then the surrounding parentheses
        names = re.sub(r"#[^\n]*", "", statement.group(1)).strip("() \t\r\n")
        for entry in names.split(","):
            words = entry.split()
            if words:
                # For ``X as Y`` the name usable in the stub is ``Y``
                bound.add(words[-1])

    # Checking each name separately matters: a stub that already imports
    # Literal may still lack overload or Any.
    missing = [name for name in required if name not in bound]
    if not missing:
        return

    # Inserting above the statement (instead of after its first line) keeps
    # a multi-line parenthesised import intact.
    insert_pos = statements[0].start() if statements else 0
    import_line = f"from typing import {', '.join(missing)}\n"
    stub_path.write_text(text[:insert_pos] + import_line + text[insert_pos:])
|
|
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Stubs are written below the parent of the directory holding this script
    target_dir = Path(__file__).parent.parent
    stub_file = run_stubgen("beets.library.models", target_dir)

    # The imports go in first, then the overloads for each model class
    ensure_imports(stub_file)
    for model in (Item, Album):
        inject_overloads(stub_file, model)
|