mirror of
https://github.com/beetbox/beets.git
synced 2026-02-16 12:24:53 +01:00
generate stubs
This commit is contained in:
parent
b7091bf120
commit
59021add83
3 changed files with 266 additions and 1 deletions
|
|
@ -18,9 +18,19 @@ from __future__ import annotations
|
|||
|
||||
import re
|
||||
import time
|
||||
from types import GenericAlias
|
||||
import typing
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import beets
|
||||
from beets import util
|
||||
|
|
@ -454,6 +464,58 @@ class DurationType(Float):
|
|||
return self.null
|
||||
|
||||
|
||||
def get_type_parameters(instance: Type) -> tuple[type, type]:
|
||||
"""
|
||||
Extract T and N types from a Type instance.
|
||||
|
||||
Returns:
|
||||
Tuple of (T_type, N_type) or (object, object) if cannot determine
|
||||
"""
|
||||
n_type: type
|
||||
t_type: type
|
||||
try:
|
||||
# Get T type from model_type attribute or default to object
|
||||
t_type = getattr(instance, "model_type", object)
|
||||
|
||||
if hasattr(t_type, "__origin__") and get_origin(t_type) is not None:
|
||||
# This is already a generic type, keep it as-is
|
||||
pass
|
||||
else:
|
||||
# For non-generic types, ensure we have the actual type
|
||||
t_type = type(t_type) if not isinstance(t_type, type) else t_type
|
||||
|
||||
# Get N type from null attribute with proper handling
|
||||
if hasattr(instance, "null"):
|
||||
null_value = instance.null
|
||||
if null_value is None:
|
||||
n_type = type(None)
|
||||
else:
|
||||
n_type = type(null_value)
|
||||
# For generic types in null values, try to infer from class definition
|
||||
if isinstance(null_value, list) and hasattr(
|
||||
instance.__class__, "__orig_bases__"
|
||||
):
|
||||
# Try to extract the generic type from class inheritance
|
||||
for base in getattr(
|
||||
instance.__class__, "__orig_bases__", []
|
||||
):
|
||||
if get_origin(base) is not None:
|
||||
base_args = get_args(base)
|
||||
if len(base_args) >= 2 and base_args[1] is not Any:
|
||||
n_type = base_args[1]
|
||||
break
|
||||
else:
|
||||
# Default N type is the same as T for non-nullable types
|
||||
n_type = t_type
|
||||
|
||||
return t_type, n_type
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Final fallback: return object types
|
||||
return object, object
|
||||
|
||||
|
||||
# Shared instances of common types.
|
||||
DEFAULT = Default()
|
||||
INTEGER = Integer()
|
||||
|
|
|
|||
169
scripts/generate_model_stubs.py
Normal file
169
scripts/generate_model_stubs.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
We found that using the beets library models can sometimes be frustrating because of
|
||||
missing typehints for the model attributes and methods.
|
||||
|
||||
This script does the following in order:
|
||||
- generates or overwrite the current models.pyi stub files.
|
||||
- injects type hints for the __getitem__ method depending on the defined fields
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Union, get_args, get_origin
|
||||
|
||||
from beets.dbcore.types import get_type_parameters
|
||||
from beets.library import Album, Item, LibModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def overload_template(key: str, return_type: type) -> list[str]:
|
||||
"""Generate an overload for __getitem__ with a specific key and return type."""
|
||||
type_str = type2str(return_type)
|
||||
return [
|
||||
"@overload",
|
||||
f"def __getitem__(self, key: {key}) -> {type_str}: ...",
|
||||
]
|
||||
|
||||
|
||||
def type2str(tp: type) -> str:
|
||||
"""Convert a Python type into a PEP 604 style string for stubs."""
|
||||
# Handle NoneType first since it doesn't have __name__
|
||||
if tp is None or tp is type(None):
|
||||
return "None"
|
||||
if tp is object:
|
||||
return "Any"
|
||||
|
||||
origin = get_origin(tp)
|
||||
if origin is None:
|
||||
# It's a simple type
|
||||
return tp.__name__
|
||||
elif hasattr(origin, "__name__"):
|
||||
# It's a regular generic type like list, dict, etc.
|
||||
args = get_args(tp)
|
||||
if args:
|
||||
args_str = ", ".join(type2str(arg) for arg in args)
|
||||
return f"{origin.__name__}[{args_str}]"
|
||||
else:
|
||||
return origin.__name__
|
||||
else:
|
||||
# Handle special forms like Union, Optional, etc.
|
||||
args = get_args(tp)
|
||||
if origin is Union:
|
||||
return " | ".join(type2str(arg) for arg in args)
|
||||
else:
|
||||
# For other special forms, fall back to string representation
|
||||
return str(tp)
|
||||
|
||||
|
||||
def generate_overloads(model: type[LibModel]) -> list[str]:
|
||||
"""Generate overloads for __getitem__ based on Item._fields."""
|
||||
lines: list[str] = []
|
||||
count = 0
|
||||
for name, field_type in model._fields.items():
|
||||
return_type, null_type = get_type_parameters(field_type)
|
||||
lines.extend(
|
||||
overload_template(
|
||||
f"Literal['{name}']", Union[return_type, null_type]
|
||||
)
|
||||
)
|
||||
count += 1
|
||||
|
||||
# Default str
|
||||
lines.extend(overload_template("str", object))
|
||||
|
||||
log.info(f"Generated {count} overloads for {model.__name__}")
|
||||
return lines
|
||||
|
||||
|
||||
def inject_overloads(stub_path: Path, model: type[LibModel]) -> None:
|
||||
"""Insert generated overloads into the class definition in a .pyi file."""
|
||||
text = stub_path.read_text()
|
||||
|
||||
class_name = model.__name__
|
||||
log.info(f"Injecting overloads for {class_name} into {stub_path}")
|
||||
|
||||
# Find the class definition
|
||||
class_pattern = rf"^(class {class_name}\(.*\):)"
|
||||
match = re.search(class_pattern, text, flags=re.MULTILINE)
|
||||
if not match:
|
||||
raise RuntimeError(f"Class {class_name} not found in {stub_path}")
|
||||
|
||||
# Where to insert
|
||||
insert_pos = match.end()
|
||||
|
||||
# Prepare overload block and indent
|
||||
overloads = generate_overloads(model)
|
||||
overload_text = "\n".join(f" {line}" for line in overloads)
|
||||
|
||||
# Insert after class line
|
||||
new_text = text[:insert_pos] + "\n" + overload_text + text[insert_pos:]
|
||||
|
||||
# Write result
|
||||
stub_path.write_text(new_text)
|
||||
log.info(f"Injected overloads into {stub_path}")
|
||||
|
||||
|
||||
def run_stubgen(module: str, out_dir: Path) -> Path:
|
||||
"""Run stubgen for a module and return the generated pyi path."""
|
||||
subprocess.run(
|
||||
["stubgen", "-m", module, "--include-private", "-o", str(out_dir)],
|
||||
check=True,
|
||||
)
|
||||
# Figure out the generated file path
|
||||
pyi_path = out_dir / Path(module.replace(".", "/") + ".pyi")
|
||||
if not pyi_path.exists():
|
||||
raise FileNotFoundError(f"Stubgen did not generate {pyi_path}")
|
||||
return pyi_path
|
||||
|
||||
|
||||
def format_file(stub_path: Path) -> None:
|
||||
"""Run ruff fix on the generated stub file."""
|
||||
subprocess.run(
|
||||
[
|
||||
"ruff",
|
||||
"check",
|
||||
str(stub_path),
|
||||
"--fix",
|
||||
"--unsafe-fixes",
|
||||
"--silent",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def ensure_imports(stub_path: Path) -> None:
|
||||
"""Ensure multiple imports are present in the generated stub file."""
|
||||
text = stub_path.read_text()
|
||||
|
||||
if "from typing import Literal" not in text:
|
||||
insert_pos = text.find(
|
||||
"from typing"
|
||||
) # Attempt to find the first typing import
|
||||
if insert_pos == -1:
|
||||
# No existing import, add at the top of the file
|
||||
insert_pos = 0
|
||||
else:
|
||||
# Otherwise, find the position after the first import line
|
||||
insert_pos = text.find("\n", insert_pos) + 1
|
||||
|
||||
# Add Literal import
|
||||
new_text = (
|
||||
text[:insert_pos]
|
||||
+ "from typing import Literal, Any, overload\n"
|
||||
+ text[insert_pos:]
|
||||
)
|
||||
stub_path.write_text(new_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
out_dir = Path(__file__).parent.parent
|
||||
file = run_stubgen("beets.library.models", out_dir)
|
||||
ensure_imports(file)
|
||||
inject_overloads(file, Item)
|
||||
inject_overloads(file, Album)
|
||||
|
|
@ -1,4 +1,7 @@
|
|||
import time
|
||||
from tkinter import N
|
||||
|
||||
import pytest
|
||||
|
||||
import beets
|
||||
from beets.dbcore import types
|
||||
|
|
@ -56,3 +59,34 @@ def test_durationtype():
|
|||
beets.config["format_raw_length"] = True
|
||||
assert 61.23 == t.format(61.23)
|
||||
assert 3601.23 == t.format(3601.23)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type, type_params",
|
||||
[
|
||||
(types.DEFAULT, (str, type(None))),
|
||||
# ints
|
||||
(types.Integer(), (int, int)),
|
||||
(types.NullInteger(), (int, type(None))),
|
||||
(types.PaddedInt(2), (int, int)),
|
||||
(types.NullPaddedInt(3), (int, type(None))),
|
||||
(types.ScaledInt(2, "foo"), (int, int)),
|
||||
(types.Id(), (int, type(None))),
|
||||
# floats
|
||||
(types.Float(), (float, float)),
|
||||
(types.NullFloat(), (float, type(None))),
|
||||
(types.DateType(), (float, float)),
|
||||
(types.DurationType(), (float, float)),
|
||||
# Strings
|
||||
(types.String(), (str, str)),
|
||||
(types.DelimitedString(","), (list[str], list[str])),
|
||||
(types.MusicalKey(), (str, type(None))),
|
||||
# Other
|
||||
(types.Boolean(), (bool, bool)),
|
||||
# Paths
|
||||
(types.PathType(), (bytes, bytes)),
|
||||
(types.NullPathType(), (bytes, type(None))),
|
||||
],
|
||||
)
|
||||
def test_get_type_parameters(type: types.Type, type_params):
|
||||
assert type_params == types.get_type_parameters(type)
|
||||
|
|
|
|||
Loading…
Reference in a new issue