eremeev-d committed on
Commit
b09ffa7
·
1 Parent(s): d103a97

Moved from sentence-transformers to pure transformers

Browse files
Files changed (2) hide show
  1. core.py +30 -7
  2. requirements.txt +2 -1
core.py CHANGED
@@ -1,10 +1,13 @@
1
  import streamlit as st
2
- from sentence_transformers import SentenceTransformer
3
  from huggingface_hub import HfApi, HfFolder
4
  import datasets
5
  import logging
6
  import os
7
 
 
 
 
 
8
 
9
  @st.cache_data
10
  def login():
@@ -12,7 +15,7 @@ def login():
12
  logging.info("Trying to log in to HF")
13
  st.session_state['logged'] = True
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
- api=HfApi()
16
  api.set_access_token(HF_TOKEN)
17
  folder = HfFolder()
18
  folder.save_token(HF_TOKEN)
@@ -26,16 +29,18 @@ def login():
26
  @st.cache_resource
27
  def load_model():
28
  logging.info("Trying to load model")
29
- model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
 
 
30
  logging.info("Model loaded")
31
- return model
32
 
33
 
34
  @st.cache_resource
35
  def load_index():
36
  logging.info("Trying to load index")
37
  index = datasets.load_dataset(
38
- "eremeev-d/arxiv-abstracts-small",
39
  use_auth_token=True,
40
  split="train"
41
  )
@@ -44,11 +49,29 @@ def load_index():
44
  return index
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def get_answers(query):
48
  logging.info("Getting answers for {}".format(query))
49
- model = load_model()
50
  index = load_index()
51
- query_embedding = model.encode(query)
52
  scores, answers = index.get_nearest_examples('embedding', query_embedding)
53
  logging.info("Succesfully got answers for {}".format(query))
54
  return answers
 
1
  import streamlit as st
 
2
  from huggingface_hub import HfApi, HfFolder
3
  import datasets
4
  import logging
5
  import os
6
 
7
+ from transformers import AutoTokenizer, AutoModel
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
 
12
  @st.cache_data
13
  def login():
 
15
  logging.info("Trying to log in to HF")
16
  st.session_state['logged'] = True
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
+ api = HfApi()
19
  api.set_access_token(HF_TOKEN)
20
  folder = HfFolder()
21
  folder.save_token(HF_TOKEN)
 
29
  @st.cache_resource
30
  def load_model():
31
  logging.info("Trying to load model")
32
+ tokenizer = AutoTokenizer.from_pretrained(
33
+ 'sentence-transformers/all-MiniLM-L6-v2')
34
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
35
  logging.info("Model loaded")
36
+ return model, tokenizer
37
 
38
 
39
  @st.cache_resource
40
  def load_index():
41
  logging.info("Trying to load index")
42
  index = datasets.load_dataset(
43
+ "eremeev-d/arxiv-abstracts-small",
44
  use_auth_token=True,
45
  split="train"
46
  )
 
49
  return index
50
 
51
 
52
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Indexable model output; element 0 is taken as the
            token embeddings. Assumed to be (batch, seq, dim), since the
            reduction runs over dimension 1 -- confirm against the model.
        attention_mask: Tensor that, after a trailing axis is added, can be
            expanded to the size of the token embeddings (presumably a
            (batch, seq) mask of 0/1 values).

    Returns:
        The mask-weighted sum of the token embeddings over dimension 1
        divided by the mask total over that dimension.
    """
    token_embeddings = model_output[0]
    # Give the mask one weight per embedding component so it can be
    # multiplied element-wise with the token embeddings.
    mask = (
        attention_mask.unsqueeze(-1)
        .expand(token_embeddings.size())
        .float()
    )
    weighted_sum = torch.sum(token_embeddings * mask, 1)
    # Clamp the denominator so a fully masked row divides by 1e-9
    # instead of by zero.
    mask_total = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / mask_total
58
+
59
+
60
def get_embedding(query, model, tokenizer):
    """Embed ``query`` as an L2-normalised, mean-pooled vector.

    Args:
        query: Text (or texts) handed to ``tokenizer``.
        model: Callable invoked with the tokenizer output as keyword
            arguments; its result is passed to ``mean_pooling``.
        tokenizer: Callable returning a mapping that contains an
            'attention_mask' entry when asked for PyTorch tensors.

    Returns:
        The pooled embeddings, normalised to unit L2 norm along
        dimension 1.
    """
    encoded_input = tokenizer(
        query, padding=True, truncation=True, return_tensors='pt')
    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        model_output = model(**encoded_input)
    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)
68
+
69
+
70
  def get_answers(query):
71
  logging.info("Getting answers for {}".format(query))
72
+ model, tokenizer = load_model()
73
  index = load_index()
74
+ query_embedding = get_embedding(query, model, tokenizer).reshape(-1).numpy()
75
  scores, answers = index.get_nearest_examples('embedding', query_embedding)
76
  logging.info("Succesfully got answers for {}".format(query))
77
  return answers
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  faiss-cpu~=1.7.2
2
  sentence-transformers~=2.2.2
3
  datasets~=2.10.1
4
- huggingface_hub~=0.10.1
 
 
1
  faiss-cpu~=1.7.2
2
  sentence-transformers~=2.2.2
3
  datasets~=2.10.1
4
+ huggingface_hub~=0.10.1
5
+ torch