# NOTE: removed hosting-page residue ("Spaces: Sleeping") that was captured
# during extraction and is not part of this module.
from fastapi import FastAPI, HTTPException, UploadFile, File,Request,Depends,status | |
from fastapi.security import OAuth2PasswordBearer | |
from pydantic import BaseModel, Json | |
from uuid import uuid4, UUID | |
from typing import Optional | |
import pymupdf | |
from pinecone import Pinecone, ServerlessSpec | |
import os | |
from dotenv import load_dotenv | |
from rag import * | |
from fastapi.responses import StreamingResponse | |
import json | |
from prompts import * | |
from typing import Literal | |
from models import * | |
# Environment-driven configuration. Missing variables resolve to None here
# rather than failing fast; errors surface later at first Pinecone/auth use.
load_dotenv()
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
common_namespace = os.environ.get("COMMON_NAMESPACE")
pc = Pinecone(api_key=pinecone_api_key)
import time
index_name = os.environ.get("INDEX_NAME") # change if desired
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
if index_name not in existing_indexes:
    # Create the index on first boot. dimension=3072 presumably matches the
    # embedding model used by rag.get_vectorstore — TODO confirm.
    pc.create_index(
        name=index_name,
        dimension=3072,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    # Block startup until the serverless index is provisioned and ready.
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)
index = pc.Index(index_name)
# Single static API key, accepted as the OAuth2 bearer token by api_key_auth.
api_keys = [os.environ.get("FASTAPI_API_KEY")]
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") # use token authentication
def api_key_auth(api_key: str = Depends(oauth2_scheme)) -> None:
    """FastAPI dependency that validates the bearer token against `api_keys`.

    Raises:
        HTTPException: 401 when the presented token is not a configured key.
    """
    if api_key not in api_keys:
        # Fix: the original sent status 401 with detail "Forbidden" — that text
        # names the 403 status and misled clients. A 401 challenge should also
        # carry a WWW-Authenticate header per the OAuth2 bearer-token scheme.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
# Every route on this app requires a valid bearer token via api_key_auth.
app = FastAPI(dependencies=[Depends(api_key_auth)])
# Legacy header-based auth middleware, superseded by the OAuth2 dependency above.
# FASTAPI_KEY_NAME = os.environ.get("FASTAPI_KEY_NAME")
# FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")
# @app.middleware("http")
# async def api_key_middleware(request: Request, call_next):
#     if request.url.path not in ["/","/docs","/openapi.json"]:
#         api_key = request.headers.get(FASTAPI_KEY_NAME)
#         if api_key != FASTAPI_API_KEY:
#             raise HTTPException(status_code=403, detail="invalid API key :/")
#     response = await call_next(request)
#     return response
class StyleWriter(BaseModel):
    """Optional writing-style directives forwarded to the answer generator."""
    style: Optional[str] = "neutral"    # free-form style label
    tonality: Optional[str] = "formal"  # free-form tonality label
# Models advertised by get_models. NOTE(review): "o1-preview" is accepted by
# UserInput.model but absent here — confirm whether that omission is intended.
models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
class UserInput(BaseModel):
    """Request payload for answer generation."""
    prompt: str                                   # user question
    enterprise_id: str                            # Pinecone namespace to retrieve context from
    stream: Optional[bool] = False                # stream the answer chunk by chunk
    messages: Optional[list[dict]] = []           # prior chat turns (Pydantic copies mutable defaults per instance)
    style_tonality: Optional[StyleWriter] = None  # optional style/tonality overrides
    marque: Optional[str] = None                  # brand / enterprise display name
    model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest","o1-preview"] = "gpt-4o"
class EnterpriseData(BaseModel):
    """Metadata accompanying a document upload."""
    name: str                       # human-readable enterprise name
    id: Optional[str] = None        # namespace id; generated during upload when absent
    filename: Optional[str] = None  # overrides the uploaded file's own name when set
tasks = []


def greet_json():
    """Trivial greeting payload, usable as a liveness check."""
    return dict(Hello="World!")
async def upload_file(file: UploadFile, enterprise_data: Json[EnterpriseData]):
    """Ingest a PDF: extract its text, chunk it, and index it in Pinecone.

    Args:
        file: Uploaded PDF document.
        enterprise_data: Target enterprise; ``id`` is generated when absent.

    Returns:
        dict summarising the stored document (names, ids, chunk count).

    Raises:
        HTTPException: 500 when the vector store cannot be created or any
            processing step fails.
    """
    try:
        contents = await file.read()
        # Normalise the enterprise name: one C-level pass over the string
        # instead of five chained .replace() calls.
        enterprise_name = enterprise_data.name.translate(
            str.maketrans({c: "_" for c in " -./\\"})
        ).strip()
        filename = enterprise_data.filename if enterprise_data.filename is not None else file.filename
        # Assign a new namespace id if the caller did not provide one.
        if enterprise_data.id is None:
            clean_name = remove_non_standard_ascii(enterprise_name)
            enterprise_data.id = f"{clean_name}_{uuid4()}"
        # Extract all text from the PDF.
        pdf_document = pymupdf.open(stream=contents, filetype="pdf")
        text = "".join(page.get_text() for page in pdf_document)
        # Chunk and index.
        text_chunks = get_text_chunks(text)
        vector_store = get_vectorstore(
            text_chunks,
            filename=filename,
            file_type="pdf",
            namespace=enterprise_data.id,
            index=index,
            enterprise_name=enterprise_name,
        )
        if not vector_store:
            raise HTTPException(status_code=500, detail="Could not create vector store")
        return {
            "file_name": filename,
            "enterprise_id": enterprise_data.id,
            "number_of_chunks": len(text_chunks),
            "filename_id": vector_store["filename_id"],
            "enterprise_name": enterprise_name,
        }
    except HTTPException:
        # Fix: the broad handler below used to swallow the deliberate
        # "Could not create vector store" error and re-wrap it as
        # "An error occurred: ..."; propagate HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
    finally:
        await file.close()
