# Source: Hugging Face Space file "app.py" by SoumyaJ (commit 75f3f8a, 5.77 kB).
# The original page chrome ("raw / history / blame") was scrape residue and has
# been folded into this header so the file is valid Python.
from fastapi import FastAPI, UploadFile,File,HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_pinecone import PineconeVectorStore
from langchain_core.runnables import RunnablePassthrough
from pathlib import Path
import uvicorn
import shutil
import os
import hashlib
from pinecone import Pinecone
import fitz
import pytesseract
from PIL import Image
from langchain.schema import Document
import io
app = FastAPI()

# Allow browser clients on any host to call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# effectively "open to everyone" — tighten the origin list before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directory for temporarily staging uploaded PDFs before ingestion.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Name of the Pinecone index that stores document embeddings.
index_name = "pinecone-chatbot"

load_dotenv()

# Propagate secrets from .env into the process environment. Skip keys that
# are missing: the original `os.environ[k] = os.getenv(k)` raises TypeError
# (environ values must be str) when a variable is not set.
for _key in ("HF_TOKEN", "PINECONE_API_KEY", "GROQ_API_KEY"):
    _value = os.getenv(_key)
    if _value is not None:
        os.environ[_key] = _value
# Chat model served by Groq (Llama 3 8B, 8192-token context window).
llm = ChatGroq(model_name = "Llama3-8b-8192")
# Sentence-transformers model used to embed both document chunks and queries.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# RAG prompt: answer strictly from the retrieved <context>, otherwise decline.
prompt = '''You are given a context below. Use it to answer the question that follows.
Provide a concise and factual response. If the answer is not in the context, simply state "I don't know based on context provided."
<context>
{context}
</context>
Question: {question}
Answer:'''
# Extracts the plain string content from the LLM's chat response.
parser = StrOutputParser()
# Pinecone client plus a handle to the embeddings index used throughout.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
index = pc.Index(name=index_name)
def generate_file_id(file_path):
    """Return the MD5 hex digest of a file's contents, used as its identity.

    Reads the file in fixed-size chunks so a large upload is never pulled
    into memory all at once (the original read the whole file in one call).
    MD5 is fine here: the digest is a content fingerprint, not a security
    boundary.
    """
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
def delete_existing_embedding(file_id):
    """Clear the Pinecone index before re-ingesting a document.

    NOTE(review): ``file_id`` is accepted but never used — whenever the
    index holds any vectors at all, the ENTIRE index is wiped, not just
    the vectors for this file. Confirm that single-document storage is
    the intended design before relying on per-file deletion.
    """
    index_stats = index.describe_index_stats()
    # Only issue a delete when the index is non-empty.
    if index_stats["total_vector_count"] > 0:
        index.delete(delete_all=True)
def tempUploadFile(filePath, file):
    """Persist an uploaded file's stream to ``filePath`` on local disk."""
    with open(filePath, 'wb') as out_stream:
        # Stream in 64 KiB chunks so large uploads never sit fully in memory.
        while chunk := file.file.read(64 * 1024):
            out_stream.write(chunk)
def loadAndSplitDocuments(filePath):
    """Load a PDF via PyMuPDFLoader and split it into overlapping text chunks."""
    pdf_docs = PyMuPDFLoader(filePath).load()
    chunker = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    return chunker.split_documents(pdf_docs)
def loadAndSplitPdfFile(filePath):
    """Extract page text and OCR'd image text from a PDF, then chunk it.

    Each page's text layer becomes one Document. Every embedded image is
    run through Tesseract OCR and, when it yields non-blank text, becomes
    an extra Document tagged with ``type="image"`` and its image index.

    Fixes over the original: the fitz document is now closed (it holds
    OS-level resources and previously leaked), and the PIL image no longer
    shadows the ``img`` tuple from ``page.get_images``.

    Returns a list of chunked Documents ready for embedding.
    """
    documents = []
    doc = fitz.open(filePath)
    try:
        for page_number, page in enumerate(doc, start=1):
            metadata = {"source": filePath, "page": page_number}

            # Plain text layer of the page.
            text = page.get_text("text")
            if text.strip():
                documents.append(Document(page_content=text, metadata=metadata))

            # OCR each embedded image; skip images with no recognizable text.
            for img_index, img_info in enumerate(page.get_images(full=True)):
                xref = img_info[0]
                image_bytes = doc.extract_image(xref)["image"]
                pil_image = Image.open(io.BytesIO(image_bytes))
                ocr_text = pytesseract.image_to_string(pil_image)
                if ocr_text.strip():
                    img_metadata = metadata.copy()
                    img_metadata["type"] = "image"
                    img_metadata["image_index"] = img_index
                    documents.append(Document(page_content=ocr_text, metadata=img_metadata))
    finally:
        # Close even if OCR or extraction raises.
        doc.close()

    splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    return splitter.split_documents(documents)
