whatsthispdf08 / app /main.py
mitulagr2's picture
minor fix
3e9ee7c
raw
history blame
2.44 kB
import os
import logging
import pathlib
import time
import re
from typing import List
from fastapi import FastAPI, Request, Depends
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from .rag import ChatPDF
# Permissive CORS: any origin, method and header may call this API.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=['*'],
        allow_headers=['*'],
    )
]

app = FastAPI(middleware=middleware)

# Scratch directory where uploaded files are written before ingestion.
files_dir = os.path.expanduser("~/wtp_be_files/")
# exist_ok=True: without it the module fails to import with FileExistsError
# whenever the directory is already present (i.e. on every restart).
os.makedirs(files_dir, exist_ok=True)

# Single assistant instance shared by every endpoint in this module.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def astreamer(generator):
    """Yield every item of *generator* unchanged, logging timing as it goes.

    An INFO record is emitted before each item is yielded and once more when
    the iterable is exhausted; each record carries the number of whole
    milliseconds elapsed since this generator started running.
    """
    started = time.time()

    def elapsed_ms():
        # Whole milliseconds since the first `next()` on this generator.
        return int((time.time() - started) * 1000)

    for chunk in generator:
        logger.info(f"Chunk being yielded (time {elapsed_ms()}ms)")
        yield chunk

    logger.info(f"Over (time {elapsed_ms()}ms)")
@app.get("/query")
async def process_input(text: str):
    """Answer the query *text* as a streamed response.

    Three outcomes, all returned as a ``text/event-stream`` StreamingResponse:

    * blank or whitespace-only query -> a fixed "query is empty" message;
    * no PDF ingested yet (``pdf_count`` is 0) -> a fixed "add a PDF" message;
    * otherwise -> the ``response_gen`` of ``session_assistant.ask``, called
      with the whitespace-stripped query.

    Fixed messages are split on whitespace (separators kept) so they are
    streamed piece by piece through the same path as a real answer.
    """
    query = text.strip() if text else ""

    if not query:
        chunks = re.split(r'(\s)', "The provided query is empty.")
    elif session_assistant.pdf_count > 0:
        chunks = session_assistant.ask(query).response_gen
    else:
        chunks = re.split(r'(\s)', "Please add a PDF document first.")

    return StreamingResponse(astreamer(chunks), media_type='text/event-stream')
async def parse_body(request: Request) -> bytes:
    """Dependency: read the request body and return it as raw bytes."""
    return await request.body()
@app.post("/upload")
def upload(data: bytes = Depends(parse_body)):
    """Store the uploaded bytes in ``files_dir`` and ingest that directory.

    The raw request body (supplied by the ``parse_body`` dependency) is
    written to a temporary file named ``file`` inside ``files_dir``, then
    ``session_assistant.ingest(files_dir)`` is called. The temporary file is
    removed afterwards whether or not ingestion succeeded.

    Returns a ``text/event-stream`` StreamingResponse carrying a status
    message: the success message only when every step succeeded, a failure
    message otherwise. Errors are logged with their traceback, not raised.
    """
    file_path = pathlib.Path(files_dir) / "file"
    try:
        try:
            file_path.write_bytes(data)
            session_assistant.ingest(files_dir)
        finally:
            # Always clean up so a failed upload does not leave a stale file
            # behind to be picked up by the next ingest of this directory.
            if file_path.exists():
                file_path.unlink()
    except Exception:
        # logger.exception records the traceback itself; the previous code
        # referenced the un-imported `traceback` module here and would have
        # raised NameError instead of logging the real error.
        logger.exception("Failed to ingest the uploaded file")
        message = "The file could not be inserted."
    else:
        message = "Files inserted successfully."

    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
@app.get("/clear")
def clear():
    """Clear the shared assistant and stream back a confirmation message.

    Calls ``session_assistant.clear()`` and returns a ``text/event-stream``
    StreamingResponse with a fixed message, split on whitespace (separators
    kept) so it is streamed piece by piece.
    """
    # This handler was previously also called `ping`, the same name as the
    # `/` handler; a distinct name avoids the redefinition and gives the
    # route its own operation name in the generated API docs.
    session_assistant.clear()
    message = "All files have been cleared. The first query may take a little longer."
    chunks = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(chunks), media_type='text/event-stream')
@app.get("/")
def ping():
    """Liveness check: always responds with the string ``"Pong!"``."""
    reply = "Pong!"
    return reply