# eng-ses-demo / app.py
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import gradio as gr

# Load the fine-tuned NLLB model in fp16 (token=False skips auth-token lookup for this public repo).
model = AutoModelForSeq2SeqLM.from_pretrained("souvorinkg/eng-ses-nllb", token=False).half()
tokenizer = NllbTokenizer.from_pretrained("souvorinkg/eng-ses-nllb")
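
# Explicit device placement is an added sketch, not part of the original Space:
# fp16 weights are mainly useful on a CUDA device (some fp16 ops are slow or
# unsupported on CPU), so move the model to a GPU when one is available.
# torch is already a transformers dependency.
import torch
model = model.to("cuda" if torch.cuda.is_available() else "cpu")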

def fix_tokenizer(tokenizer, new_lang='ses_Latn'):
    """
    Add a new language token to the tokenizer vocabulary.
    This has to be re-done every time the tokenizer is initialized.
    """
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # Always move "<mask>" to the last position.
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # Clear the added-token maps; otherwise the new token may end up there by mistake.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}

# Register ses_Latn in the tokenizer and grow the embedding matrix to match.
fix_tokenizer(tokenizer)
model.resize_token_embeddings(len(tokenizer))
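
# Sanity check (an added sketch, not in the original file): after fix_tokenizer,
# the new language code should resolve to a real vocabulary id rather than <unk>.
assert tokenizer.convert_tokens_to_ids('ses_Latn') != tokenizer.unk_token_id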

def translate(
    text, src_lang, tgt_lang,
    a=32, b=3, max_input_length=1024, num_beams=4, **kwargs
):
    """Turn a text or a list of texts into a list of translations."""
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(
        text, return_tensors='pt', padding=True, truncation=True,
        max_length=max_input_length
    )
    model.eval()  # turn off training mode (dropout etc.)
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        # Cap the output length as an affine function of the input length.
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams, **kwargs
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)
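
# Illustrative call (the sentence is made up; left commented out so the Space
# does not run generation at startup). translate() accepts a single string or a
# list of strings and always returns a list of decoded strings:
# translate("Good morning.", src_lang='eng_Latn', tgt_lang='ses_Latn')
# -> a one-element list containing the Sesotho translation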

# Fix the new/moved token embeddings in the model. After resize_token_embeddings,
# the old "<mask>" row still sits at the id now assigned to ses_Latn.
added_token_id = tokenizer.convert_tokens_to_ids('ses_Latn')
similar_lang_id = tokenizer.convert_tokens_to_ids('tsn_Latn')
embeds = model.model.shared.weight.data
# Move the embedding for "<mask>" to its new position.
embeds[added_token_id + 1] = embeds[added_token_id]
# Initialize the new language token with the embedding of a related language (Tswana).
embeds[added_token_id] = embeds[similar_lang_id]
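
# Consistency check (an added sketch, derived from the shift above): "<mask>"
# moved one position up, so it should now sit right after the new language code.
assert tokenizer.convert_tokens_to_ids('<mask>') == added_token_id + 1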

# Handle the dropdown selection and call translate() with the matching language pair.
def translateIn(SourceText, direction):
    if direction == "English to Sesotho":
        return translate(text=SourceText, src_lang='eng_Latn', tgt_lang='ses_Latn')[0]
    if direction == "Sesotho to English":
        # The source language here is Sesotho (ses_Latn), not Tswana (tsn_Latn).
        return translate(text=SourceText, src_lang='ses_Latn', tgt_lang='eng_Latn')[0]
    # Fallback for an unrecognized direction.
    return ""
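
# Example routing (hypothetical input, not part of the deployed Space):
# translateIn("Hello, friend.", "English to Sesotho")
# -> a single translated string, suitable for the Gradio text output below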

# Create the Gradio interface.
iface = gr.Interface(
    fn=translateIn,
    inputs=[
        gr.Textbox(lines=2, label="Input Text"),
        gr.Dropdown(["English to Sesotho", "Sesotho to English"], label="Translation Direction"),
    ],
    outputs="text",
    title="Bilingual Translator",
    description="Select the translation direction and enter text to translate.",
)

# Launch the app. share=True creates a temporary public link when run locally;
# Hugging Face Spaces ignores it.
iface.launch(share=True)