atrytone commited on
Commit
c5239cd
·
1 Parent(s): 8e41ad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -28
app.py CHANGED
@@ -3,32 +3,20 @@ from langchain.vectorstores import FAISS
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  import torch
5
 
 
 
 
6
 
7
- def create_miread_embed(sents, bundle):
8
- tokenizer = bundle[0]
9
- model = bundle[1]
10
- model.cpu()
11
- tokens = tokenizer(sents,
12
- max_length=512,
13
- padding=True,
14
- truncation=True,
15
- return_tensors="pt"
16
- )
17
- device = torch.device('cpu')
18
- tokens = tokens.to(device)
19
- with torch.no_grad():
20
- out = model.bert(**tokens)
21
- feature = out.last_hidden_state[:, 0, :]
22
- return feature.cpu()
23
-
24
-
25
- def get_matches(query):
26
- matches = vecdb.similarity_search_with_score(query, k=60)
27
  return matches
28
 
29
 
30
- def inference(query):
31
- matches = get_matches(query)
 
 
 
32
  auth_counts = {}
33
  j_bucket = {}
34
  n_table = []
@@ -94,17 +82,29 @@ def inference(query):
94
 
95
  return [a_output, j_output, n_output]
96
 
 
 
 
 
 
97
 
98
- model_name = "biodatlab/MIReAD-Neuro-Large"
 
99
  model_kwargs = {'device': 'cpu'}
100
  encode_kwargs = {'normalize_embeddings': False}
101
- faiss_embedder = HuggingFaceEmbeddings(
102
- model_name=model_name,
 
 
 
 
 
103
  model_kwargs=model_kwargs,
104
  encode_kwargs=encode_kwargs
105
  )
106
 
107
- vecdb = FAISS.load_local("nbdt_index", faiss_embedder)
 
108
 
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -117,7 +117,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
117
 
118
  abst = gr.Textbox(label="Abstract", lines=10)
119
 
120
- action_btn = gr.Button(value="Find Matches")
 
121
 
122
  with gr.Tab("Authors"):
123
  n_output = gr.Dataframe(
@@ -146,7 +147,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
146
  visible=False
147
  )
148
 
149
- action_btn.click(fn=inference,
 
 
 
 
 
 
150
  inputs=[
151
  abst,
152
  ],
 
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  import torch
5
 
6
+ def get_matches1(query):
7
+ matches = vecdb1.similarity_search_with_score(query, k=60)
8
+ return matches
9
 
10
+ def get_matches2(query):
11
+ matches = vecdb2.similarity_search_with_score(query, k=60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  return matches
13
 
14
 
15
+ def inference(query,method=1):
16
+ if method==1:
17
+ matches = get_matches1(query)
18
+ else:
19
+ matches = get_matches2(query)
20
  auth_counts = {}
21
  j_bucket = {}
22
  n_table = []
 
82
 
83
  return [a_output, j_output, n_output]
84
 
85
+ def inference1(query):
86
+ return inference(query,1)
87
+
88
+ def inference2(query):
89
+ return inference(query,2)
90
 
91
+ model1_name = "biodatlab/MIReAD-Neuro-Large"
92
+ model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
93
  model_kwargs = {'device': 'cpu'}
94
  encode_kwargs = {'normalize_embeddings': False}
95
+ faiss_embedder1 = HuggingFaceEmbeddings(
96
+ model_name=model1_name,
97
+ model_kwargs=model_kwargs,
98
+ encode_kwargs=encode_kwargs
99
+ )
100
+ faiss_embedder2 = HuggingFaceEmbeddings(
101
+ model_name=model2_name,
102
  model_kwargs=model_kwargs,
103
  encode_kwargs=encode_kwargs
104
  )
105
 
106
+ vecdb1 = FAISS.load_local("nbdt_index", faiss_embedder1)
107
+ vecdb2 = FAISS.load_local("indexes", faiss_embedder2)
108
 
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
117
 
118
  abst = gr.Textbox(label="Abstract", lines=10)
119
 
120
+ action_btn = gr.Button(value="Find Matches with Normal Model")
121
+ action2_btn = gr.Button(value="Find Matches with Contrastive Model")
122
 
123
  with gr.Tab("Authors"):
124
  n_output = gr.Dataframe(
 
147
  visible=False
148
  )
149
 
150
+ action_btn.click(fn=inference1,
151
+ inputs=[
152
+ abst,
153
+ ],
154
+ outputs=[a_output, j_output, n_output],
155
+ api_name="neurojane")
156
+ action2_btn.click(fn=inference2,
157
  inputs=[
158
  abst,
159
  ],