Spaces:
Runtime error
Runtime error
Moved from sentence-transformers to pure transformers
Browse files
- core.py +30 -7
- requirements.txt +2 -1
core.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1 |
import streamlit as st
|
2 |
-
from sentence_transformers import SentenceTransformer
|
3 |
from huggingface_hub import HfApi, HfFolder
|
4 |
import datasets
|
5 |
import logging
|
6 |
import os
|
7 |
|
|
|
|
|
|
|
|
|
8 |
|
9 |
@st.cache_data
|
10 |
def login():
|
@@ -12,7 +15,7 @@ def login():
|
|
12 |
logging.info("Trying to log in to HF")
|
13 |
st.session_state['logged'] = True
|
14 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
15 |
-
api=HfApi()
|
16 |
api.set_access_token(HF_TOKEN)
|
17 |
folder = HfFolder()
|
18 |
folder.save_token(HF_TOKEN)
|
@@ -26,16 +29,18 @@ def login():
|
|
26 |
@st.cache_resource
|
27 |
def load_model():
|
28 |
logging.info("Trying to load model")
|
29 |
-
|
|
|
|
|
30 |
logging.info("Model loaded")
|
31 |
-
return model
|
32 |
|
33 |
|
34 |
@st.cache_resource
|
35 |
def load_index():
|
36 |
logging.info("Trying to load index")
|
37 |
index = datasets.load_dataset(
|
38 |
-
"eremeev-d/arxiv-abstracts-small",
|
39 |
use_auth_token=True,
|
40 |
split="train"
|
41 |
)
|
@@ -44,11 +49,29 @@ def load_index():
|
|
44 |
return index
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def get_answers(query):
|
48 |
logging.info("Getting answers for {}".format(query))
|
49 |
-
model = load_model()
|
50 |
index = load_index()
|
51 |
-
query_embedding = model.
|
52 |
scores, answers = index.get_nearest_examples('embedding', query_embedding)
|
53 |
logging.info("Succesfully got answers for {}".format(query))
|
54 |
return answers
|
|
|
1 |
import streamlit as st
|
|
|
2 |
from huggingface_hub import HfApi, HfFolder
|
3 |
import datasets
|
4 |
import logging
|
5 |
import os
|
6 |
|
7 |
+
from transformers import AutoTokenizer, AutoModel
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
|
12 |
@st.cache_data
|
13 |
def login():
|
|
|
15 |
logging.info("Trying to log in to HF")
|
16 |
st.session_state['logged'] = True
|
17 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
18 |
+
api = HfApi()
|
19 |
api.set_access_token(HF_TOKEN)
|
20 |
folder = HfFolder()
|
21 |
folder.save_token(HF_TOKEN)
|
|
|
29 |
@st.cache_resource
|
30 |
def load_model():
|
31 |
logging.info("Trying to load model")
|
32 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
33 |
+
'sentence-transformers/all-MiniLM-L6-v2')
|
34 |
+
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
35 |
logging.info("Model loaded")
|
36 |
+
return model, tokenizer
|
37 |
|
38 |
|
39 |
@st.cache_resource
|
40 |
def load_index():
|
41 |
logging.info("Trying to load index")
|
42 |
index = datasets.load_dataset(
|
43 |
+
"eremeev-d/arxiv-abstracts-small",
|
44 |
use_auth_token=True,
|
45 |
split="train"
|
46 |
)
|
|
|
49 |
return index
|
50 |
|
51 |
|
52 |
+
def mean_pooling(model_output, attention_mask):
|
53 |
+
token_embeddings = model_output[0]
|
54 |
+
input_mask_expanded = attention_mask.unsqueeze(-1) \
|
55 |
+
.expand(token_embeddings.size()).float()
|
56 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) \
|
57 |
+
/ torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
58 |
+
|
59 |
+
|
60 |
+
def get_embedding(query, model, tokenizer):
|
61 |
+
encoded_input = tokenizer(
|
62 |
+
query, padding=True, truncation=True, return_tensors='pt')
|
63 |
+
with torch.no_grad():
|
64 |
+
embeds = model(**encoded_input)
|
65 |
+
embeds = mean_pooling(embeds, encoded_input['attention_mask'])
|
66 |
+
embeds = F.normalize(embeds, p=2, dim=1)
|
67 |
+
return embeds
|
68 |
+
|
69 |
+
|
70 |
def get_answers(query):
|
71 |
logging.info("Getting answers for {}".format(query))
|
72 |
+
model, tokenizer = load_model()
|
73 |
index = load_index()
|
74 |
+
query_embedding = get_embedding(query, model, tokenizer).reshape(-1).numpy()
|
75 |
scores, answers = index.get_nearest_examples('embedding', query_embedding)
|
76 |
logging.info("Succesfully got answers for {}".format(query))
|
77 |
return answers
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
faiss-cpu~=1.7.2
|
2 |
sentence-transformers~=2.2.2
|
3 |
datasets~=2.10.1
|
4 |
-
huggingface_hub~=0.10.1
|
|
|
|
1 |
faiss-cpu~=1.7.2
|
2 |
sentence-transformers~=2.2.2
|
3 |
datasets~=2.10.1
|
4 |
+
huggingface_hub~=0.10.1
|
5 |
+
torch
|