ThapeloAndrewSindane committed on
Commit
edd31b5
1 Parent(s): 9da6542

GlotLID and OpenLID

Browse files

Adding GlotLID and OpenLID

Files changed (1) hide show
  1. app.py +75 -48
app.py CHANGED
@@ -168,6 +168,9 @@ model_afriberta = load_model_pipeline('dsfsi/za-afriberta-lid', "model.bin")
168
  model_afroxlmr_base = load_model_pipeline('dsfsi/za-afro-xlmr-base-lid', "model.bin")
169
  model_afrolm = load_model_pipeline('dsfsi/za-afrolm-lid', "model.bin")
170
  za_lid = load_model_pipeline('dsfsi/za-lid-bert', "model.bin")
 
 
 
171
 
172
  # @st.cache_resource
173
  def plot(label, prob):
@@ -250,8 +253,12 @@ def compute(sentences, version = 'v3'):
250
  model_choice = model_afrolm
251
  elif version == 'BERT':
252
  model_choice = za_lid
 
 
 
 
253
  else:
254
- model_choice = [model_xlmr_large,model_serengeti, model_afriberta, model_afroxlmr_base, model_afrolm, za_lid]
255
 
256
  my_bar = st.progress(0, text=progress_text)
257
 
@@ -265,22 +272,70 @@ def compute(sentences, version = 'v3'):
265
  all_models_pred = []
266
  for model in model_choice:
267
  output = model.predict(sent)
268
- output_label = output[index]['label']
269
- output_prob = output[index]['score']
270
- output_label_language = output[index]['label']
271
- # output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
272
- # output_prob = max(min(output[1][0], 1), 0)
273
- # output_label_language = output_label.split('_')[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- # # script control
276
- # if version in ['v3', 'v2', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
277
- # main_script, all_scripts = get_script(sent)
278
- # output_label_script = output_label.split('_')[1]
279
 
280
- # if output_label_script not in all_scripts:
281
- # output_label_script = main_script
282
- # output_label = f"und_{output_label_script}"
283
- # output_prob = 0
284
 
285
 
286
  labels = labels + [output_label]
@@ -289,34 +344,6 @@ def compute(sentences, version = 'v3'):
289
  my_bar.progress(
290
  min((index) / len(sentences), 1),
291
  text=progress_text,
292
- )
293
-
294
- else:
295
- output = model_choice.predict(sent)
296
- output_label = output[index]['label']
297
- output_prob = output[index]['score']
298
- output_label_language = output[index]['label']
299
- # output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
300
- # output_prob = max(min(output[1][0], 1), 0)
301
- # output_label_language = output_label.split('_')[0]
302
-
303
- # # script control
304
- # if version in ['v3', 'v2', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
305
- # main_script, all_scripts = get_script(sent)
306
- # output_label_script = output_label.split('_')[1]
307
-
308
- # if output_label_script not in all_scripts:
309
- # output_label_script = main_script
310
- # output_label = f"und_{output_label_script}"
311
- # output_prob = 0
312
-
313
-
314
- labels = labels + [output_label]
315
- probs = probs + [output_prob]
316
-
317
- my_bar.progress(
318
- min((index) / len(sentences), 1),
319
- text=progress_text,
320
  )
321
  my_bar.empty()
322
  return probs, labels
@@ -343,8 +370,8 @@ with tab1:
343
 
344
  version = st.radio(
345
  "Choose model",
346
- ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "All-Models"],
347
- captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", 'All-Models'],
348
  index = 4,
349
  key = 'version_tab1',
350
  horizontal = True
@@ -376,7 +403,7 @@ with tab1:
376
 
377
  # plot
378
  if version == "All-Models":
379
- plot_multiples(["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT"], labels, probs)
380
  else:
381
  plot(label, prob)
382
 
@@ -385,8 +412,8 @@ with tab2:
385
 
386
  version = st.radio(
387
  "Choose model",
388
- ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "All-Models"],
389
- captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "All-Models"],
390
  index = 4,
391
  key = 'version_tab2',
392
  horizontal = True
 
168
  model_afroxlmr_base = load_model_pipeline('dsfsi/za-afro-xlmr-base-lid', "model.bin")
169
  model_afrolm = load_model_pipeline('dsfsi/za-afrolm-lid', "model.bin")
170
  za_lid = load_model_pipeline('dsfsi/za-lid-bert', "model.bin")
171
+ openlid = load_model('laurievb/OpenLID', "model.bin")
172
+ glotlid_3 = load_model(constants.MODEL_NAME, "model_v3.bin")
173
+
174
 
