add nllb translate codes
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ from gradio.themes.utils import colors, fonts, sizes
|
|
24 |
import argparse
|
25 |
|
26 |
import langid
|
27 |
-
from transformers import pipeline
|
28 |
|
29 |
|
30 |
class myTheme(Base):
|
@@ -112,8 +112,6 @@ def opus_trans(article, target_language):
|
|
112 |
target_lang = "en"
|
113 |
elif target_language == "Chinese":
|
114 |
target_lang = "zh"
|
115 |
-
elif target_language == "Spanish":
|
116 |
-
target_lang = "es"
|
117 |
|
118 |
if result_lang != target_lang:
|
119 |
task_name = f"translation_{result_lang}_to_{target_lang}"
|
@@ -129,15 +127,31 @@ def opus_trans(article, target_language):
|
|
129 |
|
130 |
|
131 |
def nllb_trans(article, target_language):
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
|
135 |
def translate(article, toolkit, target_language):
|
136 |
if toolkit == "OPUS":
|
137 |
translated = opus_trans(article, target_language)
|
138 |
-
return translated
|
139 |
elif toolkit == "NLLB":
|
140 |
-
|
|
|
|
|
141 |
|
142 |
|
143 |
myTheme = myTheme()
|
|
|
24 |
import argparse
|
25 |
|
26 |
import langid
|
27 |
+
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
28 |
|
29 |
|
30 |
class myTheme(Base):
|
|
|
112 |
target_lang = "en"
|
113 |
elif target_language == "Chinese":
|
114 |
target_lang = "zh"
|
|
|
|
|
115 |
|
116 |
if result_lang != target_lang:
|
117 |
task_name = f"translation_{result_lang}_to_{target_lang}"
|
|
|
127 |
|
128 |
|
129 |
def nllb_trans(article, target_language):
|
130 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
131 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
132 |
+
inputs = tokenizer(article, return_tensors="pt")
|
133 |
+
|
134 |
+
if target_language == "English":
|
135 |
+
target_lang = "Eng_Latn"
|
136 |
+
elif target_language == "Chinese":
|
137 |
+
target_lang = "zho_Hans"
|
138 |
+
|
139 |
+
translated_tokens = model.generate(
|
140 |
+
**inputs,
|
141 |
+
forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
|
142 |
+
max_length=30,
|
143 |
+
)
|
144 |
+
|
145 |
+
return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
146 |
|
147 |
|
148 |
def translate(article, toolkit, target_language):
|
149 |
if toolkit == "OPUS":
|
150 |
translated = opus_trans(article, target_language)
|
|
|
151 |
elif toolkit == "NLLB":
|
152 |
+
translated = nllb_trans(article, target_language)
|
153 |
+
|
154 |
+
return translated
|
155 |
|
156 |
|
157 |
myTheme = myTheme()
|