add mbart for translation
Browse files
app.py
CHANGED
@@ -25,10 +25,16 @@ import argparse
|
|
25 |
|
26 |
import langid
|
27 |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
28 |
|
|
|
29 |
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
30 |
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
31 |
|
|
|
|
|
|
|
|
|
32 |
|
33 |
class myTheme(Base):
|
34 |
def __init__(
|
@@ -155,11 +161,23 @@ def nllb_trans(article, target_language):
|
|
155 |
return translated
|
156 |
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
def translate(article, toolkit, target_language):
|
159 |
if toolkit == "OPUS":
|
160 |
translated = opus_trans(article, target_language)
|
161 |
elif toolkit == "NLLB":
|
162 |
translated = nllb_trans(article, target_language)
|
|
|
|
|
163 |
|
164 |
return translated
|
165 |
|
@@ -169,7 +187,7 @@ myTheme = myTheme()
|
|
169 |
with gr.Blocks(theme=myTheme) as demo:
|
170 |
article = gr.Textbox(label="Article")
|
171 |
toolkit_select = gr.Radio(
|
172 |
-
["OPUS", "NLLB"], label="Select Translation Model", value="OPUS"
|
173 |
)
|
174 |
lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
|
175 |
result = gr.Textbox(label="Translated Result")
|
|
|
25 |
|
26 |
import langid
|
27 |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
28 |
+
from easynmt import EasyNMT
|
29 |
|
30 |
+
# Initialize nllb-200 models
|
31 |
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
32 |
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
33 |
|
34 |
+
# Initialize mbart50 models
|
35 |
+
mbart_m2en_model = EasyNMT("mbart50_m2en")
|
36 |
+
mbart_en2m_model = EasyNMT("mbart50_en2m")
|
37 |
+
|
38 |
|
39 |
class myTheme(Base):
|
40 |
def __init__(
|
|
|
161 |
return translated
|
162 |
|
163 |
|
164 |
+
def mbart_trans(article, target_language):
|
165 |
+
result_lang = detect_lang(article)
|
166 |
+
|
167 |
+
if result_lang != target_language:
|
168 |
+
if target_language == "English":
|
169 |
+
return mbart_m2en_model.translate(article)
|
170 |
+
else:
|
171 |
+
return mbart_en2m_model.translate(article, target_lang="zh")
|
172 |
+
|
173 |
+
|
174 |
def translate(article, toolkit, target_language):
|
175 |
if toolkit == "OPUS":
|
176 |
translated = opus_trans(article, target_language)
|
177 |
elif toolkit == "NLLB":
|
178 |
translated = nllb_trans(article, target_language)
|
179 |
+
elif toolkit == "MBART":
|
180 |
+
translated = mbart_trans(article, target_language)
|
181 |
|
182 |
return translated
|
183 |
|
|
|
187 |
with gr.Blocks(theme=myTheme) as demo:
|
188 |
article = gr.Textbox(label="Article")
|
189 |
toolkit_select = gr.Radio(
|
190 |
+
["OPUS", "NLLB", "MBART"], label="Select Translation Model", value="OPUS"
|
191 |
)
|
192 |
lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
|
193 |
result = gr.Textbox(label="Translated Result")
|