#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 RhapsodyAI
#
# @author: bokai xu <[email protected]>
# @date: 2024/07/13
#
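"""Minimal visual PDF retrieval: render each PDF page to an image, embed
the pages with a multimodal embedding model, and score them against text
queries by dot product. A runnable demo lives under ``__main__`` below."""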


import hashlib

import fitz  # PyMuPDF
import torch
import tqdm
from PIL import Image

def get_image_md5(img: Image.Image) -> str:
    """MD5 of the image's raw pixel bytes; used as a stable page id."""
    hash_md5 = hashlib.md5()
    hash_md5.update(img.tobytes())
    return hash_md5.hexdigest()

def pdf_to_images(pdf_path: str, dpi: int = 100) -> list:
    """Render every page of a PDF to an RGB PIL image at the given DPI."""
    doc = fitz.open(pdf_path)
    images = []
    for page in tqdm.tqdm(doc):
        pix = page.get_pixmap(dpi=dpi)
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        images.append(img)
    doc.close()
    return images


class PDFVisualRetrieval:
    """In-memory visual index: per knowledge base, page embeddings and
    page images keyed by the MD5 of each page image's raw bytes."""

    def __init__(self, model, tokenizer):
        self.tokenizer = tokenizer
        self.model = model
        self.reps = {}    # knowledge_base_name -> {image_md5: embedding}
        self.images = {}  # knowledge_base_name -> {image_md5: PIL image}
    
    def add_visual_documents(self, knowledge_base_name: str, images: list):
        """Embed each PIL image and store it under `knowledge_base_name`."""
        if knowledge_base_name not in self.reps:
            self.reps[knowledge_base_name] = {}
        if knowledge_base_name not in self.images:
            self.images[knowledge_base_name] = {}
        for image in tqdm.tqdm(images):
            image_md5 = get_image_md5(image)
            with torch.no_grad():
                # Empty text plus the image produces the document-side embedding.
                reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
            self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
            self.images[knowledge_base_name][image_md5] = image
    
    def retrieve(self, knowledge_base: str, query: str, topk: int):
        """Return (page indices, scores, page images) of the top-k pages for a query."""
        doc_reps = list(self.reps[knowledge_base].values())
        # Instruction prefix kept verbatim ("relavant" included): embedding
        # models are sensitive to the exact prompt used at training time.
        query_with_instruction = "Represent this query for retrieving relavant document: " + query
        with torch.no_grad():
            query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
        doc_reps_cat = torch.stack(doc_reps, dim=0)
        # Dot-product similarity between the query and every stored page.
        similarities = torch.matmul(query_rep, doc_reps_cat.T)
        topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
        topk_values_np = topk_values.cpu().numpy()
        topk_doc_ids_np = topk_doc_ids.cpu().numpy()
        all_images_doc_list = list(self.images[knowledge_base].values())
        images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
        return topk_doc_ids_np, topk_values_np, images_topk
    
    def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
        """Render a PDF to page images and index them into a knowledge base."""
        print("[1/2] rendering pdf to images..")
        images = pdf_to_images(pdf_file_path, dpi=dpi)
        print("[2/2] model encoding images..")
        self.add_visual_documents(knowledge_base_name=knowledge_base_name, images=images)
        print("add pdf ok.")


if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    device = 'cuda:0'

    # Load the model; substitute `model_path` with your local model path.
    model_path = '/home/jeeves/xubokai/minicpm-visual-embedding-v0'
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.to(device)
    model.eval()

    pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"
    retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
    retriever.add_pdf('test', pdf_path)

    topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
    # expected top page: 2
    topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
    # expected top page: 3
    topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
    # expected top page: 2
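
    # Example of consuming the returned values: print the matched page ids
    # and scores, then save the best page image (hypothetical output path).
    print(topk_doc_ids_np, topk_values_np)
    images_topk[0].save('best_page.png')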