Spaces:
Paused
Paused
import aiohttp | |
import json | |
import urllib | |
from networks import ( | |
ChathubRequestPayloadConstructor, | |
ConversationRequestHeadersConstructor, | |
MessageParser, | |
OpenaiStreamOutputer, | |
) | |
from conversations import ConversationStyle | |
from utils.logger import logger | |
from utils.enver import enver | |
class ConversationConnector:
    """
    Open a ChatHub websocket, send one prompt and stream back the reply.

    Input params:
        - `conversation_style` (str):
            - One of the `ConversationStyle` values (case-insensitive).
            - Unknown values fall back to `ConversationStyle.PRECISE`.
        - `sec_access_token`, `client_id`, `conversation_id`
            - Generated by `ConversationCreator`
        - `invocation_id` (int):
            - For 1st request, this value must be `0`.
            - For all requests after, any integer is valid.
            - To make it simple, use `1` for all requests after the 1st one.
        - `cookies` (dict | None):
            - Cookies handed to the `aiohttp.ClientSession`.
            - `None` means "no cookies" (a fresh empty dict per instance).
    """

    def __init__(
        self,
        conversation_style: str = ConversationStyle.PRECISE.value,
        sec_access_token: str = "",
        client_id: str = "",
        conversation_id: str = "",
        invocation_id: int = 0,
        cookies=None,
    ):
        conversation_style_enum_values = [
            style.value for style in ConversationStyle.__members__.values()
        ]
        normalized_style = conversation_style.lower()
        if normalized_style in conversation_style_enum_values:
            self.conversation_style = normalized_style
        else:
            # Unknown style: fall back to the default instead of failing.
            self.conversation_style = ConversationStyle.PRECISE.value
        print(f"Model: [{self.conversation_style}]")

        self.sec_access_token = sec_access_token
        self.client_id = client_id
        self.conversation_id = conversation_id
        self.invocation_id = invocation_id
        # A `{}` default in the signature would be shared by every instance,
        # so the empty dict is created here, once per instance.
        self.cookies = {} if cookies is None else cookies

    async def wss_send(self, message):
        """Serialize `message` to JSON and send it as one websocket frame.

        Every frame is terminated with the record separator `\\x1e`, the same
        delimiter `stream_chat` splits incoming frames on.
        """
        serialized_websocket_message = json.dumps(message, ensure_ascii=False) + "\x1e"
        await self.wss.send_str(serialized_websocket_message)

    async def init_handshake(self):
        """Perform the opening exchange on a freshly connected websocket.

        Sends the protocol announcement, waits for (and discards) one text
        reply, then sends a `{"type": 6}` message.
        """
        await self.wss_send({"protocol": "json", "version": 1})
        await self.wss.receive_str()
        await self.wss_send({"type": 6})

    async def connect(self):
        """Create the HTTP session, open the websocket and run the handshake.

        Sets `self.ws_url`, `self.aiohttp_session` and `self.wss`. If opening
        the websocket or the handshake fails, the session is closed before
        the exception is re-raised so it does not leak.
        """
        # `import urllib` alone does not guarantee that the `urllib.parse`
        # submodule is loaded, so import the function explicitly.
        from urllib.parse import quote

        self.quotelized_sec_access_token = quote(self.sec_access_token)
        self.ws_url = (
            f"wss://sydney.bing.com/sydney/ChatHub"
            f"?sec_access_token={self.quotelized_sec_access_token}"
        )
        self.aiohttp_session = aiohttp.ClientSession(cookies=self.cookies)
        try:
            headers_constructor = ConversationRequestHeadersConstructor()
            # NOTE(review): presumably populates `enver.proxy` — confirm in utils.enver.
            enver.set_envs(proxies=True)
            self.wss = await self.aiohttp_session.ws_connect(
                self.ws_url,
                headers=headers_constructor.request_headers,
                proxy=enver.proxy,
            )
            await self.init_handshake()
        except BaseException:
            await self.aiohttp_session.close()
            raise

    async def send_chathub_request(self, prompt: str, system_prompt: str = None):
        """Build the ChatHub request payload for `prompt` and send it.

        The payload is kept on `self.connect_request_payload`.
        """
        payload_constructor = ChathubRequestPayloadConstructor(
            prompt=prompt,
            conversation_style=self.conversation_style,
            client_id=self.client_id,
            conversation_id=self.conversation_id,
            invocation_id=self.invocation_id,
            system_prompt=system_prompt,
        )
        self.connect_request_payload = payload_constructor.request_payload
        await self.wss_send(self.connect_request_payload)

    async def _close(self):
        """Close the websocket and the HTTP session if they exist and are open."""
        wss = getattr(self, "wss", None)
        if wss is not None and not wss.closed:
            await wss.close()
        session = getattr(self, "aiohttp_session", None)
        if session is not None and not session.closed:
            await session.close()

    async def stream_chat(
        self, prompt: str = "", system_prompt: str = None, yield_output=False
    ):
        """Send `prompt` and consume the streamed response (async generator).

        Args:
            prompt: User message to send.
            system_prompt: Optional system prompt forwarded to the payload.
            yield_output: When true, yield the outputer's "Role" chunk first,
                then every parsed chunk, then a final "Finished" chunk. When
                false, messages are only passed to the parser and nothing is
                yielded.

        Raises:
            ConnectionError: If the websocket reports a transport error.

        The websocket and session are always closed on exit, including when
        an exception is raised or the consumer stops iterating early.
        """
        try:
            await self.connect()
            await self.send_chathub_request(prompt=prompt, system_prompt=system_prompt)
            message_parser = MessageParser(outputer=OpenaiStreamOutputer())

            if yield_output:
                yield message_parser.outputer.output(content="", content_type="Role")

            while not self.wss.closed:
                # `receive_str()` raises on any non-text frame (including the
                # close frame), so read the raw message and branch on its type.
                ws_message = await self.wss.receive()
                if ws_message.type == aiohttp.WSMsgType.ERROR:
                    raise ConnectionError(
                        f"ChatHub websocket error: {ws_message.data!r}"
                    )
                if ws_message.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    break
                if ws_message.type != aiohttp.WSMsgType.TEXT:
                    continue

                # One frame may carry several JSON records separated by \x1e.
                for line in ws_message.data.split("\x1e"):
                    if not line:
                        continue
                    data = json.loads(line)
                    message_type = data.get("type")

                    # Stream: Meaningful Messages
                    if message_type == 1:
                        if not yield_output:
                            message_parser.parse(data)
                            continue
                        output = message_parser.parse(data, return_output=True)
                        if isinstance(output, list):
                            for item in output:
                                yield item
                        elif output:
                            yield output
                    # Stream: End of Conversation
                    elif message_type == 3:
                        finished_str = "\n[Finished]"
                        logger.success(finished_str)
                        self.invocation_id += 1
                        # Close before the final yield so the connection is
                        # released even if the consumer stops right after it.
                        await self._close()
                        if yield_output:
                            yield message_parser.outputer.output(
                                content=finished_str, content_type="Finished"
                            )
                        return
                    # Type 2 (list of all messages in the whole conversation),
                    # type 6 (heartbeat signal) and any other type are ignored.
        finally:
            await self._close()