Added type hints for pipeline stage decorators

This commit is contained in:
Sebastian Mohr 2025-02-01 15:08:41 +01:00
parent c83f2e4e71
commit 09b15aaf52

View file

@ -31,9 +31,13 @@ To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
import queue
import sys
from threading import Lock, Thread
from typing import Callable, Generator, Optional, Union
from typing_extensions import TypeVar, TypeVarTuple, Unpack
BUBBLE = "__PIPELINE_BUBBLE__"
POISON = "__PIPELINE_POISON__"
@ -149,7 +153,17 @@ def multiple(messages):
return MultiMessage(messages)
def stage(func):
A = TypeVarTuple("A")
T = TypeVar("T")
R = TypeVar("R")
def stage(
func: Callable[
[Unpack[A], T],
R,
],
):
"""Decorate a function to become a simple stage.
>>> @stage
@ -163,7 +177,7 @@ def stage(func):
[3, 4, 5]
"""
def coro(*args):
def coro(*args: Unpack[A]) -> Generator[Union[R, T, None], T, None]:
task = None
while True:
task = yield task
@ -172,7 +186,7 @@ def stage(func):
return coro
def mutator_stage(func):
def mutator_stage(func: Callable[[Unpack[A], T], None]):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
@ -187,7 +201,7 @@ def mutator_stage(func):
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args):
def coro(*args: Unpack[A]) -> Generator[Optional[T], T, None]:
task = None
while True:
task = yield task
@ -406,9 +420,7 @@ class Pipeline:
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(
MiddlePipelineThread(
coro, queues[i - 1], queues[i], threads
)
MiddlePipelineThread(coro, queues[i - 1], queues[i], threads)
)
# Last stage.