"""FastAPI backend for chatting with uploaded PDF documents.

Exposes a WebSocket chat endpoint, a streaming query endpoint, and endpoints
to upload and clear documents held by one process-wide ``ChatPDF`` assistant.
"""

import asyncio
import json
import os
import shutil
import tempfile
from datetime import datetime
from typing import List

from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF

middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
]

app = FastAPI(middleware=middleware)

# Parent directory for uploads. Each upload request writes into its own
# temporary sub-directory so concurrent uploads cannot clobber each other.
files_dir = os.path.expanduser("~/wtp_be_files/")

# A single assistant instance shared by every client and request.
session_assistant = ChatPDF()


def _chat_message(client_id: int, text: str) -> str:
    """Serialize one chat message as JSON, stamped with the current HH:MM time."""
    # Stamp at send time, not connection time, so each message carries its own time.
    message = {"time": datetime.now().strftime("%H:%M"), "clientId": client_id, "message": text}
    return json.dumps(message)


class ConnectionManager:
    """Track open WebSocket connections and send text messages to them."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not currently tracked."""
        # Tolerate repeated removal: a socket may be dropped by broadcast()
        # after a failed send and again by its own endpoint on exit.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked connection, dropping ones that fail."""
        # Iterate over a snapshot because failed sends mutate the list.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except (WebSocketDisconnect, RuntimeError, OSError):
                # One closed peer must not prevent delivery to the others.
                self.disconnect(connection)


manager = ConnectionManager()


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Answer each non-blank text received on *websocket* with the assistant.

    The answer is streamed back to the sender as one JSON message per chunk.
    When the client disconnects, an "Offline" notice is broadcast to the
    remaining connections.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = (await websocket.receive_text()).strip()
            if not data:
                continue
            if not session_assistant.pdf_count > 0:
                await manager.send_personal_message(
                    _chat_message(client_id, "Please, add a PDF document first."), websocket
                )
                continue
            print("FETCHING STREAM")
            # NOTE(review): ask() and response_gen are synchronous and block the
            # event loop while the answer is produced; consider offloading once
            # ChatPDF's thread-safety is confirmed.
            streaming_response = session_assistant.ask(data)
            print("STARTING STREAM")
            for text in streaming_response.response_gen:
                await manager.send_personal_message(_chat_message(client_id, text), websocket)
            print("ENDING STREAM")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(_chat_message(client_id, "Offline"))
    finally:
        # Also untrack on any other error or cancellation so later broadcasts
        # never target a dead socket (a no-op if already removed above).
        manager.disconnect(websocket)


async def astreamer(generator):
    """Re-yield each item of a synchronous *generator*, pausing 0.1 s after each.

    Cancellation (e.g. the HTTP client disconnecting) is propagated rather
    than swallowed, as asyncio requires.
    """
    print("streaming........")
    try:
        for item in generator:
            print(item)
            yield item
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print("cancelled")
        raise


@app.get("/query")
async def process_input(text: str):
    """Stream the assistant's answer to *text* with media type ``text/event-stream``.

    Returns ``None`` (serialized as JSON ``null``) when *text* is blank.
    """
    text = text.strip() if text else ""
    if not text:
        return None
    streaming_response = session_assistant.ask(text)
    return StreamingResponse(
        astreamer(streaming_response.response_gen), media_type="text/event-stream"
    )


def _save_upload(file: UploadFile, dest_dir: str) -> None:
    """Copy one uploaded file into *dest_dir*, always closing its stream."""
    try:
        # Keep only the final path component: the client-supplied filename
        # could otherwise contain "../" and escape dest_dir.
        name = os.path.basename(file.filename or "")
        if name in ("", ".", ".."):
            return
        file.file.seek(0)
        with open(os.path.join(dest_dir, name), "wb") as destination:
            shutil.copyfileobj(file.file, destination)
    finally:
        file.file.close()


@app.post("/upload")
def upload(files: list[UploadFile]):
    """Save the uploaded files, ingest them into the assistant, then delete them.

    FastAPI runs this sync handler in a thread pool, so each request uses its
    own temporary directory to avoid ingesting or deleting another request's
    files.
    """
    os.makedirs(files_dir, exist_ok=True)
    upload_dir = tempfile.mkdtemp(dir=files_dir)
    try:
        try:
            for file in files:
                _save_upload(file, upload_dir)
        finally:
            # As before, ingest whatever was saved even if a later copy failed.
            session_assistant.ingest(upload_dir)
    finally:
        # Remove the saved copies even when ingestion raises.
        shutil.rmtree(upload_dir, ignore_errors=True)
    return "Files inserted!"


@app.get("/clear")
def clear():
    """Remove every ingested document from the assistant."""
    session_assistant.clear()
    return "All files have been cleared."


@app.get("/")
def ping():
    """Liveness check."""
    return "Pong!"