richylyq committed on
Commit
93baa69
·
1 Parent(s): 76ef102

add mbart for translation

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -25,10 +25,16 @@ import argparse
25
 
26
  import langid
27
  from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
 
28
 
 
29
  tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
30
  model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
31
 
 
 
 
 
32
 
33
  class myTheme(Base):
34
  def __init__(
@@ -155,11 +161,23 @@ def nllb_trans(article, target_language):
155
  return translated
156
 
157
 
 
 
 
 
 
 
 
 
 
 
158
  def translate(article, toolkit, target_language):
159
  if toolkit == "OPUS":
160
  translated = opus_trans(article, target_language)
161
  elif toolkit == "NLLB":
162
  translated = nllb_trans(article, target_language)
 
 
163
 
164
  return translated
165
 
@@ -169,7 +187,7 @@ myTheme = myTheme()
169
  with gr.Blocks(theme=myTheme) as demo:
170
  article = gr.Textbox(label="Article")
171
  toolkit_select = gr.Radio(
172
- ["OPUS", "NLLB"], label="Select Translation Model", value="OPUS"
173
  )
174
  lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
175
  result = gr.Textbox(label="Translated Result")
 
25
 
26
  import langid
27
  from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
28
+ from easynmt import EasyNMT
29
 
30
+ # Initialize nllb-200 models
31
  tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
32
  model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
33
 
34
+ # Initialize mbart50 models
35
+ mbart_m2en_model = EasyNMT("mbart50_m2en")
36
+ mbart_en2m_model = EasyNMT("mbart50_en2m")
37
+
38
 
39
  class myTheme(Base):
40
  def __init__(
 
161
  return translated
162
 
163
 
164
def mbart_trans(article, target_language):
    """Translate ``article`` with the mBART-50 EasyNMT models.

    The language of ``article`` is detected with ``detect_lang`` and compared
    against ``target_language``; translation only runs when they differ.

    Args:
        article: Text to translate.
        target_language: Desired output language. ``"English"`` selects the
            ``mbart50_m2en`` model; any other value is translated to Chinese
            (``"zh"``) with the ``mbart50_en2m`` model.

    Returns:
        The translated text, or ``article`` unchanged when it is already in
        ``target_language``.
    """
    # Already in the requested language: hand the text back instead of
    # falling off the end of the function and returning None.
    if detect_lang(article) == target_language:
        return article

    if target_language == "English":
        # EasyNMT.translate() requires target_lang, even for the m2en model.
        return mbart_m2en_model.translate(article, target_lang="en")
    return mbart_en2m_model.translate(article, target_lang="zh")
172
+
173
+
174
  def translate(article, toolkit, target_language):
175
  if toolkit == "OPUS":
176
  translated = opus_trans(article, target_language)
177
  elif toolkit == "NLLB":
178
  translated = nllb_trans(article, target_language)
179
+ elif toolkit == "MBART":
180
+ translated = mbart_trans(article, target_language)
181
 
182
  return translated
183
 
 
187
  with gr.Blocks(theme=myTheme) as demo:
188
  article = gr.Textbox(label="Article")
189
  toolkit_select = gr.Radio(
190
+ ["OPUS", "NLLB", "MBART"], label="Select Translation Model", value="OPUS"
191
  )
192
  lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
193
  result = gr.Textbox(label="Translated Result")