bokesyo commited on
Commit
be78911
1 Parent(s): f288393

Delete pipeline_gradio.py

Browse files
Files changed (1) hide show
  1. pipeline_gradio.py +0 -165
pipeline_gradio.py DELETED
@@ -1,165 +0,0 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- #
4
- # Copyright @2023 RhapsodyAI
5
- #
6
- # @author: bokai xu <[email protected]>
7
- # @date: 2024/07/13
8
- #
9
-
10
-
11
- import tqdm
12
- from PIL import Image
13
- import hashlib
14
- import torch
15
- import fitz
16
- import threading
17
- import gradio as gr
18
-
19
-
20
- def get_image_md5(img: Image.Image):
21
- img_byte_array = img.tobytes()
22
- hash_md5 = hashlib.md5()
23
- hash_md5.update(img_byte_array)
24
- hex_digest = hash_md5.hexdigest()
25
- return hex_digest
26
-
27
- def pdf_to_images(pdf_path, dpi=100):
28
- doc = fitz.open(pdf_path)
29
- images = []
30
- for page in tqdm.tqdm(doc):
31
- pix = page.get_pixmap(dpi=dpi)
32
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
33
- images.append(img)
34
- return images
35
-
36
- def calculate_md5_from_binary(binary_data):
37
- hash_md5 = hashlib.md5()
38
- hash_md5.update(binary_data)
39
- return hash_md5.hexdigest()
40
-
41
- class PDFVisualRetrieval:
42
- def __init__(self, model, tokenizer):
43
- self.tokenizer = tokenizer
44
- self.model = model
45
- self.reps = {}
46
- self.images = {}
47
- self.lock = threading.Lock()
48
-
49
- def retrieve(self, knowledge_base: str, query: str, topk: int):
50
- doc_reps = list(self.reps[knowledge_base].values())
51
- query_with_instruction = "Represent this query for retrieving relavant document: " + query
52
- with torch.no_grad():
53
- query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
54
- doc_reps_cat = torch.stack(doc_reps, dim=0)
55
- similarities = torch.matmul(query_rep, doc_reps_cat.T)
56
- topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
57
- topk_values_np = topk_values.cpu().numpy()
58
- topk_doc_ids_np = topk_doc_ids.cpu().numpy()
59
- similarities_np = similarities.cpu().numpy()
60
- all_images_doc_list = list(self.images[knowledge_base].values())
61
- images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
62
- return topk_doc_ids_np, topk_values_np, images_topk
63
-
64
- def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
65
- if knowledge_base_name not in self.reps:
66
- self.reps[knowledge_base_name] = {}
67
- if knowledge_base_name not in self.images:
68
- self.images[knowledge_base_name] = {}
69
- doc = fitz.open(pdf_file_path)
70
- print("model encoding images..")
71
- for page in tqdm.tqdm(doc):
72
- pix = page.get_pixmap(dpi=dpi)
73
- image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
74
- image_md5 = get_image_md5(image)
75
- with torch.no_grad():
76
- reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
77
- self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
78
- self.images[knowledge_base_name][image_md5] = image
79
- return
80
-
81
- def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
82
- knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
83
- if knowledge_base_name not in self.reps:
84
- self.reps[knowledge_base_name] = {}
85
- else:
86
- return knowledge_base_name
87
- if knowledge_base_name not in self.images:
88
- self.images[knowledge_base_name] = {}
89
- dpi = 100
90
- doc = fitz.open("pdf", pdf_file_binary)
91
- for page in progress.tqdm(doc):
92
- with self.lock: # because we hope one 16G gpu only process one image at the same time
93
- pix = page.get_pixmap(dpi=dpi)
94
- image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
95
- image_md5 = get_image_md5(image)
96
- with torch.no_grad():
97
- reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
98
- self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
99
- self.images[knowledge_base_name][image_md5] = image
100
- return knowledge_base_name
101
-
102
- def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
103
- doc_reps = list(self.reps[knowledge_base].values())
104
- query_with_instruction = "Represent this query for retrieving relavant document: " + query
105
- with torch.no_grad():
106
- query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
107
- doc_reps_cat = torch.stack(doc_reps, dim=0)
108
- similarities = torch.matmul(query_rep, doc_reps_cat.T)
109
- topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
110
- topk_values_np = topk_values.cpu().numpy()
111
- topk_doc_ids_np = topk_doc_ids.cpu().numpy()
112
- similarities_np = similarities.cpu().numpy()
113
- all_images_doc_list = list(self.images[knowledge_base].values())
114
- images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
115
- return images_topk
116
-
117
-
118
- if __name__ == "__main__":
119
- from transformers import AutoModel
120
- from transformers import AutoTokenizer
121
- from PIL import Image
122
- import torch
123
-
124
- device = 'cuda:0'
125
-
126
- # Load model, be sure to substitute `model_path` by your model path
127
- model_path = '/home/jeeves/xubokai/minicpm-visual-embedding-v0' # replace with your local model path
128
- # pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"
129
-
130
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
131
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
132
- model.to(device)
133
-
134
- retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
135
-
136
- # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
137
- # # 2
138
- # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
139
- # # 3
140
- # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
141
- # # 2
142
-
143
- with gr.Blocks() as app:
144
- gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")
145
-
146
- with gr.Row():
147
- file_input = gr.File(type="binary", label="Upload PDF")
148
- file_result = gr.Text(label="Knowledge Base ID (remember this!)")
149
- process_button = gr.Button("Process PDF")
150
-
151
- process_button.click(retriever.add_pdf_gradio, inputs=[file_input], outputs=file_result)
152
-
153
- with gr.Row():
154
- kb_id_input = gr.Text(label="Your Knowledge Base ID")
155
- query_input = gr.Text(label="Your Queston")
156
- topk_input = inputs=gr.Number(value=1, minimum=1, maximum=5, step=1, label="Top K")
157
- retrieve_button = gr.Button("Retrieve")
158
-
159
- with gr.Row():
160
- images_output = gr.Gallery(label="Retrieved Pages")
161
-
162
- retrieve_button.click(retriever.retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
163
-
164
- app.launch()
165
-