Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -181,7 +181,7 @@ if 'slider' not in st.session_state:
|
|
181 |
st.session_state['slider'] = 0
|
182 |
|
183 |
if 'radio' not in st.session_state:
|
184 |
-
st.session_state['radio'] = '
|
185 |
|
186 |
if 'show' not in st.session_state:
|
187 |
st.session_state['show'] = False
|
@@ -217,18 +217,18 @@ def run_inference(model_name, query):
|
|
217 |
|
218 |
# Compute the similarity scores for these combinations
|
219 |
|
220 |
-
if model_name=='
|
221 |
similarity_scores = model_nli.predict(sentence_combinations)
|
222 |
scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
|
223 |
sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
|
224 |
results = [pred[idx] for _,idx in list(sim_scores_argsort)[:int(top_K)]]
|
225 |
|
226 |
-
if model_name=='STSB
|
227 |
similarity_scores = model_nli_stsb.predict(sentence_combinations)
|
228 |
sim_scores_argsort = reversed(np.argsort(similarity_scores))
|
229 |
results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
|
230 |
|
231 |
-
if model_name=='
|
232 |
similarity_scores = model_baseline.predict(sentence_combinations)
|
233 |
scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
|
234 |
sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
|
@@ -254,8 +254,8 @@ top_K = st.text_input('Choose Number of Result: ','10')
|
|
254 |
|
255 |
model_name = st.radio(
|
256 |
"Choose Model",
|
257 |
-
("
|
258 |
-
key='radio', on_change=callback, args=('radio','CivileLaw-IR')
|
259 |
)
|
260 |
|
261 |
|
@@ -266,7 +266,7 @@ if st.button('Run', key='run'):
|
|
266 |
st.session_state['show'] = True
|
267 |
st.session_state['results'] = results
|
268 |
st.session_state['query'] = query
|
269 |
-
model_dict = {'
|
270 |
st.session_state['model'] = model_dict[model_name]
|
271 |
|
272 |
|
|
|
181 |
st.session_state['slider'] = 0
|
182 |
|
183 |
if 'radio' not in st.session_state:
|
184 |
+
st.session_state['radio'] = 'Civile-Law-IR'
|
185 |
|
186 |
if 'show' not in st.session_state:
|
187 |
st.session_state['show'] = False
|
|
|
217 |
|
218 |
# Compute the similarity scores for these combinations
|
219 |
|
220 |
+
if model_name=='Civile-Law-IR':
|
221 |
similarity_scores = model_nli.predict(sentence_combinations)
|
222 |
scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
|
223 |
sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
|
224 |
results = [pred[idx] for _,idx in list(sim_scores_argsort)[:int(top_K)]]
|
225 |
|
226 |
+
if model_name=='STSB':
|
227 |
similarity_scores = model_nli_stsb.predict(sentence_combinations)
|
228 |
sim_scores_argsort = reversed(np.argsort(similarity_scores))
|
229 |
results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
|
230 |
|
231 |
+
if model_name=='DR-Baseline':
|
232 |
similarity_scores = model_baseline.predict(sentence_combinations)
|
233 |
scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
|
234 |
sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
|
|
|
254 |
|
255 |
model_name = st.radio(
|
256 |
"Choose Model",
|
257 |
+
("Civile-Law-IR", "STSB", "DR-Baseline"),
|
258 |
+
key='radio', on_change=callback, args=('radio','CivileLaw-IR'), help="Civile-Law-IR: trained on civile-NLI-dataset, STSB: trained on STSB french dataset, DR-Baseline: existing nli model trained on ms marco dataset"
|
259 |
)
|
260 |
|
261 |
|
|
|
266 |
st.session_state['show'] = True
|
267 |
st.session_state['results'] = results
|
268 |
st.session_state['query'] = query
|
269 |
+
model_dict = {'Civile-Law-IR': 'NLI-Syn', 'STSB': 'NLI-stsb', 'DR-Baseline': 'NLI-baseline'}
|
270 |
st.session_state['model'] = model_dict[model_name]
|
271 |
|
272 |
|