hahahahahahahah3 committed on
Commit 5e98383 · verified · 1 Parent(s): d797aab

Update app.py

Files changed (1)
  1. app.py +259 -62
app.py CHANGED
@@ -1,63 +1,260 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ import tqdm
+ from PIL import Image
+ import hashlib
+ import torch
+ import fitz
+ import threading
  import gradio as gr
+ import spaces
+ import os
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
+ import numpy as np
+ import json
+
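+ # Local cache: one sub-directory per knowledge base, holding the source PDF,
+ # the rendered page images, and their embeddings.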
+ cache_dir = '/data/kb_cache'
+ os.makedirs(cache_dir, exist_ok=True)
+
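+ # MD5 of the raw image bytes; used as a stable filename for cached page images.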
+ def get_image_md5(img: Image.Image):
+     img_byte_array = img.tobytes()
+     hash_md5 = hashlib.md5()
+     hash_md5.update(img_byte_array)
+     hex_digest = hash_md5.hexdigest()
+     return hex_digest
+
+ def calculate_md5_from_binary(binary_data):
+     hash_md5 = hashlib.md5()
+     hash_md5.update(binary_data)
+     return hash_md5.hexdigest()
+
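+ # Step 1: render each PDF page at 200 dpi, embed it with the visual embedding model,
+ # and cache the images and embeddings under an ID derived from the PDF's MD5.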
+ @spaces.GPU(duration=100)
+ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
+     global model, tokenizer
+     model.eval()
+
+     knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+     this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
+     os.makedirs(this_cache_dir, exist_ok=True)
+
+     with open(os.path.join(this_cache_dir, f"src.pdf"), 'wb') as file:
+         file.write(pdf_file_binary)
+
+     dpi = 200
+     doc = fitz.open("pdf", pdf_file_binary)
+
+     reps_list = []
+     images = []
+     image_md5s = []
+
+     for page in progress.tqdm(doc):
+         # with self.lock:  # we want one 16 GB GPU to process only one image at a time
+         pix = page.get_pixmap(dpi=dpi)
+         image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+         image_md5 = get_image_md5(image)
+         image_md5s.append(image_md5)
+         with torch.no_grad():
+             reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
+         reps_list.append(reps.squeeze(0).cpu().numpy())
+         images.append(image)
+
+     for idx in range(len(images)):
+         image = images[idx]
+         image_md5 = image_md5s[idx]
+         cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
+         image.save(cache_image_path)
+
+     np.save(os.path.join(this_cache_dir, f"reps.npy"), reps_list)
+
+     with open(os.path.join(this_cache_dir, f"md5s.txt"), 'w') as f:
+         for item in image_md5s:
+             f.write(item + '\n')
+
+     return knowledge_base_name
+
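+ # Step 2: embed the query, score every cached page by dot product, and return the
+ # top-k page images; the retrieval result is also logged to a per-query JSON file.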
+ # @spaces.GPU
+ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
+     global model, tokenizer
+
+     model.eval()
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+     if not os.path.exists(target_cache_dir):
+         return None
+
+     md5s = []
+     with open(os.path.join(target_cache_dir, f"md5s.txt"), 'r') as f:
+         for line in f:
+             md5s.append(line.rstrip('\n'))
+
+     doc_reps = np.load(os.path.join(target_cache_dir, f"reps.npy"))
+
+     query_with_instruction = "Represent this query for retrieving relevant document: " + query
+     with torch.no_grad():
+         query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
+
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
+
+     similarities = torch.matmul(query_rep, doc_reps_cat.T)
+
+     topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
+
+     topk_values_np = topk_values.cpu().numpy()
+
+     topk_doc_ids_np = topk_doc_ids.cpu().numpy()
+
+     similarities_np = similarities.cpu().numpy()
+
+     images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids_np]
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
+         f.write(json.dumps(
+             {
+                 "knowledge_base": knowledge_base,
+                 "query": query,
+                 "retrieved_docs": [os.path.join(target_cache_dir, f"{md5s[idx]}.png") for idx in topk_doc_ids_np]
+             }, indent=4, ensure_ascii=False
+         ))
+
+     return images_topk
+
+
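+ # Feedback handlers: re-read the logged retrieval for this query and write a copy
+ # annotated with the user's preference.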
+ def upvote(knowledge_base, query):
+     global model, tokenizer
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "upvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+     print("up", os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"))
+
+     gr.Info('Received, babe! Thank you!')
+
+     return
+
+
+ def downvote(knowledge_base, query):
+     global model, tokenizer
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "downvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+     print("down", os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"))
+
+     gr.Info('Received, babe! Thank you!')
+
+     return
+
+
+
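+ # Load both models once at startup: the visual embedding model for retrieval and
+ # MiniCPM-V-2.6 for answer generation.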
+ device = 'cuda'
+
+ print("emb model load begin...")
+ model_path = 'RhapsodyAI/minicpm-visual-embedding-v0' # replace with your local model path
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+ model.eval()
+ model.to(device)
+ print("emb model load success!")
+
+ print("gen model load begin...")
+ gen_model_path = 'openbmb/MiniCPM-V-2_6'
+ gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_path, trust_remote_code=True)
+ gen_model = AutoModel.from_pretrained(gen_model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
+ gen_model.eval()
+ gen_model.to(device)
+ print("gen model load success!")
+
+
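+ # Step 3: feed the retrieved page images plus the question to MiniCPM-V-2.6 and
+ # return its answer.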
+ @spaces.GPU(duration=50)
+ def answer_question(images, question):
+     global gen_model, gen_tokenizer
+     # here each element of images is a tuple of (image_path, None).
+     images_ = [Image.open(image[0]).convert('RGB') for image in images]
+     msgs = [{'role': 'user', 'content': [question, *images_]}]
+     answer = gen_model.chat(
+         image=None,
+         msgs=msgs,
+         tokenizer=gen_tokenizer
+     )
+     print(answer)
+     return answer
+
+
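+ # Gradio UI: upload and process a PDF, retrieve pages for a query, answer with the
+ # retrieved pages, and collect up/down votes.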
+ with gr.Blocks() as app:
+     gr.Markdown("# MiniCPMV-RAG-PDFQA: Two Vision Language Models Enable End-to-End RAG")
+
+     gr.Markdown("""
+ - A Vision Language Model Dense Retriever ([minicpm-visual-embedding-v0](https://huggingface.co/RhapsodyAI/minicpm-visual-embedding-v0)) **directly reads** your PDFs **without the need for OCR**, produces **multimodal dense representations**, and builds your personal library.
+
+ - **Ask a question**: it retrieves the most relevant pages, then [MiniCPM-V-2.6](https://huggingface.co/spaces/openbmb/MiniCPM-V-2_6) answers your question based on the recalled pages, with strong multi-image understanding capability.
+
+ - It helps you read a long **visually-intensive** or **text-oriented** PDF document and find the pages that answer your question.
+
+ - It helps you build a personal library and retrieve book pages from a large collection of books.
+
+ - It works like a human: read, store, retrieve, and answer with full vision.
+ """)
+
+     gr.Markdown("- The online demo currently supports PDF documents with fewer than 50 pages due to the GPU time limit. Deploy on your own machine for longer PDFs and books.")
+
+     with gr.Row():
+         file_input = gr.File(type="binary", label="Step 1: Upload PDF")
+         file_result = gr.Text(label="Knowledge Base ID (remember it, it is re-usable!)")
+         process_button = gr.Button("Process PDF (don't click until the PDF upload has finished)")
+
+     process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
+
+     with gr.Row():
+         kb_id_input = gr.Text(label="Your Knowledge Base ID (paste your Knowledge Base ID here, it is re-usable)")
+         query_input = gr.Text(label="Your Question")
+         topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")
+         retrieve_button = gr.Button("Step 2: Retrieve Pages")
+
+     with gr.Row():
+         images_output = gr.Gallery(label="Retrieved Pages")
+
+     retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
+
+     with gr.Row():
+         button = gr.Button("Step 3: Answer Question with Retrieved Pages")
+
+     gen_model_response = gr.Textbox(label="MiniCPM-V-2.6's Answer")
+
+     button.click(fn=answer_question, inputs=[images_output, query_input], outputs=gen_model_response)
+
+     with gr.Row():
+         downvote_button = gr.Button("🤣Downvote")
+         upvote_button = gr.Button("🤗Upvote")
+
+     upvote_button.click(upvote, inputs=[kb_id_input, query_input], outputs=None)
+     downvote_button.click(downvote, inputs=[kb_id_input, query_input], outputs=None)
+
+     gr.Markdown("By using this demo, you agree to share your usage data with us for research purposes, to help improve the user experience.")
+
+
+ app.launch()
+