Spaces:
Running
Running
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
from typing import Dict | |
import dns.exception | |
# pylint: disable=unused-import | |
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import] | |
Backend, | |
DatagramSocket, | |
Socket, | |
StreamSocket, | |
) | |
# pylint: enable=unused-import | |
# The backend returned by get_default_backend(); stays None until
# set_default_backend() runs (get_default_backend() calls it on first use).
_default_backend = None

# Backend instances created by get_backend(), keyed by backend name, so
# repeated lookups of the same name reuse the same instance.
_backends: Dict[str, Backend] = {}

# Allow sniffio import to be disabled for testing purposes: when True,
# sniff() behaves as though the sniffio module were not installed.
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
    """The asynchronous I/O library in use could not be determined."""

    # Raised by sniff() when neither sniffio nor the asyncio fallback can
    # identify a running async library.
    # NOTE(review): DNSException is documented to use the class docstring as
    # the message when raised without arguments, which is why this class has
    # a descriptive docstring rather than an empty ``pass`` body.
def get_backend(name: str) -> Backend:
    """Return the asynchronous I/O backend called *name*.

    *name*, a ``str``, selects the backend.  The supported names are
    ``"asyncio"`` and ``"trio"``.

    Backends are created lazily on first request and remembered in the
    module-level ``_backends`` mapping, so later calls with the same name
    return the instance created the first time.

    Raises NotImplementedError if *name* is not a known backend.
    """
    # pylint: disable=import-outside-toplevel,redefined-outer-name
    cached = _backends.get(name)
    if cached:
        return cached

    if name == "asyncio":
        import dns._asyncio_backend

        created = dns._asyncio_backend.Backend()
    elif name == "trio":
        import dns._trio_backend

        created = dns._trio_backend.Backend()
    else:
        raise NotImplementedError(f"unimplemented async backend {name}")

    _backends[name] = created
    return created
def sniff() -> str:
    """Attempt to determine the in-use asynchronous I/O library.

    The ``sniffio`` module is consulted when it can be imported and the
    module-level ``_no_sniffio`` testing switch is false.  Otherwise the
    only library that can be detected is asyncio, which is reported when
    ``asyncio.get_running_loop()`` finds a running event loop.

    Returns the name of the library: whatever
    ``sniffio.current_async_library()`` reports, or ``"asyncio"`` from the
    fallback check.

    Raises AsyncLibraryNotFoundError if the library cannot be determined.
    """
    # pylint: disable=import-outside-toplevel
    try:
        if _no_sniffio:
            # Testing hook: behave exactly as if sniffio were not installed.
            raise ImportError
        import sniffio
    except ImportError:
        sniffio = None  # type: ignore

    if sniffio is not None:
        # Deliberately outside the ImportError guard above: an ImportError
        # raised from inside sniffio is a genuine failure, not a sign that
        # sniffio is missing, so it must not trigger the asyncio fallback.
        try:
            return sniffio.current_async_library()
        except sniffio.AsyncLibraryNotFoundError as exc:
            raise AsyncLibraryNotFoundError(
                "sniffio cannot determine async library"
            ) from exc

    import asyncio

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # "No running event loop" is the expected negative answer here; the
        # RuntimeError context would only clutter the caller's traceback.
        raise AsyncLibraryNotFoundError("no async library detected") from None
    return "asyncio"
def get_default_backend() -> Backend:
    """Return the default backend, creating it on first use.

    When no default has been chosen yet, the running asynchronous library
    is detected with ``sniff()`` and installed via ``set_default_backend()``.
    Errors raised by either of those propagate to the caller.
    """
    current = _default_backend
    if not current:
        current = set_default_backend(sniff())
    return current
def set_default_backend(name: str) -> Backend:
    """Choose the backend that ``get_default_backend()`` will return.

    *name*, a ``str``, is the backend name and is resolved with
    ``get_backend()``.

    Calling this is usually unnecessary, because ``get_default_backend()``
    picks a backend on its own in many cases.  When ``sniffio`` is not
    installed, or in tests, this selects the backend explicitly.

    Returns the newly selected default backend.  Raises NotImplementedError
    (from ``get_backend()``) if *name* is unknown, in which case the current
    default is left unchanged.
    """
    global _default_backend
    chosen = get_backend(name)
    _default_backend = chosen
    return chosen