# Hugging Face Space (status: Running)
import base64
import gc
import io
import os
import tempfile

import chromadb
import gradio as gr
import numpy as np
import ocrmypdf
import pandas as pd
import pymupdf
import spaces
import torch
from PIL import Image
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader
from langchain import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from pdfminer.high_level import extract_text
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
)

# NOTE(review): star import presumably provides clean_text, extract_images,
# extract_pdfs and image_to_bytes (used below but not defined here) — confirm.
from utils import *
# Hugging Face id of the LLaVA-NeXT checkpoint used to caption extracted images.
VISION_MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

# The vision-language model is only loaded when a CUDA device is present.
# Without one, both names are bound to None so that later code can detect the
# missing model instead of failing with a NameError.
if torch.cuda.is_available():
    processor = LlavaNextProcessor.from_pretrained(VISION_MODEL_ID)
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        VISION_MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        # Passing `load_in_4bit=True` directly to from_pretrained is deprecated;
        # the supported form is a BitsAndBytesConfig in `quantization_config`.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
else:
    processor = None
    vision_model = None
def get_image_description(image):
    """Caption *image* with the module-level LLaVA-NeXT model.

    Args:
        image: The image to describe, passed to the processor as ``images``.

    Returns:
        A one-element list holding the generated description. The list shape is
        kept for callers that index ``[0]``.

    Raises:
        RuntimeError: If no CUDA device is available, in which case the vision
            model is never loaded at module import.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("Vision model is unavailable: no CUDA device detected")

    # Free cached GPU memory from earlier calls before the next forward pass.
    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    # Keyword arguments: the positional order of the processor's text/images
    # parameters has differed between transformers releases.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
    output = vision_model.generate(**inputs, max_new_tokens=100)

    # generate() returns the prompt tokens followed by the continuation. Decode
    # only the new tokens so the instruction text is not stored as part of the
    # description (and embedded into the vector DB).
    prompt_length = inputs["input_ids"].shape[-1]
    description = processor.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()
    return [description]
# Extra stylesheet handed to gr.Blocks(css=...). It colours the element whose
# elem_id is "table_col" (no component in this file sets that id).
CSS = "\n#table_col {background-color: rgb(33, 41, 54);}\n"
def get_vectordb(text, images):
    """Index extracted text chunks and image captions in an in-memory ChromaDB.

    Args:
        text: Concatenated document text. It is split into chunks of up to 500
            characters with a 10-character overlap.
        images: Extracted images. Each is captioned with get_image_description
            and stored (via image_to_bytes) in the ``"image"`` metadata field.

    Returns:
        The chromadb client holding the ``"text_db"`` and ``"image_db"``
        collections.
    """
    client = chromadb.EphemeralClient()
    loader = ImageLoader()
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="multi-qa-mpnet-base-dot-v1"
    )

    # Drop collections left over from a previous extraction.
    # NOTE(review): this cleanup only matters if EphemeralClient instances share
    # state within the process. If they do, concurrent sessions overwrite each
    # other's collections — confirm.
    # Depending on the chromadb version, list_collections() yields Collection
    # objects or plain names; handle both.
    existing = {getattr(c, "name", c) for c in client.list_collections()}
    for collection_name in ("text_db", "image_db"):
        if collection_name in existing:
            client.delete_collection(collection_name)

    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    # Captioning is best-effort: one failing image must not abort ingestion,
    # but the failure is reported instead of being swallowed silently.
    descs = []
    for image in images:
        try:
            descs.append(get_image_description(image)[0])
        except Exception as err:
            print(f"Image description failed: {err!r}")
            descs.append("Could not generate image description due to some error")

    if images:
        image_collection.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=[{"image": image_to_bytes(img)} for img in images],
        )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=10,
    )
    doc_texts = (
        [doc.page_content for doc in splitter.create_documents([text])] if text else []
    )
    # Whitespace-only text (e.g. a scanned PDF where OCR found nothing) yields
    # no chunks, and chromadb rejects add() with an empty id list.
    if doc_texts:
        text_collection.add(
            ids=[str(i) for i in range(len(doc_texts))], documents=doc_texts
        )
    return client
def extract_data_from_pdfs(docs, session, include_images, progress=gr.Progress()):
    """OCR the selected PDFs, extract their text (and optionally images), and index them.

    Args:
        docs: Entries of the ``doc_collection`` state. Each one is passed to
            ``ocrmypdf.ocr`` as its input (presumably a file path — confirm
            against ``extract_pdfs``).
        session: Per-session state dict; ``session["processed"]`` is set to True.
        include_images: Value of the Include/Exclude radio. Images are extracted
            only when it equals ``"Include Images"``.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        A 6-tuple matching the ``embed.click`` outputs:
        - the chromadb client
        - the updated session
        - a visibility update for the sample-data container
        - a text preview (at most 2000 characters, plus "..." when truncated)
        - up to two sample images
        - the status HTML

    Raises:
        gr.Error: If no documents were selected.
    """
    if not docs:
        raise gr.Error("No documents to process")
    progress(0, "Extracting Images")
    progress(0.25, "Extracting Text")

    want_images = include_images == "Include Images"
    text_parts = []
    images = []
    # OCR output goes to a private scratch directory rather than a fixed
    # "ocr.pdf" in the working directory. Concurrent sessions therefore cannot
    # overwrite each other's file, and the OCR'd copies are deleted afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        for index, doc in enumerate(docs):
            ocr_path = os.path.join(workdir, f"ocr_{index}.pdf")
            ocrmypdf.ocr(doc, ocr_path, deskew=True, skip_text=True)
            text_parts.append(clean_text(extract_text(ocr_path)))
            if want_images:
                images.extend(extract_images([ocr_path]))

        all_text = "".join(f"{part}\n\n" for part in text_parts)
        progress(
            0.6, "Generating image descriptions and inserting everything into vectorDB"
        )
        # Indexed before workdir is removed, in case the extracted images are
        # still lazily backed by files inside it (can't tell from here).
        vectordb = get_vectordb(all_text, images)

    progress(1, "Completed")
    session["processed"] = True
    preview = all_text if len(all_text) <= 2000 else all_text[:2000] + "..."
    return (
        vectordb,
        session,
        gr.Row(visible=True),
        preview,
        images[:2],
        "<h1 style='text-align: center'>Completed</h1>",
    )
# Embedding function used by conversation() to reopen the collections. It names
# the same sentence-transformers model that get_vectordb() indexes with, so that
# query vectors match the stored ones.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1",
)
def _build_llm(hf_token, model_path):
    """Return the user's HuggingFaceEndpoint model, or the default Llama-3 endpoint."""
    hf_token = (hf_token or "").strip()
    model_path = (model_path or "").strip()
    if hf_token and model_path:
        repo_id, token = model_path, hf_token
    else:
        repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        # Previously the literal string "None" was passed when P_HF_TOKEN was
        # unset, which HuggingFaceEndpoint then tried to authenticate with.
        # A real None presumably lets it fall back to its own token lookup —
        # confirm against the installed langchain_community version.
        token = os.getenv("P_HF_TOKEN")
    return HuggingFaceEndpoint(
        repo_id=repo_id,
        temperature=0.4,
        max_new_tokens=800,
        huggingfacehub_api_token=token,
    )


