diff --git a/beets/util/pipeline.py b/beets/util/pipeline.py index d797dea13..3a31175b6 100644 --- a/beets/util/pipeline.py +++ b/beets/util/pipeline.py @@ -30,11 +30,12 @@ up a bottleneck stage by dividing its work among multiple threads. To do so, pass an iterable of coroutines to the Pipeline constructor in place of any single coroutine. """ +from __future__ import annotations import queue import sys from threading import Lock, Thread -from typing import Callable, Generator, Optional, Union +from typing import Callable, Generator from typing_extensions import TypeVar, TypeVarTuple, Unpack @@ -152,15 +153,20 @@ def multiple(messages): return MultiMessage(messages) -A = TypeVarTuple("A") -T = TypeVar("T") +# Arguments of the function (omitting the task) +Args = TypeVarTuple("Args") +# Task as an additional argument to the function +Task = TypeVar("Task") +# Return type of the function (should normally be task but sadly +# we cant enforce this with the current stage functions without +# a refactor) R = TypeVar("R") def stage( func: Callable[ - [Unpack[A], T], - Optional[R], + [Unpack[Args], Task], + R | None, ], ): """Decorate a function to become a simple stage. @@ -176,8 +182,8 @@ def stage( [3, 4, 5] """ - def coro(*args: Unpack[A]) -> Generator[Union[R, T, None], T, R]: - task = None + def coro(*args: Unpack[Args]) -> Generator[R | Task | None, Task, None]: + task: R | Task | None = None while True: task = yield task task = func(*(args + (task,))) @@ -185,7 +191,7 @@ def stage( return coro -def mutator_stage(func: Callable[[Unpack[A], T], R]): +def mutator_stage(func: Callable[[Unpack[Args], Task], R]): """Decorate a function that manipulates items in a coroutine to become a simple stage. @@ -200,7 +206,7 @@ def mutator_stage(func: Callable[[Unpack[A], T], R]): [{'x': True}, {'a': False, 'x': True}] """ - def coro(*args: Unpack[A]) -> Generator[Union[T, None], T, None]: + def coro(*args: Unpack[Args]) -> Generator[Task | None, Task, None]: task = None while True: task = yield task @@ -419,9 +425,7 @@ class Pipeline: for i in range(1, queue_count): for coro in self.stages[i]: threads.append( - MiddlePipelineThread( - coro, queues[i - 1], queues[i], threads - ) + MiddlePipelineThread(coro, queues[i - 1], queues[i], threads) ) # Last stage.