Spaces:
Sleeping
Sleeping
"""Implementation of the WebSocket protocol. | |
`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional | |
communication between the browser and server. WebSockets are supported in the | |
current versions of all major browsers. | |
This module implements the final version of the WebSocket protocol as | |
defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. | |
.. versionchanged:: 4.0 | |
Removed support for the draft 76 protocol version. | |
""" | |
import abc | |
import asyncio | |
import base64 | |
import hashlib | |
import os | |
import sys | |
import struct | |
import tornado | |
from urllib.parse import urlparse | |
import warnings | |
import zlib | |
from tornado.concurrent import Future, future_set_result_unless_cancelled | |
from tornado.escape import utf8, native_str, to_unicode | |
from tornado import gen, httpclient, httputil | |
from tornado.ioloop import IOLoop, PeriodicCallback | |
from tornado.iostream import StreamClosedError, IOStream | |
from tornado.log import gen_log, app_log | |
from tornado.netutil import Resolver | |
from tornado import simple_httpclient | |
from tornado.queues import Queue | |
from tornado.tcpclient import TCPClient | |
from tornado.util import _websocket_mask | |
from typing import ( | |
TYPE_CHECKING, | |
cast, | |
Any, | |
Optional, | |
Dict, | |
Union, | |
List, | |
Awaitable, | |
Callable, | |
Tuple, | |
Type, | |
) | |
from types import TracebackType | |
if TYPE_CHECKING: | |
from typing_extensions import Protocol | |
# The zlib compressor types aren't actually exposed anywhere | |
# publicly, so declare protocols for the portions we use. | |
class _Compressor(Protocol): | |
def compress(self, data: bytes) -> bytes: | |
pass | |
def flush(self, mode: int) -> bytes: | |
pass | |
class _Decompressor(Protocol): | |
unconsumed_tail = b"" # type: bytes | |
def decompress(self, data: bytes, max_length: int) -> bytes: | |
pass | |
class _WebSocketDelegate(Protocol): | |
# The common base interface implemented by WebSocketHandler on | |
# the server side and WebSocketClientConnection on the client | |
# side. | |
def on_ws_connection_close( | |
self, close_code: Optional[int] = None, close_reason: Optional[str] = None | |
) -> None: | |
pass | |
def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: | |
pass | |
def on_ping(self, data: bytes) -> None: | |
pass | |
def on_pong(self, data: bytes) -> None: | |
pass | |
def log_exception( | |
self, | |
typ: Optional[Type[BaseException]], | |
value: Optional[BaseException], | |
tb: Optional[TracebackType], | |
) -> None: | |
pass | |
_default_max_message_size = 10 * 1024 * 1024 | |
class WebSocketError(Exception): | |
pass | |
class WebSocketClosedError(WebSocketError): | |
"""Raised by operations on a closed connection. | |
.. versionadded:: 3.2 | |
""" | |
pass | |
class _DecompressTooLargeError(Exception): | |
pass | |
class _WebSocketParams(object): | |
def __init__( | |
self, | |
ping_interval: Optional[float] = None, | |
ping_timeout: Optional[float] = None, | |
max_message_size: int = _default_max_message_size, | |
compression_options: Optional[Dict[str, Any]] = None, | |
) -> None: | |
self.ping_interval = ping_interval | |
self.ping_timeout = ping_timeout | |
self.max_message_size = max_message_size | |
self.compression_options = compression_options | |
class WebSocketHandler(tornado.web.RequestHandler): | |
"""Subclass this class to create a basic WebSocket handler. | |
Override `on_message` to handle incoming messages, and use | |
`write_message` to send messages to the client. You can also | |
override `open` and `on_close` to handle opened and closed | |
connections. | |
Custom upgrade response headers can be sent by overriding | |
`~tornado.web.RequestHandler.set_default_headers` or | |
`~tornado.web.RequestHandler.prepare`. | |
See http://dev.w3.org/html5/websockets/ for details on the | |
JavaScript interface. The protocol is specified at | |
http://tools.ietf.org/html/rfc6455. | |
Here is an example WebSocket handler that echos back all received messages | |
back to the client: | |
.. testcode:: | |
class EchoWebSocket(tornado.websocket.WebSocketHandler): | |
def open(self): | |
print("WebSocket opened") | |
def on_message(self, message): | |
self.write_message(u"You said: " + message) | |
def on_close(self): | |
print("WebSocket closed") | |
.. testoutput:: | |
:hide: | |
WebSockets are not standard HTTP connections. The "handshake" is | |
HTTP, but after the handshake, the protocol is | |
message-based. Consequently, most of the Tornado HTTP facilities | |
are not available in handlers of this type. The only communication | |
methods available to you are `write_message()`, `ping()`, and | |
`close()`. Likewise, your request handler class should implement | |
`open()` method rather than ``get()`` or ``post()``. | |
If you map the handler above to ``/websocket`` in your application, you can | |
invoke it in JavaScript with:: | |
var ws = new WebSocket("ws://localhost:8888/websocket"); | |
ws.onopen = function() { | |
ws.send("Hello, world"); | |
}; | |
ws.onmessage = function (evt) { | |
alert(evt.data); | |
}; | |
This script pops up an alert box that says "You said: Hello, world". | |
Web browsers allow any site to open a websocket connection to any other, | |
instead of using the same-origin policy that governs other network | |
access from JavaScript. This can be surprising and is a potential | |
security hole, so since Tornado 4.0 `WebSocketHandler` requires | |
applications that wish to receive cross-origin websockets to opt in | |
by overriding the `~WebSocketHandler.check_origin` method (see that | |
method's docs for details). Failure to do so is the most likely | |
cause of 403 errors when making a websocket connection. | |
When using a secure websocket connection (``wss://``) with a self-signed | |
certificate, the connection from a browser may fail because it wants | |
to show the "accept this certificate" dialog but has nowhere to show it. | |
You must first visit a regular HTML page using the same certificate | |
to accept it before the websocket connection will succeed. | |
If the application setting ``websocket_ping_interval`` has a non-zero | |
value, a ping will be sent periodically, and the connection will be | |
closed if a response is not received before the ``websocket_ping_timeout``. | |
Messages larger than the ``websocket_max_message_size`` application setting | |
(default 10MiB) will not be accepted. | |
.. versionchanged:: 4.5 | |
Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and | |
``websocket_max_message_size``. | |
""" | |
def __init__( | |
self, | |
application: tornado.web.Application, | |
request: httputil.HTTPServerRequest, | |
**kwargs: Any | |
) -> None: | |
super().__init__(application, request, **kwargs) | |
self.ws_connection = None # type: Optional[WebSocketProtocol] | |
self.close_code = None # type: Optional[int] | |
self.close_reason = None # type: Optional[str] | |
self._on_close_called = False | |
async def get(self, *args: Any, **kwargs: Any) -> None: | |
self.open_args = args | |
self.open_kwargs = kwargs | |
# Upgrade header should be present and should be equal to WebSocket | |
if self.request.headers.get("Upgrade", "").lower() != "websocket": | |
self.set_status(400) | |
log_msg = 'Can "Upgrade" only to "WebSocket".' | |
self.finish(log_msg) | |
gen_log.debug(log_msg) | |
return | |
# Connection header should be upgrade. | |
# Some proxy servers/load balancers | |
# might mess with it. | |
headers = self.request.headers | |
connection = map( | |
lambda s: s.strip().lower(), headers.get("Connection", "").split(",") | |
) | |
if "upgrade" not in connection: | |
self.set_status(400) | |
log_msg = '"Connection" must be "Upgrade".' | |
self.finish(log_msg) | |
gen_log.debug(log_msg) | |
return | |
# Handle WebSocket Origin naming convention differences | |
# The difference between version 8 and 13 is that in 8 the | |
# client sends a "Sec-Websocket-Origin" header and in 13 it's | |
# simply "Origin". | |
if "Origin" in self.request.headers: | |
origin = self.request.headers.get("Origin") | |
else: | |
origin = self.request.headers.get("Sec-Websocket-Origin", None) | |
# If there was an origin header, check to make sure it matches | |
# according to check_origin. When the origin is None, we assume it | |
# did not come from a browser and that it can be passed on. | |
if origin is not None and not self.check_origin(origin): | |
self.set_status(403) | |
log_msg = "Cross origin websockets not allowed" | |
self.finish(log_msg) | |
gen_log.debug(log_msg) | |
return | |
self.ws_connection = self.get_websocket_protocol() | |
if self.ws_connection: | |
await self.ws_connection.accept_connection(self) | |
else: | |
self.set_status(426, "Upgrade Required") | |
self.set_header("Sec-WebSocket-Version", "7, 8, 13") | |
def ping_interval(self) -> Optional[float]: | |
"""The interval for websocket keep-alive pings. | |
Set websocket_ping_interval = 0 to disable pings. | |
""" | |
return self.settings.get("websocket_ping_interval", None) | |
def ping_timeout(self) -> Optional[float]: | |
"""If no ping is received in this many seconds, | |
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). | |
Default is max of 3 pings or 30 seconds. | |
""" | |
return self.settings.get("websocket_ping_timeout", None) | |
def max_message_size(self) -> int: | |
"""Maximum allowed message size. | |
If the remote peer sends a message larger than this, the connection | |
will be closed. | |
Default is 10MiB. | |
""" | |
return self.settings.get( | |
"websocket_max_message_size", _default_max_message_size | |
) | |
def write_message( | |
self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False | |
) -> "Future[None]": | |
"""Sends the given message to the client of this Web Socket. | |
The message may be either a string or a dict (which will be | |
encoded as json). If the ``binary`` argument is false, the | |
message will be sent as utf8; in binary mode any byte string | |
is allowed. | |
If the connection is already closed, raises `WebSocketClosedError`. | |
Returns a `.Future` which can be used for flow control. | |
.. versionchanged:: 3.2 | |
`WebSocketClosedError` was added (previously a closed connection | |
would raise an `AttributeError`) | |
.. versionchanged:: 4.3 | |
Returns a `.Future` which can be used for flow control. | |
.. versionchanged:: 5.0 | |
Consistently raises `WebSocketClosedError`. Previously could | |
sometimes raise `.StreamClosedError`. | |
""" | |
if self.ws_connection is None or self.ws_connection.is_closing(): | |
raise WebSocketClosedError() | |
if isinstance(message, dict): | |
message = tornado.escape.json_encode(message) | |
return self.ws_connection.write_message(message, binary=binary) | |
def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]: | |
"""Override to implement subprotocol negotiation. | |
``subprotocols`` is a list of strings identifying the | |
subprotocols proposed by the client. This method may be | |
overridden to return one of those strings to select it, or | |
``None`` to not select a subprotocol. | |
Failure to select a subprotocol does not automatically abort | |
the connection, although clients may close the connection if | |
none of their proposed subprotocols was selected. | |
The list may be empty, in which case this method must return | |
None. This method is always called exactly once even if no | |
subprotocols were proposed so that the handler can be advised | |
of this fact. | |
.. versionchanged:: 5.1 | |
Previously, this method was called with a list containing | |
an empty string instead of an empty list if no subprotocols | |
were proposed by the client. | |
""" | |
return None | |
def selected_subprotocol(self) -> Optional[str]: | |
"""The subprotocol returned by `select_subprotocol`. | |
.. versionadded:: 5.1 | |
""" | |
assert self.ws_connection is not None | |
return self.ws_connection.selected_subprotocol | |
def get_compression_options(self) -> Optional[Dict[str, Any]]: | |
"""Override to return compression options for the connection. | |
If this method returns None (the default), compression will | |
be disabled. If it returns a dict (even an empty one), it | |
will be enabled. The contents of the dict may be used to | |
control the following compression options: | |
``compression_level`` specifies the compression level. | |
``mem_level`` specifies the amount of memory used for the internal compression state. | |
These parameters are documented in details here: | |
https://docs.python.org/3.6/library/zlib.html#zlib.compressobj | |
.. versionadded:: 4.1 | |
.. versionchanged:: 4.5 | |
Added ``compression_level`` and ``mem_level``. | |
""" | |
# TODO: Add wbits option. | |
return None | |
def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: | |
"""Invoked when a new WebSocket is opened. | |
The arguments to `open` are extracted from the `tornado.web.URLSpec` | |
regular expression, just like the arguments to | |
`tornado.web.RequestHandler.get`. | |
`open` may be a coroutine. `on_message` will not be called until | |
`open` has returned. | |
.. versionchanged:: 5.1 | |
``open`` may be a coroutine. | |
""" | |
pass | |
def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: | |
"""Handle incoming messages on the WebSocket | |
This method must be overridden. | |
.. versionchanged:: 4.5 | |
``on_message`` can be a coroutine. | |
""" | |
raise NotImplementedError | |
def ping(self, data: Union[str, bytes] = b"") -> None: | |
"""Send ping frame to the remote end. | |
The data argument allows a small amount of data (up to 125 | |
bytes) to be sent as a part of the ping message. Note that not | |
all websocket implementations expose this data to | |
applications. | |
Consider using the ``websocket_ping_interval`` application | |
setting instead of sending pings manually. | |
.. versionchanged:: 5.1 | |
The data argument is now optional. | |
""" | |
data = utf8(data) | |
if self.ws_connection is None or self.ws_connection.is_closing(): | |
raise WebSocketClosedError() | |
self.ws_connection.write_ping(data) | |
def on_pong(self, data: bytes) -> None: | |
"""Invoked when the response to a ping frame is received.""" | |
pass | |
def on_ping(self, data: bytes) -> None: | |
"""Invoked when the a ping frame is received.""" | |
pass | |
def on_close(self) -> None: | |
"""Invoked when the WebSocket is closed. | |
If the connection was closed cleanly and a status code or reason | |
phrase was supplied, these values will be available as the attributes | |
``self.close_code`` and ``self.close_reason``. | |
.. versionchanged:: 4.0 | |
Added ``close_code`` and ``close_reason`` attributes. | |
""" | |
pass | |
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: | |
"""Closes this Web Socket. | |
Once the close handshake is successful the socket will be closed. | |
``code`` may be a numeric status code, taken from the values | |
defined in `RFC 6455 section 7.4.1 | |
<https://tools.ietf.org/html/rfc6455#section-7.4.1>`_. | |
``reason`` may be a textual message about why the connection is | |
closing. These values are made available to the client, but are | |
not otherwise interpreted by the websocket protocol. | |
.. versionchanged:: 4.0 | |
Added the ``code`` and ``reason`` arguments. | |
""" | |
if self.ws_connection: | |
self.ws_connection.close(code, reason) | |
self.ws_connection = None | |
def check_origin(self, origin: str) -> bool: | |
"""Override to enable support for allowing alternate origins. | |
The ``origin`` argument is the value of the ``Origin`` HTTP | |
header, the url responsible for initiating this request. This | |
method is not called for clients that do not send this header; | |
such requests are always allowed (because all browsers that | |
implement WebSockets support this header, and non-browser | |
clients do not have the same cross-site security concerns). | |
Should return ``True`` to accept the request or ``False`` to | |
reject it. By default, rejects all requests with an origin on | |
a host other than this one. | |
This is a security protection against cross site scripting attacks on | |
browsers, since WebSockets are allowed to bypass the usual same-origin | |
policies and don't use CORS headers. | |
.. warning:: | |
This is an important security measure; don't disable it | |
without understanding the security implications. In | |
particular, if your authentication is cookie-based, you | |
must either restrict the origins allowed by | |
``check_origin()`` or implement your own XSRF-like | |
protection for websocket connections. See `these | |
<https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_ | |
`articles | |
<https://devcenter.heroku.com/articles/websocket-security>`_ | |
for more. | |
To accept all cross-origin traffic (which was the default prior to | |
Tornado 4.0), simply override this method to always return ``True``:: | |
def check_origin(self, origin): | |
return True | |
To allow connections from any subdomain of your site, you might | |
do something like:: | |
def check_origin(self, origin): | |
parsed_origin = urllib.parse.urlparse(origin) | |
return parsed_origin.netloc.endswith(".mydomain.com") | |
.. versionadded:: 4.0 | |
""" | |
parsed_origin = urlparse(origin) | |
origin = parsed_origin.netloc | |
origin = origin.lower() | |
host = self.request.headers.get("Host") | |
# Check to see that origin matches host directly, including ports | |
return origin == host | |
def set_nodelay(self, value: bool) -> None: | |
"""Set the no-delay flag for this stream. | |
By default, small messages may be delayed and/or combined to minimize | |
the number of packets sent. This can sometimes cause 200-500ms delays | |
due to the interaction between Nagle's algorithm and TCP delayed | |
ACKs. To reduce this delay (at the expense of possibly increasing | |
bandwidth usage), call ``self.set_nodelay(True)`` once the websocket | |
connection is established. | |
See `.BaseIOStream.set_nodelay` for additional details. | |
.. versionadded:: 3.1 | |
""" | |
assert self.ws_connection is not None | |
self.ws_connection.set_nodelay(value) | |
def on_connection_close(self) -> None: | |
if self.ws_connection: | |
self.ws_connection.on_connection_close() | |
self.ws_connection = None | |
if not self._on_close_called: | |
self._on_close_called = True | |
self.on_close() | |
self._break_cycles() | |
def on_ws_connection_close( | |
self, close_code: Optional[int] = None, close_reason: Optional[str] = None | |
) -> None: | |
self.close_code = close_code | |
self.close_reason = close_reason | |
self.on_connection_close() | |
def _break_cycles(self) -> None: | |
# WebSocketHandlers call finish() early, but we don't want to | |
# break up reference cycles (which makes it impossible to call | |
# self.render_string) until after we've really closed the | |
# connection (if it was established in the first place, | |
# indicated by status code 101). | |
if self.get_status() != 101 or self._on_close_called: | |
super()._break_cycles() | |
def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]: | |
websocket_version = self.request.headers.get("Sec-WebSocket-Version") | |
if websocket_version in ("7", "8", "13"): | |
params = _WebSocketParams( | |
ping_interval=self.ping_interval, | |
ping_timeout=self.ping_timeout, | |
max_message_size=self.max_message_size, | |
compression_options=self.get_compression_options(), | |
) | |
return WebSocketProtocol13(self, False, params) | |
return None | |
def _detach_stream(self) -> IOStream: | |
# disable non-WS methods | |
for method in [ | |
"write", | |
"redirect", | |
"set_header", | |
"set_cookie", | |
"set_status", | |
"flush", | |
"finish", | |
]: | |
setattr(self, method, _raise_not_supported_for_websockets) | |
return self.detach() | |
def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: | |
raise RuntimeError("Method not supported for Web Sockets") | |
class WebSocketProtocol(abc.ABC): | |
"""Base class for WebSocket protocol versions.""" | |
def __init__(self, handler: "_WebSocketDelegate") -> None: | |
self.handler = handler | |
self.stream = None # type: Optional[IOStream] | |
self.client_terminated = False | |
self.server_terminated = False | |
def _run_callback( | |
self, callback: Callable, *args: Any, **kwargs: Any | |
) -> "Optional[Future[Any]]": | |
"""Runs the given callback with exception handling. | |
If the callback is a coroutine, returns its Future. On error, aborts the | |
websocket connection and returns None. | |
""" | |
try: | |
result = callback(*args, **kwargs) | |
except Exception: | |
self.handler.log_exception(*sys.exc_info()) | |
self._abort() | |
return None | |
else: | |
if result is not None: | |
result = gen.convert_yielded(result) | |
assert self.stream is not None | |
self.stream.io_loop.add_future(result, lambda f: f.result()) | |
return result | |
def on_connection_close(self) -> None: | |
self._abort() | |
def _abort(self) -> None: | |
"""Instantly aborts the WebSocket connection by closing the socket""" | |
self.client_terminated = True | |
self.server_terminated = True | |
if self.stream is not None: | |
self.stream.close() # forcibly tear down the connection | |
self.close() # let the subclass cleanup | |
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: | |
raise NotImplementedError() | |
def is_closing(self) -> bool: | |
raise NotImplementedError() | |
async def accept_connection(self, handler: WebSocketHandler) -> None: | |
raise NotImplementedError() | |
def write_message( | |
self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False | |
) -> "Future[None]": | |
raise NotImplementedError() | |
def selected_subprotocol(self) -> Optional[str]: | |
raise NotImplementedError() | |
def write_ping(self, data: bytes) -> None: | |
raise NotImplementedError() | |
# The entry points below are used by WebSocketClientConnection, | |
# which was introduced after we only supported a single version of | |
# WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13 | |
# boundary is currently pretty ad-hoc. | |
def _process_server_headers( | |
self, key: Union[str, bytes], headers: httputil.HTTPHeaders | |
) -> None: | |
raise NotImplementedError() | |
def start_pinging(self) -> None: | |
raise NotImplementedError() | |
async def _receive_frame_loop(self) -> None: | |
raise NotImplementedError() | |
def set_nodelay(self, x: bool) -> None: | |
raise NotImplementedError() | |
class _PerMessageDeflateCompressor(object): | |
def __init__( | |
self, | |
persistent: bool, | |
max_wbits: Optional[int], | |
compression_options: Optional[Dict[str, Any]] = None, | |
) -> None: | |
if max_wbits is None: | |
max_wbits = zlib.MAX_WBITS | |
# There is no symbolic constant for the minimum wbits value. | |
if not (8 <= max_wbits <= zlib.MAX_WBITS): | |
raise ValueError( | |
"Invalid max_wbits value %r; allowed range 8-%d", | |
max_wbits, | |
zlib.MAX_WBITS, | |
) | |
self._max_wbits = max_wbits | |
if ( | |
compression_options is None | |
or "compression_level" not in compression_options | |
): | |
self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL | |
else: | |
self._compression_level = compression_options["compression_level"] | |
if compression_options is None or "mem_level" not in compression_options: | |
self._mem_level = 8 | |
else: | |
self._mem_level = compression_options["mem_level"] | |
if persistent: | |
self._compressor = self._create_compressor() # type: Optional[_Compressor] | |
else: | |
self._compressor = None | |
def _create_compressor(self) -> "_Compressor": | |
return zlib.compressobj( | |
self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level | |
) | |
def compress(self, data: bytes) -> bytes: | |
compressor = self._compressor or self._create_compressor() | |
data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH) | |
assert data.endswith(b"\x00\x00\xff\xff") | |
return data[:-4] | |
class _PerMessageDeflateDecompressor(object): | |
def __init__( | |
self, | |
persistent: bool, | |
max_wbits: Optional[int], | |
max_message_size: int, | |
compression_options: Optional[Dict[str, Any]] = None, | |
) -> None: | |
self._max_message_size = max_message_size | |
if max_wbits is None: | |
max_wbits = zlib.MAX_WBITS | |
if not (8 <= max_wbits <= zlib.MAX_WBITS): | |
raise ValueError( | |
"Invalid max_wbits value %r; allowed range 8-%d", | |
max_wbits, | |
zlib.MAX_WBITS, | |
) | |
self._max_wbits = max_wbits | |
if persistent: | |
self._decompressor = ( | |
self._create_decompressor() | |
) # type: Optional[_Decompressor] | |
else: | |
self._decompressor = None | |
def _create_decompressor(self) -> "_Decompressor": | |
return zlib.decompressobj(-self._max_wbits) | |
def decompress(self, data: bytes) -> bytes: | |
decompressor = self._decompressor or self._create_decompressor() | |
result = decompressor.decompress( | |
data + b"\x00\x00\xff\xff", self._max_message_size | |
) | |
if decompressor.unconsumed_tail: | |
raise _DecompressTooLargeError() | |
return result | |
class WebSocketProtocol13(WebSocketProtocol): | |
"""Implementation of the WebSocket protocol from RFC 6455. | |
This class supports versions 7 and 8 of the protocol in addition to the | |
final version 13. | |
""" | |
# Bit masks for the first byte of a frame. | |
FIN = 0x80 | |
RSV1 = 0x40 | |
RSV2 = 0x20 | |
RSV3 = 0x10 | |
RSV_MASK = RSV1 | RSV2 | RSV3 | |
OPCODE_MASK = 0x0F | |
stream = None # type: IOStream | |
def __init__( | |
self, | |
handler: "_WebSocketDelegate", | |
mask_outgoing: bool, | |
params: _WebSocketParams, | |
) -> None: | |
WebSocketProtocol.__init__(self, handler) | |
self.mask_outgoing = mask_outgoing | |
self.params = params | |
self._final_frame = False | |
self._frame_opcode = None | |
self._masked_frame = None | |
self._frame_mask = None # type: Optional[bytes] | |
self._frame_length = None | |
self._fragmented_message_buffer = None # type: Optional[bytearray] | |
self._fragmented_message_opcode = None | |
self._waiting = None # type: object | |
self._compression_options = params.compression_options | |
self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor] | |
self._compressor = None # type: Optional[_PerMessageDeflateCompressor] | |
self._frame_compressed = None # type: Optional[bool] | |
# The total uncompressed size of all messages received or sent. | |
# Unicode messages are encoded to utf8. | |
# Only for testing; subject to change. | |
self._message_bytes_in = 0 | |
self._message_bytes_out = 0 | |
# The total size of all packets received or sent. Includes | |
# the effect of compression, frame overhead, and control frames. | |
self._wire_bytes_in = 0 | |
self._wire_bytes_out = 0 | |
self.ping_callback = None # type: Optional[PeriodicCallback] | |
self.last_ping = 0.0 | |
self.last_pong = 0.0 | |
self.close_code = None # type: Optional[int] | |
self.close_reason = None # type: Optional[str] | |
# Use a property for this to satisfy the abc. | |
def selected_subprotocol(self) -> Optional[str]: | |
return self._selected_subprotocol | |
def selected_subprotocol(self, value: Optional[str]) -> None: | |
self._selected_subprotocol = value | |
async def accept_connection(self, handler: WebSocketHandler) -> None: | |
try: | |
self._handle_websocket_headers(handler) | |
except ValueError: | |
handler.set_status(400) | |
log_msg = "Missing/Invalid WebSocket headers" | |
handler.finish(log_msg) | |
gen_log.debug(log_msg) | |
return | |
try: | |
await self._accept_connection(handler) | |
except asyncio.CancelledError: | |
self._abort() | |
return | |
except ValueError: | |
gen_log.debug("Malformed WebSocket request received", exc_info=True) | |
self._abort() | |
return | |
def _handle_websocket_headers(self, handler: WebSocketHandler) -> None: | |
"""Verifies all invariant- and required headers | |
If a header is missing or have an incorrect value ValueError will be | |
raised | |
""" | |
fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version") | |
if not all(map(lambda f: handler.request.headers.get(f), fields)): | |
raise ValueError("Missing/Invalid WebSocket headers") | |
def compute_accept_value(key: Union[str, bytes]) -> str: | |
"""Computes the value for the Sec-WebSocket-Accept header, | |
given the value for Sec-WebSocket-Key. | |
""" | |
sha1 = hashlib.sha1() | |
sha1.update(utf8(key)) | |
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value | |
return native_str(base64.b64encode(sha1.digest())) | |
def _challenge_response(self, handler: WebSocketHandler) -> str: | |
return WebSocketProtocol13.compute_accept_value( | |
cast(str, handler.request.headers.get("Sec-Websocket-Key")) | |
) | |
async def _accept_connection(self, handler: WebSocketHandler) -> None: | |
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") | |
if subprotocol_header: | |
subprotocols = [s.strip() for s in subprotocol_header.split(",")] | |
else: | |
subprotocols = [] | |
self.selected_subprotocol = handler.select_subprotocol(subprotocols) | |
if self.selected_subprotocol: | |
assert self.selected_subprotocol in subprotocols | |
handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) | |
extensions = self._parse_extensions_header(handler.request.headers) | |
for ext in extensions: | |
if ext[0] == "permessage-deflate" and self._compression_options is not None: | |
# TODO: negotiate parameters if compression_options | |
# specifies limits. | |
self._create_compressors("server", ext[1], self._compression_options) | |
if ( | |
"client_max_window_bits" in ext[1] | |
and ext[1]["client_max_window_bits"] is None | |
): | |
# Don't echo an offered client_max_window_bits | |
# parameter with no value. | |
del ext[1]["client_max_window_bits"] | |
handler.set_header( | |
"Sec-WebSocket-Extensions", | |
httputil._encode_header("permessage-deflate", ext[1]), | |
) | |
break | |
handler.clear_header("Content-Type") | |
handler.set_status(101) | |
handler.set_header("Upgrade", "websocket") | |
handler.set_header("Connection", "Upgrade") | |
handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) | |
handler.finish() | |
self.stream = handler._detach_stream() | |
self.start_pinging() | |
try: | |
open_result = handler.open(*handler.open_args, **handler.open_kwargs) | |
if open_result is not None: | |
await open_result | |
except Exception: | |
handler.log_exception(*sys.exc_info()) | |
self._abort() | |
return | |
await self._receive_frame_loop() | |
def _parse_extensions_header( | |
self, headers: httputil.HTTPHeaders | |
) -> List[Tuple[str, Dict[str, str]]]: | |
extensions = headers.get("Sec-WebSocket-Extensions", "") | |
if extensions: | |
return [httputil._parse_header(e.strip()) for e in extensions.split(",")] | |
return [] | |
def _process_server_headers( | |
self, key: Union[str, bytes], headers: httputil.HTTPHeaders | |
) -> None: | |
"""Process the headers sent by the server to this client connection. | |
'key' is the websocket handshake challenge/response key. | |
""" | |
assert headers["Upgrade"].lower() == "websocket" | |
assert headers["Connection"].lower() == "upgrade" | |
accept = self.compute_accept_value(key) | |
assert headers["Sec-Websocket-Accept"] == accept | |
extensions = self._parse_extensions_header(headers) | |
for ext in extensions: | |
if ext[0] == "permessage-deflate" and self._compression_options is not None: | |
self._create_compressors("client", ext[1]) | |
else: | |
raise ValueError("unsupported extension %r", ext) | |
self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None) | |
def _get_compressor_options( | |
self, | |
side: str, | |
agreed_parameters: Dict[str, Any], | |
compression_options: Optional[Dict[str, Any]] = None, | |
) -> Dict[str, Any]: | |
"""Converts a websocket agreed_parameters set to keyword arguments | |
for our compressor objects. | |
""" | |
options = dict( | |
persistent=(side + "_no_context_takeover") not in agreed_parameters | |
) # type: Dict[str, Any] | |
wbits_header = agreed_parameters.get(side + "_max_window_bits", None) | |
if wbits_header is None: | |
options["max_wbits"] = zlib.MAX_WBITS | |
else: | |
options["max_wbits"] = int(wbits_header) | |
options["compression_options"] = compression_options | |
return options | |
def _create_compressors( | |
self, | |
side: str, | |
agreed_parameters: Dict[str, Any], | |
compression_options: Optional[Dict[str, Any]] = None, | |
) -> None: | |
# TODO: handle invalid parameters gracefully | |
allowed_keys = set( | |
[ | |
"server_no_context_takeover", | |
"client_no_context_takeover", | |
"server_max_window_bits", | |
"client_max_window_bits", | |
] | |
) | |
for key in agreed_parameters: | |
if key not in allowed_keys: | |
raise ValueError("unsupported compression parameter %r" % key) | |
other_side = "client" if (side == "server") else "server" | |
self._compressor = _PerMessageDeflateCompressor( | |
**self._get_compressor_options(side, agreed_parameters, compression_options) | |
) | |
self._decompressor = _PerMessageDeflateDecompressor( | |
max_message_size=self.params.max_message_size, | |
**self._get_compressor_options( | |
other_side, agreed_parameters, compression_options | |
) | |
) | |
def _write_frame( | |
self, fin: bool, opcode: int, data: bytes, flags: int = 0 | |
) -> "Future[None]": | |
data_len = len(data) | |
if opcode & 0x8: | |
# All control frames MUST have a payload length of 125 | |
# bytes or less and MUST NOT be fragmented. | |
if not fin: | |
raise ValueError("control frames may not be fragmented") | |
if data_len > 125: | |
raise ValueError("control frame payloads may not exceed 125 bytes") | |
if fin: | |
finbit = self.FIN | |
else: | |
finbit = 0 | |
frame = struct.pack("B", finbit | opcode | flags) | |
if self.mask_outgoing: | |
mask_bit = 0x80 | |
else: | |
mask_bit = 0 | |
if data_len < 126: | |
frame += struct.pack("B", data_len | mask_bit) | |
elif data_len <= 0xFFFF: | |
frame += struct.pack("!BH", 126 | mask_bit, data_len) | |
else: | |
frame += struct.pack("!BQ", 127 | mask_bit, data_len) | |
if self.mask_outgoing: | |
mask = os.urandom(4) | |
data = mask + _websocket_mask(mask, data) | |
frame += data | |
self._wire_bytes_out += len(frame) | |
return self.stream.write(frame) | |
def write_message( | |
self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False | |
) -> "Future[None]": | |
"""Sends the given message to the client of this Web Socket.""" | |
if binary: | |
opcode = 0x2 | |
else: | |
opcode = 0x1 | |
if isinstance(message, dict): | |
message = tornado.escape.json_encode(message) | |
message = tornado.escape.utf8(message) | |
assert isinstance(message, bytes) | |
self._message_bytes_out += len(message) | |
flags = 0 | |
if self._compressor: | |
message = self._compressor.compress(message) | |
flags |= self.RSV1 | |
# For historical reasons, write methods in Tornado operate in a semi-synchronous | |
# mode in which awaiting the Future they return is optional (But errors can | |
# still be raised). This requires us to go through an awkward dance here | |
# to transform the errors that may be returned while presenting the same | |
# semi-synchronous interface. | |
try: | |
fut = self._write_frame(True, opcode, message, flags=flags) | |
except StreamClosedError: | |
raise WebSocketClosedError() | |
async def wrapper() -> None: | |
try: | |
await fut | |
except StreamClosedError: | |
raise WebSocketClosedError() | |
return asyncio.ensure_future(wrapper()) | |
def write_ping(self, data: bytes) -> None: | |
"""Send ping frame.""" | |
assert isinstance(data, bytes) | |
self._write_frame(True, 0x9, data) | |
async def _receive_frame_loop(self) -> None: | |
try: | |
while not self.client_terminated: | |
await self._receive_frame() | |
except StreamClosedError: | |
self._abort() | |
self.handler.on_ws_connection_close(self.close_code, self.close_reason) | |
async def _read_bytes(self, n: int) -> bytes: | |
data = await self.stream.read_bytes(n) | |
self._wire_bytes_in += n | |
return data | |
async def _receive_frame(self) -> None: | |
# Read the frame header. | |
data = await self._read_bytes(2) | |
header, mask_payloadlen = struct.unpack("BB", data) | |
is_final_frame = header & self.FIN | |
reserved_bits = header & self.RSV_MASK | |
opcode = header & self.OPCODE_MASK | |
opcode_is_control = opcode & 0x8 | |
if self._decompressor is not None and opcode != 0: | |
# Compression flag is present in the first frame's header, | |
# but we can't decompress until we have all the frames of | |
# the message. | |
self._frame_compressed = bool(reserved_bits & self.RSV1) | |
reserved_bits &= ~self.RSV1 | |
if reserved_bits: | |
# client is using as-yet-undefined extensions; abort | |
self._abort() | |
return | |
is_masked = bool(mask_payloadlen & 0x80) | |
payloadlen = mask_payloadlen & 0x7F | |
# Parse and validate the length. | |
if opcode_is_control and payloadlen >= 126: | |
# control frames must have payload < 126 | |
self._abort() | |
return | |
if payloadlen < 126: | |
self._frame_length = payloadlen | |
elif payloadlen == 126: | |
data = await self._read_bytes(2) | |
payloadlen = struct.unpack("!H", data)[0] | |
elif payloadlen == 127: | |
data = await self._read_bytes(8) | |
payloadlen = struct.unpack("!Q", data)[0] | |
new_len = payloadlen | |
if self._fragmented_message_buffer is not None: | |
new_len += len(self._fragmented_message_buffer) | |
if new_len > self.params.max_message_size: | |
self.close(1009, "message too big") | |
self._abort() | |
return | |
# Read the payload, unmasking if necessary. | |
if is_masked: | |
self._frame_mask = await self._read_bytes(4) | |
data = await self._read_bytes(payloadlen) | |
if is_masked: | |
assert self._frame_mask is not None | |
data = _websocket_mask(self._frame_mask, data) | |
# Decide what to do with this frame. | |
if opcode_is_control: | |
# control frames may be interleaved with a series of fragmented | |
# data frames, so control frames must not interact with | |
# self._fragmented_* | |
if not is_final_frame: | |
# control frames must not be fragmented | |
self._abort() | |
return | |
elif opcode == 0: # continuation frame | |
if self._fragmented_message_buffer is None: | |
# nothing to continue | |
self._abort() | |
return | |
self._fragmented_message_buffer.extend(data) | |
if is_final_frame: | |
opcode = self._fragmented_message_opcode | |
data = bytes(self._fragmented_message_buffer) | |
self._fragmented_message_buffer = None | |
else: # start of new data message | |
if self._fragmented_message_buffer is not None: | |
# can't start new message until the old one is finished | |
self._abort() | |
return | |
if not is_final_frame: | |
self._fragmented_message_opcode = opcode | |
self._fragmented_message_buffer = bytearray(data) | |
if is_final_frame: | |
handled_future = self._handle_message(opcode, data) | |
if handled_future is not None: | |
await handled_future | |
def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]": | |
"""Execute on_message, returning its Future if it is a coroutine.""" | |
if self.client_terminated: | |
return None | |
if self._frame_compressed: | |
assert self._decompressor is not None | |
try: | |
data = self._decompressor.decompress(data) | |
except _DecompressTooLargeError: | |
self.close(1009, "message too big after decompression") | |
self._abort() | |
return None | |
if opcode == 0x1: | |
# UTF-8 data | |
self._message_bytes_in += len(data) | |
try: | |
decoded = data.decode("utf-8") | |
except UnicodeDecodeError: | |
self._abort() | |
return None | |
return self._run_callback(self.handler.on_message, decoded) | |
elif opcode == 0x2: | |
# Binary data | |
self._message_bytes_in += len(data) | |
return self._run_callback(self.handler.on_message, data) | |
elif opcode == 0x8: | |
# Close | |
self.client_terminated = True | |
if len(data) >= 2: | |
self.close_code = struct.unpack(">H", data[:2])[0] | |
if len(data) > 2: | |
self.close_reason = to_unicode(data[2:]) | |
# Echo the received close code, if any (RFC 6455 section 5.5.1). | |
self.close(self.close_code) | |
elif opcode == 0x9: | |
# Ping | |
try: | |
self._write_frame(True, 0xA, data) | |
except StreamClosedError: | |
self._abort() | |
self._run_callback(self.handler.on_ping, data) | |
elif opcode == 0xA: | |
# Pong | |
self.last_pong = IOLoop.current().time() | |
return self._run_callback(self.handler.on_pong, data) | |
else: | |
self._abort() | |
return None | |
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: | |
"""Closes the WebSocket connection.""" | |
if not self.server_terminated: | |
if not self.stream.closed(): | |
if code is None and reason is not None: | |
code = 1000 # "normal closure" status code | |
if code is None: | |
close_data = b"" | |
else: | |
close_data = struct.pack(">H", code) | |
if reason is not None: | |
close_data += utf8(reason) | |
try: | |
self._write_frame(True, 0x8, close_data) | |
except StreamClosedError: | |
self._abort() | |
self.server_terminated = True | |
if self.client_terminated: | |
if self._waiting is not None: | |
self.stream.io_loop.remove_timeout(self._waiting) | |
self._waiting = None | |
self.stream.close() | |
elif self._waiting is None: | |
# Give the client a few seconds to complete a clean shutdown, | |
# otherwise just close the connection. | |
self._waiting = self.stream.io_loop.add_timeout( | |
self.stream.io_loop.time() + 5, self._abort | |
) | |
if self.ping_callback: | |
self.ping_callback.stop() | |
self.ping_callback = None | |
def is_closing(self) -> bool: | |
"""Return ``True`` if this connection is closing. | |
The connection is considered closing if either side has | |
initiated its closing handshake or if the stream has been | |
shut down uncleanly. | |
""" | |
return self.stream.closed() or self.client_terminated or self.server_terminated | |
def ping_interval(self) -> Optional[float]: | |
interval = self.params.ping_interval | |
if interval is not None: | |
return interval | |
return 0 | |
def ping_timeout(self) -> Optional[float]: | |
timeout = self.params.ping_timeout | |
if timeout is not None: | |
return timeout | |
assert self.ping_interval is not None | |
return max(3 * self.ping_interval, 30) | |
def start_pinging(self) -> None: | |
"""Start sending periodic pings to keep the connection alive""" | |
assert self.ping_interval is not None | |
if self.ping_interval > 0: | |
self.last_ping = self.last_pong = IOLoop.current().time() | |
self.ping_callback = PeriodicCallback( | |
self.periodic_ping, self.ping_interval * 1000 | |
) | |
self.ping_callback.start() | |
def periodic_ping(self) -> None: | |
"""Send a ping to keep the websocket alive | |
Called periodically if the websocket_ping_interval is set and non-zero. | |
""" | |
if self.is_closing() and self.ping_callback is not None: | |
self.ping_callback.stop() | |
return | |
# Check for timeout on pong. Make sure that we really have | |
# sent a recent ping in case the machine with both server and | |
# client has been suspended since the last ping. | |
now = IOLoop.current().time() | |
since_last_pong = now - self.last_pong | |
since_last_ping = now - self.last_ping | |
assert self.ping_interval is not None | |
assert self.ping_timeout is not None | |
if ( | |
since_last_ping < 2 * self.ping_interval | |
and since_last_pong > self.ping_timeout | |
): | |
self.close() | |
return | |
self.write_ping(b"") | |
self.last_ping = now | |
def set_nodelay(self, x: bool) -> None: | |
self.stream.set_nodelay(x) | |
class WebSocketClientConnection(simple_httpclient._HTTPConnection): | |
"""WebSocket client connection. | |
This class should not be instantiated directly; use the | |
`websocket_connect` function instead. | |
""" | |
protocol = None # type: WebSocketProtocol | |
def __init__( | |
self, | |
request: httpclient.HTTPRequest, | |
on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, | |
compression_options: Optional[Dict[str, Any]] = None, | |
ping_interval: Optional[float] = None, | |
ping_timeout: Optional[float] = None, | |
max_message_size: int = _default_max_message_size, | |
subprotocols: Optional[List[str]] = None, | |
resolver: Optional[Resolver] = None, | |
) -> None: | |
self.connect_future = Future() # type: Future[WebSocketClientConnection] | |
self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] | |
self.key = base64.b64encode(os.urandom(16)) | |
self._on_message_callback = on_message_callback | |
self.close_code = None # type: Optional[int] | |
self.close_reason = None # type: Optional[str] | |
self.params = _WebSocketParams( | |
ping_interval=ping_interval, | |
ping_timeout=ping_timeout, | |
max_message_size=max_message_size, | |
compression_options=compression_options, | |
) | |
scheme, sep, rest = request.url.partition(":") | |
scheme = {"ws": "http", "wss": "https"}[scheme] | |
request.url = scheme + sep + rest | |
request.headers.update( | |
{ | |
"Upgrade": "websocket", | |
"Connection": "Upgrade", | |
"Sec-WebSocket-Key": self.key, | |
"Sec-WebSocket-Version": "13", | |
} | |
) | |
if subprotocols is not None: | |
request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols) | |
if compression_options is not None: | |
# Always offer to let the server set our max_wbits (and even though | |
# we don't offer it, we will accept a client_no_context_takeover | |
# from the server). | |
# TODO: set server parameters for deflate extension | |
# if requested in self.compression_options. | |
request.headers["Sec-WebSocket-Extensions"] = ( | |
"permessage-deflate; client_max_window_bits" | |
) | |
# Websocket connection is currently unable to follow redirects | |
request.follow_redirects = False | |
self.tcp_client = TCPClient(resolver=resolver) | |
super().__init__( | |
None, | |
request, | |
lambda: None, | |
self._on_http_response, | |
104857600, | |
self.tcp_client, | |
65536, | |
104857600, | |
) | |
def __del__(self) -> None: | |
if self.protocol is not None: | |
# Unclosed client connections can sometimes log "task was destroyed but | |
# was pending" warnings if shutdown strikes at the wrong time (such as | |
# while a ping is being processed due to ping_interval). Log our own | |
# warning to make it a little more deterministic (although it's still | |
# dependent on GC timing). | |
warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning) | |
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: | |
"""Closes the websocket connection. | |
``code`` and ``reason`` are documented under | |
`WebSocketHandler.close`. | |
.. versionadded:: 3.2 | |
.. versionchanged:: 4.0 | |
Added the ``code`` and ``reason`` arguments. | |
""" | |
if self.protocol is not None: | |
self.protocol.close(code, reason) | |
self.protocol = None # type: ignore | |
def on_connection_close(self) -> None: | |
if not self.connect_future.done(): | |
self.connect_future.set_exception(StreamClosedError()) | |
self._on_message(None) | |
self.tcp_client.close() | |
super().on_connection_close() | |
def on_ws_connection_close( | |
self, close_code: Optional[int] = None, close_reason: Optional[str] = None | |
) -> None: | |
self.close_code = close_code | |
self.close_reason = close_reason | |
self.on_connection_close() | |
def _on_http_response(self, response: httpclient.HTTPResponse) -> None: | |
if not self.connect_future.done(): | |
if response.error: | |
self.connect_future.set_exception(response.error) | |
else: | |
self.connect_future.set_exception( | |
WebSocketError("Non-websocket response") | |
) | |
async def headers_received( | |
self, | |
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], | |
headers: httputil.HTTPHeaders, | |
) -> None: | |
assert isinstance(start_line, httputil.ResponseStartLine) | |
if start_line.code != 101: | |
await super().headers_received(start_line, headers) | |
return | |
if self._timeout is not None: | |
self.io_loop.remove_timeout(self._timeout) | |
self._timeout = None | |
self.headers = headers | |
self.protocol = self.get_websocket_protocol() | |
self.protocol._process_server_headers(self.key, self.headers) | |
self.protocol.stream = self.connection.detach() | |
IOLoop.current().add_callback(self.protocol._receive_frame_loop) | |
self.protocol.start_pinging() | |
# Once we've taken over the connection, clear the final callback | |
# we set on the http request. This deactivates the error handling | |
# in simple_httpclient that would otherwise interfere with our | |
# ability to see exceptions. | |
self.final_callback = None # type: ignore | |
future_set_result_unless_cancelled(self.connect_future, self) | |
def write_message( | |
self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False | |
) -> "Future[None]": | |
"""Sends a message to the WebSocket server. | |
If the stream is closed, raises `WebSocketClosedError`. | |
Returns a `.Future` which can be used for flow control. | |
.. versionchanged:: 5.0 | |
Exception raised on a closed stream changed from `.StreamClosedError` | |
to `WebSocketClosedError`. | |
""" | |
if self.protocol is None: | |
raise WebSocketClosedError("Client connection has been closed") | |
return self.protocol.write_message(message, binary=binary) | |
def read_message( | |
self, | |
callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None, | |
) -> Awaitable[Union[None, str, bytes]]: | |
"""Reads a message from the WebSocket server. | |
If on_message_callback was specified at WebSocket | |
initialization, this function will never return messages | |
Returns a future whose result is the message, or None | |
if the connection is closed. If a callback argument | |
is given it will be called with the future when it is | |
ready. | |
""" | |
awaitable = self.read_queue.get() | |
if callback is not None: | |
self.io_loop.add_future(asyncio.ensure_future(awaitable), callback) | |
return awaitable | |
def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: | |
return self._on_message(message) | |
def _on_message( | |
self, message: Union[None, str, bytes] | |
) -> Optional[Awaitable[None]]: | |
if self._on_message_callback: | |
self._on_message_callback(message) | |
return None | |
else: | |
return self.read_queue.put(message) | |
def ping(self, data: bytes = b"") -> None: | |
"""Send ping frame to the remote end. | |
The data argument allows a small amount of data (up to 125 | |
bytes) to be sent as a part of the ping message. Note that not | |
all websocket implementations expose this data to | |
applications. | |
Consider using the ``ping_interval`` argument to | |
`websocket_connect` instead of sending pings manually. | |
.. versionadded:: 5.1 | |
""" | |
data = utf8(data) | |
if self.protocol is None: | |
raise WebSocketClosedError() | |
self.protocol.write_ping(data) | |
def on_pong(self, data: bytes) -> None: | |
pass | |
def on_ping(self, data: bytes) -> None: | |
pass | |
def get_websocket_protocol(self) -> WebSocketProtocol: | |
return WebSocketProtocol13(self, mask_outgoing=True, params=self.params) | |
def selected_subprotocol(self) -> Optional[str]: | |
"""The subprotocol selected by the server. | |
.. versionadded:: 5.1 | |
""" | |
return self.protocol.selected_subprotocol | |
def log_exception( | |
self, | |
typ: "Optional[Type[BaseException]]", | |
value: Optional[BaseException], | |
tb: Optional[TracebackType], | |
) -> None: | |
assert typ is not None | |
assert value is not None | |
app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb)) | |
def websocket_connect( | |
url: Union[str, httpclient.HTTPRequest], | |
callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None, | |
connect_timeout: Optional[float] = None, | |
on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, | |
compression_options: Optional[Dict[str, Any]] = None, | |
ping_interval: Optional[float] = None, | |
ping_timeout: Optional[float] = None, | |
max_message_size: int = _default_max_message_size, | |
subprotocols: Optional[List[str]] = None, | |
resolver: Optional[Resolver] = None, | |
) -> "Awaitable[WebSocketClientConnection]": | |
"""Client-side websocket support. | |
Takes a url and returns a Future whose result is a | |
`WebSocketClientConnection`. | |
``compression_options`` is interpreted in the same way as the | |
return value of `.WebSocketHandler.get_compression_options`. | |
The connection supports two styles of operation. In the coroutine | |
style, the application typically calls | |
`~.WebSocketClientConnection.read_message` in a loop:: | |
conn = yield websocket_connect(url) | |
while True: | |
msg = yield conn.read_message() | |
if msg is None: break | |
# Do something with msg | |
In the callback style, pass an ``on_message_callback`` to | |
``websocket_connect``. In both styles, a message of ``None`` | |
indicates that the connection has been closed. | |
``subprotocols`` may be a list of strings specifying proposed | |
subprotocols. The selected protocol may be found on the | |
``selected_subprotocol`` attribute of the connection object | |
when the connection is complete. | |
.. versionchanged:: 3.2 | |
Also accepts ``HTTPRequest`` objects in place of urls. | |
.. versionchanged:: 4.1 | |
Added ``compression_options`` and ``on_message_callback``. | |
.. versionchanged:: 4.5 | |
Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size`` | |
arguments, which have the same meaning as in `WebSocketHandler`. | |
.. versionchanged:: 5.0 | |
The ``io_loop`` argument (deprecated since version 4.1) has been removed. | |
.. versionchanged:: 5.1 | |
Added the ``subprotocols`` argument. | |
.. versionchanged:: 6.3 | |
Added the ``resolver`` argument. | |
""" | |
if isinstance(url, httpclient.HTTPRequest): | |
assert connect_timeout is None | |
request = url | |
# Copy and convert the headers dict/object (see comments in | |
# AsyncHTTPClient.fetch) | |
request.headers = httputil.HTTPHeaders(request.headers) | |
else: | |
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) | |
request = cast( | |
httpclient.HTTPRequest, | |
httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS), | |
) | |
conn = WebSocketClientConnection( | |
request, | |
on_message_callback=on_message_callback, | |
compression_options=compression_options, | |
ping_interval=ping_interval, | |
ping_timeout=ping_timeout, | |
max_message_size=max_message_size, | |
subprotocols=subprotocols, | |
resolver=resolver, | |
) | |
if callback is not None: | |
IOLoop.current().add_future(conn.connect_future, callback) | |
return conn.connect_future | |