bokesyo commited on
Commit
e9b5d42
1 Parent(s): 63057ca

Create pipeline_gradio.py

Browse files
Files changed (1) hide show
  1. pipeline_gradio.py +165 -0
pipeline_gradio.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # Copyright @2023 RhapsodyAI, ModelBest Inc. (modelbest.cn)
5
+ #
6
+ # @author: bokai xu <[email protected]>
7
+ # @date: 2024/07/13
8
+ #
9
+
10
+
11
+ import tqdm
12
+ from PIL import Image
13
+ import hashlib
14
+ import torch
15
+ import fitz
16
+ import threading
17
+ import gradio as gr
18
+
19
+
20
+ def get_image_md5(img: Image.Image):
21
+ img_byte_array = img.tobytes()
22
+ hash_md5 = hashlib.md5()
23
+ hash_md5.update(img_byte_array)
24
+ hex_digest = hash_md5.hexdigest()
25
+ return hex_digest
26
+
27
+ def pdf_to_images(pdf_path, dpi=100):
28
+ doc = fitz.open(pdf_path)
29
+ images = []
30
+ for page in tqdm.tqdm(doc):
31
+ pix = page.get_pixmap(dpi=dpi)
32
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
33
+ images.append(img)
34
+ return images
35
+
36
+ def calculate_md5_from_binary(binary_data):
37
+ hash_md5 = hashlib.md5()
38
+ hash_md5.update(binary_data)
39
+ return hash_md5.hexdigest()
40
+
41
+ class PDFVisualRetrieval:
42
+ def __init__(self, model, tokenizer):
43
+ self.tokenizer = tokenizer
44
+ self.model = model
45
+ self.reps = {}
46
+ self.images = {}
47
+ self.lock = threading.Lock()
48
+
49
+ def retrieve(self, knowledge_base: str, query: str, topk: int):
50
+ doc_reps = list(self.reps[knowledge_base].values())
51
+ query_with_instruction = "Represent this query for retrieving relavant document: " + query
52
+ with torch.no_grad():
53
+ query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
54
+ doc_reps_cat = torch.stack(doc_reps, dim=0)
55
+ similarities = torch.matmul(query_rep, doc_reps_cat.T)
56
+ topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
57
+ topk_values_np = topk_values.cpu().numpy()
58
+ topk_doc_ids_np = topk_doc_ids.cpu().numpy()
59
+ similarities_np = similarities.cpu().numpy()
60
+ all_images_doc_list = list(self.images[knowledge_base].values())
61
+ images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
62
+ return topk_doc_ids_np, topk_values_np, images_topk
63
+
64
+ def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
65
+ if knowledge_base_name not in self.reps:
66
+ self.reps[knowledge_base_name] = {}
67
+ if knowledge_base_name not in self.images:
68
+ self.images[knowledge_base_name] = {}
69
+ doc = fitz.open(pdf_file_path)
70
+ print("model encoding images..")
71
+ for page in tqdm.tqdm(doc):
72
+ pix = page.get_pixmap(dpi=dpi)
73
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
74
+ image_md5 = get_image_md5(image)
75
+ with torch.no_grad():
76
+ reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
77
+ self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
78
+ self.images[knowledge_base_name][image_md5] = image
79
+ return
80
+
81
+ def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
82
+ knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
83
+ if knowledge_base_name not in self.reps:
84
+ self.reps[knowledge_base_name] = {}
85
+ else:
86
+ return knowledge_base_name
87
+ if knowledge_base_name not in self.images:
88
+ self.images[knowledge_base_name] = {}
89
+ dpi = 100
90
+ doc = fitz.open("pdf", pdf_file_binary)
91
+ for page in progress.tqdm(doc):
92
+ with self.lock: # because we hope one 16G gpu only process one image at the same time
93
+ pix = page.get_pixmap(dpi=dpi)
94
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
95
+ image_md5 = get_image_md5(image)
96
+ with torch.no_grad():
97
+ reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
98
+ self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
99
+ self.images[knowledge_base_name][image_md5] = image
100
+ return knowledge_base_name
101
+
102
+ def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
103
+ doc_reps = list(self.reps[knowledge_base].values())
104
+ query_with_instruction = "Represent this query for retrieving relavant document: " + query
105
+ with torch.no_grad():
106
+ query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
107
+ doc_reps_cat = torch.stack(doc_reps, dim=0)
108
+ similarities = torch.matmul(query_rep, doc_reps_cat.T)
109
+ topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
110
+ topk_values_np = topk_values.cpu().numpy()
111
+ topk_doc_ids_np = topk_doc_ids.cpu().numpy()
112
+ similarities_np = similarities.cpu().numpy()
113
+ all_images_doc_list = list(self.images[knowledge_base].values())
114
+ images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
115
+ return images_topk
116
+
117
+
118
+ if __name__ == "__main__":
119
+ from transformers import AutoModel
120
+ from transformers import AutoTokenizer
121
+ from PIL import Image
122
+ import torch
123
+
124
+ device = 'cuda:0'
125
+
126
+ # Load model, be sure to substitute `model_path` by your model path
127
+ model_path = '/home/jeeves/xubokai/minicpm-visual-embedding-v0' # replace with your local model path
128
+ # pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"
129
+
130
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
131
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
132
+ model.to(device)
133
+
134
+ retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
135
+
136
+ # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
137
+ # # 2
138
+ # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
139
+ # # 3
140
+ # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
141
+ # # 2
142
+
143
+ with gr.Blocks() as app:
144
+ gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")
145
+
146
+ with gr.Row():
147
+ file_input = gr.File(type="binary", label="Upload PDF")
148
+ file_result = gr.Text(label="Knowledge Base ID (remember this!)")
149
+ process_button = gr.Button("Process PDF")
150
+
151
+ process_button.click(retriever.add_pdf_gradio, inputs=[file_input], outputs=file_result)
152
+
153
+ with gr.Row():
154
+ kb_id_input = gr.Text(label="Your Knowledge Base ID")
155
+ query_input = gr.Text(label="Your Queston")
156
+ topk_input = inputs=gr.Number(value=1, minimum=1, maximum=5, step=1, label="Top K")
157
+ retrieve_button = gr.Button("Retrieve")
158
+
159
+ with gr.Row():
160
+ images_output = gr.Gallery(label="Retrieved Pages")
161
+
162
+ retrieve_button.click(retriever.retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
163
+
164
+ app.launch()
165
+