File size: 7,056 Bytes
e9b5d42
 
 
d4d97bb
e9b5d42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 RhapsodyAI
#
# @author: bokai xu <[email protected]>
# @date: 2024/07/13
#


import tqdm
from PIL import Image
import hashlib
import torch
import fitz
import threading
import gradio as gr


def get_image_md5(img: Image.Image):
    """Return the hex MD5 digest of the raw pixel buffer of *img*.

    Only ``img.tobytes()`` is hashed, so the image mode and size are not part
    of the digest; two images are keyed identically iff their pixel bytes match.
    """
    return hashlib.md5(img.tobytes()).hexdigest()

def pdf_to_images(pdf_path, dpi=100):
    """Render every page of a PDF to a PIL image.

    Args:
        pdf_path: path of the PDF, passed straight to ``fitz.open``.
        dpi: rendering resolution passed to ``page.get_pixmap``.

    Returns:
        A list with one image per page, in document order; each pixmap's
        samples are decoded as ``"RGB"``.
    """
    images = []
    # The context manager closes the document (releasing its file handle)
    # even if rendering a page raises; the previous version never closed it.
    with fitz.open(pdf_path) as doc:
        for page in tqdm.tqdm(doc):
            pix = page.get_pixmap(dpi=dpi)
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images

def calculate_md5_from_binary(binary_data):
    """Return the lowercase hex MD5 digest of the bytes-like *binary_data*.

    Used to derive a stable knowledge-base id from an uploaded PDF's content.
    """
    digest = hashlib.md5(binary_data)
    return digest.hexdigest()

