# hf_translator / translate.py
# (Hugging Face Space page residue: uploaded by Temuzin64, commit d59893d verified)
import streamlit as st
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
### Getting the Languages supported ####
# "Name (code)" pairs understood by mbart-50-many-to-many-mmt.
# NOTE: the original string was truncated — the final ")" after sl_SI was missing.
LanguageCovered = "Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"

# Parse each "Name (code)" entry into name -> code in one pass
# (replaces five chained list comprehensions that rebuilt the list repeatedly).
lang_dict = {}
for entry in LanguageCovered.split(","):
    name, _, code = entry.strip().partition(" ")
    lang_dict[name] = code.strip("()")

languages = list(lang_dict)           # display names for the select boxes
codes = list(lang_dict.values())      # mbart codes, parallel to `languages`
# Multilingual mBART-50 many-to-many translation model; downloads weights on
# first run via the transformers cache.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)
# Run on GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = model.to(device)
def translate_text(text, source_lang, target_lang):
    """Translate `text` between languages using the module-level mbart model.

    Args:
        text: Source-language string to translate.
        source_lang: mbart-50 language code of the input (e.g. "en_XX").
        target_lang: mbart-50 language code of the desired output.

    Returns:
        The translated string, with special tokens removed.
    """
    # The mbart-50 tokenizer must know the source language before encoding.
    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt").to(device)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        # Force the first generated token to be the target-language code,
        # which is how mbart-50 selects the output language.
        generated_tokens = model.generate(
            **encoded_text,
            forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
        )
    # Decode the output
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
st.markdown("### Language Translator")

# Resolved mbart codes; empty until a selection is made.
source_language = ''
target_language = ''

source = st.sidebar.selectbox('Source Language', languages)
if source:
    source_language = lang_dict.get(source)
    st.write(source_language)

target = st.sidebar.selectbox('Target Language', languages)
if target:
    target_language = lang_dict.get(target)
    st.write(target_language)

# The submit button must live inside the form; handling happens after it.
with st.form(key="myForm"):
    text = st.text_area("Enter your text")
    submit = st.form_submit_button("Submit", type='primary')

# Translate only when the form was submitted with text and both languages set.
if submit and text and source_language and target_language:
    with st.spinner(f"{source} to {target} translating"):
        translation = translate_text(text, source_language, target_language)
        st.write(translation)