Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -11,12 +11,17 @@ def get_matches2(query):
|
|
11 |
matches = vecdb2.similarity_search_with_score(query, k=60)
|
12 |
return matches
|
13 |
|
|
|
|
|
|
|
14 |
|
15 |
def inference(query,method=1):
|
16 |
if method==1:
|
17 |
matches = get_matches1(query)
|
18 |
-
|
19 |
matches = get_matches2(query)
|
|
|
|
|
20 |
auth_counts = {}
|
21 |
j_bucket = {}
|
22 |
n_table = []
|
@@ -88,8 +93,12 @@ def inference1(query):
|
|
88 |
def inference2(query):
|
89 |
return inference(query,2)
|
90 |
|
|
|
|
|
|
|
91 |
model1_name = "biodatlab/MIReAD-Neuro-Large"
|
92 |
model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
|
|
|
93 |
model_kwargs = {'device': 'cpu'}
|
94 |
encode_kwargs = {'normalize_embeddings': False}
|
95 |
faiss_embedder1 = HuggingFaceEmbeddings(
|
@@ -102,9 +111,15 @@ faiss_embedder2 = HuggingFaceEmbeddings(
|
|
102 |
model_kwargs=model_kwargs,
|
103 |
encode_kwargs=encode_kwargs
|
104 |
)
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
vecdb1 = FAISS.load_local("nbdt_index", faiss_embedder1)
|
107 |
vecdb2 = FAISS.load_local("indexes", faiss_embedder2)
|
|
|
108 |
|
109 |
|
110 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
@@ -118,7 +133,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
118 |
abst = gr.Textbox(label="Abstract", lines=10)
|
119 |
|
120 |
action_btn = gr.Button(value="Find Matches with Normal Model")
|
121 |
-
action2_btn = gr.Button(value="Find Matches with Contrastive Model")
|
|
|
122 |
|
123 |
with gr.Tab("Authors"):
|
124 |
n_output = gr.Dataframe(
|
@@ -159,5 +175,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
159 |
],
|
160 |
outputs=[a_output, j_output, n_output],
|
161 |
api_name="neurojane")
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
demo.launch(debug=True)
|
|
|
11 |
matches = vecdb2.similarity_search_with_score(query, k=60)
|
12 |
return matches
|
13 |
|
14 |
+
def get_matches3(query):
|
15 |
+
matches = vecdb3.similarity_search_with_score(query, k=60)
|
16 |
+
|
17 |
|
18 |
def inference(query,method=1):
|
19 |
if method==1:
|
20 |
matches = get_matches1(query)
|
21 |
+
elif method==2:
|
22 |
matches = get_matches2(query)
|
23 |
+
else:
|
24 |
+
matches = get_matches3(query)
|
25 |
auth_counts = {}
|
26 |
j_bucket = {}
|
27 |
n_table = []
|
|
|
93 |
def inference2(query):
|
94 |
return inference(query,2)
|
95 |
|
96 |
+
def inference3(query):
|
97 |
+
return inference(query,3)
|
98 |
+
|
99 |
model1_name = "biodatlab/MIReAD-Neuro-Large"
|
100 |
model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
|
101 |
+
model3_name = "biodatlab/SciBERT-Neuro-Contrastive"
|
102 |
model_kwargs = {'device': 'cpu'}
|
103 |
encode_kwargs = {'normalize_embeddings': False}
|
104 |
faiss_embedder1 = HuggingFaceEmbeddings(
|
|
|
111 |
model_kwargs=model_kwargs,
|
112 |
encode_kwargs=encode_kwargs
|
113 |
)
|
114 |
+
faiss_embedder3 = HuggingFaceEmbeddings(
|
115 |
+
model_name=model3_name,
|
116 |
+
model_kwargs=model_kwargs,
|
117 |
+
encode_kwargs=encode_kwargs
|
118 |
+
)
|
119 |
|
120 |
vecdb1 = FAISS.load_local("nbdt_index", faiss_embedder1)
|
121 |
vecdb2 = FAISS.load_local("indexes", faiss_embedder2)
|
122 |
+
vecdb3 = FAISS.load_local("indexes/scibert_contr",faiss_embedder3)
|
123 |
|
124 |
|
125 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
|
133 |
abst = gr.Textbox(label="Abstract", lines=10)
|
134 |
|
135 |
action_btn = gr.Button(value="Find Matches with Normal Model")
|
136 |
+
action2_btn = gr.Button(value="Find Matches with MIReAD Contrastive Model")
|
137 |
+
action3_btn = gr.Button(value="Find Matches with SciBERT Contrastive Model")
|
138 |
|
139 |
with gr.Tab("Authors"):
|
140 |
n_output = gr.Dataframe(
|
|
|
175 |
],
|
176 |
outputs=[a_output, j_output, n_output],
|
177 |
api_name="neurojane")
|
178 |
+
action3_btn.click(fn=inference3,
|
179 |
+
inputs=[
|
180 |
+
abst,
|
181 |
+
],
|
182 |
+
outputs=[a_output, j_output, n_output],
|
183 |
+
api_name="neurojane")
|
184 |
|
185 |
demo.launch(debug=True)
|