175
  # @st.cache_resource
176
  def plot(label, prob):
 
253
  model_choice = model_afrolm
254
  elif version == 'BERT':
255
  model_choice = za_lid
256
+ elif version == 'OpenLID':
257
+ model_choice = openlid
258
+ elif version == 'GlotLID v3':
259
+ model_choice = glotlid_3
260
  else:
261
+ model_choice = [model_xlmr_large,model_serengeti, model_afriberta, model_afroxlmr_base, model_afrolm, za_lid, openlid, glotlid_3]
262
 
263
  my_bar = st.progress(0, text=progress_text)
264
 
 
272
  all_models_pred = []
273
  for model in model_choice:
274
  output = model.predict(sent)
275
+ if version in ["openlid-201", "GlotLID v3"]:
276
+
277
+ output_label = output[index]['label']
278
+ output_prob = output[index]['score']
279
+ output_label_language = output[index]['label']
280
+ labels = labels + [output_label]
281
+ probs = probs + [output_prob]
282
+
283
+ my_bar.progress(
284
+ min((index) / len(sentences), 1),
285
+ text=progress_text,
286
+ )
287
+ else:
288
+
289
+ output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
290
+ output_prob = max(min(output[1][0], 1), 0)
291
+ output_label_language = output_label.split('_')[0]
292
+
293
+ # script control
294
+ if version in ['GlotLID v3', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
295
+ main_script, all_scripts = get_script(sent)
296
+ output_label_script = output_label.split('_')[1]
297
+
298
+ if output_label_script not in all_scripts:
299
+ output_label_script = main_script
300
+ output_label = f"und_{output_label_script}"
301
+ output_prob = 0
302
+
303
+
304
+ labels = labels + [output_label]
305
+ probs = probs + [output_prob]
306
+
307
+ my_bar.progress(
308
+ min((index) / len(sentences), 1),
309
+ text=progress_text,
310
+ )
311
+
312
+ else:
313
+ output = model_choice.predict(sent)
314
+ if version not in ["openlid-201", "GlotLID v3"]
315
+ output_label = output[index]['label']
316
+ output_prob = output[index]['score']
317
+ output_label_language = output[index]['label']
318
+ labels = labels + [output_label]
319
+ probs = probs + [output_prob]
320
+
321
+ my_bar.progress(
322
+ min((index) / len(sentences), 1),
323
+ text=progress_text,
324
+ )
325
+ else:
326
+ output_label = output[0][0].split('__')[-1].replace('_Hans', '_Hani').replace('_Hant', '_Hani')
327
+ output_prob = max(min(output[1][0], 1), 0)
328
+ output_label_language = output_label.split('_')[0]
329
 
330
+ # script control
331
+ if version in ['GlotLID v3', 'openlid-201', 'nllb-218'] and output_label_language!= 'zxx':
332
+ main_script, all_scripts = get_script(sent)
333
+ output_label_script = output_label.split('_')[1]
334
 
335
+ if output_label_script not in all_scripts:
336
+ output_label_script = main_script
337
+ output_label = f"und_{output_label_script}"
338
+ output_prob = 0
339
 
340
 
341
  labels = labels + [output_label]
 
344
  my_bar.progress(
345
  min((index) / len(sentences), 1),
346
  text=progress_text,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  )
348
  my_bar.empty()
349
  return probs, labels
 
370
 
371
  version = st.radio(
372
  "Choose model",
373
+ ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "openlid-201", "GlotLID v3", "All-Models"],
374
+ captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "OpenLID", "GlotLID v3",'All-Models'],
375
  index = 4,
376
  key = 'version_tab1',
377
  horizontal = True
 
403
 
404
  # plot
405
  if version == "All-Models":
406
+ plot_multiples(["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT", "OpenLID", "GlotLID v3"], labels, probs)
407
  else:
408
  plot(label, prob)
409
 
 
412
 
413
  version = st.radio(
414
  "Choose model",
415
+ ["xlmrlarge", "serengeti", "afriberta", "afroxlmrbase", "afrolm", "BERT","openlid-201", "GlotLID v3", "All-Models"],
416
+ captions=["za-XLMR-Large", "za-Serengeti", "za-AfriBERTa", "za-Afro-XLMR-BASE", "za-AfroLM", "za-BERT", "OpenLID", "GlotLID v3", "All-Models"],
417
  index = 4,
418
  key = 'version_tab2',
419
  horizontal = True