import asyncio
import datetime
from typing import Dict, List

from fastapi import WebSocket

from backend.report_type import BasicReport, DetailedReport
from backend.chat import ChatAgentWithMemory
from gpt_researcher.utils.enum import ReportType, Tone
from multi_agents.main import run_research_task
from gpt_researcher.actions import stream_output  # Import stream_output
from backend.server.server_utils import CustomLogsHandler


class WebSocketManager:
    """Manage websockets: connections, per-socket outbound queues, and chat."""

    def __init__(self):
        """Initialize the WebSocketManager class."""
        self.active_connections: List[WebSocket] = []
        # One background sender task and one outbound queue per connection.
        self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
        self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
        # NOTE(review): a single chat agent is stored on the manager, so the
        # most recently written report replaces the chat knowledge for every
        # connection handled by this manager — confirm this is intended.
        self.chat_agent = None

    async def start_sender(self, websocket: WebSocket):
        """Forward queued messages to ``websocket`` until it goes away.

        The literal message ``"ping"`` is answered with ``"pong"``; any other
        message is sent verbatim with ``send_text``. The loop ends when the
        ``None`` sentinel (queued by :meth:`disconnect`) is received, when the
        socket is no longer in ``active_connections``, or when a send fails.
        """
        queue = self.message_queues.get(websocket)
        if queue is None:
            return

        while True:
            message = await queue.get()
            if message is None or websocket not in self.active_connections:
                break
            try:
                if message == "ping":
                    await websocket.send_text("pong")
                else:
                    await websocket.send_text(message)
            except Exception:
                # Sending failed (e.g. the client went away); stop forwarding.
                # CancelledError is a BaseException and deliberately not caught,
                # so cancelling this task still marks it as cancelled.
                break

    async def connect(self, websocket: WebSocket):
        """Accept a websocket and start its dedicated sender task."""
        await websocket.accept()
        self.active_connections.append(websocket)
        self.message_queues[websocket] = asyncio.Queue()
        self.sender_tasks[websocket] = asyncio.create_task(
            self.start_sender(websocket))

    async def disconnect(self, websocket: WebSocket):
        """Disconnect a websocket and tear down its sender task and queue.

        Calling this for a socket that is not active is a no-op.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)

        # pop() rather than del: tolerate partially-initialised state.
        task = self.sender_tasks.pop(websocket, None)
        queue = self.message_queues.pop(websocket, None)
        if task is not None:
            task.cancel()
        if queue is not None:
            # Sentinel that tells the sender loop to exit.
            await queue.put(None)

    async def start_streaming(self, task, report_type, report_source,
                              source_urls, document_urls, tone, websocket,
                              headers=None):
        """Run a research task for ``websocket`` and return the report.

        ``tone`` is the *name* of a ``Tone`` member; an unknown name raises
        ``KeyError``. After the report is produced, a fresh chat agent is
        created from it so follow-up chat uses the new knowledge.
        """
        tone = Tone[tone]
        # add customized JSON config file path here
        config_path = "default"
        report = await run_agent(
            task, report_type, report_source, source_urls, document_urls,
            tone, websocket, headers=headers, config_path=config_path)
        # Create new Chat Agent whenever a new report is written
        self.chat_agent = ChatAgentWithMemory(report, config_path, headers)
        return report

    async def chat(self, message, websocket):
        """Chat with the agent based message diff.

        If no report has been produced yet, a notice is sent instead.
        """
        if self.chat_agent:
            await self.chat_agent.chat(message, websocket)
        else:
            await websocket.send_json({
                "type": "chat",
                "content": "Knowledge empty, please run the research first to obtain knowledge",
            })


async def run_agent(task, report_type, report_source, source_urls,
                    document_urls, tone: Tone, websocket, headers=None,
                    config_path=""):
    """Run the researcher selected by ``report_type`` and return its report.

    - ``"multi_agents"``: runs ``run_research_task`` and returns the value
      under its ``"report"`` key (``""`` if absent).
    - ``ReportType.DetailedReport.value``: runs a ``DetailedReport``.
    - anything else: runs a ``BasicReport``.

    Progress is streamed through a ``CustomLogsHandler`` wrapping
    ``websocket`` rather than through the raw socket.
    """
    # Create logs handler for this research task
    logs_handler = CustomLogsHandler(websocket, task)

    if report_type == "multi_agents":
        result = await run_research_task(
            query=task,
            websocket=logs_handler,  # Use logs_handler instead of raw websocket
            stream_output=stream_output,
            tone=tone,
            headers=headers,
        )
        return result.get("report", "")

    # Detailed and basic reports take identical constructor arguments;
    # only the class differs.
    if report_type == ReportType.DetailedReport.value:
        report_cls = DetailedReport
    else:
        report_cls = BasicReport

    researcher = report_cls(
        query=task,
        report_type=report_type,
        report_source=report_source,
        source_urls=source_urls,
        document_urls=document_urls,
        tone=tone,
        config_path=config_path,
        websocket=logs_handler,  # Use logs_handler instead of raw websocket
        headers=headers,
    )
    return await researcher.run()