# Spaces:
# Sleeping
# Sleeping
import asyncio | |
import contextlib | |
import functools | |
import socket | |
import traceback | |
import typing | |
import unittest | |
from tornado.concurrent import Future | |
from tornado import gen | |
from tornado.httpclient import HTTPError, HTTPRequest | |
from tornado.locks import Event | |
from tornado.log import gen_log, app_log | |
from tornado.netutil import Resolver | |
from tornado.simple_httpclient import SimpleAsyncHTTPClient | |
from tornado.template import DictLoader | |
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog | |
from tornado.web import Application, RequestHandler | |
# Import the websocket module under test up front so that an import failure
# is reported with its real traceback.
try:
    import tornado.websocket  # noqa: F401
    from tornado.util import _websocket_mask_python
except ImportError:
    # The unittest module presents misleading errors on ImportError
    # (it acts as if websocket_test could not be found, hiding the underlying
    # error). If we get an ImportError here (which could happen due to
    # TORNADO_EXTENSION=1), print some extra information before failing.
    traceback.print_exc()
    raise
from tornado.websocket import ( | |
WebSocketHandler, | |
websocket_connect, | |
WebSocketError, | |
WebSocketClosedError, | |
) | |
# The compiled speedups extension is optional. When it cannot be imported,
# ``speedups`` is bound to None so callers can check for its availability.
try:
    from tornado import speedups
except ImportError:
    speedups = None  # type: ignore
class TestWebSocketHandler(WebSocketHandler):
    """Base class for testing handlers that exposes the on_close event.

    This allows for tests to see the close code and reason on the
    server side.

    Keyword arguments accepted by ``initialize`` (supplied through the
    route's kwargs dict in the ``Application``):

    * ``close_future``: if not None, ``on_close`` resolves it with the
      tuple ``(close_code, close_reason)``.
    * ``compression_options``: returned unchanged from
      ``get_compression_options``.
    """

    def initialize(self, close_future=None, compression_options=None):
        self.close_future = close_future
        self.compression_options = compression_options

    def get_compression_options(self):
        return self.compression_options

    def on_close(self):
        if self.close_future is not None:
            self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
    """Echo each message back, keeping binary messages binary."""

    @gen.coroutine
    def on_message(self, message):
        # This method uses ``yield``. Without ``@gen.coroutine`` it would return
        # a bare generator that Tornado cannot run, and nothing would be echoed.
        try:
            yield self.write_message(message, isinstance(message, bytes))
        except (asyncio.CancelledError, WebSocketClosedError):
            # The connection may be closed (or the task cancelled) before the
            # write completes; both cases are deliberately ignored.
            pass
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Raise ZeroDivisionError from on_message to exercise error handling."""

    def on_message(self, message):
        1 / 0
class HeaderHandler(TestWebSocketHandler):
    """Check that plain-HTTP response methods fail once the socket is open.

    Each listed RequestHandler method must raise ``RuntimeError``. Afterwards
    the value of the request's ``X-Test`` header (or ``""``) is sent to the
    client.
    """

    def open(self):
        methods_to_test = [
            functools.partial(self.write, "This should not work"),
            functools.partial(self.redirect, "http://localhost/elsewhere"),
            functools.partial(self.set_header, "X-Test", ""),
            functools.partial(self.set_cookie, "Chocolate", "Chip"),
            functools.partial(self.set_status, 503),
            self.flush,
            self.finish,
        ]
        for method in methods_to_test:
            try:
                # In a websocket context, many RequestHandler methods
                # raise RuntimeErrors.
                method()  # type: ignore
                raise Exception("did not get expected exception")
            except RuntimeError:
                pass
        self.write_message(self.request.headers.get("X-Test", ""))
class HeaderEchoHandler(TestWebSocketHandler):
    """Copy request headers whose names start with ``x-test`` into the response.

    Also sets a fixed ``X-Extra-Response-Header`` on every response.
    """

    def set_default_headers(self):
        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")

    def prepare(self):
        # Header names are matched case-insensitively.
        for k, v in self.request.headers.get_all():
            if k.lower().startswith("x-test"):
                self.set_header(k, v)
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP endpoint that answers GET with the body "ok"."""

    def get(self):
        self.write("ok")
class RedirectHandler(RequestHandler):
    """Plain HTTP endpoint that redirects GET requests to ``/echo``."""

    def get(self):
        self.redirect("/echo")
class CloseReasonHandler(TestWebSocketHandler):
    """Close from the server as soon as the socket opens (code 1001, "goodbye")."""

    def open(self):
        self.on_close_called = False
        self.close(1001, "goodbye")