class PDFVisualRetrieval:
    """In-memory visual retrieval index over rendered PDF pages.

    Each knowledge base maps a page's pixel MD5 (``get_image_md5``) both to
    the model's embedding (``self.reps``) and to the rendered page image
    (``self.images``). Both inner dicts receive the same keys in the same
    order, so their ``.values()`` lists stay index-aligned; ``retrieve``
    relies on that alignment to map result indices back to images. Pages
    with identical pixels share a single entry.
    """

    # Prefix prepended to every query before encoding. The spelling
    # ("relavant") is kept verbatim on purpose: presumably the embedding model
    # is used with this exact prompt, so editing it could shift query
    # embeddings -- re-validate retrieval quality before changing it.
    _QUERY_INSTRUCTION = "Represent this query for retrieving relavant document: "

    def __init__(self, model, tokenizer):
        self.tokenizer = tokenizer
        self.model = model
        # knowledge base name -> {page image md5 -> embedding tensor}
        self.reps = {}
        # knowledge base name -> {page image md5 -> PIL image}
        self.images = {}
        # Serializes page rendering + image encoding so only one page image is
        # pushed through the model at a time (original note: one 16G GPU
        # should process one image at a time).
        self.lock = threading.Lock()

    def _encode_page(self, page, dpi):
        """Render *page* at *dpi*, embed it, and return ``(md5, rep, image)``."""
        # Rendering stays inside the lock as in the original Gradio path;
        # this also keeps fitz calls from different requests off each other.
        with self.lock:
            pix = page.get_pixmap(dpi=dpi)
            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            with torch.no_grad():
                reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
        return get_image_md5(image), reps.squeeze(0), image

    def retrieve(self, knowledge_base: str, query: str, topk: int):
        """Return the pages of *knowledge_base* most similar to *query*.

        Args:
            knowledge_base: name of a knowledge base filled by ``add_pdf`` or
                ``add_pdf_gradio``.
            query: natural-language query text.
            topk: number of pages wanted. It is cast with ``int`` (a Gradio
                ``Number`` may deliver a float) and capped at the number of
                indexed pages, both of which ``torch.topk`` requires.

        Returns:
            ``(topk_doc_ids_np, topk_values_np, images_topk)``: page indices and
            similarity scores as numpy arrays in ``torch.topk`` order
            (largest first), plus the matching page images.

        Raises:
            KeyError: if *knowledge_base* is unknown.
            ValueError: if *knowledge_base* contains no pages.
        """
        # Snapshot both dicts together so result indices refer to one page list.
        doc_reps = list(self.reps[knowledge_base].values())
        all_images_doc_list = list(self.images[knowledge_base].values())
        if not doc_reps:
            raise ValueError(f"knowledge base {knowledge_base!r} contains no pages")
        query_with_instruction = self._QUERY_INSTRUCTION + query
        with torch.no_grad():
            query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
        doc_reps_cat = torch.stack(doc_reps, dim=0)
        # One similarity score per indexed page.
        similarities = torch.matmul(query_rep, doc_reps_cat.T)
        k = min(int(topk), len(doc_reps))
        topk_values, topk_doc_ids = torch.topk(similarities, k=k)
        topk_values_np = topk_values.cpu().numpy()
        topk_doc_ids_np = topk_doc_ids.cpu().numpy()
        images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
        return topk_doc_ids_np, topk_values_np, images_topk

    def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
        """Render, embed, and index every page of the PDF at *pdf_file_path*.

        Pages are merged into *knowledge_base_name*, creating it if needed.
        """
        # images first, so a concurrent reader never sees a rep without its image.
        images = self.images.setdefault(knowledge_base_name, {})
        reps = self.reps.setdefault(knowledge_base_name, {})
        print("model encoding images..")
        with fitz.open(pdf_file_path) as doc:
            for page in tqdm.tqdm(doc):
                image_md5, rep, image = self._encode_page(page, dpi)
                images[image_md5] = image
                reps[image_md5] = rep
        return

    def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
        """Gradio handler: index an uploaded PDF and return its knowledge-base id.

        The id is the MD5 of the uploaded bytes, so re-uploading the same file
        returns the existing id without re-encoding. The index is built in
        local dicts and published only after every page is encoded. If
        encoding fails part-way, nothing is registered and a retry re-processes
        the file. Previously a partial or empty index was left behind, and later
        uploads of the same file returned early with it.
        """
        knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
        if knowledge_base_name in self.reps:
            return knowledge_base_name
        dpi = 100
        reps = {}
        images = {}
        with fitz.open("pdf", pdf_file_binary) as doc:
            for page in progress.tqdm(doc):
                image_md5, rep, image = self._encode_page(page, dpi)
                reps[image_md5] = rep
                images[image_md5] = image
        # Publish images before reps: readers look a base up via self.reps.
        self.images[knowledge_base_name] = images
        self.reps[knowledge_base_name] = reps
        return knowledge_base_name

    def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
        """Gradio handler: return only the retrieved page images for the gallery."""
        _, _, images_topk = self.retrieve(knowledge_base, query, topk)
        return images_topk


if __name__ == "__main__":
    # Demo entry point: load the embedding model and serve a Gradio UI for
    # uploading PDFs and querying them.
    import os

    from transformers import AutoModel
    from transformers import AutoTokenizer

    # Use the first GPU when available; fall back to CPU so the demo still starts.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Local model path: set MEMEX_MODEL_PATH or replace the default below.
    model_path = os.environ.get('MEMEX_MODEL_PATH', '/home/jeeves/xubokai/minicpm-visual-embedding-v0')

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.to(device)

    retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)

    with gr.Blocks() as app:
        gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")

        with gr.Row():
            file_input = gr.File(type="binary", label="Upload PDF")
            file_result = gr.Text(label="Knowledge Base ID (remember this!)")
            process_button = gr.Button("Process PDF")

        process_button.click(retriever.add_pdf_gradio, inputs=[file_input], outputs=file_result)

        with gr.Row():
            kb_id_input = gr.Text(label="Your Knowledge Base ID")
            query_input = gr.Text(label="Your Question")
            # precision=0 makes Gradio deliver an int, as torch.topk requires.
            topk_input = gr.Number(value=1, minimum=1, maximum=5, step=1, precision=0, label="Top K")
            retrieve_button = gr.Button("Retrieve")

        with gr.Row():
            images_output = gr.Gallery(label="Retrieved Pages")

        retrieve_button.click(retriever.retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)

    app.launch()