def conversation(
    vectordb_client, msg, num_context, img_context, history, hf_token, model_path
):
    """Answer *msg* using retrieved text chunks and image captions as LLM context.

    Args:
        vectordb_client: chromadb client returned by get_vectordb (None if the
            user has not extracted any data yet).
        msg: The user's question.
        num_context: Number of text chunks to retrieve.
        img_context: Number of images to retrieve.
        history: Chatbot history as a list of (user, bot) tuples.
        hf_token: Optional Hugging Face token for a custom model.
        model_path: Optional repo id of a custom model. It is used only when
            hf_token is also given.

    Returns:
        A 3-tuple:
        - the updated history
        - the retrieved text chunks
        - the retrieved images, as PIL Images

    Raises:
        gr.Error: If no data has been extracted or the message is empty.
    """
    if vectordb_client is None:
        raise gr.Error("Please extract data first")
    if not msg or not msg.strip():
        raise gr.Error("Please enter a question")

    llm = _build_llm(hf_token, model_path)

    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )

    # Empty collections are skipped rather than queried. chromadb also
    # type-checks n_results, so the slider values are cast to int.
    if text_collection.count():
        results = text_collection.query(
            query_texts=[msg], include=["documents"], n_results=int(num_context)
        )["documents"][0]
    else:
        results = []

    if image_collection.count():
        similar_images = image_collection.query(
            query_texts=[msg],
            include=["metadatas", "documents"],
            n_results=int(img_context),
        )
        img_links = [meta["image"] for meta in similar_images["metadatas"][0]]
        img_captions = similar_images["documents"][0]
    else:
        img_links, img_captions = [], []

    # Each "image" metadata value is base64-encoded image bytes.
    retrieved_images = [
        Image.open(io.BytesIO(base64.b64decode(link))) for link in img_links
    ]
    img_desc = "\n".join(img_captions) if img_links else "No Images Are Provided"

    template = """
Context:
{context}
Included Images:
{images}
Question:
{question}
Answer:
"""
    # The template references {images}, so it must be declared as an input too.
    prompt = PromptTemplate(
        template=template, input_variables=["context", "images", "question"]
    )
    context = "\n\n".join(results)
    # invoke() replaces the deprecated LLM.__call__.
    response = llm.invoke(prompt.format(context=context, question=msg, images=img_desc))
    return (history or []) + [(msg, response)], results, retrieved_images