def get_documents(enterprise_id: str):
    """List distinct document names stored in an enterprise's namespace.

    Vector ids follow the pattern ``<doc_name>_<suffix>``; the trailing
    segment is dropped to recover the document name.

    Returns:
        list[str]: unique document names in first-seen order.

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        # dict preserves insertion order and gives O(1) dedup, replacing the
        # original O(n^2) `not in list` scan; also avoids shadowing builtin `id`.
        seen: dict[str, None] = {}
        for id_batch in index.list(namespace=enterprise_id):
            for vector_id in id_batch:
                doc_name = "_".join(vector_id.split("_")[:-1])
                seen.setdefault(doc_name, None)
        return list(seen)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def delete_document(enterprise_id: str, filename_id: str):
    """Delete every chunk of one document from an enterprise's namespace.

    Returns:
        dict with a confirmation message and all deleted vector ids.
        (The original reported only the final batch, and raised
        UnboundLocalError — surfaced as a confusing 500 — when no id
        matched the prefix.)

    Raises:
        HTTPException: 500 on any Pinecone failure.
    """
    try:
        deleted_ids: list[str] = []
        for id_batch in index.list(prefix=f"{filename_id}_", namespace=enterprise_id):
            index.delete(ids=id_batch, namespace=enterprise_id)
            deleted_ids.extend(id_batch)
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def delete_all_documents(enterprise_id: str):
    """Wipe every vector stored under the given enterprise namespace.

    Raises:
        HTTPException: 500 when the Pinecone delete call fails.
    """
    try:
        index.delete(delete_all=True, namespace=enterprise_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(exc)}")
    return {"message": "All documents deleted"}
import async_timeout
import asyncio
# Hard cap on how long a streamed generation may run before being aborted.
GENERATION_TIMEOUT_SEC = 60
async def stream_generator(response, prompt):
    """Yield JSON-encoded chunks of a streamed LLM answer.

    Each yielded item is a standalone JSON object carrying the formatted
    prompt and the next content chunk. NOTE(review): consecutive objects are
    emitted with no separator — confirm the consumer can split the stream.

    Args:
        response: async iterable of str/bytes chunks from the LLM.
        prompt: the fully formatted prompt echoed in every chunk.

    Raises:
        HTTPException: 504 on timeout. NOTE(review): this fires after the
            response body has started streaming, so FastAPI cannot actually
            send a 504 status — the connection is just cut; verify intended
            client-side handling.
    """
    async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
        try:
            async for chunk in response:
                if isinstance(chunk, bytes):
                    chunk = chunk.decode('utf-8') # Convert bytes to str if needed
                yield json.dumps({"prompt": prompt, "content": chunk})
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail="Stream timed out")
def generate_answer(user_input: UserInput):
    """Answer a user prompt with retrieved context, optionally streaming.

    Retrieves context from the enterprise namespace (plus the common one),
    formats the prompt, and delegates generation to the langchain helper.

    Returns:
        StreamingResponse of JSON chunks when ``user_input.stream`` is true,
        otherwise a dict with the formatted prompt, answer and context.

    Raises:
        HTTPException: 500 on any retrieval/generation failure.
    """
    try:
        prompt = user_input.prompt
        enterprise_id = user_input.enterprise_id
        template_prompt = base_template
        # Fall back to an empty context when retrieval finds nothing.
        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace) or ""
        # Fix: `getattr(user_input, "marque", "")` never returned "" because the
        # pydantic field always exists (default None), so helpers received None.
        enterprise_name = user_input.marque or ""
        # The two original branches were identical except for these kwargs.
        style_kwargs = {}
        if user_input.style_tonality is not None:
            style_kwargs = {
                "style": user_input.style_tonality.style,
                "tonality": user_input.style_tonality.tonality,
            }
        prompt_formated = prompt_reformatting(
            template_prompt,
            context,
            prompt,
            enterprise_name=enterprise_name,
            **style_kwargs,
        )
        answer = generate_response_via_langchain(
            prompt,
            model=user_input.model,
            stream=user_input.stream,
            context=context,
            messages=user_input.messages,
            template=template_prompt,
            enterprise_name=enterprise_name,
            enterprise_id=enterprise_id,
            index=index,
            **style_kwargs,
        )
        if user_input.stream:
            return StreamingResponse(stream_generator(answer, prompt_formated), media_type="application/json")
        return {
            "prompt": prompt_formated,
            "answer": answer,
            "context": context,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
def get_models():
    """Return the chat models this API advertises for selection."""
    return {"models": models}