"""FastAPI backend exposing one shared ChatPDF assistant over HTTP.

Endpoints:
    GET  /query   -- ask ``session_assistant`` a question; the answer is streamed.
    POST /upload  -- write the raw request body into ``files_dir`` and call
                     ``session_assistant.ingest(files_dir)`` on it.
    GET  /clear   -- call ``session_assistant.clear()``.
    GET  /        -- liveness check.

Every text response is streamed as chunks with media type ``text/event-stream``.
"""
import os
import logging
import pathlib
import time
import re
from typing import List

from fastapi import FastAPI, Request, Depends
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF

middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
]

app = FastAPI(middleware=middleware)

# Directory that uploaded bodies are written into before ingestion.
files_dir = os.path.expanduser("~/wtp_be_files/")
# exist_ok=True: the directory survives restarts; without it every startup
# after the first one crashes with FileExistsError.
os.makedirs(files_dir, exist_ok=True)

# A single assistant instance shared by every request handled by this process.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Splits a message on whitespace while keeping each whitespace character as
# its own chunk, so the streamed chunks concatenate back to the original text.
_WORD_SPLIT = re.compile(r"(\s)")


def astreamer(generator):
    """Yield every item of *generator* unchanged, logging elapsed time.

    Despite the name this is a plain synchronous generator.

    Args:
        generator: Any iterable of chunks to forward.

    Yields:
        The items of *generator*, in order.
    """
    start = time.monotonic()
    for chunk in generator:
        logger.info(
            "Chunk being yielded (time %dms)", int((time.monotonic() - start) * 1000)
        )
        yield chunk
    logger.info("Over (time %dms)", int((time.monotonic() - start) * 1000))


def _stream_text(message, status_code=200):
    """Return a StreamingResponse that emits *message* word by word."""
    return StreamingResponse(
        astreamer(_WORD_SPLIT.split(message)),
        media_type="text/event-stream",
        status_code=status_code,
    )


@app.get("/query")
async def process_input(text: str):
    """Stream the assistant's answer to *text*, or an explanatory message.

    Args:
        text: The user's question (query parameter); surrounding whitespace
            is stripped before it is passed to the assistant.

    Returns:
        A streaming response of the answer chunks. If the query is blank or
        ``session_assistant.pdf_count`` is not positive, a fixed message is
        streamed instead.
    """
    query = text.strip() if text else ""
    if not query:
        return _stream_text("The provided query is empty.")
    if session_assistant.pdf_count > 0:
        # NOTE(review): ask() is called synchronously inside an async handler,
        # so any up-front work it does blocks the event loop. Consider a plain
        # ``def`` handler once ChatPDF is confirmed safe for concurrent ask()
        # calls.
        # NOTE(review): assumes response_gen is an iterable of text chunks —
        # confirm against ChatPDF.ask.
        chunks = session_assistant.ask(query).response_gen
        return StreamingResponse(astreamer(chunks), media_type="text/event-stream")
    return _stream_text("Please add a PDF document first.")


async def parse_body(request: Request):
    """Dependency returning the raw request body as bytes."""
    data: bytes = await request.body()
    return data


@app.post("/upload")
def upload(data: bytes = Depends(parse_body)):
    """Write the raw body to ``files_dir``, ingest that directory, then delete it.

    Args:
        data: Raw request body supplied by ``parse_body``.

    Returns:
        A streaming success message. Returns a 400 message if the body is
        empty, or a 500 message if writing or ingestion raised.
    """
    # Log the size only: concatenating str + bytes (as the previous print did)
    # raises TypeError on every request.
    logger.info("Upload received: %d bytes", len(data) if data else 0)
    if not data:
        return _stream_text("No file data was received.", status_code=400)

    path = os.path.join(files_dir, "file")
    try:
        with open(path, "wb") as f:
            f.write(data)
        session_assistant.ingest(files_dir)
    except Exception:
        # Request boundary: record the full traceback and report the failure
        # instead of claiming success.
        logger.exception("Failed to ingest uploaded file")
        return _stream_text("Failed to insert files.", status_code=500)
    finally:
        # Always remove the temp file so a failed upload does not linger in
        # files_dir. Cleanup is best-effort and must not mask the result.
        try:
            pathlib.Path(path).unlink(missing_ok=True)
        except OSError:
            logger.warning("Could not remove temporary upload %s", path, exc_info=True)
    return _stream_text("Files inserted successfully.")


@app.get("/clear")
def clear():
    """Call ``session_assistant.clear()`` and stream a confirmation message."""
    session_assistant.clear()
    return _stream_text(
        "All files have been cleared. \nThe first query may take a little longer."
    )


@app.get("/")
def ping():
    """Liveness check; always returns "Pong!"."""
    return "Pong!"