p3nguknight commited on
Commit
ce8881a
·
1 Parent(s): c5aa334
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -22,7 +22,7 @@ PIXTAL_MODEL_ID = "mistral-community--pixtral-12b-240910"
22
  PIXTRAL_MODEL_SNAPSHOT = "95758896fcf4691ec9674f29ec90d1441d9d26d2"
23
  PIXTRAL_MODEL_PATH = (
24
  pathlib.Path().home()
25
- / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/{PIXTRAL_MODEL_SNAPSHOT}"
26
  )
27
 
28
 
@@ -30,13 +30,13 @@ COLPALI_GEMMA_MODEL_ID = "vidore--colpaligemma-3b-pt-448-base"
30
  COLPALI_GEMMA_MODEL_SNAPSHOT = "12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa"
31
  COLPALI_GEMMA_MODEL_PATH = (
32
  pathlib.Path().home()
33
- / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
34
  )
35
  COLPALI_MODEL_ID = "vidore--colpali-v1.2"
36
  COLPALI_MODEL_SNAPSHOT = "2d54d5d3684a4f5ceeefbef95df0c94159fd6a45"
37
  COLPALI_MODEL_PATH = (
38
  pathlib.Path().home()
39
- / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/{COLPALI_MODEL_SNAPSHOT}"
40
  )
41
 
42
 
@@ -46,11 +46,15 @@ def image_to_base64(image_path):
46
  return f"data:image/jpeg;base64,{encoded_string}"
47
 
48
 
49
- @spaces.GPU
50
- def model_inference(
51
  images,
52
  text,
53
  ):
 
 
 
 
54
  tokenizer = MistralTokenizer.from_file(f"{PIXTRAL_MODEL_PATH}/tekken.json")
55
  model = Transformer.from_folder(PIXTRAL_MODEL_PATH)
56
 
