ssilwal commited on
Commit
f18e17f
·
1 Parent(s): 6460284

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -181,7 +181,7 @@ if 'slider' not in st.session_state:
181
  st.session_state['slider'] = 0
182
 
183
  if 'radio' not in st.session_state:
184
- st.session_state['radio'] = 'CivileLaw-IR'
185
 
186
  if 'show' not in st.session_state:
187
  st.session_state['show'] = False
@@ -217,18 +217,18 @@ def run_inference(model_name, query):
217
 
218
  # Compute the similarity scores for these combinations
219
 
220
- if model_name=='CivileLaw-IR':
221
  similarity_scores = model_nli.predict(sentence_combinations)
222
  scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
223
  sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
224
  results = [pred[idx] for _,idx in list(sim_scores_argsort)[:int(top_K)]]
225
 
226
- if model_name=='STSB-IR':
227
  similarity_scores = model_nli_stsb.predict(sentence_combinations)
228
  sim_scores_argsort = reversed(np.argsort(similarity_scores))
229
  results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
230
 
231
- if model_name=='NoRank-Baseline':
232
  similarity_scores = model_baseline.predict(sentence_combinations)
233
  scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
234
  sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
@@ -254,8 +254,8 @@ top_K = st.text_input('Choose Number of Result: ','10')
254
 
255
  model_name = st.radio(
256
  "Choose Model",
257
- ("CivileLaw-IR", "STSB-IR", "NoRank-Baseline"),
258
- key='radio', on_change=callback, args=('radio','CivileLaw-IR')
259
  )
260
 
261
 
@@ -266,7 +266,7 @@ if st.button('Run', key='run'):
266
  st.session_state['show'] = True
267
  st.session_state['results'] = results
268
  st.session_state['query'] = query
269
- model_dict = {'CivileLaw-IR': 'NLI-Syn', 'STSB-IR': 'NLI-stsb', 'NoRank-Baseline': 'NLI-baseline'}
270
  st.session_state['model'] = model_dict[model_name]
271
 
272
 
 
181
  st.session_state['slider'] = 0
182
 
183
  if 'radio' not in st.session_state:
184
+ st.session_state['radio'] = 'Civile-Law-IR'
185
 
186
  if 'show' not in st.session_state:
187
  st.session_state['show'] = False
 
217
 
218
  # Compute the similarity scores for these combinations
219
 
220
+ if model_name=='Civile-Law-IR':
221
  similarity_scores = model_nli.predict(sentence_combinations)
222
  scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
223
  sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
224
  results = [pred[idx] for _,idx in list(sim_scores_argsort)[:int(top_K)]]
225
 
226
+ if model_name=='STSB':
227
  similarity_scores = model_nli_stsb.predict(sentence_combinations)
228
  sim_scores_argsort = reversed(np.argsort(similarity_scores))
229
  results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
230
 
231
+ if model_name=='DR-Baseline':
232
  similarity_scores = model_baseline.predict(sentence_combinations)
233
  scores = [(score_max[0],idx) for idx,score_max in enumerate(similarity_scores)]
234
  sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
 
254
 
255
  model_name = st.radio(
256
  "Choose Model",
257
+ ("Civile-Law-IR", "STSB", "DR-Baseline"),
258
+ key='radio', on_change=callback, args=('radio','CivileLaw-IR'), help="Civile-Law-IR: trained on civile-NLI-dataset, STSB: trained on STSB french dataset, DR-Baseline: existing nli model trained on ms marco dataset"
259
  )
260
 
261
 
 
266
  st.session_state['show'] = True
267
  st.session_state['results'] = results
268
  st.session_state['query'] = query
269
+ model_dict = {'Civile-Law-IR': 'NLI-Syn', 'STSB': 'NLI-stsb', 'DR-Baseline': 'NLI-baseline'}
270
  st.session_state['model'] = model_dict[model_name]
271
 
272