# NOTE: removed non-Python text ("Spaces: / Sleeping / Sleeping" — a Hugging
# Face Spaces status banner pasted above the imports) that made this module
# unparseable.
# Standard library
import hashlib
import io
import os
import shutil
from pathlib import Path

# Third-party
import fitz
import pytesseract
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pinecone import Pinecone

# LangChain
from langchain.schema import Document
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
app = FastAPI()

# Wide-open CORS: this service is exposed to a browser frontend on another
# origin. NOTE(review): allow_origins=["*"] with allow_credentials=True is
# permissive — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Local scratch directory where uploaded PDFs live while being embedded.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

index_name = "pinecone-chatbot"

load_dotenv()
# fix: the original did os.environ[k] = os.getenv(k) unconditionally, which
# raises TypeError (environ values must be str) and crashed startup whenever
# any of these variables was missing from the environment / .env file.
for _key in ("HF_TOKEN", "PINECONE_API_KEY", "GROQ_API_KEY"):
    _value = os.getenv(_key)
    if _value is not None:
        os.environ[_key] = _value

llm = ChatGroq(model_name="Llama3-8b-8192")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Prompt that grounds the LLM strictly in the retrieved context.
prompt = '''You are given a context below. Use it to answer the question that follows.
Provide a concise and factual response. If the answer is not in the context, simply state "I don't know based on context provided."
<context>
{context}
</context>
Question: {question}
Answer:'''

parser = StrOutputParser()

pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
index = pc.Index(name=index_name)
def generate_file_id(file_path):
    """Return the hex MD5 digest of the file's contents.

    Used as a content-based identity for an uploaded PDF so re-uploads of
    the same bytes can be recognized.

    Args:
        file_path: path to a readable file on local disk.

    Returns:
        32-character lowercase hex MD5 string.
    """
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        # fix: stream in 64 KiB chunks instead of f.read() — the original
        # loaded the entire PDF into memory just to hash it.
        for chunk in iter(lambda: f.read(65536), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
def delete_existing_embedding(file_id):
    """Clear the Pinecone index before re-ingesting a document.

    NOTE(review): `file_id` is currently unused — vectors are not tagged by
    file, so the only way to avoid stale data is to wipe every vector in the
    index. Confirm this single-document-at-a-time model before adding
    multi-document support.
    """
    stats = index.describe_index_stats()
    index_has_vectors = stats["total_vector_count"] > 0
    if index_has_vectors:
        index.delete(delete_all=True)
def tempUploadFile(filePath, file):
    """Persist an upload's stream to `filePath` on local disk.

    `file` is expected to expose a `.file` attribute holding a readable
    binary stream (as FastAPI's UploadFile does).
    """
    with open(filePath, "wb") as destination:
        shutil.copyfileobj(file.file, destination)
def loadAndSplitDocuments(filePath):
    """Load a PDF via PyMuPDFLoader and split it into overlapping chunks.

    Returns a list of LangChain Documents, chunked to 2000 characters with
    500 characters of overlap to preserve context across chunk boundaries.
    """
    pages = PyMuPDFLoader(filePath).load()
    chunker = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    return chunker.split_documents(pages)
def loadAndSplitPdfFile(filePath):
    """Extract text AND OCR'd image text from a PDF, then chunk it.

    For every page, the native text layer is captured, and each embedded
    image is run through Tesseract OCR so scanned/graphical content becomes
    searchable too. Image-derived chunks carry ``type: "image"`` and an
    ``image_index`` in their metadata.

    Args:
        filePath: path to the PDF on local disk.

    Returns:
        List of LangChain Documents split into 2000-char chunks with
        500-char overlap.
    """
    documents = []
    doc = fitz.open(filePath)
    try:
        for page_number, page in enumerate(doc, start=1):
            metadata = {"source": filePath, "page": page_number}
            text = page.get_text("text")
            if text.strip():
                documents.append(Document(page_content=text, metadata=metadata))
            # OCR every embedded image on the page.
            for img_index, img_info in enumerate(page.get_images(full=True)):
                # fix: the original reused the name `img` for both the xref
                # tuple and the PIL image, shadowing the loop variable.
                xref = img_info[0]
                image_bytes = doc.extract_image(xref)["image"]
                with Image.open(io.BytesIO(image_bytes)) as image:
                    ocr_text = pytesseract.image_to_string(image)
                if ocr_text.strip():
                    img_metadata = metadata.copy()
                    img_metadata["type"] = "image"
                    img_metadata["image_index"] = img_index
                    documents.append(
                        Document(page_content=ocr_text, metadata=img_metadata)
                    )
    finally:
        # fix: the document handle was never released in the original.
        doc.close()
    splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    return splitter.split_documents(documents)
def prepare_retriever(filePath="", load_from_pinecone=False):
    """Either build a retriever over the existing index, or ingest a PDF.

    Args:
        filePath: when given (and ``load_from_pinecone`` is False), the PDF
            is chunked, embedded, and upserted into Pinecone. Returns None
            in this branch — callers use it purely for its side effect.
        load_from_pinecone: when True, skip ingestion and return a retriever
            over the vectors already stored in the index (top-5 similarity).

    Returns:
        A LangChain retriever when ``load_from_pinecone`` is True; otherwise
        None.
    """
    if load_from_pinecone:
        vector_store = PineconeVectorStore.from_existing_index(index_name, embeddings)
        return vector_store.as_retriever(search_kwargs={"k": 5})
    elif filePath:
        doc_chunks = loadAndSplitPdfFile(filePath)
        vector_data = []
        for i, doc in enumerate(doc_chunks):
            embedding = embeddings.embed_query(doc.page_content)
            if embedding:
                metadata = {
                    # Store the raw chunk text so answers can be grounded
                    # without a second document lookup.
                    "text": doc.page_content,
                    "source": doc.metadata.get("source", "unknown"),
                    "page": doc.metadata.get("page", i),
                }
                vector_data.append((str(i), embedding, metadata))
        print(f"Upserting {len(vector_data)} records into Pinecone...")
        # fix: upsert in batches — a single request carrying every vector of
        # a large PDF can exceed Pinecone's per-request payload limit.
        batch_size = 100
        for start in range(0, len(vector_data), batch_size):
            index.upsert(vectors=vector_data[start:start + batch_size])
        print("Upsert complete")
def get_retriever_chain(retriever):
    """Assemble the RAG pipeline for a given retriever.

    The chain retrieves context for the incoming question, renders the
    grounding prompt, queries the LLM, and parses the reply to a string.
    """
    chat_prompt = ChatPromptTemplate.from_template(prompt)
    inputs = {"context": retriever, "question": RunnablePassthrough()}
    return inputs | chat_prompt | llm | parser
# NOTE(review): no @app.post(...) decorator is visible on this handler —
# verify the route registration wasn't lost; as written it is never exposed.
def UploadFileInStore(file: UploadFile = File(...)):
    """Ingest an uploaded PDF: save it, wipe old vectors, embed it, clean up.

    Raises:
        HTTPException(400): when the upload is not a PDF.
    """
    # fix: the extension check was case-sensitive (".PDF" was rejected) and
    # crashed with AttributeError when filename was missing.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="File must be a pdf file")
    filePath = Path(UPLOAD_DIR) / file.filename
    tempUploadFile(filePath, file)
    try:
        file_id = generate_file_id(filePath)
        delete_existing_embedding(file_id)
        prepare_retriever(filePath)
    finally:
        # fix: the temp file leaked whenever embedding raised; always remove.
        if os.path.exists(filePath):
            os.remove(filePath)
    return JSONResponse({"message": "File uploaded successfully"})
# NOTE(review): no route decorator is visible on this handler either —
# confirm how it is exposed to clients.
async def QnAFromPdf(query: str):
    """Answer `query` against the embeddings currently stored in Pinecone."""
    pinecone_retriever = prepare_retriever(load_from_pinecone=True)
    rag_chain = get_retriever_chain(pinecone_retriever)
    return rag_chain.invoke(query)
if __name__ == "__main__":
    # Development entry point: serve the API on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)