class AsyncPrepareHandler(TestWebSocketHandler):
    """Handler whose prepare() is a coroutine; messages are echoed back."""

    @gen.coroutine
    def prepare(self):
        # This method yields, so it must be a coroutine for Tornado to run it.
        yield gen.moment

    def on_message(self, message):
        self.write_message(message)
class PathArgsHandler(TestWebSocketHandler):
    """Send the URL path argument captured by the route as the first message."""

    def open(self, arg):
        self.write_message(arg)
class CoroutineOnMessageHandler(TestWebSocketHandler):
    """Coroutine on_message used to check that messages are handled one at a time.

    If on_message is entered while an earlier call is still sleeping, the
    handler sends "another coroutine is already sleeping" ahead of the echo.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Number of on_message calls currently suspended in gen.sleep.
        self.sleeping = 0

    @gen.coroutine
    def on_message(self, message):
        # This method yields, so it must be a coroutine for Tornado to run it.
        if self.sleeping > 0:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        yield gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
class RenderMessageHandler(TestWebSocketHandler):
    """Render each message through the "message.html" template and send the result."""

    def on_message(self, message):
        self.write_message(self.render_string("message.html", message=message))
class SubprotocolHandler(TestWebSocketHandler):
    """Select "goodproto" when offered and report the choice on open.

    Raises if select_subprotocol is called twice, or if open() runs
    without select_subprotocol having been called first.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        self.select_subprotocol_called = False

    def select_subprotocol(self, subprotocols):
        if self.select_subprotocol_called:
            raise Exception("select_subprotocol called twice")
        self.select_subprotocol_called = True
        if "goodproto" in subprotocols:
            return "goodproto"
        return None

    def open(self):
        if not self.select_subprotocol_called:
            raise Exception("select_subprotocol not called")
        self.write_message("subprotocol=%s" % self.selected_subprotocol)
class OpenCoroutineHandler(TestWebSocketHandler):
    """Coroutine open() used to check that on_message waits for open to finish.

    ``test`` is the running test case. open() waits on its ``message_sent``
    event, so the client's message is already in flight, then sleeps briefly.
    on_message raises if it runs before open() has completed.
    """

    def initialize(self, test, **kwargs):
        super().initialize(**kwargs)
        self.test = test
        self.open_finished = False

    @gen.coroutine
    def open(self):
        # This method yields, so it must be a coroutine for Tornado to run it.
        yield self.test.message_sent.wait()
        yield gen.sleep(0.010)
        self.open_finished = True

    def on_message(self, message):
        if not self.open_finished:
            raise Exception("on_message called before open finished")
        self.write_message("ok")
class ErrorInOpenHandler(TestWebSocketHandler):
    """Raise from a synchronous open() to exercise error handling."""

    def open(self):
        raise Exception("boom")
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
    """Raise from a native-coroutine open() after one event-loop iteration."""

    async def open(self):
        await asyncio.sleep(0)
        raise Exception("boom")
class NoDelayHandler(TestWebSocketHandler):
    """Enable TCP_NODELAY via set_nodelay(True), then send "hello"."""

    def open(self):
        self.set_nodelay(True)
        self.write_message("hello")
class WebSocketBaseTestCase(AsyncHTTPTestCase):
    """Base test case that closes every ws_connect connection in tearDown."""

    def setUp(self):
        super().setUp()
        # Client connections opened through ws_connect; closed in tearDown.
        self.conns_to_close = []

    def tearDown(self):
        for conn in self.conns_to_close:
            conn.close()
        super().tearDown()

    @gen.coroutine
    def ws_connect(self, path, **kwargs):
        """Open a websocket to ``path`` on this test's HTTP server.

        Extra keyword arguments are passed to ``websocket_connect``. The
        connection is recorded so ``tearDown`` closes it. The coroutine
        resolves to the client connection object.
        """
        # @gen.coroutine is required: callers ``yield`` this method, and the
        # body itself yields.
        ws = yield websocket_connect(
            "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
        )
        self.conns_to_close.append(ws)
        return ws
