# ColFlor-Demo / app.py
# Last commit by ahmed-masry: "Update app.py" (8ee10ab, verified)
import os
import spaces
import gradio as gr
import torch
from modeling_colflor import ColFlor
from processing_colflor import ColFlorProcessor
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
# Load model
# Module-level setup: executed once at import time; `model` and `processor`
# are the globals used by `search` and `index_gpu` for every request.
model_name = "ahmed-masry/ColFlor"
# Access token taken from the HF_TOKEN environment variable (None if unset).
# NOTE(review): presumably required only if the checkpoint is gated/private — confirm.
token = os.environ.get("HF_TOKEN")
# Load the retrieval model with device_map="cuda" and switch it to eval mode.
# Assumes ColFlor is a torch / transformers model where .eval() disables
# training-only behaviour such as dropout — TODO confirm in modeling_colflor.
model = ColFlor.from_pretrained(
model_name, device_map="cuda", token = token).eval()
# Matching processor; supplies `process_queries` (used by `search`) and
# `process_images` (used as the DataLoader collate_fn in `index_gpu`).
processor = ColFlorProcessor.from_pretrained(model_name, token = token)
# Blank 768x768 white RGB image; not referenced by any other code in this file.
mock_image = Image.new("RGB", (768, 768), (255, 255, 255))
@spaces.GPU
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that best match ``query``.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced by ``index_gpu`` (a list of CPU tensors,
            one entry per page).
        images: Page images, aligned index-for-index with ``ds``.
        k: Number of results to return. It is coerced to ``int`` because UI
            widgets may deliver numeric values as floats.

    Returns:
        A list of ``(image, label)`` tuples for the Gallery, best match first.

    Raises:
        gr.Error: If no documents have been indexed yet.
    """
    # Without this guard the evaluator fails deep inside scoring with an
    # unhelpful error when the user searches before indexing.
    if not ds or not images:
        raise gr.Error("Please index documents before searching.")
    k = int(k)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare torch.device objects: a plain str never equals `model.device`,
    # so the original check was always true and moved the model on every call.
    if model.device != torch.device(device):
        model.to(device)

    with torch.no_grad():
        batch_query = processor.process_queries([query])
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Presumably scores is (num_queries, num_pages); with a single query, row 0
    # holds its page scores. Take the last k of the ascending argsort, reversed.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
    return [(images[idx], f"Page {idx}") for idx in top_k_indices]
def index(files, ds):
    """Gradio click handler: rasterise the uploaded PDFs, then embed the pages.

    Returns exactly what ``index_gpu`` returns (status message, embedding
    state, page images).
    """
    print("Converting files")
    pages = convert_files(files)
    page_count = len(pages)
    print(f"Files converted with {page_count} images.")
    result = index_gpu(pages, ds)
    return result
def convert_files(files, max_pages=150):
    """Rasterise every uploaded PDF into one flat list of page images.

    Args:
        files: Iterable of PDF file paths (as passed from the upload widget),
            or ``None`` when nothing was uploaded, which is treated as no files.
        max_pages: Exclusive upper bound on the total number of pages. The
            default preserves the demo's original limit of 150.

    Returns:
        List of page images, ordered by file and then by page.

    Raises:
        gr.Error: As soon as the accumulated page count reaches ``max_pages``.
    """
    images = []
    # `files or ()` avoids a TypeError when the button is clicked with no upload.
    for f in files or ():
        images.extend(convert_from_path(f, thread_count=4))
        # Checked after each file so conversion stops early once over the limit.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images
@spaces.GPU
def index_gpu(images, ds, batch_size=4):
    """Embed page images with the ColFlor model, one mini-batch at a time.

    Args:
        images: Page images, e.g. as returned by ``convert_files``.
        ds: Previous embedding state. It is intentionally not extended: the
            returned embeddings are rebuilt from ``images`` so that entry ``i``
            always belongs to ``images[i]``. The parameter is kept so the
            existing Gradio wiring still works.
        batch_size: Number of pages per forward pass.

    Returns:
        ``(status_message, embeddings, images)``, where ``embeddings`` is a
        list of CPU tensors aligned index-for-index with ``images``.
    """
    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=processor.process_images,
    )
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare torch.device objects; a str never equals `model.device`.
    if model.device != torch.device(device):
        model.to(device)

    # Start from an empty list. Appending to the old `ds` after re-indexing left
    # stale embeddings that no longer matched `images`, so `search` returned
    # wrong pages or raised IndexError.
    embeddings = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # unbind over the leading (batch) dim -> one tensor per page.
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images
def get_example():
    """Return the demo's sample inputs as ``[[pdf_file_list, query]]``."""
    sample_pdfs = ["climate_youth_magazine.pdf"]
    sample_query = "How much tropical forest is cut annually ?"
    return [[sample_pdfs, sample_query]]
# Gradio UI: left column uploads/indexes PDFs, right column takes the query;
# the search button and result gallery sit below the row.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColFlor: Towards BERT-Size Vision-Language Document Retrieval Models")
    gr.Markdown("""Demo to test ColFlor on PDF documents. This space is adapted from [ColPali Demo Space](https://huggingface.co/spaces/manu/ColPali-demo)
For more details about ColFlor, please refer to our blogpost (https://huggingface.co/blog/ahmed-masry/colflor).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents !
⚠️ This model performs best on English documents, and does not generalize well to other languages.
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio documents extension filters with a leading dot (".pdf").
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            # Label repaired from mis-decoded UTF-8 ("πŸ”„" -> 🔄).
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # Per-session state: page embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
    # Define the actions
    # Label repaired from mis-decoded UTF-8 ("πŸ”" -> 🔍).
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
if __name__ == "__main__":
    # Enable the request queue (at most 10 pending events), then serve the app
    # with debug output. `queue` returns the same Blocks object.
    app = demo.queue(max_size=10)
    app.launch(debug=True)