tcy6 committed
Commit 4599dc2
1 Parent(s): 619d8a5

Update app.py

Files changed (1)
app.py +215 -4
app.py CHANGED
@@ -1,7 +1,218 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from PIL import Image
+ import hashlib
+ import torch
+ import fitz  # PyMuPDF
  import gradio as gr
+ import spaces
+ import os
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
+ import numpy as np
+ import json
 
+ cache_dir = '/home/user/data'
+ os.makedirs(cache_dir, exist_ok=True)

+ def get_image_md5(img: Image.Image):
+     # fingerprint a rendered page image; used as the cached PNG's filename
+     img_byte_array = img.tobytes()
+     hash_md5 = hashlib.md5()
+     hash_md5.update(img_byte_array)
+     hex_digest = hash_md5.hexdigest()
+     return hex_digest
+
+ def calculate_md5_from_binary(binary_data):
+     # fingerprint the raw PDF bytes; used as the knowledge base ID
+     hash_md5 = hashlib.md5()
+     hash_md5.update(binary_data)
+     return hash_md5.hexdigest()
+
+ @spaces.GPU(duration=100)
+ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
+     global model, tokenizer
+
+     knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+     this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
+     os.makedirs(this_cache_dir, exist_ok=True)
+
+     with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
+         file.write(pdf_file_binary)
+
+     dpi = 200
+     doc = fitz.open("pdf", pdf_file_binary)
+
+     reps_list = []
+     images = []
+     image_md5s = []
+
+     for page in progress.tqdm(doc):
+         # a lock could serialize this loop so one 16G GPU processes one image at a time
+         pix = page.get_pixmap(dpi=dpi)
+         image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+         image_md5 = get_image_md5(image)
+         image_md5s.append(image_md5)
+         with torch.no_grad():
+             reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
+         reps_list.append(reps.squeeze(0).cpu().numpy())
+         images.append(image)
+
+     for idx in range(len(images)):
+         image = images[idx]
+         image_md5 = image_md5s[idx]
+         cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
+         image.save(cache_image_path)
+
+     np.save(os.path.join(this_cache_dir, "reps.npy"), reps_list)
+
+     with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
+         for item in image_md5s:
+             f.write(item + '\n')
+
+     return knowledge_base_name
+
+ # @spaces.GPU
+ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
+     global model, tokenizer
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+     if not os.path.exists(target_cache_dir):
+         return None
+
+     md5s = []
+     with open(os.path.join(target_cache_dir, "md5s.txt"), 'r') as f:
+         for line in f:
+             md5s.append(line.rstrip('\n'))
+
+     doc_reps = np.load(os.path.join(target_cache_dir, "reps.npy"))
+
+     query_with_instruction = "Represent this query for retrieving relevant document: " + query
+     with torch.no_grad():
+         query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
+
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
+
+     # dot-product similarity between the query and every cached page embedding
+     similarities = torch.matmul(query_rep, doc_reps_cat.T)
+     topk_values, topk_doc_ids = torch.topk(similarities, k=int(topk))  # gr.Number may deliver a float
+     topk_doc_ids_np = topk_doc_ids.cpu().numpy()
+
+     images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids_np]
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
+         f.write(json.dumps(
+             {
+                 "knowledge_base": knowledge_base,
+                 "query": query,
+                 "retrieved_docs": [os.path.join(target_cache_dir, f"{md5s[idx]}.png") for idx in topk_doc_ids_np]
+             }, indent=4, ensure_ascii=False
+         ))
+
+     return images_topk
+
+
+ def upvote(knowledge_base, query):
+     # attach a positive preference label to the logged retrieval result
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "upvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+     print("up", os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"))
+     gr.Info('Received, babe! Thank you!')
+
+
+ def downvote(knowledge_base, query):
+     # attach a negative preference label to the logged retrieval result
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'r') as f:
+         data = json.loads(f.read())
+
+     data["user_preference"] = "downvote"
+
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"), 'w') as f:
+         f.write(json.dumps(data, indent=4, ensure_ascii=False))
+
+     print("down", os.path.join(target_cache_dir, f"q-{query_md5}-withpref.json"))
+     gr.Info('Received, babe! Thank you!')
+
+
+ device = 'cuda'
+ model_path = 'RhapsodyAI/minicpm-visual-embedding-v0'  # replace with your local model path
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+ model.to(device)
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("# Memex: OCR-free Visual Document Embedding Model as Your Personal Librarian")
+     gr.Markdown("""The model takes only images as document-side inputs and produces vectors representing document pages. Memex is trained on over 200k query-visual document pairs, including textual documents, visual documents, arXiv figures, plots, charts, industry documents, textbooks, ebooks, and openly available PDFs. Its performance is on par with our ablation text embedding model on text-oriented documents, and it has an advantage on visually-intensive documents.
+ Our model is capable of:
+ - Helping you read a long visually-intensive or text-oriented PDF document and find the pages that answer your question.
+ - Helping you build a personal library and retrieve book pages from a large collection of books.
+ - Working like a human: it reads and comprehends with vision, and remembers multimodal information in its hippocampus.""")
+
+     gr.Markdown("- Our model is proudly based on the MiniCPM-V series: [MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6), [MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2).")
+     gr.Markdown("- We open-sourced our model at [RhapsodyAI/minicpm-visual-embedding-v0](https://huggingface.co/RhapsodyAI/minicpm-visual-embedding-v0).")
+     gr.Markdown("- Currently we support PDF documents with fewer than 50 pages; PDFs over 50 pages will hit the GPU time limit.")
+
+     with gr.Row():
+         file_input = gr.File(type="binary", label="Upload PDF")
+         file_result = gr.Text(label="Knowledge Base ID (remember this!)")
+         process_button = gr.Button("Process PDF (don't click until the PDF upload has finished)")
+
+     process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
+
+     with gr.Row():
+         kb_id_input = gr.Text(label="Your Knowledge Base ID (paste it here)")
+         query_input = gr.Text(label="Your Question")
+         topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")
+         retrieve_button = gr.Button("Step 1: Retrieve")
+
+     with gr.Row():
+         downvote_button = gr.Button("🤣Downvote")
+         upvote_button = gr.Button("🤗Upvote")
+
+     with gr.Row():
+         images_output = gr.Gallery(label="Step 2: Retrieved Pages")
+
+     retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
+
+     upvote_button.click(upvote, inputs=[kb_id_input, query_input], outputs=None)
+     downvote_button.click(downvote, inputs=[kb_id_input, query_input], outputs=None)
+
+     gr.Markdown("By using this demo, you agree to share your usage data with us for research purposes, to help improve the user experience.")
+
+
+ app.launch()
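
For reference, a minimal usage sketch of the embedding interface this commit relies on, outside of Gradio. The checkpoint name, the `.reps` output field, the query instruction prefix, and the dot-product scoring are all taken from the code above; `example.pdf` and the sample question are hypothetical placeholders.

import torch
import fitz  # PyMuPDF
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_path = 'RhapsodyAI/minicpm-visual-embedding-v0'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to('cuda')

# Render the first page of a PDF to an RGB image, as add_pdf_gradio does.
doc = fitz.open('example.pdf')  # hypothetical input file
pix = doc[0].get_pixmap(dpi=200)
page = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

# Embed the page (image-only) and the query (text-only, with the instruction prefix).
query = "Represent this query for retrieving relevant document: What is the refund policy?"
with torch.no_grad():
    page_rep = model(text=[''], image=[page], tokenizer=tokenizer).reps.squeeze(0).cpu()
    query_rep = model(text=[query], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()

# Score with the same dot product used in retrieve_gradio.
print(float(torch.matmul(query_rep, page_rep)))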