ThapeloAndrewSindane committed on
Commit
9e87bd5
1 Parent(s): 730820d

Adding All

Browse files

Adding predictions from all models

Files changed (1) hide show
  1. app.py +103 -33
app.py CHANGED
@@ -192,6 +192,41 @@ def plot(label, prob):
192
  ax.set_xlabel("Confidence", color=BLACK_COLOR)
193
  st.pyplot(fig)
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def compute(sentences, version = 'v3'):
196
  """Computes the language probablities and labels for the given sentences.
197
 
@@ -213,8 +248,10 @@ def compute(sentences, version = 'v3'):
213
  model_choice = model_afroxlmr_base
214
  elif version=='afrolm':
215
  model_choice = model_afrolm
216
- else:
217
  model_choice = za_lid
 
 
218
 
219
  my_bar = st.progress(0, text=progress_text)
220
 
@@ -224,33 +261,63 @@ def compute(sentences, version = 'v3'):
224
  sentences = [preprocess_text(sent) for sent in sentences]
225
 
226
  for index, sent in enumerate(sentences):
227
-
228
- output = model_choice.predict(sent)
229
- output_label = output[index]['label']
230
- output_prob = output[index]['score']
231
- output_label_language = output[index]['label']
232
- # output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
233
- # output_prob = max(min(output[1][0], 1), 0)
234
- # output_label_language = output_label.split('_')[0]
235
-
236
- # # script control
237
- # if version in ['v3', 'v2', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
238
- # main_script, all_scripts = get_script(sent)
239
- # output_label_script = output_label.split('_')[1]
240
-
241
- # if output_label_script not in all_scripts:
242
- # output_label_script = main_script
243
- # output_label = f"und_{output_label_script}"
244
- # output_prob = 0
245
-
246
-
247
- labels = labels + [output_label]
248
- probs = probs + [output_prob]
249
-
250
- my_bar.progress(
251
- min((index) / len(sentences), 1),
252
- text=progress_text,
253
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  my_bar.empty()
255
  return probs, labels
256
 
@@ -276,8 +343,8 @@ with tab1:
276
 
277
  version = st.radio(
278
  "Choose model",
279
- ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT"],
280
- captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT"],
281
  index = 4,
282
  key = 'version_tab1',
283
  horizontal = True
@@ -308,15 +375,18 @@ with tab1:
308
  f.write(f"{sent}, {label}: {prob}\n")
309
 
310
  # plot
311
- plot(label, prob)
 
 
 
312
 
313
 
314
  with tab2:
315
 
316
  version = st.radio(
317
  "Choose model",
318
- ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT"],
319
- captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT"],
320
  index = 4,
321
  key = 'version_tab2',
322
  horizontal = True
 
192
  ax.set_xlabel("Confidence", color=BLACK_COLOR)
193
  st.pyplot(fig)
194
 
195
+ # @st.cache_resource
196
def plot_multiples(models, labels, probs):
    """Render one horizontal confidence bar per model in the Streamlit app.

    Each model gets a bar whose width is its confidence score, annotated
    with the predicted label and the score rounded to two decimals. The
    chart is emitted with ``st.pyplot``.

    Args:
        models: Sequence of model names; used for the y-tick labels and to
            size the figure (one inch of height per model).
        labels: Predicted label per model, in the same order as ``models``.
        probs: Confidence score per model, in the same order as ``models``.
            The x-axis is fixed to the range [0, 1].
    """
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"

    # A zero-height figure cannot be rendered, so keep at least one inch of
    # height even when no models are passed.
    fig, ax = plt.subplots(figsize=(8, max(len(models), 1)))
    try:
        # Transparent backgrounds so the chart blends with the page theme.
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")

        ax.spines["left"].set_color(BLACK_COLOR)
        ax.spines["bottom"].set_color(BLACK_COLOR)
        ax.tick_params(axis="x", colors=BLACK_COLOR)

        ax.spines[["right", "top"]].set_visible(False)

        # One bar per model, stacked along the y-axis.
        y_positions = range(len(models))
        ax.barh(y=y_positions, width=probs, color=ORANGE_COLOR)

        # Annotate each bar with its label and score, just past the bar end.
        for i, (prob, label) in enumerate(zip(probs, labels)):
            ax.text(
                prob + 0.01,
                i,
                f"{label} ({prob:.2f})",
                va='center',
                color=BLACK_COLOR,
            )

        ax.set_yticks(y_positions)
        ax.set_yticklabels(models, color=BLACK_COLOR)

        ax.set_xlim(0, 1)
        ax.set_xlabel("Confidence", color=BLACK_COLOR)
        ax.set_title("Model Predictions", color=BLACK_COLOR)

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # without this, each rerun of the app would leak one figure.
        plt.close(fig)
228
+
229
+
230
  def compute(sentences, version = 'v3'):
231
  """Computes the language probablities and labels for the given sentences.
