Spaces:
Running
Running
from __future__ import annotations | |
import sys | |
import threading | |
from collections.abc import Awaitable, Callable, Generator | |
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait | |
from contextlib import AbstractContextManager, contextmanager | |
from dataclasses import dataclass, field | |
from inspect import isawaitable | |
from types import TracebackType | |
from typing import ( | |
Any, | |
AsyncContextManager, | |
ContextManager, | |
Generic, | |
Iterable, | |
TypeVar, | |
cast, | |
overload, | |
) | |
from ._core import _eventloop | |
from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals | |
from ._core._synchronization import Event | |
from ._core._tasks import CancelScope, create_task_group | |
from .abc import AsyncBackend | |
from .abc._tasks import TaskStatus | |
if sys.version_info >= (3, 11): | |
from typing import TypeVarTuple, Unpack | |
else: | |
from typing_extensions import TypeVarTuple, Unpack | |
# Return type of a callable executed through this module's thread helpers / portals.
T_Retval = TypeVar("T_Retval")
# Covariant type of the value produced by an async context manager wrapped for
# synchronous use (see _BlockingAsyncContextManager).
T_co = TypeVar("T_co", covariant=True)
# Captures the positional argument types of a callable so that the ``*args``
# forwarded alongside it can be type-checked against its signature.
PosArgsT = TypeVarTuple("PosArgsT")
def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    The call is handed to the async backend recorded for the current worker thread,
    and its result is returned to the caller.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the current thread is not an AnyIO worker thread (no
        backend/token is recorded in the thread-local state)
    """
    try:
        backend = threadlocals.current_async_backend
        loop_token = threadlocals.current_token
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        return backend.run_async_from_thread(func, args, token=loop_token)
def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    The call is handed to the async backend recorded for the current worker thread,
    and its return value is passed back to the caller.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the current thread is not an AnyIO worker thread (no
        backend/token is recorded in the thread-local state)
    """
    try:
        backend = threadlocals.current_async_backend
        loop_token = threadlocals.current_token
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        return backend.run_sync_from_thread(func, args, token=loop_token)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an asynchronous one through a portal.

    ``__enter__`` starts :meth:`run_async_cm` as a task in the portal and blocks
    until the async context manager's ``__aenter__()`` has finished. ``__exit__``
    then lets that task continue into ``__aexit__()`` and blocks until it is done.
    """

    # Resolved by run_async_cm() with the value returned (or exception raised) by
    # the async context manager's __aenter__()
    _enter_future: Future[T_co]
    # Future of the run_async_cm() task itself; resolves to __aexit__()'s result
    _exit_future: Future[bool | None]
    # Set (via the portal) by __exit__ to let run_async_cm() proceed to __aexit__()
    _exit_event: Event
    # Exception info forwarded to __aexit__(); remains (None, None, None) if the
    # task reaches __aexit__() before __exit__ has recorded anything
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for the sync side to exit, then exit it.

        Runs as a task in the portal (started by ``__enter__``). The outcome of
        ``__aenter__()`` is published through ``_enter_future``.

        :return: the return value of the async context manager's ``__aexit__()``
        """
        try:
            # NOTE(review): presumably created here rather than in __init__ because
            # Event() needs the running async backend -- confirm.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Hand the failure to the thread blocked in __enter__, and also fail
            # this task (and thus _exit_future).
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            # NOTE: returning from inside ``finally`` discards any exception raised
            # by the wait above, so the task's result is always what __aexit__()
            # returned (unless __aexit__() itself raises).
            return result

    def __enter__(self) -> T_co:
        """
        Start :meth:`run_async_cm` in the portal and block until ``__aenter__()`` ends.

        :return: the value returned by the async context manager's ``__aenter__()``
            (its exception is re-raised here instead, if it raised one)
        """
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """
        Forward the exception info to ``__aexit__()`` and wait for it to complete.

        ``_exit_event`` exists at this point because ``__enter__`` only returned
        after :meth:`run_async_cm` created it.

        :return: the return value of the async context manager's ``__aexit__()``
            (a true value suppresses the exception in the ``with`` block)
        """
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
class _BlockingPortalTaskStatus(TaskStatus):
    """
    Task status handed to tasks launched by :meth:`BlockingPortal.start_task`.

    Calling :meth:`started` resolves the wrapped future with the supplied value.
    """

    def __init__(self, future: Future):
        self._future = future

    def started(self, value: object = None) -> None:
        """Report readiness by resolving the wrapped future with ``value``."""
        target = self._future
        target.set_result(value)
