File size: 4,614 Bytes
0201468
 
 
 
 
3679015
 
0201468
 
 
 
 
 
 
 
 
 
 
 
82cf8b3
0201468
82cf8b3
851a649
0201468
82cf8b3
0201468
82cf8b3
0201468
 
 
 
 
 
 
 
 
 
 
21cc2de
0201468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21cc2de
0201468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee10ab
 
0201468
8ee10ab
0201468
 
 
 
8ee10ab
0201468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import spaces

import gradio as gr
import torch
from modeling_colflor import ColFlor
from processing_colflor import ColFlorProcessor
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

# Checkpoint identifier shared by the model and processor loaders below.
model_name = "ahmed-masry/ColFlor"
# Hub access token taken from the environment; None when HF_TOKEN is unset.
token = os.getenv("HF_TOKEN")

# Load the ColFlor weights with device_map="cuda" and switch to eval mode.
model = ColFlor.from_pretrained(
    model_name,
    device_map="cuda",
    token=token,
).eval()

# Processor used by the handlers below for both queries and page images.
processor = ColFlorProcessor.from_pretrained(
    model_name,
    token=token,
)

# Plain white 768x768 RGB image (not referenced by the code visible here).
mock_image = Image.new(mode="RGB", size=(768, 768), color=(255, 255, 255))


@spaces.GPU
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that score highest for ``query``.

    Args:
        query: Free-text query to embed and score against the index.
        ds: List of per-page document embeddings, as built by ``index_gpu``.
        images: Page images; ``images[i]`` must correspond to ``ds[i]``.
        k: Number of results to return (wired to a Gradio slider).

    Returns:
        A list of ``(image, caption)`` tuples ordered best match first. The
        caption uses the zero-based position of the page in ``images``.

    Raises:
        gr.Error: If ``ds`` is empty, i.e. nothing has been indexed yet.
    """
    # An empty index cannot be ranked; surface a clear message to the user
    # instead of letting the scoring step fail.
    if len(ds) == 0:
        raise gr.Error("Please upload and index documents before searching.")

    # Slice bounds must be integers; coerce in case the UI hands over a float.
    k = int(k)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare device objects: a plain string does not compare equal to a
    # torch.device, which made the original check always trigger a move.
    if torch.device(device) != model.device:
        model.to(device)

    with torch.no_grad():
        batch_query = processor.process_queries([query])
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    # One embedding per query, moved back to the CPU for scoring.
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # argsort is ascending: keep the last k entries of the single query row
    # and reverse them so the best-scoring page comes first.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    return [(images[idx], f"Page {idx}") for idx in top_k_indices]


def index(files, ds):
    """Convert the uploaded files to page images and embed them.

    Args:
        files: Uploaded files, forwarded unchanged to ``convert_files``.
        ds: Embeddings left over from a previous indexing run. Kept in the
            signature for the Gradio wiring but deliberately not reused.

    Returns:
        The ``(status message, embeddings, images)`` tuple returned by
        ``index_gpu``.
    """
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    # The image list is rebuilt from scratch on every run, and index_gpu
    # appends to whatever list it receives. Forwarding the previous ``ds``
    # would keep stale embeddings whose positions no longer match ``images``,
    # so start each run from an empty list.
    return index_gpu(images, [])
    


def convert_files(files):
    """Render every page of the given PDF files into a flat list of images.

    Args:
        files: Iterable of PDF paths accepted by ``convert_from_path``.
            ``None`` (nothing uploaded) is treated like an empty list.

    Returns:
        The page images of all files, in file order then page order.

    Raises:
        gr.Error: If the accumulated number of pages reaches 150 or more.
    """
    images = []
    for f in files or []:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so an oversized upload is rejected without
        # rasterising the remaining PDFs first.
        if len(images) >= 150:
            raise gr.Error("The number of images in the dataset should be less than 150.")
    return images


@spaces.GPU
def index_gpu(images, ds):
    """Embed page images with the ColFlor model and append them to ``ds``.

    Args:
        images: Page images to embed; they are processed in batches of 4.
        ds: List that receives one CPU embedding per image, in input order.
            It is extended in place and also returned.

    Returns:
        A ``(status message, ds, images)`` tuple.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare device objects: a plain string does not compare equal to a
    # torch.device, which made the original check always trigger a move.
    if torch.device(device) != model.device:
        model.to(device)

    # shuffle=False keeps embeddings in the same order as ``images`` so that
    # positions in ``ds`` can be used to look pages up later.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_images,
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Split the batch into per-page embeddings and keep them on the CPU.
        ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", ds, images


def get_example():
    """Return the demo's sample rows: one ``[files, query]`` pair."""
    sample_files = ["climate_youth_magazine.pdf"]
    sample_query = "How much tropical forest is cut annually ?"
    return [[sample_files, sample_query]]

# Build the two-column UI: upload/index on the left, query controls on the
# right, with the search button and result gallery underneath.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColFlor: Towards BERT-Size Vision-Language Document Retrieval Models")
    gr.Markdown("""Demo to test ColFlor on PDF documents. This space is adapted from [ColPali Demo Space](https://huggingface.co/spaces/manu/ColPali-demo)

    For more details about ColFlor, please refer to our blogpost (https://huggingface.co/blog/ahmed-masry/colflor). 

    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents !

    ⚠️ This model performs best on English documents, and does not generalize well to other languages.  
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio expects file extensions with a leading dot (or a category
            # name such as "image"); a bare "pdf" is neither.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # State holders passed between the index and search handlers:
            # document embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)

    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    # Indexing fills the status box and both state holders; searching reads
    # the states and renders the returned (image, caption) pairs.
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])

if __name__ == "__main__":
    # Enable the request queue (max_size=10), then start the app with
    # debug output turned on.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)