Spaces:
Running
Running
import os | |
import shutil | |
import json | |
import asyncio | |
from datetime import datetime | |
from typing import List | |
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect | |
from fastapi.middleware import Middleware | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import StreamingResponse | |
from .rag import ChatPDF | |
# Application-wide middleware stack. The single CORS entry accepts any origin,
# method, and header.
# NOTE(review): wildcard CORS is very permissive; confirm it is intended
# outside of local development.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    ),
]

app = FastAPI(middleware=middleware)

# Directory under the current user's home where uploads are staged.
files_dir = os.path.expanduser("~/wtp_be_files/")

# One ChatPDF instance, created at import time and used by every handler in
# this module.
session_assistant = ChatPDF()
class ConnectionManager:
    """Track accepted WebSocket connections and send text messages to them."""

    def __init__(self):
        # Connections are kept in the order in which they were accepted.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*.

        A socket that is not tracked (for example one that was already
        removed) is ignored instead of raising ``ValueError``.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* as text to the single connection *websocket*."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every connection tracked when the call starts."""
        # Iterate over a snapshot: every send awaits, so other code can call
        # connect()/disconnect() in the meantime, and mutating the list that
        # is being iterated would silently skip connections.
        for connection in list(self.active_connections):
            await connection.send_text(message)


# Module-level manager instance used by the handlers in this file.
manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Serve one chat client over *websocket* until it disconnects.

    Every non-blank text frame is passed to ``session_assistant.ask`` and the
    chunks of the answer are sent back to the same socket one by one. While
    ``session_assistant.pdf_count`` is not positive, the client is told to add
    a PDF first. On disconnect the socket is dropped from ``manager`` and an
    "Offline" message is broadcast to the remaining connections.

    Each outgoing message is a JSON object with the keys ``time`` (HH:MM),
    ``clientId`` and ``message``.
    """

    def build_payload(text: str) -> str:
        # The timestamp is taken per message; computing it once at connect
        # time would stamp every later message with a stale time.
        return json.dumps({
            "time": datetime.now().strftime("%H:%M"),
            "clientId": client_id,
            "message": text,
        })

    await manager.connect(websocket)
    try:
        while True:
            data = (await websocket.receive_text()).strip()
            if not data:
                # Blank frames are ignored.
                continue

            if not session_assistant.pdf_count > 0:
                await manager.send_personal_message(
                    build_payload("Please, add a PDF document first."),
                    websocket,
                )
                continue

            print("FETCHING STREAM")
            streaming_response = session_assistant.ask(data)
            print("STARTING STREAM")
            # NOTE(review): response_gen is consumed with a plain ``for``; if
            # producing a chunk blocks, the event loop is stalled meanwhile --
            # TODO confirm.
            for text in streaming_response.response_gen:
                await manager.send_personal_message(build_payload(text), websocket)
            print("ENDING STREAM")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(build_payload("Offline"))
async def astreamer(generator):
    """Re-emit the items of a synchronous iterable as an async stream.

    Each item is printed, yielded, and followed by a 0.1 second pause.

    If the surrounding task is cancelled, the sentinel string ``'cancelled'``
    is yielded once and the ``asyncio.CancelledError`` is then re-raised, so
    the cancellation is not swallowed.
    """
    try:
        print("streaming........")
        for item in generator:
            print(item)
            yield item
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        yield 'cancelled'
        # Propagate the cancellation once the consumer asks for more data;
        # suppressing it would leave the cancelled task running.
        raise
async def process_input(text: str):
    """Stream the assistant's answer to *text*.

    Returns ``None`` when *text* is empty or contains only whitespace.
    Otherwise the stripped text is passed to ``session_assistant.ask`` and the
    answer's ``response_gen`` is wrapped by ``astreamer`` in a
    ``StreamingResponse`` with media type ``text/event-stream``.
    """
    question = text.strip() if text else ""
    if not question:
        return None

    answer = session_assistant.ask(question)
    return StreamingResponse(
        astreamer(answer.response_gen),
        media_type='text/event-stream',
    )
def upload(files: list[UploadFile]):
    """Stage the uploaded files on disk and ingest them into the assistant.

    Every upload is copied into ``files_dir``. Once all copies have succeeded
    the directory is passed to ``session_assistant.ingest``. The staging
    directory is removed afterwards, whether or not copying or ingestion
    failed.

    Returns the string ``"Files inserted!"``. Errors from copying or ingesting
    propagate to the caller.
    """
    # exist_ok: a directory left behind by an interrupted earlier call must
    # not make this call fail.
    os.makedirs(files_dir, exist_ok=True)
    try:
        for index, file in enumerate(files):
            try:
                # The filename is supplied by the client: keep only its last
                # path component so that it cannot point outside files_dir.
                name = os.path.basename(file.filename or "")
                if name in ("", ".", ".."):
                    name = f"upload_{index}"
                path = os.path.join(files_dir, name)
                file.file.seek(0)
                with open(path, 'wb') as destination:
                    shutil.copyfileobj(file.file, destination)
            finally:
                file.file.close()
        # Ingest only after every copy succeeded, so that a partially written
        # directory is never ingested.
        session_assistant.ingest(files_dir)
    finally:
        # Always clean up, including when ingest() raises.
        shutil.rmtree(files_dir)
    return "Files inserted!"
def ping():
    """Call ``session_assistant.clear()`` and return a confirmation message.

    NOTE(review): a later definition in this file is also named ``ping``;
    confirm how the two are told apart by callers.
    """
    session_assistant.clear()
    confirmation = "All files have been cleared."
    return confirmation
def ping():
    """Return the fixed reply ``"Pong!"``."""
    reply = "Pong!"
    return reply