tcy6 commited on
Commit
e82ff0e
1 Parent(s): d8a9d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -19,6 +19,7 @@ import json
19
  cache_dir = '/data/KB'
20
  os.makedirs(cache_dir, exist_ok=True)
21
 
 
22
  def weighted_mean_pooling(hidden, attention_mask):
23
  attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
24
  s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
@@ -26,6 +27,7 @@ def weighted_mean_pooling(hidden, attention_mask):
26
  reps = s / d
27
  return reps
28
 
 
29
  @torch.no_grad()
30
  def encode(text_or_image_list):
31
  global model, tokenizer
@@ -106,7 +108,7 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
106
 
107
  return knowledge_base_name
108
 
109
- # @spaces.GPU
110
  def retrieve_gradio(knowledge_base: str, query: str, topk: int):
111
  global model, tokenizer
112
 
 
19
  cache_dir = '/data/KB'
20
  os.makedirs(cache_dir, exist_ok=True)
21
 
22
+ @spaces.GPU
23
  def weighted_mean_pooling(hidden, attention_mask):
24
  attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
25
  s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
 
27
  reps = s / d
28
  return reps
29
 
30
+ @spaces.GPU
31
  @torch.no_grad()
32
  def encode(text_or_image_list):
33
  global model, tokenizer
 
108
 
109
  return knowledge_base_name
110
 
111
+ @spaces.GPU
112
  def retrieve_gradio(knowledge_base: str, query: str, topk: int):
113
  global model, tokenizer
114