bokesyo commited on
Commit
ee22b11
1 Parent(s): c37926b

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +103 -0
pipeline.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # Copyright @2023 RhapsodyAI, ModelBest Inc. (modelbest.cn)
5
+ #
6
+ # @author: bokai xu <[email protected]>
7
+ # @date: 2024/07/13
8
+ #
9
+
10
+
11
+ import tqdm
12
+ from PIL import Image
13
+ import hashlib
14
+ import torch
15
+ import fitz
16
+
17
+
18
+ def get_image_md5(img: Image.Image):
19
+ img_byte_array = img.tobytes()
20
+ hash_md5 = hashlib.md5()
21
+ hash_md5.update(img_byte_array)
22
+ hex_digest = hash_md5.hexdigest()
23
+ return hex_digest
24
+
25
+ def pdf_to_images(pdf_path, dpi=100):
26
+ doc = fitz.open(pdf_path)
27
+ images = []
28
+ for page in tqdm.tqdm(doc):
29
+ pix = page.get_pixmap(dpi=dpi)
30
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
31
+ images.append(img)
32
+ return images
33
+
34
+
35
+ class PDFVisualRetrieval:
36
+ def __init__(self, model, tokenizer):
37
+ self.tokenizer = tokenizer
38
+ self.model = model
39
+ self.reps = {}
40
+ self.images = {}
41
+
42
+ def add_visual_documents(self, knowledge_base_name: str, images: Image.Image):
43
+ if knowledge_base_name not in self.reps:
44
+ self.reps[knowledge_base_name] = {}
45
+ if knowledge_base_name not in self.images:
46
+ self.images[knowledge_base_name] = {}
47
+ for image in tqdm.tqdm(images):
48
+ image_md5 = get_image_md5(image)
49
+ with torch.no_grad():
50
+ reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
51
+ self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
52
+ self.images[knowledge_base_name][image_md5] = image
53
+ return
54
+
55
+ def retrieve(self, knowledge_base: str, query: str, topk: int):
56
+ doc_reps = list(self.reps[knowledge_base].values())
57
+ query_with_instruction = "Represent this query for retrieving relavant document: " + query
58
+ with torch.no_grad():
59
+ query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
60
+ doc_reps_cat = torch.stack(doc_reps, dim=0)
61
+ similarities = torch.matmul(query_rep, doc_reps_cat.T)
62
+ topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
63
+ topk_values_np = topk_values.cpu().numpy()
64
+ topk_doc_ids_np = topk_doc_ids.cpu().numpy()
65
+ similarities_np = similarities.cpu().numpy()
66
+ all_images_doc_list = list(self.images[knowledge_base].values())
67
+ images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
68
+ return topk_doc_ids_np, topk_values_np, images_topk
69
+
70
+ def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
71
+ print("[1/2] rendering pdf to images..")
72
+ images = pdf_to_images(pdf_file_path, dpi=dpi)
73
+ print("[2/2] model encoding images..")
74
+ self.add_visual_documents(knowledge_base_name=knowledge_base_name, images=images)
75
+ print("add pdf ok.")
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ from transformers import AutoModel
81
+ from transformers import AutoTokenizer
82
+ from PIL import Image
83
+ import torch
84
+
85
+ device = 'cuda:0'
86
+
87
+ # Load model, be sure to substitute `model_path` by your model path
88
+ model_path = '/home/jeeves/xubokai/minicpm-visual-embedding-v0'
89
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
90
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
91
+ model.to(device)
92
+
93
+ pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"
94
+ retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
95
+ retriever.add_pdf('test', pdf_path)
96
+
97
+ topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
98
+ # 2
99
+ topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
100
+ # 3
101
+ topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
102
+ # 2
103
+