richylyq commited on
Commit
33e3967
·
1 Parent(s): e059be8

add nllb translate codes

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -24,7 +24,7 @@ from gradio.themes.utils import colors, fonts, sizes
24
  import argparse
25
 
26
  import langid
27
- from transformers import pipeline
28
 
29
 
30
  class myTheme(Base):
@@ -112,8 +112,6 @@ def opus_trans(article, target_language):
112
  target_lang = "en"
113
  elif target_language == "Chinese":
114
  target_lang = "zh"
115
- elif target_language == "Spanish":
116
- target_lang = "es"
117
 
118
  if result_lang != target_lang:
119
  task_name = f"translation_{result_lang}_to_{target_lang}"
@@ -129,15 +127,31 @@ def opus_trans(article, target_language):
129
 
130
 
131
  def nllb_trans(article, target_language):
132
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
 
135
  def translate(article, toolkit, target_language):
136
  if toolkit == "OPUS":
137
  translated = opus_trans(article, target_language)
138
- return translated
139
  elif toolkit == "NLLB":
140
- pass
 
 
141
 
142
 
143
  myTheme = myTheme()
 
24
  import argparse
25
 
26
  import langid
27
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
28
 
29
 
30
  class myTheme(Base):
 
112
  target_lang = "en"
113
  elif target_language == "Chinese":
114
  target_lang = "zh"
 
 
115
 
116
  if result_lang != target_lang:
117
  task_name = f"translation_{result_lang}_to_{target_lang}"
 
127
 
128
 
129
  def nllb_trans(article, target_language):
130
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
131
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
132
+ inputs = tokenizer(article, return_tensors="pt")
133
+
134
+ if target_language == "English":
135
+ target_lang = "Eng_Latn"
136
+ elif target_language == "Chinese":
137
+ target_lang = "zho_Hans"
138
+
139
+ translated_tokens = model.generate(
140
+ **inputs,
141
+ forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
142
+ max_length=30,
143
+ )
144
+
145
+ return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
146
 
147
 
148
  def translate(article, toolkit, target_language):
149
  if toolkit == "OPUS":
150
  translated = opus_trans(article, target_language)
 
151
  elif toolkit == "NLLB":
152
+ translated = nllb_trans(article, target_language)
153
+
154
+ return translated
155
 
156
 
157
  myTheme = myTheme()