p3nguknight commited on
Commit
18fd83d
Β·
0 Parent(s):

Initial commit

Browse files
Files changed (6) hide show
  1. .gitattributes +35 -0
  2. README.md +20 -0
  3. app.py +242 -0
  4. packages.txt +1 -0
  5. plants_and_people.pdf +0 -0
  6. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ColQwen & Pixtral
3
+ short_description: Document Question Answering with ColQwen & Pixtral
4
+ emoji: πŸ‘€
5
+ colorFrom: purple
6
+ colorTo: blue
7
+ sdk: gradio
8
+ sdk_version: 4.44.0
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ models:
13
+ - vidore--colqwen2-base
14
+ - vidore/colqwen2-v0.1
15
+ - mistral-community/pixtral-12b-240910
16
+ preload_from_hub:
17
+ - vidore/colqwen2-base added_tokens.json,chat_template.json,config.json,generation_config.json,merges.txt,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json,vocab.json c722b912b50b14e404b91679db710fa2e1c6a762
18
+ - vidore/colqwen2-v0.1 adapter_config.json,adapter_model.safetensors,added_tokens.json,chat_template.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json,vocab.json 6b9ef3c32c97c0bb3be99bc35a05d9f30e0cada5
19
+ - mistral-community/pixtral-12b-240910 params.json,tekken.json,consolidated.safetensors 95758896fcf4691ec9674f29ec90d1441d9d26d2
20
+ ---
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from typing import cast
3
+ import pathlib
4
+ import gradio as gr
5
+ import spaces
6
+ import torch
7
+ from ColQwen_engine.models import ColQwen2, ColQwen2Processor
8
+ from mistral_common.protocol.instruct.messages import (
9
+ ImageURLChunk,
10
+ TextChunk,
11
+ UserMessage,
12
+ )
13
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
14
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
15
+ from mistral_inference.generate import generate
16
+ from mistral_inference.transformer import Transformer
17
+ from pdf2image import convert_from_path
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+
21
+ PIXTAL_MODEL_ID = "mistral-community--pixtral-12b-240910"
22
+ PIXTRAL_MODEL_SNAPSHOT = "95758896fcf4691ec9674f29ec90d1441d9d26d2"
23
+ PIXTRAL_MODEL_PATH = (
24
+ pathlib.Path().home()
25
+ / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/snapshots/{PIXTRAL_MODEL_SNAPSHOT}"
26
+ )
27
+
28
+ COLQWEN_BASE_MODEL_ID = "vidore--colqwen2-base"
29
+ COLQWEN_BASE_MODEL_SNAPSHOT = "c722b912b50b14e404b91679db710fa2e1c6a762"
30
+ COLQWEN_BASE_MODEL_PATH = (
31
+ pathlib.Path().home()
32
+ / f".cache/huggingface/hub/models--{COLQWEN_BASE_MODEL_ID}/snapshots/{COLQWEN_BASE_MODEL_SNAPSHOT}"
33
+ )
34
+ COLQWEN_MODEL_ID = "vidore--colqwen2-v0.1"
35
+ COLQWEN_MODEL_SNAPSHOT = "6b9ef3c32c97c0bb3be99bc35a05d9f30e0cada5"
36
+ COLQWEN_MODEL_PATH = (
37
+ pathlib.Path().home()
38
+ / f".cache/huggingface/hub/models--{COLQWEN_MODEL_ID}/snapshots/{COLQWEN_MODEL_SNAPSHOT}"
39
+ )
40
+
41
+
42
+ def image_to_base64(image_path):
43
+ with open(image_path, "rb") as img:
44
+ encoded_string = base64.b64encode(img.read()).decode("utf-8")
45
+ return f"data:image/jpeg;base64,{encoded_string}"
46
+
47
+
48
+ @spaces.GPU(duration=60)
49
+ def pixtral_inference(
50
+ images,
51
+ text,
52
+ ):
53
+ if len(images) == 0:
54
+ raise gr.Error("No images for generation")
55
+ if text == "":
56
+ raise gr.Error("No query for generation")
57
+ tokenizer = MistralTokenizer.from_file(f"{PIXTRAL_MODEL_PATH}/tekken.json")
58
+ model = Transformer.from_folder(PIXTRAL_MODEL_PATH)
59
+
60
+ messages = [
61
+ UserMessage(
62
+ content=[ImageURLChunk(image_url=image_to_base64(i[0])) for i in images]
63
+ + [TextChunk(text=text)]
64
+ )
65
+ ]
66
+
67
+ completion_request = ChatCompletionRequest(messages=messages)
68
+
69
+ encoded = tokenizer.encode_chat_completion(completion_request)
70
+
71
+ images = encoded.images
72
+ tokens = encoded.tokens
73
+
74
+ out_tokens, _ = generate(
75
+ [tokens],
76
+ model,
77
+ images=[images],
78
+ max_tokens=512,
79
+ temperature=0.45,
80
+ eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
81
+ )
82
+ result = tokenizer.decode(out_tokens[0])
83
+ return result
84
+
85
+
86
+ @spaces.GPU(duration=60)
87
+ def retrieve(query: str, ds, images, k):
88
+ if len(images) == 0:
89
+ raise gr.Error("No docs/images for retrieval")
90
+ if query == "":
91
+ raise gr.Error("No query for retrieval")
92
+
93
+ model = ColQwen2.from_pretrained(
94
+ COLQWEN_BASE_MODEL_PATH,
95
+ torch_dtype=torch.bfloat16,
96
+ device_map="cuda",
97
+ ).eval()
98
+
99
+ model.load_adapter(COLQWEN_MODEL_PATH)
100
+ model = model.eval()
101
+ processor = cast(
102
+ ColQwen2Processor, ColQwen2Processor.from_pretrained(COLQWEN_MODEL_PATH)
103
+ )
104
+
105
+ qs = []
106
+ with torch.no_grad():
107
+ batch_query = processor.process_queries([query])
108
+ batch_query = {k: v.to("cuda") for k, v in batch_query.items()}
109
+ embeddings_query = model(**batch_query)
110
+ qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
111
+
112
+ scores = processor.score(qs, ds).numpy()
113
+ top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
114
+ results = []
115
+ for idx in top_k_indices:
116
+ results.append((images[idx], f"Score {scores[0][idx]:.2f}"))
117
+ del model
118
+ del processor
119
+ torch.cuda.empty_cache()
120
+ return results
121
+
122
+
123
+ def index(files, ds):
124
+ images = convert_files(files)
125
+ return index_gpu(images, ds)
126
+
127
+
128
+ def convert_files(files):
129
+ images = []
130
+ for f in files:
131
+ images.extend(convert_from_path(f, thread_count=4))
132
+
133
+ if len(images) >= 150:
134
+ raise gr.Error("The number of images in the dataset should be less than 150.")
135
+ return images
136
+
137
+
138
+ @spaces.GPU(duration=60)
139
+ def index_gpu(images, ds):
140
+ model = ColQwen2.from_pretrained(
141
+ COLQWEN_BASE_MODEL_PATH,
142
+ torch_dtype=torch.bfloat16,
143
+ device_map="cuda",
144
+ ).eval()
145
+
146
+ model.load_adapter(COLQWEN_MODEL_PATH)
147
+ model = model.eval()
148
+ processor = cast(
149
+ ColQwen2Processor, ColQwen2Processor.from_pretrained(COLQWEN_MODEL_PATH)
150
+ )
151
+
152
+ # run inference - docs
153
+ dataloader = DataLoader(
154
+ images,
155
+ batch_size=4,
156
+ shuffle=False,
157
+ collate_fn=lambda x: processor.process_images(x),
158
+ )
159
+
160
+ for batch_doc in tqdm(dataloader):
161
+ with torch.no_grad():
162
+ batch_doc = {k: v.to("cuda") for k, v in batch_doc.items()}
163
+ embeddings_doc = model(**batch_doc)
164
+ ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
165
+ del model
166
+ del processor
167
+ torch.cuda.empty_cache()
168
+ return f"Uploaded and converted {len(images)} pages", ds, images
169
+
170
+
171
+ def get_example():
172
+ return [
173
+ [["plants_and_people.pdf"], "What is the global population in 2050 ? "],
174
+ [["plants_and_people.pdf"], "Where was Teosinte domesticated ?"],
175
+ ]
176
+
177
+
178
+ css = """
179
+ #title-container {
180
+ margin: 0 auto;
181
+ max-width: 800px;
182
+ text-align: center;
183
+ }
184
+ #col-container {
185
+ margin: 0 auto;
186
+ max-width: 600px;
187
+ }
188
+ """
189
+ file = gr.File(file_types=["pdf"], file_count="multiple", label="PDFs")
190
+ query = gr.Textbox("", placeholder="Enter your query here", label="Query")
191
+
192
+ with gr.Blocks(
193
+ title="Document Question Answering with ColQwen & Pixtral",
194
+ theme=gr.themes.Soft(),
195
+ css=css,
196
+ ) as demo:
197
+ with gr.Row(elem_id="title-container"):
198
+ gr.Markdown("""# Document Question Answering with ColQwen & Pixtral""")
199
+ with gr.Column(elem_id="col-container"):
200
+ with gr.Row():
201
+ gr.Examples(
202
+ examples=get_example(),
203
+ inputs=[file, query],
204
+ )
205
+
206
+ with gr.Row():
207
+ with gr.Column(scale=2):
208
+ gr.Markdown("## Index PDFs")
209
+ file.render()
210
+ convert_button = gr.Button("πŸ”„ Run", variant="primary")
211
+ message = gr.Textbox("Files not yet uploaded", label="Status")
212
+ embeds = gr.State(value=[])
213
+ imgs = gr.State(value=[])
214
+ img_chunk = gr.State(value=[])
215
+
216
+ with gr.Column(scale=3):
217
+ gr.Markdown("## Retrieve with ColQwen and answer with Pixtral")
218
+ query.render()
219
+ k = gr.Slider(
220
+ minimum=1,
221
+ maximum=4,
222
+ step=1,
223
+ label="Number of docs to retrieve",
224
+ value=1,
225
+ )
226
+ answer_button = gr.Button("πŸƒ Run", variant="primary")
227
+
228
+ output_gallery = gr.Gallery(
229
+ label="Retrieved docs", height=400, show_label=True, interactive=False
230
+ )
231
+ output = gr.Textbox(label="Answer", lines=2, interactive=False)
232
+
233
+ convert_button.click(
234
+ index, inputs=[file, embeds], outputs=[message, embeds, imgs]
235
+ )
236
+ answer_button.click(
237
+ retrieve, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
238
+ ).then(pixtral_inference, inputs=[output_gallery, query], outputs=[output])
239
+
240
+
241
+ if __name__ == "__main__":
242
+ demo.queue(max_size=10).launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ poppler-utils
plants_and_people.pdf ADDED
Binary file (487 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ transformers==4.45.1
3
+ huggingface_hub==0.25.0
4
+ pdf2image==1.17.0
5
+ spaces==0.30.2
6
+ colpali_engine==0.3.1
7
+ mistral_inference==1.4.0