Spaces:
Running
Running
import os | |
import logging | |
import pathlib | |
import time | |
import re | |
from typing import List | |
from fastapi import FastAPI, Request, Depends | |
from fastapi.middleware import Middleware | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import StreamingResponse | |
from .rag import ChatPDF | |
# CORS is wide open: any origin may call this API with any method and header.
# NOTE(review): narrow allow_origins if the frontend's origin is known.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=['*'],
        allow_headers=['*'],
    )
]
app = FastAPI(middleware=middleware)

# Uploaded files are written here, ingested, and then deleted by the upload handler.
files_dir = os.path.expanduser("~/wtp_be_files/")
# exist_ok=True: a bare makedirs() raises FileExistsError once the directory
# exists, which made every start after the first one crash at import time.
os.makedirs(files_dir, exist_ok=True)

# One assistant instance shared by every request handled in this process.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def astreamer(generator):
    """Pass every item of *generator* through unchanged, logging timing.

    Before each item is yielded an INFO record reports the milliseconds
    elapsed since iteration started; a final "Over" record is logged once
    *generator* is exhausted.
    """
    started = time.time()

    def elapsed_ms():
        return int((time.time() - started) * 1000)

    for chunk in generator:
        logger.info("Chunk being yielded (time %sms)", elapsed_ms())
        yield chunk
    logger.info("Over (time %sms)", elapsed_ms())
async def process_input(text: str):
    """Stream a response to the query *text* as ``text/event-stream``.

    A missing or whitespace-only query gets a fixed "empty query" notice.
    A query received while ``session_assistant.pdf_count`` is not positive
    gets a fixed "add a PDF" notice. Notices are split into words and
    whitespace so they stream like a model answer. Otherwise the stripped
    query is passed to ``session_assistant.ask`` and its ``response_gen``
    is streamed.
    """
    query = text.strip() if text else ""
    if not query:
        chunks = re.split(r'(\s)', "The provided query is empty.")
    elif session_assistant.pdf_count > 0:
        chunks = session_assistant.ask(query).response_gen
    else:
        chunks = re.split(r'(\s)', "Please add a PDF document first.")
    return StreamingResponse(astreamer(chunks), media_type='text/event-stream')
async def parse_body(request: Request):
    """Dependency that yields the raw request body bytes (see ``upload``)."""
    return await request.body()
def upload(filename: str, data: bytes = Depends(parse_body)):
    """Ingest one uploaded file and stream back a status message.

    The raw request body is written into ``files_dir``,
    ``session_assistant.ingest(files_dir)`` is called, and the temporary copy
    is deleted afterwards whether or not ingestion succeeded.

    Args:
        filename: Client-supplied file name. Only its final path component is
            used, so values such as ``"../../etc/x"`` cannot escape
            ``files_dir``.
        data: Raw request body, provided by the ``parse_body`` dependency.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` carrying the status
        message, split into words and whitespace.
    """
    # Untrusted input: strip any directory components from the client name.
    safe_name = os.path.basename(filename or "")
    if not data:
        message = "No file data received."
    elif safe_name in ("", ".", ".."):
        message = "Invalid filename."
    else:
        logger.info("Filename: %s", safe_name)
        path = os.path.join(files_dir, safe_name)
        try:
            with open(path, "wb") as f:
                f.write(data)
            session_assistant.ingest(files_dir)
        except Exception:
            # Boundary handler: log the full traceback and tell the client it
            # failed (previously this branch referenced an unimported
            # `traceback` module and itself raised NameError).
            logger.exception("Failed to ingest %s", safe_name)
            message = "Failed to insert the file."
        else:
            message = "Files inserted successfully."
        finally:
            # Always remove the temporary copy: ingest() is handed the whole
            # directory, so a leftover file could be picked up again later.
            try:
                pathlib.Path(path).unlink()
            except FileNotFoundError:
                pass
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
def clear():
    """Reset the shared assistant and stream a confirmation message.

    Calls ``session_assistant.clear()`` and then streams the confirmation
    text, split into words and whitespace, as ``text/event-stream``.
    """
    # Renamed from `ping`: the `ping` defined right after this function reused
    # that name, so this definition was shadowed at module level and its name
    # misdescribed what it does. Module-level `ping` is unaffected.
    session_assistant.clear()
    message = "All files have been cleared. The first query may take a little longer."
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
def ping():
    """Return the fixed reply ``"Pong!"``."""
    reply = "Pong!"
    return reply