# Hugging Face Space: English <-> Sesotho translator
# (scraped "Spaces: Sleeping" status banner removed — it is not valid Python)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, NllbTokenizer
import gradio as gr

# NLLB checkpoint fine-tuned for English <-> Sesotho.
# NOTE(review): .half() assumes fp16-capable hardware; on a CPU-only Space,
# half-precision inference may be slow or unsupported — confirm the deployment target.
model = AutoModelForSeq2SeqLM.from_pretrained("souvorinkg/eng-ses-nllb", token=False).half()
# NllbTokenizer is used directly (not AutoTokenizer) because fix_tokenizer()
# below patches NllbTokenizer-specific internals to register 'ses_Latn'.
tokenizer = NllbTokenizer.from_pretrained("souvorinkg/eng-ses-nllb")
def fix_tokenizer(tokenizer, new_lang='ses_Latn'):
    """
    Register a new language code with an NllbTokenizer.

    NLLB tokenizers keep language codes in internal lookup tables rather than
    in the regular vocabulary, so adding a plain token is not enough.  This
    rewires ``lang_code_to_id`` / ``id_to_lang_code`` and the fairseq mappings
    so that ``new_lang`` takes the slot just before "<mask>", which always
    stays in the last position.  Must be re-run each time the tokenizer is
    loaded (the patch is not persisted).

    NOTE(review): relies on private NllbTokenizer internals
    (``fairseq_tokens_to_ids``, ``sp_model``, ``_additional_special_tokens``);
    these internals changed in newer transformers releases — verify the
    pinned library version supports them.
    """
    # If new_lang was already registered as an added token, don't count it twice.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len-1
    tokenizer.id_to_lang_code[old_len-1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
# Register the Sesotho language code, then grow the model's embedding matrix
# so it has a row for the newly added token.
fix_tokenizer(tokenizer)
model.resize_token_embeddings(len(tokenizer))
def translate(
    text, src_lang, tgt_lang,
    a=32, b=3, max_input_length=1024, num_beams=4, **kwargs
):
    """Turn a text or a list of texts into a list of translations.

    Args:
        text: a string or a list of strings to translate.
        src_lang: NLLB source language code, e.g. 'eng_Latn'.
        tgt_lang: NLLB target language code, e.g. 'ses_Latn'.
        a, b: output-length budget — max_new_tokens = a + b * input_length.
        max_input_length: truncation limit (in tokens) for the input.
        num_beams: beam-search width for ``model.generate``.
        **kwargs: forwarded verbatim to ``model.generate``.

    Returns:
        A list of decoded translations, one per input text.
    """
    # Setting src_lang controls which language tag the tokenizer prepends.
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(
        text, return_tensors='pt', padding=True, truncation=True,
        max_length=max_input_length
    )
    model.eval()  # turn off training mode (disables dropout etc.)
    result = model.generate(
        **inputs.to(model.device),
        # Force the decoder to begin with the target-language token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams, **kwargs
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)
# Fix the new/moved token embeddings in the model.
added_token_id = tokenizer.convert_tokens_to_ids('ses_Latn')
similar_lang_id = tokenizer.convert_tokens_to_ids('tsn_Latn')
embeds = model.model.shared.weight.data
# resize_token_embeddings appended one fresh row, and fix_tokenizer shifted
# "<mask>" up by one slot — copy its old embedding into its new position
# before overwriting the vacated slot.
embeds[added_token_id + 1] = embeds[added_token_id]
# Initialize the new language token from a similar language (Tswana).
embeds[added_token_id] = embeds[similar_lang_id]
# Load the tokenizer and model for Kinyarwanda to English
# tokenizer_kin_to_en = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="kin_Latn", tgt_lang="eng_Latn")
#tokenizer_en_to_ses = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="eng_Latn", tgt_lang="ses_Latn")
#tokenizer_tsn_to_eng = AutoTokenizer.from_pretrained("souvorinkg/eng-ses-nllb", token=False, src_lang="tsn_Latn", tgt_lang="eng_Latn")
# Define the translation function for English to Kinyarwanda
# def translate_en_to_kin(SourceText):
#     inputs = tokenizer_en_to_kin(SourceText, return_tensors="pt")
#     translated_tokens = model.generate(**inputs, max_length=30)
#     return tokenizer_en_to_kin.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# Define the translation function for Kinyarwanda to English
# def translate_kin_to_en(SourceText):
#     inputs = tokenizer_kin_to_en(SourceText, return_tensors="pt")
#     translated_tokens = model.generate(**inputs, max_length=30, no_repeat_ngram_size=2)
#     return tokenizer_kin_to_en.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# def translate_en_to_ses(SourceText):
#     inputs = tokenizer_en_to_ses(SourceText, return_tensors="pt")
#     translated_tokens = model.generate(**inputs, max_length=30)
#     return tokenizer_tsn_to_eng.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# def translate_ses_to_en(SourceText):
#     inputs = tokenizer_tsn_to_eng(SourceText, return_tensors="pt")
#     translated_tokens = model.generate(**inputs, max_length=30)
#     return tokenizer_en_to_ses.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# def translate_en_to_tsn(SourceText):
#     inputs = tokenizer_en_to_tsn(SourceText, return_tensors="pt")
#     translated_tokens = model.generate(**inputs, max_length=30)
#     return tokenizer_en_to_tsn.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# Function to handle dropdown selection and call the appropriate translation function
def translateIn(SourceText, direction):
    """Dispatch a UI translation request to ``translate``.

    Args:
        SourceText: the text entered by the user.
        direction: one of the dropdown choices
            ("English to Sesotho" / "Sesotho to English").

    Returns:
        The translated string.

    Raises:
        ValueError: if ``direction`` is not a recognized choice
        (previously this silently returned None).
    """
    if direction == "English to Sesotho":
        result = translate(text=SourceText, src_lang='eng_Latn', tgt_lang='ses_Latn')
        return result[0]
    if direction == "Sesotho to English":
        # BUG FIX: the original passed src_lang='tsn_Latn' (Tswana) although the
        # direction is Sesotho->English; use the 'ses_Latn' code that
        # fix_tokenizer registered. TODO(review): confirm against the
        # fine-tuning data that 'ses_Latn' was the source tag used in training.
        result = translate(text=SourceText, src_lang='ses_Latn', tgt_lang='eng_Latn')
        return result[0]
    raise ValueError(f"Unknown translation direction: {direction!r}")
# Create the Gradio interface: one input text box plus a direction dropdown.
iface = gr.Interface(
    fn=translateIn,
    inputs=[
        gr.Textbox(lines=2, label="Input Text"),
        gr.Dropdown(
            ["English to Sesotho", "Sesotho to English"],
            label="Translation Direction",
        ),
    ],
    outputs="text",
    title="Bilingual Translator",
    description="Select translation direction and enter text to translate.",
)

# Launch the app (share=True also creates a temporary public URL).
iface.launch(share=True)