232
 
 
248
  model_choice = model_afroxlmr_base
249
  elif version=='afrolm':
250
  model_choice = model_afrolm
251
+ elif version == 'BERT':
252
  model_choice = za_lid
253
+ else:
254
+ model_choice = [model_xlmr_large,model_serengeti, model_afriberta, model_afroxlmr_base, model_afrolm, za_lid]
255
 
256
  my_bar = st.progress(0, text=progress_text)
257
 
 
261
  sentences = [preprocess_text(sent) for sent in sentences]
262
 
263
  for index, sent in enumerate(sentences):
264
+ if type(model_choice) == list:
265
+ all_models_pred = []
266
+ for model in model_choise:
267
+ output = model.predict(sent)
268
+ output_label = output[index]['label']
269
+ output_prob = output[index]['score']
270
+ output_label_language = output[index]['label']
271
+ # output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
272
+ # output_prob = max(min(output[1][0], 1), 0)
273
+ # output_label_language = output_label.split('_')[0]
274
+
275
+ # # script control
276
+ # if version in ['v3', 'v2', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
277
+ # main_script, all_scripts = get_script(sent)
278
+ # output_label_script = output_label.split('_')[1]
279
+
280
+ # if output_label_script not in all_scripts:
281
+ # output_label_script = main_script
282
+ # output_label = f"und_{output_label_script}"
283
+ # output_prob = 0
284
+
285
+
286
+ labels = labels + [output_label]
287
+ probs = probs + [output_prob]
288
+
289
+ my_bar.progress(
290
+ min((index) / len(sentences), 1),
291
+ text=progress_text,
292
+ )
293
+
294
+ else:
295
+ output = model_choice.predict(sent)
296
+ output_label = output[index]['label']
297
+ output_prob = output[index]['score']
298
+ output_label_language = output[index]['label']
299
+ # output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
300
+ # output_prob = max(min(output[1][0], 1), 0)
301
+ # output_label_language = output_label.split('_')[0]
302
+
303
+ # # script control
304
+ # if version in ['v3', 'v2', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
305
+ # main_script, all_scripts = get_script(sent)
306
+ # output_label_script = output_label.split('_')[1]
307
+
308
+ # if output_label_script not in all_scripts:
309
+ # output_label_script = main_script
310
+ # output_label = f"und_{output_label_script}"
311
+ # output_prob = 0
312
+
313
+
314
+ labels = labels + [output_label]
315
+ probs = probs + [output_prob]
316
+
317
+ my_bar.progress(
318
+ min((index) / len(sentences), 1),
319
+ text=progress_text,
320
+ )
321
  my_bar.empty()
322
  return probs, labels
323
 
 
343
 
344
  version = st.radio(
345
  "Choose model",
346
+ ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "All-Models"],
347
+ captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", 'All-Models'],
348
  index = 4,
349
  key = 'version_tab1',
350
  horizontal = True
 
375
  f.write(f"{sent}, {label}: {prob}\n")
376
 
377
  # plot
378
+ if version == "All-Models":
379
+ plot_multiples(["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT"], labels, probs)
380
+ else:
381
+ plot(label, prob)
382
 
383
 
384
  with tab2:
385
 
386
  version = st.radio(
387
  "Choose model",
388
+ ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "All-Models"],
389
+ captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "All-Models"],
390
  index = 4,
391
  key = 'version_tab2',
392
  horizontal = True