yonigozlan committed
Commit 52a6b6f · 1 Parent(s): 4b4739f

initial commit (v1.2)

Files changed (4)
  1. README.md +2 -2
  2. app.py +141 -0
  3. packages.txt +1 -0
  4. requirements.txt +3 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: ColPali Transformers
- emoji: 🌖
+ emoji: 📚
  colorFrom: yellow
- colorTo: yellow
+ colorTo: red
  sdk: gradio
  sdk_version: 5.9.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,141 @@
+ import os
+
+ import gradio as gr
+ import spaces
+ import torch
+ from pdf2image import convert_from_path
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from transformers import ColPaliForRetrieval, ColPaliProcessor
+
+
+ @spaces.GPU
+ def install_fa2():
+     print("Install FA2")
+     os.system("pip install flash-attn --no-build-isolation")
+
+
+ # install_fa2()
+
+ model_name = "vidore/colpali-v1.2-hf"
+
+
+ model = ColPaliForRetrieval.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="cuda:0",  # or "mps" if on Apple Silicon
+     # attn_implementation="flash_attention_2",  # should work on A100
+ ).eval()
+ processor = ColPaliProcessor.from_pretrained(model_name)
+
+
+ @spaces.GPU
+ def search(query: str, ds, images, k):
+     k = min(k, len(ds))
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     if device != model.device:
+         model.to(device)
+
+     qs = []
+     with torch.no_grad():
+         batch_query = processor(text=[query]).to(model.device)
+         query_embeddings = model(**batch_query).embeddings
+         qs.extend(list(torch.unbind(query_embeddings.to("cpu"))))
+
+     scores = processor.score_retrieval(qs, ds)
+
+     top_k_indices = scores[0].topk(k).indices.tolist()
+
+     results = []
+     for idx in top_k_indices:
+         results.append((images[idx], f"Page {idx}"))
+
+     return results
+
+
+ def index(files, ds):
+     print("Converting files")
+     images = convert_files(files)
+     print(f"Files converted with {len(images)} images.")
+     return index_gpu(images, ds)
+
+
+ def convert_files(files):
+     images = []
+     for f in files:
+         images.extend(convert_from_path(f, thread_count=4))
+
+     if len(images) >= 150:
+         raise gr.Error("The number of images in the dataset should be less than 150.")
+     return images
+
+
+ @spaces.GPU
+ def index_gpu(images, ds):
+     """Embed a list of page images with ColPali and append the embeddings to ds."""
+
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     if device != model.device:
+         model.to(device)
+
+     # Run inference over the document pages in batches
+     dataloader = DataLoader(
+         images,
+         batch_size=4,
+         shuffle=False,
+         collate_fn=lambda x: processor(images=x).to(model.device),
+     )
+
+     for batch_doc in tqdm(dataloader):
+         with torch.no_grad():
+             batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
+             embeddings_doc = model(**batch_doc).embeddings
+             ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+     return f"Uploaded and converted {len(images)} pages", ds, images
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         "# ColPali: Efficient Document Retrieval with Vision Language Models 📚"
+     )
+     gr.Markdown("""Demo to test the Transformers 🤗 implementation of ColPali on PDF documents.<br>
+     ColPali is the model introduced in the [ColPali paper](https://arxiv.org/abs/2407.01449).<br>
+     This demo allows you to upload PDF files and search for the most relevant pages based on your query.
+     Refresh the page if you change documents!<br>
+     ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.<br>
+     Other models will be released with better robustness towards different languages and document formats!<br>
+     Demo by [manu](https://huggingface.co/spaces/manu/ColPali-demo)
+     """)
+     with gr.Row():
+         with gr.Column(scale=2):
+             gr.Markdown("## 1️⃣ Upload PDFs")
+             file = gr.File(
+                 file_types=["pdf"], file_count="multiple", label="Upload PDFs"
+             )
+
+             convert_button = gr.Button("🔄 Index documents")
+             message = gr.Textbox("Files not yet uploaded", label="Status")
+             embeds = gr.State(value=[])
+             imgs = gr.State(value=[])
+
+         with gr.Column(scale=3):
+             gr.Markdown("## 2️⃣ Search")
+             query = gr.Textbox(placeholder="Enter your query here", label="Query")
+             k = gr.Slider(
+                 minimum=1, maximum=10, step=1, label="Number of results", value=5
+             )
+
+     # Define the actions
+     search_button = gr.Button("🔍 Search", variant="primary")
+     output_gallery = gr.Gallery(
+         label="Retrieved Documents", height=600, show_label=True
+     )
+
+     convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
+     search_button.click(
+         search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)
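
For reference, the core retrieval flow in app.py can be exercised without the Gradio UI. The sketch below reuses only calls that appear in the commit above (ColPaliForRetrieval, ColPaliProcessor, model(...).embeddings, processor.score_retrieval); the PDF path and the query string are placeholders, and the CPU fallback is an assumption rather than something this Space does.

```python
# Minimal sketch of app.py's index + search flow, outside Gradio.
# Assumptions: "sample.pdf" is a placeholder path; poppler-utils is installed.
import torch
from pdf2image import convert_from_path
from transformers import ColPaliForRetrieval, ColPaliProcessor

model_name = "vidore/colpali-v1.2-hf"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32

model = ColPaliForRetrieval.from_pretrained(model_name, torch_dtype=dtype).to(device).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

pages = convert_from_path("sample.pdf")  # one PIL image per PDF page

with torch.no_grad():
    # Multi-vector embeddings: one set of patch vectors per page,
    # one set of token vectors for the query.
    page_embeddings = model(**processor(images=pages).to(device)).embeddings
    query_embeddings = model(**processor(text=["What does the chart show?"]).to(device)).embeddings

# Late-interaction scoring, as in app.py's search(); higher score = more relevant page
scores = processor.score_retrieval(query_embeddings, page_embeddings)
best = scores[0].argmax().item()
print(f"Most relevant page index: {best}")
```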
packages.txt ADDED
@@ -0,0 +1 @@
+ poppler-utils
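
packages.txt installs Poppler because pdf2image is a thin wrapper around Poppler's command-line tools; without it, the conversion in convert_files would fail at runtime. A small illustrative check (the path is a placeholder):

```python
# pdf2image shells out to Poppler (pdfinfo/pdftoppm); if the poppler-utils
# system package is missing, convert_from_path raises PDFInfoNotInstalledError.
from pdf2image import convert_from_path
from pdf2image.exceptions import PDFInfoNotInstalledError

try:
    pages = convert_from_path("sample.pdf", thread_count=4)  # placeholder path
    print(f"Rendered {len(pages)} page image(s)")
except PDFInfoNotInstalledError:
    print("Poppler is not installed; add poppler-utils to packages.txt")
```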
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ git+https://github.com/huggingface/transformers.git@main
+ pdf2image
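
requirements.txt installs transformers from the main branch rather than a tagged release, presumably because the ColPali classes imported in app.py had not yet shipped in a stable version at the time of this commit. A quick sanity check for the environment:

```python
# Verify the installed transformers build exposes the ColPali classes
# used by app.py; older stable releases raise ImportError here.
import transformers

print(transformers.__version__)
from transformers import ColPaliForRetrieval, ColPaliProcessor  # noqa: F401
```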