# langtrans / app.py
# Last commit ea13813 by richylyq: "tweak m2m model code" (7.19 kB)
"""
translation program for simple text
1. detect language from langdetect
2. translate to target language given by user
Example from
https://www.thepythoncode.com/article/machine-translation-using-huggingface-transformers-in-python
user_input:
string: string to be translated
target_lang: language to be translated to
Returns:
string: translated string of text
try this : https://pypi.org/project/EasyNMT/
and this : https://huggingface.co/IDEA-CCNL/Randeng-Deltalm-362M-En-Zh
"""
from __future__ import annotations
from typing import Iterable
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import argparse
import langid
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from easynmt import EasyNMT
# Alternative model initializations, currently disabled and kept for reference.
# # Initialize nllb-200 models
# tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
# # Initialize mbart50 models
# mbart_m2en_model = EasyNMT("mbart50_m2en")
# mbart_en2m_model = EasyNMT("mbart50_en2m")

# M2M-100 model loaded through EasyNMT. This statement runs once, when the
# module is imported, and the resulting object is shared by m2m_trans.
M2M_MODEL_NAME = "m2m_100_1.2B"
m2m_model = EasyNMT(M2M_MODEL_NAME)
class myTheme(Base):
    """Custom Gradio theme for the translation app.

    Defaults to a red / blue / orange palette, large text, the "handjet"
    Google font (with a cursive fallback), and IBM Plex Mono for code. On top
    of the base theme it applies a diagonal striped page background, gradient
    primary buttons, bolder block titles, thicker borders and drop shadows.
    All constructor arguments are keyword-only.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.red,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.orange,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("handjet"),
            "cursive",
            # "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Core palette / sizing / typography handed to the Gradio base theme.
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Theme-variable overrides. Values starting with "*" refer to other
        # theme variables (Gradio's theming convention, e.g. *primary_800).
        overrides = {
            "body_background_fill": "repeating-linear-gradient(135deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
            "button_primary_background_fill": "linear-gradient(90deg, *primary_600, *secondary_800)",
            "button_primary_background_fill_hover": "linear-gradient(45deg, *primary_200, *secondary_300)",
            "button_primary_text_color": "white",
            "slider_color": "*secondary_300",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_shadow": "*shadow_drop_lg",
            "button_large_padding": "24px",
        }
        super().set(**overrides)
def detect_lang(article):
    """
    Language detection using the langid library.

    Args:
        article (string): article that the user wishes to translate

    Returns:
        string: short language code of the detected language (the callers in
        this file compare it against codes such as "en" and "zh")
    """
    # langid.classify returns a (language_code, score) pair; only the code is needed.
    lang_code, _score = langid.classify(article)
    return lang_code
def opus_trans(article, target_language):
    """
    Translation by Helsinki-NLP (OPUS-MT) model.

    The source language is detected with detect_lang, and the matching
    "Helsinki-NLP/opus-mt-<source>-<target>" model is loaded through a
    transformers translation pipeline.

    Args:
        article (string): article that user wishes to translate
        target_language (string): language that user wishes to translate
            article into; "English" or "Chinese"

    Returns:
        string: translated article, or a message starting with "Error:" when
        no supported target language was chosen, when the detected language
        already equals the target, or when the model could not be loaded/run
    """
    # Validate the target first. Any other value (e.g. no Radio selection)
    # has no OPUS language code and would otherwise leave target_lang unbound.
    target_lang = {"English": "en", "Chinese": "zh"}.get(target_language)
    if target_lang is None:
        return "Error: Please select a target language (English or Chinese) and try again."

    result_lang = detect_lang(article)
    if result_lang == target_lang:
        return "Error: You chose the same language as the article detected language. Please reselect language and try again."

    task_name = f"translation_{result_lang}_to_{target_lang}"
    model_name = f"Helsinki-NLP/opus-mt-{result_lang}-{target_lang}"
    try:
        # NOTE(review): the pipeline is rebuilt on every call; consider caching
        # it per model_name if latency matters.
        translator = pipeline(task_name, model=model_name, tokenizer=model_name)
        translated = translator(article)[0]["translation_text"]
    except Exception:
        # UI boundary: report failures as text instead of crashing the app.
        # The typical cause is that no OPUS model exists for this language
        # pair. Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
        translated = "Error: Model doesn't exist"
    return translated
def nllb_trans(article, target_language):
result_lang = detect_lang(article)
inputs = tokenizer(article, return_tensors="pt")
if target_language == "English":
target_lang = "eng_Latn"
target_language = "en"
elif target_language == "Chinese":
target_lang = "zho_Hans"
target_language = "zh"
if result_lang != target_language:
translated_tokens = model.generate(
**inputs,
forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
max_length=30,
)
translated = tokenizer.batch_decode(
translated_tokens, skip_special_tokens=True
)[0]
else:
translated = "Error: You chose the same language as the article detected language. Please reselect language and try again."
return translated
# Holds EasyNMT mBART-50 models by name so each is loaded at most once.
_mbart_models = {}


def _load_mbart(model_name):
    """Return the EasyNMT model called model_name, loading it on first use."""
    model = _mbart_models.get(model_name)
    if model is None:
        model = EasyNMT(model_name)
        _mbart_models[model_name] = model
    return model


def mbart_trans(article, target_language):
    """
    Translation by mBART-50 models through EasyNMT.

    Args:
        article (string): article that user wishes to translate
        target_language (string): "English" or "Chinese"

    Returns:
        string: translated article, or a message starting with "Error:" when
        no supported target language was chosen or the detected language
        already equals the target
    """
    # Map the UI name to a language code. Comparing the detected code with
    # the UI name ("en" vs "English") would never match.
    target_code = {"English": "en", "Chinese": "zh"}.get(target_language)
    if target_code is None:
        return "Error: Please select a target language (English or Chinese) and try again."

    if detect_lang(article) == target_code:
        return "Error: You chose the same language as the article detected language. Please reselect language and try again."

    # Per their names, mbart50_m2en translates many languages -> English and
    # mbart50_en2m translates English -> many languages.
    # NOTE(review): the en2m model presumably expects English input; confirm
    # how it behaves for other source languages.
    model_name = "mbart50_m2en" if target_code == "en" else "mbart50_en2m"
    return _load_mbart(model_name).translate(article, target_lang=target_code)
def m2m_trans(article, target_language):
    """
    Translation by the M2M-100 model (the module-level EasyNMT m2m_model).

    Args:
        article (string): article that user wishes to translate
        target_language (string): "English" or "Chinese"

    Returns:
        string: translated article, or a message starting with "Error:" when
        no supported target language was chosen or the detected language
        already equals the target
    """
    # Map the UI name to a language code first. The detected language is a
    # code such as "en", so comparing it with "English" would never match.
    # Unknown targets previously fell through and returned None.
    target_code = {"English": "en", "Chinese": "zh"}.get(target_language)
    if target_code is None:
        return "Error: Please select a target language (English or Chinese) and try again."

    if detect_lang(article) == target_code:
        return "Error: You chose the same language as the article detected language. Please reselect language and try again."

    return m2m_model.translate(article, target_lang=target_code)
def translate(article, toolkit, target_language):
    """
    Dispatch a translation request to the selected toolkit.

    Args:
        article (string): text to translate
        toolkit (string): one of "OPUS", "NLLB", "MBART", "M2M"
        target_language (string): target language name chosen in the UI

    Returns:
        string: the selected translator's result, or an "Error:" message for
        an unknown toolkit (previously this raised UnboundLocalError)
    """
    if toolkit == "OPUS":
        return opus_trans(article, target_language)
    if toolkit == "NLLB":
        return nllb_trans(article, target_language)
    if toolkit == "MBART":
        return mbart_trans(article, target_language)
    if toolkit == "M2M":
        return m2m_trans(article, target_language)
    return f"Error: Unknown translation model {toolkit!r}. Please select OPUS, NLLB, MBART or M2M."
# Build and launch the Gradio UI. The theme instance gets its own name so it
# no longer shadows (replaces) the myTheme class.
theme = myTheme()
with gr.Blocks(theme=theme) as demo:
    article = gr.Textbox(label="Article")
    toolkit_select = gr.Radio(
        ["OPUS", "NLLB", "MBART", "M2M"], label="Select Translation Model", value="OPUS"
    )
    # Pre-select a language so the Translate button cannot be used without a
    # target language.
    lang_select = gr.Radio(
        ["English", "Chinese"], label="Select Desired Language", value="English"
    )
    result = gr.Textbox(label="Translated Result")
    trans_btn = gr.Button("Translate")
    trans_btn.click(
        fn=translate, inputs=[article, toolkit_select, lang_select], outputs=result
    )
demo.launch()