def prepare_retriever(filePath = "", load_from_pinecone = False):
    """Build a retriever from the existing index, or ingest a PDF into it.

    With ``load_from_pinecone=True``: returns a top-5 similarity retriever
    over the existing Pinecone index.
    With a ``filePath``: embeds the file's chunks and upserts them into
    Pinecone; returns None (matching the original contract — the upload
    endpoint ignores the return value).
    With neither argument: no-op.

    NOTE(review): vector ids are the chunk position ``str(i)``, so chunks
    from different files would collide — this only works because ingestion
    wipes the index first. Confirm before supporting multiple documents.
    """
    if load_from_pinecone:
        vector_store = PineconeVectorStore.from_existing_index(index_name, embeddings)
        return vector_store.as_retriever(search_kwargs={"k": 5})
    elif filePath:
        doc_chunks = loadAndSplitPdfFile(filePath)
        vector_data = []
        for i, doc in enumerate(doc_chunks):
            embedding = embeddings.embed_query(doc.page_content)
            if embedding:
                metadata = {
                    "text": doc.page_content,
                    "source": doc.metadata.get("source", "unknown"),
                    "page": doc.metadata.get("page", i),  # page info when available
                }
                vector_data.append((str(i), embedding, metadata))
        print(f"Upserting {len(vector_data)} records into Pinecone...")
        # Pinecone rejects oversized upsert payloads (2 MB / 1000-vector
        # request limits), so send in batches instead of one giant call.
        batch_size = 100
        for start in range(0, len(vector_data), batch_size):
            index.upsert(vectors=vector_data[start:start + batch_size])
        print("Upsert complete")
def get_retriever_chain(retriever):
    """Compose the RAG pipeline: retrieve context, fill the prompt, ask the LLM.

    The passthrough forwards the raw question while the retriever supplies
    the context; the parser reduces the chat response to a plain string.
    """
    qa_prompt = ChatPromptTemplate.from_template(prompt)
    pipeline_inputs = {"context": retriever, "question": RunnablePassthrough()}
    return pipeline_inputs | qa_prompt | llm | parser
@app.post("/UploadFileInStore")
def UploadFileInStore(file: UploadFile = File(...)):
    """Ingest an uploaded PDF: stage it to disk, re-embed it into Pinecone.

    Raises HTTP 400 for non-PDF uploads. The extension check is now
    case-insensitive (".PDF" was previously rejected), and the staged copy
    is removed in a ``finally`` so it no longer leaks when embedding fails.
    """
    if not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a pdf file")
    filePath = Path(UPLOAD_DIR) / file.filename
    tempUploadFile(filePath, file)
    try:
        file_id = generate_file_id(filePath)
        delete_existing_embedding(file_id)
        prepare_retriever(filePath)
    finally:
        # Clean up the staged copy even if embedding/upsert raised.
        if os.path.exists(filePath):
            os.remove(filePath)
    return JSONResponse({"message": "File uploaded successfully"})
@app.get("/QnAFromPdf")
def QnAFromPdf(query: str):
    """Answer a question against the previously ingested PDF embeddings.

    Declared as a plain ``def`` (not ``async def``): ``chain.invoke`` is a
    blocking call, and FastAPI runs sync endpoints in its threadpool instead
    of stalling the event loop, which the original async version did.
    """
    retriever = prepare_retriever(load_from_pinecone=True)
    chain = get_retriever_chain(retriever)
    return chain.invoke(query)
# Local development entry point; hosting platforms typically launch uvicorn
# externally, in which case this block never runs.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)