# Hugging Face Spaces status: Running on Zero (ZeroGPU)
import base64
import os
import subprocess
import sys
from io import BytesIO

import spaces
import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from colpali_engine.models import ColQwen2, ColQwen2Processor
def install_fa2():
    """Install the ``flash-attn`` package into the running interpreter's environment.

    Best-effort, like the original ``os.system`` call: a failed install is
    reported on stdout rather than raised.

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    print("Install FA2")
    # `sys.executable -m pip` targets the interpreter running this app; a bare
    # `pip` found on PATH may belong to a different environment.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        check=False,
    )
    if result.returncode != 0:
        print(f"flash-attn installation failed (exit code {result.returncode})")
    return result.returncode == 0
# Optional: uncomment to install flash-attn at startup, then enable the
# attn_implementation argument below.
# install_fa2()

_COLQWEN2_CHECKPOINT = "vidore/colqwen2-v1.0"

# Load the ColQwen2 retriever and its processor once, at import time; the
# request handlers below reuse these module-level objects.
model = ColQwen2.from_pretrained(
    _COLQWEN2_CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    # attn_implementation="flash_attention_2",  # should work on A100
)
model.eval()  # inference mode (nn.Module.eval mutates in place and returns self)
processor = ColQwen2Processor.from_pretrained(_COLQWEN2_CHECKPOINT)
def encode_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string.

    Args:
        image: A PIL image. Modes other than RGB/L (e.g. RGBA or palette "P")
            are converted to RGB first, because the JPEG writer rejects them.

    Returns:
        The JPEG bytes, base64-encoded and decoded to an ASCII ``str`` suitable
        for a ``data:image/jpeg;base64,...`` URL.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def query_gpt4o_mini(query, images):
    """Ask OpenAI's GPT-4o-mini to answer ``query`` using the given page images.

    Args:
        query: The user's question, inserted into the prompt.
        images: PIL page images (any number, including zero). Each one is sent
            as a separate base64 JPEG ``image_url`` content part, in order.

    Returns:
        The text content of the first completion choice.

    The API key is read from the ``OPENAI_KEY`` environment variable. If that
    variable is unset, ``api_key=None`` is passed and the OpenAI client falls
    back to its own default lookup.
    """
    from openai import OpenAI

    base64_images = [encode_image_to_base64(image) for image in images]
    client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))
    PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages. Use them to construct a response to the question, and cite your sources.
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
Answer in the same language as the query.
Query: {query}
PDF pages:
"""
    # One text part followed by one image part per page. Building the list
    # from `base64_images` (instead of hard-coding five entries) avoids an
    # IndexError when fewer pages are retrieved and stops silently dropping
    # pages when more are retrieved.
    content = [{"type": "text", "text": PROMPT.format(query=query)}]
    content.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
        }
        for b64 in base64_images
    )
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": content}],
        max_tokens=500,
    )
    return response.choices[0].message.content
@spaces.GPU
def _retrieve_top_k(query, ds, k):
    """Return the indices of the ``k`` entries of ``ds`` that score highest for ``query``.

    Embeds the query with the ColQwen2 model and scores it against the page
    embeddings with ``processor.score``. Decorated with ``spaces.GPU`` so it is
    granted a GPU when the Space runs on ZeroGPU hardware; the decorator is
    documented as having no effect elsewhere.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device, and a torch.device never compares equal
    # to a plain str. Normalise first so the model is only moved when needed.
    if torch.device(device) != model.device:
        model.to(device)
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))
    scores = processor.score(qs, ds, device=device)
    return scores[0].topk(k).indices.tolist()


def search(query: str, ds, images, k):
    """Retrieve the best-matching pages for ``query`` and ask GPT-4o-mini to answer from them.

    Args:
        query: The user's question.
        ds: Page embeddings produced by ``index_gpu`` (one tensor per page).
        images: Page images, aligned index-for-index with ``ds``.
        k: Number of pages to retrieve. It is coerced to ``int`` and clamped to
            the number of indexed pages.

    Returns:
        ``(results, ai_response)``. ``results`` is a list of
        ``(image, "Page <idx>")`` gallery items, and ``ai_response`` is the
        model's text answer.

    Raises:
        gr.Error: if no documents have been indexed yet.
    """
    if not ds:
        raise gr.Error("No documents indexed yet. Upload PDFs and click 'Index documents' first.")
    # Slider values may arrive as floats; topk() requires an int.
    k = min(int(k), len(ds))
    top_k_indices = _retrieve_top_k(query, ds, k)
    results = [(images[idx], f"Page {idx}") for idx in top_k_indices]
    # Generate response from GPT-4o-mini outside the GPU-allocated helper.
    # It expects bare images, not (image, label) gallery tuples.
    ai_response = query_gpt4o_mini(query, [image for image, _ in results])
    return results, ai_response
def index(files, ds):
    """Rasterise the uploaded PDFs, then embed the resulting pages.

    Thin pipeline: ``convert_files`` turns ``files`` into page images, and
    ``index_gpu`` embeds them into ``ds``. Returns whatever ``index_gpu``
    returns (status message, embeddings, images).
    """
    print("Converting files")
    page_images = convert_files(files)
    n_pages = len(page_images)
    print(f"Files converted with {n_pages} images.")
    return index_gpu(page_images, ds)
def convert_files(files):
    """Rasterise every uploaded PDF into a single flat list of page images.

    Args:
        files: The value of the ``gr.File`` upload component. It is ``None`` or
            empty when nothing was uploaded. Each item is passed straight to
            ``pdf2image.convert_from_path``, presumably a file path (Gradio's
            default ``filepath`` type); confirm for the installed Gradio version.

    Returns:
        The pages of all files, in upload order.

    Raises:
        gr.Error: if no file was uploaded, or if the total page count reaches 150.
    """
    # Without this guard, iterating None raises an opaque TypeError.
    if not files:
        raise gr.Error("Please upload at least one PDF before indexing.")
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Checked after each file, so no further PDFs are converted once the
        # limit is hit.
        if len(images) >= 150:
            raise gr.Error("The number of images in the dataset should be less than 150.")
    return images
@spaces.GPU
def index_gpu(images, ds):
    """Embed every page image with ColQwen2 and store the embeddings in ``ds``.

    Args:
        images: Page images to index, in page order.
        ds: The session's embedding list (a ``gr.State`` value). It is cleared
            and refilled with one CPU tensor per image, so ``ds[i]`` always
            corresponds to ``images[i]``.

    Returns:
        ``(status_message, ds, images)``, matching the outputs wired to the
        "Index documents" button.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device; comparing it to a str is never equal.
    if torch.device(device) != model.device:
        model.to(device)
    # Rebuild from scratch. `search` looks pages up as images[idx] with idx taken
    # from positions in ds. Appending to embeddings from a previous upload while
    # `images` is replaced would misalign the two (wrong page or IndexError).
    ds.clear()
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_images,
    )
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {key: value.to(model.device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Keep embeddings on CPU so they can live in Gradio session state.
            ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", ds, images
# Gradio UI. The emoji in the labels below were restored from mojibake: the
# UTF-8 bytes had been decoded as ISO-8859-7 ("β οΈ" -> ⚠️, "1οΈβ£" -> 1️⃣).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
ColPali is model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents !
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing english text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats !
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio expects file extensions with a leading dot; a bare "pdf"
            # is neither an extension nor a recognised category.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # Per-session state: page embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery, output_text])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)