class WebSocketTest(WebSocketBaseTestCase):
    """End-to-end tests of WebSocketHandler together with websocket_connect.

    Generator-based tests are decorated with ``@gen_test`` so the IOLoop
    drives them. Without it, each one returns an un-run generator and none
    of its assertions execute. Most routes share ``self.close_future``, which
    the handlers resolve in on_close with ``(close_code, close_reason)``.
    """

    def get_app(self):
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/redirect", RedirectHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
                ("/nodelay", NoDelayHandler),
            ],
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super().tearDown()
        # Clear template loaders cached on RequestHandler, presumably so this
        # test's DictLoader does not leak into other tests (assumed intent).
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        # Callback style: drive the IOLoop with self.stop/self.wait rather
        # than @gen_test.
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message("hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, "hello \u00e9")

    @gen_test
    def test_error_in_closed_client_write_message(self):
        ws = yield self.ws_connect("/echo")
        ws.close()
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_http_redirect(self):
        with self.assertRaises(HTTPError):
            yield self.ws_connect("/redirect")

    @gen_test
    def test_websocket_network_fail(self):
        # Bind and immediately release a port so nothing is listening on it.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*", required=False):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        with contextlib.closing(
            (yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))
        ) as ws:
            ws.write_message("hello")
            ws.write_message("world")
            # Close the underlying stream.
            ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        with contextlib.closing(
            (
                yield websocket_connect(
                    HTTPRequest(
                        "ws://127.0.0.1:%d/header" % self.get_http_port(),
                        headers={"X-Test": "hello"},
                    )
                )
            )
        ) as ws:
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        with contextlib.closing(
            (
                yield websocket_connect(
                    HTTPRequest(
                        "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                        headers={"X-Test-Hello": "hello"},
                    )
                )
            )
        ) as ws:
            self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
            self.assertEqual(
                ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
            )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}
        with contextlib.closing(
            (yield websocket_connect(HTTPRequest(url, headers=headers)))
        ) as ws:
            ws.write_message("hello")
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}
        with contextlib.closing(
            (yield websocket_connect(HTTPRequest(url, headers=headers)))
        ) as ws:
            ws.write_message("hello")
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()
        # CaresResolver may return ipv6-only results for localhost, but our
        # server is only running on ipv4. Test for this edge case and skip
        # the test if it happens.
        addrinfo = yield Resolver().resolve("localhost", port)
        families = {addr[0] for addr in addrinfo}
        if socket.AF_INET not in families:
            # skipTest raises SkipTest, so nothing below runs in this case.
            self.skipTest("localhost does not resolve to ipv4")
        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default. If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        # OpenCoroutineHandler.open waits on this event before finishing.
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_nodelay(self):
        ws = yield self.ws_connect("/nodelay")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
    """``async def`` on_message used to check that messages are handled serially.

    If on_message is entered while an earlier call is still sleeping, the
    handler sends "another coroutine is already sleeping" ahead of the echo.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Number of on_message calls currently suspended in gen.sleep.
        self.sleeping = 0

    async def on_message(self, message):
        if self.sleeping > 0:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        await gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
    """Check that a native-coroutine on_message processes messages in order."""

    def get_app(self):
        return Application([("/native", NativeCoroutineOnMessageHandler)])

    @gen_test
    def test_native_coroutine(self):
        ws = yield self.ws_connect("/native")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")
class CompressionTestMixin(object):
    """Shared tests for websocket compression negotiation.

    Subclasses choose the server and client compression options and
    implement ``verify_wire_bytes`` to check the observed on-the-wire sizes.
    """

    MESSAGE = "Hello world. Testing 123 123"

    def get_app(self):
        class LimitedHandler(TestWebSocketHandler):
            """Reply with the length of each message; cap messages at 1024."""

            def max_message_size(self):
                return 1024

            def on_message(self, message):
                self.write_message(str(len(message)))

        return Application(
            [
                (
                    "/echo",
                    EchoHandler,
                    dict(compression_options=self.get_server_compression_options()),
                ),
                (
                    "/limited",
                    LimitedHandler,
                    dict(compression_options=self.get_server_compression_options()),
                ),
            ]
        )

    def get_server_compression_options(self):
        return None

    def get_client_compression_options(self):
        return None

    def verify_wire_bytes(self, bytes_in: int, bytes_out: int) -> None:
        raise NotImplementedError()

    @gen_test
    def test_message_sizes(self: typing.Any):
        ws = yield self.ws_connect(
            "/echo", compression_options=self.get_client_compression_options()
        )
        # Send the same message three times so we can measure the
        # effect of the context_takeover options.
        for i in range(3):
            ws.write_message(self.MESSAGE)
            response = yield ws.read_message()
            self.assertEqual(response, self.MESSAGE)
        self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
        self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
        self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)

    @gen_test
    def test_size_limit(self: typing.Any):
        ws = yield self.ws_connect(
            "/limited", compression_options=self.get_client_compression_options()
        )
        # Small messages pass through.
        ws.write_message("a" * 128)
        response = yield ws.read_message()
        self.assertEqual(response, "128")
        # This message is too big after decompression, but it compresses
        # down to a size that will pass the initial checks.
        ws.write_message("a" * 2048)
        response = yield ws.read_message()
        self.assertIsNone(response)
class UncompressedTestMixin(CompressionTestMixin):
    """Specialization of CompressionTestMixin when we expect no compression."""

    def verify_wire_bytes(self: typing.Any, bytes_in, bytes_out):
        # Each frame has a 2-byte header. Bytes out also include the
        # 4-byte mask key per message, for 6 bytes of overhead.
        self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
        self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Neither side requests compression."""
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the server enables compression, so frames stay uncompressed."""

    def get_server_compression_options(self):
        return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the client enables compression, so frames stay uncompressed."""

    def get_client_compression_options(self):
        return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
    """Both sides enable default compression; wire sizes must shrink."""

    def get_server_compression_options(self):
        return {}

    def get_client_compression_options(self):
        return {}

    def verify_wire_bytes(self, bytes_in, bytes_out):
        self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
        self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
        # Bytes out includes the 4 bytes mask key per message.
        self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
    """Shared known-answer tests for a websocket masking function.

    Subclasses implement ``mask(mask, data)``. The expected outputs below are
    ``data`` XORed byte-by-byte with the 4-byte ``mask``, repeated as needed.
    """

    def mask(self, mask: bytes, data: bytes) -> bytes:
        raise NotImplementedError()

    def test_mask(self: typing.Any):
        self.assertEqual(self.mask(b"abcd", b""), b"")
        self.assertEqual(self.mask(b"abcd", b"b"), b"\x03")
        self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP")
        self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd")
        # Include test cases with \x00 bytes (to ensure that the C
        # extension isn't depending on null-terminated strings) and
        # bytes with the high bit set (to smoke out signedness issues).
        self.assertEqual(
            self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"),
            b"\xff\xfa\xff\xff\xfe\xfb",
        )
        self.assertEqual(
            self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"),
            b"\xff\xfa\xff\xff\xfb\xfe",
        )
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests against the pure-Python implementation."""

    def mask(self, mask, data):
        return _websocket_mask_python(mask, data)
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests against the compiled speedups module.

    ``speedups`` is None when the extension failed to import. In that case the
    whole class is skipped instead of failing with AttributeError.
    """

    def mask(self, mask, data):
        return speedups.websocket_mask(mask, data)
