Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -19,6 +19,7 @@ import json
 cache_dir = '/data/KB'
 os.makedirs(cache_dir, exist_ok=True)
 
+@spaces.GPU
 def weighted_mean_pooling(hidden, attention_mask):
     attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
     s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
@@ -26,6 +27,7 @@ def weighted_mean_pooling(hidden, attention_mask):
     reps = s / d
     return reps
 
+@spaces.GPU
 @torch.no_grad()
 def encode(text_or_image_list):
     global model, tokenizer
@@ -106,7 +108,7 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
 
     return knowledge_base_name
 
-
+@spaces.GPU
 def retrieve_gradio(knowledge_base: str, query: str, topk: int):
     global model, tokenizer
 
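Note on the change: each hunk adds Hugging Face's spaces.GPU decorator, which on a ZeroGPU Space attaches a GPU only for the duration of the decorated call, so the Space holds no hardware between requests. A minimal sketch of that pattern, assuming the spaces package that ZeroGPU Spaces provide; the model name and the embed helper below are illustrative stand-ins, not code from app.py:

import spaces
import torch
from transformers import AutoModel, AutoTokenizer

# Load once at startup; the heavy lifting happens inside the decorated call.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

@spaces.GPU        # ZeroGPU attaches a GPU only while this function runs
@torch.no_grad()   # inference only, no autograd bookkeeping
def embed(texts):
    model.to("cuda")  # move to the GPU that ZeroGPU just attached
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to("cuda")
    hidden = model(**inputs).last_hidden_state             # [batch, seq, dim]
    mask = inputs["attention_mask"].unsqueeze(-1).float()  # [batch, seq, 1]
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)    # plain mean pooling

Decorating retrieve_gradio in the same way (third hunk) presumably keeps the query encoding it performs on the same on-demand GPU.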
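The two hunks touching weighted_mean_pooling skip the line between them that defines d (old line 25), so the function reads incompletely on this page. For reference, a self-contained sketch of this style of weighted mean pooling, with the normalizer d filled in as the sum of the position weights; treat that line as an assumption, not a quote from app.py:

import torch

def weighted_mean_pooling(hidden, attention_mask):
    # Weight each non-padding position by its 1-based index, so later tokens
    # contribute more to the pooled representation.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()  # assumed normalizer
    reps = s / d
    return reps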