Removed Optional and Union and resolved a minor mypy shadowing issue.

This commit is contained in:
Sebastian Mohr 2025-02-09 20:23:19 +01:00
parent fdf7afbfe3
commit c17a774dd6

View file

@ -30,11 +30,12 @@ up a bottleneck stage by dividing its work among multiple threads.
To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
from __future__ import annotations
import queue
import sys
from threading import Lock, Thread
from typing import Callable, Generator, Optional, Union
from typing import Callable, Generator
from typing_extensions import TypeVar, TypeVarTuple, Unpack
@ -152,15 +153,20 @@ def multiple(messages):
return MultiMessage(messages)
A = TypeVarTuple("A")
T = TypeVar("T")
# Arguments of the function (omitting the task)
Args = TypeVarTuple("Args")
# Task as an additional argument to the function
Task = TypeVar("Task")
# Return type of the function (should normally be task but sadly
# we cant enforce this with the current stage functions without
# a refactor)
R = TypeVar("R")
def stage(
func: Callable[
[Unpack[A], T],
Optional[R],
[Unpack[Args], Task],
R | None,
],
):
"""Decorate a function to become a simple stage.
@ -176,8 +182,8 @@ def stage(
[3, 4, 5]
"""
def coro(*args: Unpack[A]) -> Generator[Union[R, T, None], T, R]:
task = None
def coro(*args: Unpack[Args]) -> Generator[R | Task | None, Task, None]:
task: R | Task | None = None
while True:
task = yield task
task = func(*(args + (task,)))
@ -185,7 +191,7 @@ def stage(
return coro
def mutator_stage(func: Callable[[Unpack[A], T], R]):
def mutator_stage(func: Callable[[Unpack[Args], Task], R]):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
@ -200,7 +206,7 @@ def mutator_stage(func: Callable[[Unpack[A], T], R]):
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args: Unpack[A]) -> Generator[Union[T, None], T, None]:
def coro(*args: Unpack[Args]) -> Generator[Task | None, Task, None]:
task = None
while True:
task = yield task
@ -419,9 +425,7 @@ class Pipeline:
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(
MiddlePipelineThread(
coro, queues[i - 1], queues[i], threads
)
MiddlePipelineThread(coro, queues[i - 1], queues[i], threads)
)
# Last stage.