atrytone commited on
Commit
b02d896
·
1 Parent(s): 530e694

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -11,12 +11,17 @@ def get_matches2(query):
11
  matches = vecdb2.similarity_search_with_score(query, k=60)
12
  return matches
13
 
 
 
 
14
 
15
  def inference(query,method=1):
16
  if method==1:
17
  matches = get_matches1(query)
18
- else:
19
  matches = get_matches2(query)
 
 
20
  auth_counts = {}
21
  j_bucket = {}
22
  n_table = []
@@ -88,8 +93,12 @@ def inference1(query):
88
  def inference2(query):
89
  return inference(query,2)
90
 
 
 
 
91
  model1_name = "biodatlab/MIReAD-Neuro-Large"
92
  model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
 
93
  model_kwargs = {'device': 'cpu'}
94
  encode_kwargs = {'normalize_embeddings': False}
95
  faiss_embedder1 = HuggingFaceEmbeddings(
@@ -102,9 +111,15 @@ faiss_embedder2 = HuggingFaceEmbeddings(
102
  model_kwargs=model_kwargs,
103
  encode_kwargs=encode_kwargs
104
  )
 
 
 
 
 
105
 
106
  vecdb1 = FAISS.load_local("nbdt_index", faiss_embedder1)
107
  vecdb2 = FAISS.load_local("indexes", faiss_embedder2)
 
108
 
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -118,7 +133,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
118
  abst = gr.Textbox(label="Abstract", lines=10)
119
 
120
  action_btn = gr.Button(value="Find Matches with Normal Model")
121
- action2_btn = gr.Button(value="Find Matches with Contrastive Model")
 
122
 
123
  with gr.Tab("Authors"):
124
  n_output = gr.Dataframe(
@@ -159,5 +175,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
159
  ],
160
  outputs=[a_output, j_output, n_output],
161
  api_name="neurojane")
 
 
 
 
 
 
162
 
163
  demo.launch(debug=True)
 
11
  matches = vecdb2.similarity_search_with_score(query, k=60)
12
  return matches
13
 
14
+ def get_matches3(query):
15
+ matches = vecdb3.similarity_search_with_score(query, k=60)
16
+
17
 
18
  def inference(query,method=1):
19
  if method==1:
20
  matches = get_matches1(query)
21
+ elif method==2:
22
  matches = get_matches2(query)
23
+ else:
24
+ matches = get_matches3(query)
25
  auth_counts = {}
26
  j_bucket = {}
27
  n_table = []
 
93
  def inference2(query):
94
  return inference(query,2)
95
 
96
+ def inference3(query):
97
+ return inference(query,3)
98
+
99
  model1_name = "biodatlab/MIReAD-Neuro-Large"
100
  model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
101
+ model3_name = "biodatlab/SciBERT-Neuro-Contrastive"
102
  model_kwargs = {'device': 'cpu'}
103
  encode_kwargs = {'normalize_embeddings': False}
104
  faiss_embedder1 = HuggingFaceEmbeddings(
 
111
  model_kwargs=model_kwargs,
112
  encode_kwargs=encode_kwargs
113
  )
114
+ faiss_embedder3 = HuggingFaceEmbeddings(
115
+ model_name=model3_name,
116
+ model_kwargs=model_kwargs,
117
+ encode_kwargs=encode_kwargs
118
+ )
119
 
120
  vecdb1 = FAISS.load_local("nbdt_index", faiss_embedder1)
121
  vecdb2 = FAISS.load_local("indexes", faiss_embedder2)
122
+ vecdb3 = FAISS.load_local("indexes/scibert_contr",faiss_embedder3)
123
 
124
 
125
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
133
  abst = gr.Textbox(label="Abstract", lines=10)
134
 
135
  action_btn = gr.Button(value="Find Matches with Normal Model")
136
+ action2_btn = gr.Button(value="Find Matches with MIReAD Contrastive Model")
137
+ action3_btn = gr.Button(value="Find Matches with SciBERT Contrastive Model")
138
 
139
  with gr.Tab("Authors"):
140
  n_output = gr.Dataframe(
 
175
  ],
176
  outputs=[a_output, j_output, n_output],
177
  api_name="neurojane")
178
+ action3_btn.click(fn=inference3,
179
+ inputs=[
180
+ abst,
181
+ ],
182
+ outputs=[a_output, j_output, n_output],
183
+ api_name="neurojane")
184
 
185
  demo.launch(debug=True)