tcy6 commited on
Commit
56e4893
1 Parent(s): 4b231a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import tqdm
2
  from PIL import Image
3
  import hashlib
4
  import torch
 
5
  import fitz
6
  import threading
7
  import gradio as gr
@@ -18,6 +19,36 @@ import json
18
  cache_dir = '/data/KB'
19
  os.makedirs(cache_dir, exist_ok=True)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def get_image_md5(img: Image.Image):
22
  img_byte_array = img.tobytes()
23
  hash_md5 = hashlib.md5()
@@ -57,8 +88,8 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
57
  image_md5 = get_image_md5(image)
58
  image_md5s.append(image_md5)
59
  with torch.no_grad():
60
- reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
61
- reps_list.append(reps.squeeze(0).cpu().numpy())
62
  images.append(image)
63
 
64
  for idx in range(len(images)):
@@ -95,7 +126,7 @@ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
95
 
96
  query_with_instruction = "Represent this query for retrieving relavant document: " + query
97
  with torch.no_grad():
98
- query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
99
 
100
  query_md5 = hashlib.md5(query.encode()).hexdigest()
101
 
 
2
  from PIL import Image
3
  import hashlib
4
  import torch
5
+ import torch.nn.functional as F
6
  import fitz
7
  import threading
8
  import gradio as gr
 
19
  cache_dir = '/data/KB'
20
  os.makedirs(cache_dir, exist_ok=True)
21
 
22
+ def weighted_mean_pooling(hidden, attention_mask):
23
+ attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
24
+ s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
25
+ d = attention_mask_.sum(dim=1, keepdim=True).float()
26
+ reps = s / d
27
+ return reps
28
+
29
+ @torch.no_grad()
30
+ def encode(text_or_image_list):
31
+ global model, tokenizer
32
+ if (isinstance(text_or_image_list[0], str)):
33
+ inputs = {
34
+ "text": text_or_image_list,
35
+ 'image': [None] * len(text_or_image_list),
36
+ 'tokenizer': tokenizer
37
+ }
38
+ else:
39
+ inputs = {
40
+ "text": [''] * len(text_or_image_list),
41
+ 'image': text_or_image_list,
42
+ 'tokenizer': tokenizer
43
+ }
44
+ outputs = model(**inputs)
45
+ attention_mask = outputs.attention_mask
46
+ hidden = outputs.last_hidden_state
47
+
48
+ reps = weighted_mean_pooling(hidden, attention_mask)
49
+ embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
50
+ return embeddings
51
+
52
  def get_image_md5(img: Image.Image):
53
  img_byte_array = img.tobytes()
54
  hash_md5 = hashlib.md5()
 
88
  image_md5 = get_image_md5(image)
89
  image_md5s.append(image_md5)
90
  with torch.no_grad():
91
+ reps = encode([image])
92
+ reps_list.append(reps)
93
  images.append(image)
94
 
95
  for idx in range(len(images)):
 
126
 
127
  query_with_instruction = "Represent this query for retrieving relavant document: " + query
128
  with torch.no_grad():
129
+ query_rep = encode([query_with_instruction])
130
 
131
  query_md5 = hashlib.md5(query.encode()).hexdigest()
132