def check_validity_and_llm(session_states):
    """Switch to the Chat tab (id 2) once this session's data has been extracted.

    Args:
        session_states: Per-session state dict. extract_data_from_pdfs sets its
            ``"processed"`` key to True.

    Returns:
        A ``gr.Tabs`` update selecting the Chat tab.

    Raises:
        gr.Error: If extraction has not completed for this session.
    """
    if (session_states or {}).get("processed", False):
        return gr.Tabs(selected=2)
    raise gr.Error("Please extract data first")
def get_stats(vectordb):
    """Summarise how much data is indexed in *vectordb*.

    The previous version called ``.get()`` on the chromadb client, which has no
    such method, and counted dict keys instead of chunks.

    Args:
        vectordb: chromadb client returned by get_vectordb, or None.

    Returns:
        A 3-tuple ``(summary, "", "")``. The summary lists the number of text
        chunks and images, one per line.
    """
    if vectordb is None:
        return "No data extracted yet", "", ""
    text_chunks = vectordb.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    ).count()
    image_count = vectordb.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    ).count()
    text_data = [f"Chunks: {text_chunks}", f"Images: {image_count}"]
    return "\n".join(text_data), "", ""
with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=gr.themes.sizes.text_md)) as demo:
    # Per-session state shared between the three tabs.
    vectordb = gr.State()  # chromadb client produced by extract_data_from_pdfs
    doc_collection = gr.State(value=[])  # selected documents, written by extract_pdfs
    session_states = gr.State(value={})  # gets "processed": True after extraction
    references = gr.State(value=[])  # text chunks retrieved for the last question

    gr.Markdown(
        """<h2><center>Multimodal PDF Chatbot</center></h2>
<h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
    )
    gr.Markdown(
        """<center><h3><b>Note: </b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents</h3></center><br>
<center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
    )
    gr.Markdown(
        """
<center><b>Warning: </b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description</center>
"""
    )

    with gr.Tabs() as tabs:
        # Tab 0: file upload.
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        file_types=["pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        # Tab 1: OCR / extraction / indexing.
        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")
            with gr.Row():
                include_images = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Include Images",
                    label="Include/ Exclude Images",
                    interactive=True,
                )
            with gr.Row(equal_height=True, variant="panel") as row:
                selected = gr.Dataframe(
                    interactive=False,
                    col_count=(1, "fixed"),
                    headers=["Selected Files"],
                )
                prog = gr.HTML(
                    value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>"
                )
            with gr.Accordion("See Parts of Extracted Data", open=False):
                with gr.Column(visible=True) as sample_data:
                    with gr.Row():
                        with gr.Column():
                            ext_text = gr.Textbox(
                                label="Sample Extracted Text", lines=15
                            )
                        with gr.Column():
                            images = gr.Gallery(
                                label="Sample Extracted Images", columns=1, rows=2
                            )

        # Tab 2: retrieval-augmented chat.
        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Accordion("Config (Advanced) (Optional)", open=False):
                with gr.Row(variant="panel", equal_height=True):
                    choice = gr.Radio(
                        ["chromaDB"],
                        value="chromaDB",
                        label="Vector Database",
                        interactive=True,
                    )
                with gr.Accordion("Use your own model (optional)", open=False):
                    hf_token = gr.Textbox(
                        label="HuggingFace Token", interactive=True
                    )
                    model_path = gr.Textbox(label="Model Path", interactive=True)
                with gr.Row(variant="panel", equal_height=True):
                    num_context = gr.Slider(
                        label="Number of text context elements",
                        minimum=1,
                        maximum=20,
                        step=1,
                        interactive=True,
                        value=3,
                    )
                    img_context = gr.Slider(
                        label="Number of image context elements",
                        minimum=1,
                        maximum=10,
                        step=1,
                        interactive=True,
                        value=2,
                    )
            with gr.Row():
                with gr.Column():
                    # The first positional argument of gr.Gallery is `value`, so
                    # the title must be passed as `label`.
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
                    with gr.Accordion("Text References", open=False):
                        # NOTE(review): gen_refs is never called or registered in
                        # this file, so no references are rendered. It presumably
                        # was meant to be decorated with
                        # @gr.render(inputs=references) (Gradio >= 4.36) —
                        # confirm the Gradio version before enabling it.
                        def gen_refs(references):
                            for i, ref in enumerate(references):
                                gr.Textbox(
                                    label=f"Reference-{i+1}", value=ref, lines=3
                                )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your question here (e.g. 'What is this document about?')",
                    interactive=True,
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton(
                    [msg, chatbot, ret_images], value="Clear conversation"
                )

    # Event wiring.
    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states, include_images],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            images,
            prog,
        ],
    )

    # The button and pressing Enter in the textbox trigger the same handler.
    chat_inputs = [vectordb, msg, num_context, img_context, chatbot, hf_token, model_path]
    chat_outputs = [chatbot, references, ret_images]
    submit_btn.click(conversation, chat_inputs, chat_outputs)
    msg.submit(conversation, chat_inputs, chat_outputs)

    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
    next_p1.click(check_validity_and_llm, session_states, tabs)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()