Spaces:
Running
on
Zero
Running
on
Zero
AdrienB134
commited on
Commit
·
79fd59c
1
Parent(s):
cc33a9b
feazrgf
Browse files
app.py
CHANGED
@@ -13,7 +13,83 @@ from pdf2image import convert_from_path
|
|
13 |
from PIL import Image
|
14 |
from torch.utils.data import DataLoader
|
15 |
from tqdm import tqdm
|
16 |
-
from transformers import AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Load model
|
19 |
model_name = "vidore/colpali-v1.2"
|
@@ -96,7 +172,7 @@ def index_gpu(images, ds):
|
|
96 |
embeddings_doc = model(**batch_doc)
|
97 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
98 |
return f"Uploaded and converted {len(images)} pages", ds, images
|
99 |
-
|
100 |
@spaces.GPU
|
101 |
def answer_gpu():
|
102 |
return 0
|
@@ -116,6 +192,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
116 |
message = gr.Textbox("Files not yet uploaded", label="Status")
|
117 |
embeds = gr.State(value=[])
|
118 |
imgs = gr.State(value=[])
|
|
|
119 |
|
120 |
with gr.Column(scale=3):
|
121 |
gr.Markdown("## 2️⃣ Search")
|
@@ -133,10 +210,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
133 |
output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
|
134 |
|
135 |
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
|
136 |
-
search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
|
137 |
|
138 |
answer_button = gr.Button("Answer", variant="primary")
|
139 |
-
|
|
|
140 |
|
141 |
if __name__ == "__main__":
|
142 |
demo.queue(max_size=10).launch(debug=True)
|
|
|
13 |
from PIL import Image
|
14 |
from torch.utils.data import DataLoader
|
15 |
from tqdm import tqdm
|
16 |
+
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
|
17 |
+
import re
|
18 |
+
import time
|
19 |
+
from PIL import Image
|
20 |
+
import torch
|
21 |
+
import subprocess
|
22 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
23 |
+
|
24 |
+
|
25 |
+
## Load idefics
|
26 |
+
id_processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
|
27 |
+
|
28 |
+
id_model = Idefics3ForConditionalGeneration.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3",
|
29 |
+
torch_dtype=torch.bfloat16,
|
30 |
+
#_attn_implementation="flash_attention_2"
|
31 |
+
).to("cuda")
|
32 |
+
|
33 |
+
BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
|
34 |
+
EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
|
35 |
+
|
36 |
+
@spaces.GPU
|
37 |
+
def model_inference(
|
38 |
+
images, text, assistant_prefix= None, decoding_strategy = "Greedy", temperature= 0.4, max_new_tokens=512,
|
39 |
+
repetition_penalty=1.2, top_p=0.8
|
40 |
+
):
|
41 |
+
if text == "" and not images:
|
42 |
+
gr.Error("Please input a query and optionally image(s).")
|
43 |
+
|
44 |
+
if text == "" and images:
|
45 |
+
gr.Error("Please input a text query along the image(s).")
|
46 |
+
|
47 |
+
if isinstance(images, Image.Image):
|
48 |
+
images = [images]
|
49 |
+
|
50 |
+
|
51 |
+
resulting_messages = [
|
52 |
+
{
|
53 |
+
"role": "user",
|
54 |
+
"content": [{"type": "image"}] + [
|
55 |
+
{"type": "text", "text": text}
|
56 |
+
]
|
57 |
+
}
|
58 |
+
]
|
59 |
+
|
60 |
+
if assistant_prefix:
|
61 |
+
text = f"{assistant_prefix} {text}"
|
62 |
+
|
63 |
+
|
64 |
+
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
|
65 |
+
inputs = processor(text=prompt, images=[images], return_tensors="pt")
|
66 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
67 |
+
|
68 |
+
generation_args = {
|
69 |
+
"max_new_tokens": max_new_tokens,
|
70 |
+
"repetition_penalty": repetition_penalty,
|
71 |
+
|
72 |
+
}
|
73 |
+
|
74 |
+
assert decoding_strategy in [
|
75 |
+
"Greedy",
|
76 |
+
"Top P Sampling",
|
77 |
+
]
|
78 |
+
if decoding_strategy == "Greedy":
|
79 |
+
generation_args["do_sample"] = False
|
80 |
+
elif decoding_strategy == "Top P Sampling":
|
81 |
+
generation_args["temperature"] = temperature
|
82 |
+
generation_args["do_sample"] = True
|
83 |
+
generation_args["top_p"] = top_p
|
84 |
+
|
85 |
+
|
86 |
+
generation_args.update(inputs)
|
87 |
+
|
88 |
+
# Generate
|
89 |
+
generated_ids = model.generate(**generation_args)
|
90 |
+
|
91 |
+
generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
|
92 |
+
return generated_texts[0]
|
93 |
|
94 |
# Load model
|
95 |
model_name = "vidore/colpali-v1.2"
|
|
|
172 |
embeddings_doc = model(**batch_doc)
|
173 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
174 |
return f"Uploaded and converted {len(images)} pages", ds, images
|
175 |
+
|
176 |
@spaces.GPU
|
177 |
def answer_gpu():
|
178 |
return 0
|
|
|
192 |
message = gr.Textbox("Files not yet uploaded", label="Status")
|
193 |
embeds = gr.State(value=[])
|
194 |
imgs = gr.State(value=[])
|
195 |
+
img_chunk = gr.State(value=[])
|
196 |
|
197 |
with gr.Column(scale=3):
|
198 |
gr.Markdown("## 2️⃣ Search")
|
|
|
210 |
output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
|
211 |
|
212 |
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
|
213 |
+
search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery, img_chunk])
|
214 |
|
215 |
answer_button = gr.Button("Answer", variant="primary")
|
216 |
+
output = gr.Textbox(label="Output")
|
217 |
+
answer_button.click(model_inference, inputs=[img_chunk, query], outputs=output)
|
218 |
|
219 |
if __name__ == "__main__":
|
220 |
demo.queue(max_size=10).launch(debug=True)
|