class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # The concrete portal implementation is supplied by the current async backend
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # Thread running the event loop; stop() resets it to None to mark the portal
        # as no longer accepting calls (see _check_running()).
        self._event_loop_thread_id: int | None = threading.get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Stop accepting new calls (without cancelling remaining tasks), then exit
        # the portal's task group.
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure that a call into the portal from the current thread is allowed.

        :raises RuntimeError: if the portal has been stopped, or if called from the
            event loop thread (blocking methods such as :meth:`call` would then wait
            on the very loop that has to resolve them)
        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == threading.get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Run ``func`` in the event loop and transfer its outcome to ``future``.

        If ``func`` returns an awaitable, it is awaited inside a cancel scope; that
        scope is cancelled if ``future`` is already cancelled, or gets cancelled
        later from a thread other than the event loop thread.

        Cancellation exceptions mark ``future`` as cancelled. Any other exception is
        stored in ``future`` (unless it was cancelled); exceptions that are not
        ``Exception`` subclasses are additionally re-raised.
        """

        def callback(f: Future[T_Retval]) -> None:
            # Propagate a cancellation of the future (done from a foreign thread)
            # into the cancel scope guarding the awaitable.
            if f.cancelled() and self._event_loop_thread_id not in (
                None,
                threading.get_ident(),
            ):
                self.call(scope.cancel)

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    if future.cancelled():
                        scope.cancel()
                    else:
                        future.add_done_callback(callback)

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Clear the closure cell so the callback registered on ``future`` no
            # longer keeps the cancel scope alive.
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementors must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval: ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval: ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on.

        :param func: any callable
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # If the task finished before calling task_status.started(), resolve the
            # readiness future so the caller below does not wait forever.
            if task_status_future.done():
                return

            if future.cancelled():
                task_status_future.cancel()
                return

            exc = future.exception()
            if exc is not None:
                task_status_future.set_exception(exc)
            else:
                task_status_future.set_exception(
                    RuntimeError("Task exited without calling task_status.started()")
                )

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        return f, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AsyncContextManager[T_co]
    ) -> ContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
@dataclass
class BlockingPortalProvider:
    """
    A manager for a blocking portal. Used as a context manager. The first thread to
    enter this context manager causes a blocking portal to be started with the specific
    parameters, and the last thread to exit causes the portal to be shut down. Thus,
    there will be exactly one blocking portal running in this context as long as at
    least one thread has entered this context manager.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options

    .. versionadded:: 4.4
    """

    backend: str = "asyncio"
    backend_options: dict[str, Any] | None = None
    # Guards the lease counter and the lazily started portal below
    _lock: threading.Lock = field(init=False, default_factory=threading.Lock)
    # Number of threads currently inside this context manager
    _leases: int = field(init=False, default=0)
    # The shared portal; only set while a portal is running
    _portal: BlockingPortal = field(init=False)
    # Context manager from start_blocking_portal(); None when no portal is running
    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
        init=False, default=None
    )

    def __enter__(self) -> BlockingPortal:
        """
        Take a lease on the shared portal, starting it if this is the first lease.

        :return: the shared blocking portal
        """
        with self._lock:
            if self._portal_cm is None:
                portal_cm = start_blocking_portal(self.backend, self.backend_options)
                # Only record the context manager once the portal actually started;
                # otherwise a failed start would leave _portal_cm set without a
                # _portal, breaking every later __enter__.
                self._portal = portal_cm.__enter__()
                self._portal_cm = portal_cm

            self._leases += 1
            return self._portal

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Release one lease; releasing the last lease shuts the portal down.

        The shutdown happens after the lock is released, so a slow shutdown does not
        hold up other threads calling ``__enter__``.
        """
        portal_cm: AbstractContextManager[BlockingPortal] | None = None
        with self._lock:
            assert self._portal_cm
            assert self._leases > 0
            self._leases -= 1
            if not self._leases:
                portal_cm = self._portal_cm
                self._portal_cm = None
                del self._portal

        if portal_cm:
            portal_cm.__exit__(None, None, None)
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        async with BlockingPortal() as portal_:
            # Only publish the portal and keep the loop alive if the caller is still
            # waiting for it. If the caller gave up (``future`` was cancelled below),
            # nobody holds the portal to stop it, so sleeping would keep the event
            # loop thread alive forever and hang the executor shutdown.
            if future.set_running_or_notify_cancel():
                future.set_result(portal_)
                await portal_.sleep_until_stopped()

    future: Future[BlockingPortal] = Future()
    with ThreadPoolExecutor(1) as executor:
        run_future = executor.submit(
            _eventloop.run,  # type: ignore[arg-type]
            run_portal,
            backend=backend,
            backend_options=backend_options,
        )
        try:
            # Wait until either the portal is available or the event loop finished
            # first (which means it failed before publishing the portal).
            wait(
                cast(Iterable[Future], [run_future, future]),
                return_when=FIRST_COMPLETED,
            )
        except BaseException:
            future.cancel()
            run_future.cancel()
            raise

        if future.done():
            portal = future.result()
            cancel_remaining_tasks = False
            try:
                yield portal
            except BaseException:
                cancel_remaining_tasks = True
                raise
            finally:
                try:
                    portal.call(portal.stop, cancel_remaining_tasks)
                except RuntimeError:
                    # call() raises RuntimeError when the portal is no longer
                    # running; there is nothing left to stop then.
                    pass

        # Propagate any exception raised in the event loop thread
        run_future.result()
def check_cancelled() -> None:
    """
    Check whether the host task of the current worker thread has been cancelled.

    The check is delegated to the async backend recorded for the current worker
    thread. If the host task's current cancel scope has been cancelled, that
    backend's cancellation exception is raised.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`
    """
    try:
        backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        backend.check_cancelled()