class ServerPeriodicPingTest(WebSocketBaseTestCase):
    """Server pings every 0.01s; each client pong is reported as "got pong"."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_pong(self, data):
                self.write_message("got pong")

        return Application([("/", PingHandler)], websocket_ping_interval=0.01)

    @gen_test
    def test_server_ping(self):
        ws = yield self.ws_connect("/")
        for i in range(3):
            response = yield ws.read_message()
            self.assertEqual(response, "got pong")
        # TODO: test that the connection gets closed if ping responses stop.
class ClientPeriodicPingTest(WebSocketBaseTestCase):
    """Client pings every 0.01s; the server reports each ping as "got ping"."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message("got ping")

        return Application([("/", PingHandler)])

    @gen_test
    def test_client_ping(self):
        ws = yield self.ws_connect("/", ping_interval=0.01)
        for i in range(3):
            response = yield ws.read_message()
            self.assertEqual(response, "got ping")
        # TODO: test that the connection gets closed if ping responses stop.
        ws.close()
class ManualPingTest(WebSocketBaseTestCase):
    """Explicit client pings; the server echoes each ping payload as a message."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message(data, binary=isinstance(data, bytes))

        return Application([("/", PingHandler)])

    @gen_test
    def test_manual_ping(self):
        ws = yield self.ws_connect("/")

        # A 126-byte payload is rejected with ValueError.
        self.assertRaises(ValueError, ws.ping, "a" * 126)

        ws.ping("hello")
        resp = yield ws.read_message()
        # on_ping always sees bytes.
        self.assertEqual(resp, b"hello")

        ws.ping(b"binary hello")
        resp = yield ws.read_message()
        self.assertEqual(resp, b"binary hello")
class MaxMessageSizeTest(WebSocketBaseTestCase):
    """``websocket_max_message_size=1024``: larger messages close the connection."""

    def get_app(self):
        return Application([("/", EchoHandler)], websocket_max_message_size=1024)

    @gen_test
    def test_large_message(self):
        ws = yield self.ws_connect("/")

        # Write a message that is allowed.
        msg = "a" * 1024
        ws.write_message(msg)
        resp = yield ws.read_message()
        self.assertEqual(resp, msg)

        # Write a message that is too large.
        ws.write_message(msg + "b")
        resp = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(resp, None)
        self.assertEqual(ws.close_code, 1009)
        self.assertEqual(ws.close_reason, "message too big")
        # TODO: Needs tests of messages split over multiple
        # continuation frames.