Spaces: Running on Zero
Commit ce8881a · Parent(s): c5aa334
Fix bugs

app.py CHANGED

@@ -22,7 +22,7 @@ PIXTAL_MODEL_ID = "mistral-community--pixtral-12b-240910"
 PIXTRAL_MODEL_SNAPSHOT = "95758896fcf4691ec9674f29ec90d1441d9d26d2"
 PIXTRAL_MODEL_PATH = (
     pathlib.Path().home()
-    / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/{PIXTRAL_MODEL_SNAPSHOT}"
+    / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/snapshots/{PIXTRAL_MODEL_SNAPSHOT}"
 )
 
 
@@ -30,13 +30,13 @@ COLPALI_GEMMA_MODEL_ID = "vidore--colpaligemma-3b-pt-448-base"
 COLPALI_GEMMA_MODEL_SNAPSHOT = "12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa"
 COLPALI_GEMMA_MODEL_PATH = (
     pathlib.Path().home()
-    / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
+    / f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/snapshots/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
 )
 COLPALI_MODEL_ID = "vidore--colpali-v1.2"
 COLPALI_MODEL_SNAPSHOT = "2d54d5d3684a4f5ceeefbef95df0c94159fd6a45"
 COLPALI_MODEL_PATH = (
     pathlib.Path().home()
-    / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/{COLPALI_MODEL_SNAPSHOT}"
+    / f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/snapshots/{COLPALI_MODEL_SNAPSHOT}"
 )
 
 
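
Note: both path fixes above correct the same bug. The Hugging Face hub cache stores pinned revisions under models--<org>--<name>/snapshots/<revision>; the old f-strings omitted the snapshots/ component, so the paths pointed at directories that do not exist. A minimal sketch of an alternative that avoids hand-assembling cache paths, assuming huggingface_hub is installed (the repo_id below is the slash form of the ID that the app encodes with "--"):

    from huggingface_hub import snapshot_download

    # Downloads on first call, reuses the local cache afterwards, and
    # returns the snapshot directory -- no hand-built path to get wrong.
    PIXTRAL_MODEL_PATH = snapshot_download(
        repo_id="mistral-community/pixtral-12b-240910",
        revision="95758896fcf4691ec9674f29ec90d1441d9d26d2",  # pinned snapshot
    )
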
@@ -46,11 +46,15 @@ def image_to_base64(image_path):
     return f"data:image/jpeg;base64,{encoded_string}"
 
 
-@spaces.GPU
-def model_inference(
+@spaces.GPU(duration=30)
+def pixtral_inference(
     images,
     text,
 ):
+    if len(images) == 0:
+        raise gr.Error("No images for generation")
+    if text == "":
+        raise gr.Error("No query for generation")
     tokenizer = MistralTokenizer.from_file(f"{PIXTRAL_MODEL_PATH}/tekken.json")
     model = Transformer.from_folder(PIXTRAL_MODEL_PATH)
 
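
Note: the new guards fail fast, before the tokenizer and the 12B Pixtral weights are loaded. In Gradio, raising gr.Error inside an event handler aborts the event and surfaces the message to the user as an in-app alert rather than a server traceback. A self-contained illustration (greet is a hypothetical handler, not from this commit):

    import gradio as gr

    def greet(name: str) -> str:
        if name == "":
            # Shown in the UI as a dismissible error message.
            raise gr.Error("Please enter a name")
        return f"Hello, {name}!"

    demo = gr.Interface(fn=greet, inputs="text", outputs="text")

    if __name__ == "__main__":
        demo.launch()
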
@@ -80,8 +84,13 @@ def model_inference(
     return result
 
 
-@spaces.GPU
-def search(query: str, ds, images, k):
+@spaces.GPU(duration=30)
+def retrieve(query: str, ds, images, k):
+    if len(images) == 0:
+        raise gr.Error("No docs/images for retrieval")
+    if query == "":
+        raise gr.Error("No query for retrieval")
+
     model = ColPali.from_pretrained(
         COLPALI_GEMMA_MODEL_PATH,
         torch_dtype=torch.bfloat16,
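
Note: every GPU entry point now uses @spaces.GPU(duration=30) instead of the bare decorator. On ZeroGPU Spaces the GPU is attached only while the decorated function runs, and the duration argument caps that allocation at roughly the given number of seconds. A toy sketch of the pattern (gpu_task is hypothetical):

    import spaces
    import torch

    @spaces.GPU(duration=30)  # GPU is attached only for this call
    def gpu_task(x: torch.Tensor) -> torch.Tensor:
        # Trivial workload purely for illustration.
        return (x.to("cuda") * 2).to("cpu")
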
@@ -101,11 +110,11 @@ def search(query: str, ds, images, k):
         embeddings_query = model(**batch_query)
         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
-    scores = processor.score(qs, ds)
-    top_k_indices = scores.argsort(axis=1)[0][-k:]
+    scores = processor.score(qs, ds).numpy()
+    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
     results = []
     for idx in top_k_indices:
-        results.append((images[idx]
+        results.append((images[idx], f"Page {idx}, Score {scores[0][idx]:.2f}"))
     del model
     del processor
     torch.cuda.empty_cache()
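
Note: two retrieval bugs are fixed in this hunk. argsort sorts ascending, so [-k:] alone yields the top-k pages ordered worst first; the added [::-1] puts the best page first in the gallery. The .numpy() conversion also matters, because PyTorch tensors reject negative-step slices such as [::-1]. On toy scores (illustrative values only):

    import numpy as np

    scores = np.array([[0.1, 0.9, 0.4, 0.7]])  # one query, four pages
    k = 2
    ascending = scores.argsort(axis=1)[0][-k:]         # [3, 1]: best page last
    descending = scores.argsort(axis=1)[0][-k:][::-1]  # [1, 3]: best page first
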
@@ -127,7 +136,7 @@ def convert_files(files):
     return images
 
 
-@spaces.GPU
+@spaces.GPU(duration=30)
 def index_gpu(images, ds):
     model = ColPali.from_pretrained(
         COLPALI_GEMMA_MODEL_PATH,
@@ -173,8 +182,8 @@ css = """
     max-width: 600px;
 }
 """
-file = gr.File(file_types=["pdf"], file_count="multiple", label="
-query = gr.Textbox(placeholder="Enter your query here", label="
+file = gr.File(file_types=["pdf"], file_count="multiple", label="Pdfs")
+query = gr.Textbox("", placeholder="Enter your query here", label="Query")
 
 with gr.Blocks(
     title="Document Question Answering with ColPali & Pixtral",
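
Note the pattern around file and query: both components are built at module level, before any layout exists, and attached to the page later (only query.render() is visible in this diff). Gradio supports this define-early, render-later style for components created outside a Blocks context. A minimal sketch:

    import gradio as gr

    # Built outside any layout; not rendered yet.
    query = gr.Textbox("", placeholder="Enter your query here", label="Query")

    with gr.Blocks() as demo:
        gr.Markdown("## Demo")
        query.render()  # the pre-built textbox is inserted at this point
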
@@ -201,32 +210,31 @@ with gr.Blocks(
     img_chunk = gr.State(value=[])
 
     with gr.Column(scale=3):
-        gr.Markdown("## 
+        gr.Markdown("## Retrieve with ColPali and Answer with Pixtral")
         query.render()
         k = gr.Slider(
-            minimum=1,
+            minimum=1,
+            maximum=4,
+            step=1,
+            label="Number of docs to retrieve",
+            value=1,
         )
-
+        answer_button = gr.Button("🏃 Run", variant="primary")
 
     # Define the actions
 
     output_gallery = gr.Gallery(
-        label="Retrieved 
+        label="Retrieved docs", height=400, show_label=True, interactive=False
     )
+    output = gr.Textbox(label="Answer", lines=2, interactive=False)
 
     convert_button.click(
         index, inputs=[file, embeds], outputs=[message, embeds, imgs]
     )
-    search_button.click(
-        search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
-    )
-
-    gr.Markdown("## Get your answer with Pixtral")
-    answer_button = gr.Button("Run", variant="primary")
-    output = gr.Markdown(label="Output")
     answer_button.click(
-
-    )
+        retrieve, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
+    ).then(pixtral_inference, inputs=[output_gallery, query], outputs=[output])
+
 
 if __name__ == "__main__":
     demo.queue(max_size=10).launch()
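
Note: the wiring change is the heart of this commit. The separate search button is gone; one Run button now chains retrieval into generation, with the gallery that retrieve returns feeding pixtral_inference. Gradio's .then() queues a second handler to run after the first event finishes. A self-contained sketch with stand-in handlers (fetch and answer are hypothetical):

    import gradio as gr

    def fetch(q):         # stand-in for retrieve()
        return [q.upper()]

    def answer(docs, q):  # stand-in for pixtral_inference()
        return f"Answer for {q!r} from {len(docs)} doc(s)"

    with gr.Blocks() as demo:
        q = gr.Textbox(label="Query")
        run = gr.Button("Run")
        gallery = gr.JSON(label="Retrieved")
        out = gr.Textbox(label="Answer")
        run.click(fetch, inputs=[q], outputs=[gallery]).then(
            answer, inputs=[gallery, q], outputs=[out]
        )

    if __name__ == "__main__":
        demo.launch()

One caveat: .then() fires even when the first handler raises gr.Error, which is why pixtral_inference validates its own inputs as well; .success() is the variant that runs only after a clean first step.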