@@ -80,8 +84,13 @@ def model_inference(
80
  return result
81
 
82
 
83
- @spaces.GPU
84
- def search(query: str, ds, images, k):
 
 
 
 
 
85
  model = ColPali.from_pretrained(
86
  COLPALI_GEMMA_MODEL_PATH,
87
  torch_dtype=torch.bfloat16,
@@ -101,11 +110,11 @@ def search(query: str, ds, images, k):
101
  embeddings_query = model(**batch_query)
102
  qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
103
 
104
- scores = processor.score(qs, ds)
105
- top_k_indices = scores.argsort(axis=1)[0][-k:]
106
  results = []
107
  for idx in top_k_indices:
108
- results.append((images[idx]), f"Page {idx}")
109
  del model
110
  del processor
111
  torch.cuda.empty_cache()
@@ -127,7 +136,7 @@ def convert_files(files):
127
  return images
128
 
129
 
130
- @spaces.GPU
131
  def index_gpu(images, ds):
132
  model = ColPali.from_pretrained(
133
  COLPALI_GEMMA_MODEL_PATH,
@@ -173,8 +182,8 @@ css = """
173
  max-width: 600px;
174
  }
175
  """
176
- file = gr.File(file_types=["pdf"], file_count="multiple", label="pdfs")
177
- query = gr.Textbox(placeholder="Enter your query here", label="query")
178
 
179
  with gr.Blocks(
180
  title="Document Question Answering with ColPali & Pixtral",
@@ -201,32 +210,31 @@ with gr.Blocks(
201
  img_chunk = gr.State(value=[])
202
 
203
  with gr.Column(scale=3):
204
- gr.Markdown("## Search with ColPali")
205
  query.render()
206
  k = gr.Slider(
207
- minimum=1, maximum=4, step=1, label="Number of results", value=1
 
 
 
 
208
  )
209
- search_button = gr.Button("πŸ” Run", variant="primary")
210
 
211
  # Define the actions
212
 
213
  output_gallery = gr.Gallery(
214
- label="Retrieved Documents", height=600, show_label=True
215
  )
 
216
 
217
  convert_button.click(
218
  index, inputs=[file, embeds], outputs=[message, embeds, imgs]
219
  )
220
- search_button.click(
221
- search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
222
- )
223
-
224
- gr.Markdown("## Get your answer with Pixtral")
225
- answer_button = gr.Button("Run", variant="primary")
226
- output = gr.Markdown(label="Output")
227
  answer_button.click(
228
- model_inference, inputs=[output_gallery, query], outputs=output
229
- )
 
230
 
231
  if __name__ == "__main__":
232
  demo.queue(max_size=10).launch()
 
22
  PIXTRAL_MODEL_SNAPSHOT = "95758896fcf4691ec9674f29ec90d1441d9d26d2"
23
  PIXTRAL_MODEL_PATH = (
24
  pathlib.Path().home()
25
+ / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/snapshots/{PIXTRAL_MODEL_SNAPSHOT}"
26
  )
27
 
28
 
 
30
  COLPALI_GEMMA_MODEL_SNAPSHOT = "12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa"
31
  COLPALI_GEMMA_MODEL_PATH = (
32
  pathlib.Path().home()
33
+ / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/snapshots/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
34
  )
35
  COLPALI_MODEL_ID = "vidore--colpali-v1.2"
36
  COLPALI_MODEL_SNAPSHOT = "2d54d5d3684a4f5ceeefbef95df0c94159fd6a45"
37
  COLPALI_MODEL_PATH = (
38
  pathlib.Path().home()
39
+ / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/snapshots/{COLPALI_MODEL_SNAPSHOT}"
40
  )
41
 
42
 
 
46
  return f"data:image/jpeg;base64,{encoded_string}"
47
 
48
 
49
+ @spaces.GPU(duration=30)
50
+ def pixtral_inference(
51
  images,
52
  text,
53
  ):
54
+ if len(images) == 0:
55
+ raise gr.Error("No images for generation")
56
+ if text == "":
57
+ raise gr.Error("No query for generation")
58
  tokenizer = MistralTokenizer.from_file(f"{PIXTRAL_MODEL_PATH}/tekken.json")
59
  model = Transformer.from_folder(PIXTRAL_MODEL_PATH)
60
 
 
84
  return result
85
 
86
 
87
+ @spaces.GPU(duration=30)
88
+ def retrieve(query: str, ds, images, k):
89
+ if len(images) == 0:
90
+ raise gr.Error("No docs/images for retrieval")
91
+ if query == "":
92
+ raise gr.Error("No query for retrieval")
93
+
94
  model = ColPali.from_pretrained(
95
  COLPALI_GEMMA_MODEL_PATH,
96
  torch_dtype=torch.bfloat16,
 
110
  embeddings_query = model(**batch_query)
111
  qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
112
 
113
+ scores = processor.score(qs, ds).numpy()
114
+ top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
115
  results = []
116
  for idx in top_k_indices:
117
+ results.append((images[idx], f"Page {idx}, Score {scores[0][idx]:.2f}"))
118
  del model
119
  del processor
120
  torch.cuda.empty_cache()
 
136
  return images
137
 
138
 
139
+ @spaces.GPU(duration=30)
140
  def index_gpu(images, ds):
141
  model = ColPali.from_pretrained(
142
  COLPALI_GEMMA_MODEL_PATH,
 
182
  max-width: 600px;
183
  }
184
  """
185
+ file = gr.File(file_types=["pdf"], file_count="multiple", label="Pdfs")
186
+ query = gr.Textbox("", placeholder="Enter your query here", label="Query")
187
 
188
  with gr.Blocks(
189
  title="Document Question Answering with ColPali & Pixtral",
 
210
  img_chunk = gr.State(value=[])
211
 
212
  with gr.Column(scale=3):
213
+ gr.Markdown("## Retrieve with ColPali and Answer with Pixtral")
214
  query.render()
215
  k = gr.Slider(
216
+ minimum=1,
217
+ maximum=4,
218
+ step=1,
219
+ label="Number of docs to retrieve",
220
+ value=1,
221
  )
222
+ answer_button = gr.Button("πŸƒ Run", variant="primary")
223
 
224
  # Define the actions
225
 
226
  output_gallery = gr.Gallery(
227
+ label="Retrieved docs", height=400, show_label=True, interactive=False
228
  )
229
+ output = gr.Textbox(label="Answer", lines=2, interactive=False)
230
 
231
  convert_button.click(
232
  index, inputs=[file, embeds], outputs=[message, embeds, imgs]
233
  )
 
 
 
 
 
 
 
234
  answer_button.click(
235
+ retrieve, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
236
+ ).then(pixtral_inference, inputs=[output_gallery, query], outputs=[output])
237
+
238
 
239
  if __name__ == "__main__":
240
  demo.queue(max_size=10).launch()