Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import io | |
import socket | |
import ssl | |
import typing | |
from ..exceptions import ProxySchemeUnsupported | |
if typing.TYPE_CHECKING: | |
from typing_extensions import Literal | |
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT | |
# Buffer types that SSLTransport.recv_into() can fill in place.
_WriteBuffer = typing.Union[bytearray, memoryview]

# Return type of whichever sslobj method _ssl_io_loop() is driving.
_ReturnValue = typing.TypeVar("_ReturnValue")

# Lets __enter__ advertise the concrete (sub)class it returns.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")

# Number of bytes requested from the underlying socket per recv() while
# feeding the TLS engine's incoming buffer.
SSL_BLOCKSIZE = 16_384
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS engine is driven through a pair of in-memory BIOs, and the
        handshake is performed eagerly before the constructor returns.

        :param socket: Socket that carries the encrypted records.
        :param ssl_context: Context whose ``wrap_bio()`` builds the TLS engine.
        :param server_hostname: Forwarded to ``wrap_bio()``.
        :param suppress_ragged_eofs: When true, an EOF that violates the TLS
            protocol during a read is reported as end-of-stream instead of
            raising ``ssl.SSLError``.
        """
        # Encrypted bytes received from the socket, waiting to be decrypted.
        self.incoming = ssl.MemoryBIO()
        # Encrypted bytes produced by the TLS engine, waiting to be sent.
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` decrypted bytes.

        With ``buffer`` given, data is written into it and the byte count is
        returned. Otherwise the data is returned as ``bytes``.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Receive up to ``buflen`` decrypted bytes. ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Receive decrypted bytes into ``buffer`` and return the number written.

        ``nbytes`` defaults to ``len(buffer)``. ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Send all of ``data``, calling send() until every byte is written."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to unsigned bytes so len() and slicing count bytes even when
        # ``data`` is a buffer with a wider item size.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                count += self.send(byte_view[count:])

    def send(self, data: bytes, flags: int = 0) -> int:
        """Encrypt and send ``data``; return the number of plaintext bytes taken."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # Keep the wrapped socket open while this file object is alive;
        # SocketIO.close() undoes this via _decref_socketios().
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Run the TLS shutdown through the socket. The socket stays open."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """
        Close the wrapped socket.

        The TLS layer is not shut down here; call unwrap() first for that.
        """
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate as reported by the TLS engine."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        # Called by socket.SocketIO.close() for files created by makefile();
        # forwarded so the wrapped socket's _io_refs count stays balanced.
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read decrypted data, optionally translating a ragged EOF into end-of-stream.

        :returns: The byte count when ``buffer`` is given, otherwise ``bytes``.
            On a suppressed ragged EOF the result is ``0`` (buffer given) or
            ``b""`` (no buffer), matching ``ssl.SSLSocket.read``.
        :raises ssl.SSLError: For any other TLS error, or for a ragged EOF
            when ``suppress_ragged_eofs`` is false.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Keep the return type consistent with the non-EOF path:
                # a count when reading into a buffer, bytes otherwise.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is retried until it completes without SSL_ERROR_WANT_READ or
        SSL_ERROR_WANT_WRITE. After every attempt, pending output from the
        outgoing BIO is flushed to the socket. When the engine wants to read,
        up to SSL_BLOCKSIZE bytes are pulled from the socket into the incoming
        BIO. An empty read marks the incoming BIO as EOF.

        :raises ssl.SSLError: Any TLS error other than WANT_READ/WANT_WRITE.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                # Pass only the arguments the wrapped sslobj method expects.
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the engine produced, even on success (e.g. the
            # final handshake message or an encrypted write).
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Peer closed the transport; let the engine see EOF so
                    # the next attempt fails instead of spinning forever.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)