make cached_classproperty more robust

This commit is contained in:
Konstantin 2025-10-19 23:07:48 +02:00
parent b924dfcd8c
commit 34ac256d5f
2 changed files with 52 additions and 23 deletions

View file

@ -26,8 +26,10 @@ import shutil
import subprocess
import sys
import tempfile
import threading
import traceback
import warnings
import weakref
from collections import Counter
from collections.abc import Sequence
from contextlib import suppress
@ -1061,24 +1063,12 @@ class cached_classproperty(Generic[T]):
instance properties, this operates on the class rather than instances.
"""
cache: ClassVar[dict[tuple[type[object], str], object]] = {}
_cache: ClassVar[
weakref.WeakKeyDictionary[type[object], dict[str, object]]
] = weakref.WeakKeyDictionary()
_lock: ClassVar[threading.RLock] = threading.RLock()
name: str | None = None
name: str = ""
# Ideally, we would like to use `Callable[[type[T]], Any]` here,
# however, `mypy` is unable to see this as a **class** property, and thinks
# that this callable receives an **instance** of the object, failing the
# type check, for example:
# >>> class Album:
# >>> @cached_classproperty
# >>> def foo(cls):
# >>> reveal_type(cls) # mypy: revealed type is "Album"
# >>> return cls.bar
#
# Argument 1 to "cached_classproperty" has incompatible type
# "Callable[[Album], ...]"; expected "Callable[[type[Album]], ...]"
#
# Therefore, we just use `Any` here, which is not ideal, but works.
def __init__(self, getter: Callable[..., T]) -> None:
"""Initialize the descriptor with the property getter function."""
self.getter: Callable[..., T] = getter
@ -1089,11 +1079,50 @@ class cached_classproperty(Generic[T]):
def __get__(self, instance: object, owner: type[object]) -> T:
"""Compute and cache if needed, and return the property value."""
key: tuple[type[object], str] = owner, self.name
if key not in self.cache:
self.cache[key] = self.getter(owner)
if self.name is None:
raise RuntimeError(
f"{self.__class__.__name__} was not properly initialized. " # noqa: ISC003
+ "__set_name__ was never called. This usually happens when "
+ "the descriptor is used outside of a class definition."
)
return cast(T, self.cache[key])
# First check without lock for performance
class_cache: dict[str, object] | None = self._cache.get(owner)
if class_cache is not None:
try:
# We know this is safe because we only put T values in the cache
return cast(T, class_cache[self.name])
except KeyError:
...
# Compute and cache with lock
with self._lock:
# Double-check inside lock
class_cache = self._cache.setdefault(owner, {})
try:
return cast(T, class_cache[self.name])
except KeyError:
...
# Compute and cache new value
value: T = self.getter(owner)
class_cache[self.name] = value
return value
@classmethod
def clear_cache(
cls, owner: type[object] | None = None, name: str | None = None
) -> None:
"""Clear cache for specific class/property or entire cache."""
if owner is None:
cls._cache.clear()
elif name is None:
keys_to_remove = [k for k in cls._cache.keys() if k[0] == owner]
for key in keys_to_remove:
del cls._cache[key]
else:
_ = cls._cache.pop(owner, None)
class LazySharedInstance(Generic[T]):

View file

@ -51,5 +51,5 @@ def pytest_assertrepr_compare(op, left, right):
@pytest.fixture(autouse=True)
def clear_cached_classproperty():
cached_classproperty.cache.clear()
def clear_cached_classproperty() -> None:
cached_classproperty.clear_cache()