from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from langchain_community.document_loaders import PyMuPDFLoader, UnstructuredPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough
from pathlib import Path
import uvicorn
import shutil
import os
import hashlib
import fitz
import pytesseract
from PIL import Image
from langchain.schema import Document
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
import io
import chromadb
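
# FastAPI app with wide-open CORS so a browser frontend hosted elsewhere can call the API.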
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
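
# Local scratch directory for uploads and an on-disk directory where Chroma persists embeddings.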
UPLOAD_DIR = "uploads" | |
os.makedirs(UPLOAD_DIR, exist_ok=True) | |
persist_directory = "/home/user/.cache/chroma_db" | |
load_dotenv() | |
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") | |
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY") | |
llm = ChatGroq(model_name = "qwen-2.5-32b") | |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
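
# Prompt template: answer strictly from the retrieved context, with special
# handling for duration/summary and end-date style questions.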
prompt = '''You are an AI assistant tasked with answering questions based on the given context.
Follow these guidelines:
- If the answer is **explicitly stated in the context**, provide a **concise and factual response**.
- If the answer is **not in the context**, simply state: *"I don't know based on the provided context."*
- If the question requires **logical reasoning** based on the context, summarize the necessary details before answering.
- If the question is about **duration or summary**, calculate or extract the total duration and provide a brief overview.
- If the question asks for an **end date** that is not found in the context, treat the **completion date** as the end date.
<context>
{context}
</context>
Question: {question}
Answer:'''

parser = StrOutputParser()
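
# Fingerprint the uploaded file by hashing its raw bytes. MD5 is used here only
# as a cheap content ID, not for any security purpose.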
def generate_file_id(file_path):
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        hasher.update(f.read())
    return hasher.hexdigest()
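
# Drop the existing Chroma collection so a new upload replaces the previously
# stored vectors. Note that file_id is currently unused: the whole collection
# is deleted regardless of which file was indexed.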
def delete_existing_embedding(file_id):
    if os.path.exists(persist_directory):
        client_settings = chromadb.config.Settings(allow_reset=True)
        vector_store = Chroma(
            persist_directory=persist_directory,
            embedding_function=embeddings,
            client_settings=client_settings,
        )
        vector_store.delete_collection()  # Drop all stored vectors
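
# Write the incoming UploadFile to a path on disk so the path-based PDF loaders
# below can read it.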
def tempUploadFile(filePath, file):
    with open(filePath, 'wb') as buffer:
        shutil.copyfileobj(file.file, buffer)
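
# Alternative loader based on UnstructuredPDFLoader. It is kept for reference
# but is not called by the upload flow below, which uses loadAndSplitPdfFile.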
def loadAndSplitDocuments(filePath):
    loader = UnstructuredPDFLoader(filePath)
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    final_chunks = splitter.split_documents(docs)
    return final_chunks
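
# Extract text page by page with PyMuPDF, run Tesseract OCR on any embedded
# images, and split everything into overlapping chunks for embedding.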
def loadAndSplitPdfFile(filePath):
    doc = fitz.open(filePath)
    documents = []
    for i, page in enumerate(doc):
        text = page.get_text("text")  # Extract text from page
        metadata = {"source": filePath, "page": i + 1}
        if text.strip():
            documents.append(Document(page_content=text, metadata=metadata))
        # Extract and process images with OCR
        images = page.get_images(full=True)
        for img_index, img in enumerate(images):
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            pil_image = Image.open(io.BytesIO(image_bytes))
            # Perform OCR on the image
            ocr_text = pytesseract.image_to_string(pil_image)
            if ocr_text.strip():
                img_metadata = metadata.copy()
                img_metadata["type"] = "image"
                img_metadata["image_index"] = img_index
                documents.append(Document(page_content=ocr_text, metadata=img_metadata))
    splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=500)
    final_chunks = splitter.split_documents(documents)
    return final_chunks
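
# Build or reopen the retriever. With load_from_chromadb=True the persisted
# Chroma store is reopened and wrapped as a top-k retriever; with a filePath
# the PDF is chunked, its metadata reduced to Chroma-compatible scalar types,
# and a fresh vector store is persisted to disk.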
def prepare_retriever(filePath="", load_from_chromadb=False):
    if load_from_chromadb:
        vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        print("Total documents stored:", vector_store._collection.count())
        return vector_store.as_retriever(search_kwargs={"k": 10})
    elif filePath:
        doc_chunks = loadAndSplitPdfFile(filePath)
        print(f"Loaded {len(doc_chunks)} documents from {filePath}")
        for doc in doc_chunks:
            if hasattr(doc, "metadata") and isinstance(doc.metadata, dict):
                # Convert Path objects to strings and drop non-scalar values,
                # since Chroma only accepts str/int/float/bool metadata.
                doc.metadata = {
                    key: str(value) if isinstance(value, Path) else value
                    for key, value in doc.metadata.items()
                    if isinstance(value, (str, int, float, bool, Path))
                }
        client_settings = chromadb.config.Settings(allow_reset=True)
        vector_store = Chroma.from_documents(
            documents=doc_chunks,
            persist_directory=persist_directory,
            embedding=embeddings,
            client_settings=client_settings,
        )
        vector_store.persist()
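
# Compose the RAG chain: the retriever supplies {context}, the raw question is
# passed through as {question}, and the Groq chat model's reply is parsed to a string.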
def get_retriever_chain(retriever):
    chat_prompt = ChatPromptTemplate.from_template(prompt)
    chain = ({"context": retriever, "question": RunnablePassthrough()} | chat_prompt | llm | parser)
    return chain
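
# Upload endpoint: validates the PDF, rebuilds the vector store from it, then
# deletes the temporary file. The "/upload" route path is an assumption; adjust
# it to whatever the client expects.
@app.post("/upload")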
def UploadFileInStore(file: UploadFile = File(...)):
    if not file.filename.endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a pdf file")
    filePath = Path(UPLOAD_DIR) / file.filename
    tempUploadFile(filePath, file)
    file_id = generate_file_id(filePath)
    delete_existing_embedding(file_id)
    prepare_retriever(filePath)
    if os.path.exists(filePath):
        os.remove(filePath)
    return JSONResponse({"message": "File uploaded successfully"})
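
# Question-answering endpoint backed by the persisted vector store. The "/ask"
# path and query-parameter style are assumptions; adjust them to match the client.
@app.get("/ask")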
async def QnAFromPdf(query: str):
    retriever = prepare_retriever(load_from_chromadb=True)
    chain = get_retriever_chain(retriever)
    response = chain.invoke(query)
    return response

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)