File size: 7,056 Bytes
e9b5d42 d4d97bb e9b5d42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 RhapsodyAI
#
# @author: bokai xu <[email protected]>
# @date: 2024/07/13
#
import tqdm
from PIL import Image
import hashlib
import torch
import fitz
import threading
import gradio as gr
def get_image_md5(img: Image.Image):
img_byte_array = img.tobytes()
hash_md5 = hashlib.md5()
hash_md5.update(img_byte_array)
hex_digest = hash_md5.hexdigest()
return hex_digest
def pdf_to_images(pdf_path, dpi=100):
doc = fitz.open(pdf_path)
images = []
for page in tqdm.tqdm(doc):
pix = page.get_pixmap(dpi=dpi)
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
images.append(img)
return images
def calculate_md5_from_binary(binary_data):
hash_md5 = hashlib.md5()
hash_md5.update(binary_data)
return hash_md5.hexdigest()
class PDFVisualRetrieval:
def __init__(self, model, tokenizer):
self.tokenizer = tokenizer
self.model = model
self.reps = {}
self.images = {}
self.lock = threading.Lock()
def retrieve(self, knowledge_base: str, query: str, topk: int):
doc_reps = list(self.reps[knowledge_base].values())
query_with_instruction = "Represent this query for retrieving relavant document: " + query
with torch.no_grad():
query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
doc_reps_cat = torch.stack(doc_reps, dim=0)
similarities = torch.matmul(query_rep, doc_reps_cat.T)
topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
topk_values_np = topk_values.cpu().numpy()
topk_doc_ids_np = topk_doc_ids.cpu().numpy()
similarities_np = similarities.cpu().numpy()
all_images_doc_list = list(self.images[knowledge_base].values())
images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
return topk_doc_ids_np, topk_values_np, images_topk
def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
if knowledge_base_name not in self.reps:
self.reps[knowledge_base_name] = {}
if knowledge_base_name not in self.images:
self.images[knowledge_base_name] = {}
doc = fitz.open(pdf_file_path)
print("model encoding images..")
for page in tqdm.tqdm(doc):
pix = page.get_pixmap(dpi=dpi)
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
image_md5 = get_image_md5(image)
with torch.no_grad():
reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
self.images[knowledge_base_name][image_md5] = image
return
def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
if knowledge_base_name not in self.reps:
self.reps[knowledge_base_name] = {}
else:
return knowledge_base_name
if knowledge_base_name not in self.images:
self.images[knowledge_base_name] = {}
dpi = 100
doc = fitz.open("pdf", pdf_file_binary)
for page in progress.tqdm(doc):
with self.lock: # because we hope one 16G gpu only process one image at the same time
pix = page.get_pixmap(dpi=dpi)
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
image_md5 = get_image_md5(image)
with torch.no_grad():
reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
self.images[knowledge_base_name][image_md5] = image
return knowledge_base_name
def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
doc_reps = list(self.reps[knowledge_base].values())
query_with_instruction = "Represent this query for retrieving relavant document: " + query
with torch.no_grad():
query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
doc_reps_cat = torch.stack(doc_reps, dim=0)
similarities = torch.matmul(query_rep, doc_reps_cat.T)
topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
topk_values_np = topk_values.cpu().numpy()
topk_doc_ids_np = topk_doc_ids.cpu().numpy()
similarities_np = similarities.cpu().numpy()
all_images_doc_list = list(self.images[knowledge_base].values())
images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
return images_topk
if __name__ == "__main__":
from transformers import AutoModel
from transformers import AutoTokenizer
from PIL import Image
import torch
device = 'cuda:0'
# Load model, be sure to substitute `model_path` by your model path
model_path = '/home/jeeves/xubokai/minicpm-visual-embedding-v0' # replace with your local model path
# pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.to(device)
retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
# topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
# # 2
# topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
# # 3
# topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
# # 2
with gr.Blocks() as app:
gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")
with gr.Row():
file_input = gr.File(type="binary", label="Upload PDF")
file_result = gr.Text(label="Knowledge Base ID (remember this!)")
process_button = gr.Button("Process PDF")
process_button.click(retriever.add_pdf_gradio, inputs=[file_input], outputs=file_result)
with gr.Row():
kb_id_input = gr.Text(label="Your Knowledge Base ID")
query_input = gr.Text(label="Your Queston")
topk_input = inputs=gr.Number(value=1, minimum=1, maximum=5, step=1, label="Top K")
retrieve_button = gr.Button("Retrieve")
with gr.Row():
images_output = gr.Gallery(label="Retrieved Pages")
retrieve_button.click(retriever.retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
app.launch()
|