hysts HF staff committed on
Commit
a71e647
·
1 Parent(s): c105fbb

Fix the language list

Browse files
Files changed (2) hide show
  1. app.py +31 -16
  2. lang_list.py +115 -197
app.py CHANGED
@@ -7,15 +7,15 @@ import torchaudio
7
  from seamless_communication.models.inference.translator import Translator
8
 
9
  from lang_list import (
10
- S2ST_TARGET_LANGUAGES,
11
- S2TT_TARGET_LANGUAGES,
12
- T2TT_TARGET_LANGUAGES,
13
- TEXT_SOURCE_LANGUAGES,
 
14
  )
15
 
16
  DESCRIPTION = "# SeamlessM4T"
17
 
18
-
19
  TASK_NAMES = [
20
  "S2ST (Speech to Speech translation)",
21
  "S2TT (Speech to Text translation)",
@@ -27,6 +27,8 @@ TASK_NAMES = [
27
  AUDIO_SAMPLE_RATE = 16000.0
28
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
29
 
 
 
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
  translator = Translator(
32
  model_name_or_card="multitask_unity_large",
@@ -46,6 +48,9 @@ def predict(
46
  target_language: str,
47
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
48
  task_name = task_name.split()[0]
 
 
 
49
  if task_name in ["S2ST", "S2TT", "ASR"]:
50
  if audio_source == "microphone":
51
  input_data = input_audio_mic
@@ -64,8 +69,8 @@ def predict(
64
  text_out, wav, sr = translator.predict(
65
  input=input_data,
66
  task_str=task_name,
67
- tgt_lang=target_language,
68
- src_lang=source_language,
69
  )
70
  if task_name in ["S2ST", "T2ST"]:
71
  return (sr, wav.cpu().detach().numpy()), text_out
@@ -88,35 +93,45 @@ def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
88
  gr.update(visible=True), # audio_box
89
  gr.update(visible=False), # input_text
90
  gr.update(visible=False), # source_language
91
- gr.update(visible=True, choices=S2ST_TARGET_LANGUAGES, value="fra"), # target_language
 
 
92
  )
93
  elif task_name == "S2TT":
94
  return (
95
  gr.update(visible=True), # audio_box
96
  gr.update(visible=False), # input_text
97
  gr.update(visible=False), # source_language
98
- gr.update(visible=True, choices=S2TT_TARGET_LANGUAGES, value="fra"), # target_language
 
 
99
  )
100
  elif task_name == "T2ST":
101
  return (
102
  gr.update(visible=False), # audio_box
103
  gr.update(visible=True), # input_text
104
  gr.update(visible=True), # source_language
105
- gr.update(visible=True, choices=S2ST_TARGET_LANGUAGES, value="fra"), # target_language
 
 
106
  )
107
  elif task_name == "T2TT":
108
  return (
109
  gr.update(visible=False), # audio_box
110
  gr.update(visible=True), # input_text
111
  gr.update(visible=True), # source_language
112
- gr.update(visible=True, choices=T2TT_TARGET_LANGUAGES, value="fra"), # target_language
 
 
113
  )
114
  elif task_name == "ASR":
115
  return (
116
  gr.update(visible=True), # audio_box
117
  gr.update(visible=False), # input_text
118
  gr.update(visible=False), # source_language
119
- gr.update(visible=True, choices=S2TT_TARGET_LANGUAGES, value="fra"), # target_language
 
 
120
  )
121
  else:
122
  raise ValueError(f"Unknown task: {task_name}")
@@ -154,14 +169,14 @@ with gr.Blocks(css="style.css") as demo:
154
  with gr.Row():
155
  source_language = gr.Dropdown(
156
  label="Source language",
157
- choices=TEXT_SOURCE_LANGUAGES,
158
- value="eng",
159
  visible=False,
160
  )
161
  target_language = gr.Dropdown(
162
  label="Target language",
163
- choices=S2ST_TARGET_LANGUAGES,
164
- value="fra",
165
  )
166
  with gr.Row() as audio_box:
167
  audio_source = gr.Radio(
 
7
  from seamless_communication.models.inference.translator import Translator
8
 
9
  from lang_list import (
10
+ LANGUAGE_NAME_TO_CODE,
11
+ S2ST_TARGET_LANGUAGE_NAMES,
12
+ S2TT_TARGET_LANGUAGE_NAMES,
13
+ T2TT_TARGET_LANGUAGE_NAMES,
14
+ TEXT_SOURCE_LANGUAGE_NAMES,
15
  )
16
 
17
  DESCRIPTION = "# SeamlessM4T"
18
 
 
19
  TASK_NAMES = [
20
  "S2ST (Speech to Speech translation)",
21
  "S2TT (Speech to Text translation)",
 
27
  AUDIO_SAMPLE_RATE = 16000.0
28
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
29
 
30
+ DEFAULT_TARGET_LANGUAGE = "French"
31
+
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33
  translator = Translator(
34
  model_name_or_card="multitask_unity_large",
 
48
  target_language: str,
49
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
50
  task_name = task_name.split()[0]
51
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
52
+ target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
53
+
54
  if task_name in ["S2ST", "S2TT", "ASR"]:
55
  if audio_source == "microphone":
56
  input_data = input_audio_mic
 
69
  text_out, wav, sr = translator.predict(
70
  input=input_data,
71
  task_str=task_name,
72
+ tgt_lang=target_language_code,
73
+ src_lang=source_language_code,
74
  )
75
  if task_name in ["S2ST", "T2ST"]:
76
  return (sr, wav.cpu().detach().numpy()), text_out
 
93
  gr.update(visible=True), # audio_box
94
  gr.update(visible=False), # input_text
95
  gr.update(visible=False), # source_language
96
+ gr.update(
97
+ visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
98
+ ), # target_language
99
  )
100
  elif task_name == "S2TT":
101
  return (
102
  gr.update(visible=True), # audio_box
103
  gr.update(visible=False), # input_text
104
  gr.update(visible=False), # source_language
105
+ gr.update(
106
+ visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
107
+ ), # target_language
108
  )
109
  elif task_name == "T2ST":
110
  return (
111
  gr.update(visible=False), # audio_box
112
  gr.update(visible=True), # input_text
113
  gr.update(visible=True), # source_language
114
+ gr.update(
115
+ visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
116
+ ), # target_language
117
  )
118
  elif task_name == "T2TT":
119
  return (
120
  gr.update(visible=False), # audio_box
121
  gr.update(visible=True), # input_text
122
  gr.update(visible=True), # source_language
123
+ gr.update(
124
+ visible=True, choices=T2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
125
+ ), # target_language
126
  )
127
  elif task_name == "ASR":
128
  return (
129
  gr.update(visible=True), # audio_box
130
  gr.update(visible=False), # input_text
131
  gr.update(visible=False), # source_language
132
+ gr.update(
133
+ visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
134
+ ), # target_language
135
  )
136
  else:
137
  raise ValueError(f"Unknown task: {task_name}")
 
169
  with gr.Row():
170
  source_language = gr.Dropdown(
171
  label="Source language",
172
+ choices=TEXT_SOURCE_LANGUAGE_NAMES,
173
+ value="English",
174
  visible=False,
175
  )
176
  target_language = gr.Dropdown(
177
  label="Target language",
178
+ choices=S2ST_TARGET_LANGUAGE_NAMES,
179
+ value=DEFAULT_TARGET_LANGUAGE,
180
  )
181
  with gr.Row() as audio_box:
182
  audio_source = gr.Radio(
lang_list.py CHANGED
@@ -1,6 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Source langs: S2ST / S2TT / ASR don't need source lang
2
  # T2TT / T2ST use this
3
- TEXT_SOURCE_LANGUAGES = [
4
  "afr",
5
  "amh",
6
  "arb",
@@ -17,7 +124,6 @@ TEXT_SOURCE_LANGUAGES = [
17
  "ces",
18
  "ckb",
19
  "cmn",
20
- "zho_Hant",
21
  "cym",
22
  "dan",
23
  "deu",
@@ -27,7 +133,6 @@ TEXT_SOURCE_LANGUAGES = [
27
  "eus",
28
  "fin",
29
  "fra",
30
- "fuv",
31
  "gaz",
32
  "gle",
33
  "glg",
@@ -75,7 +180,6 @@ TEXT_SOURCE_LANGUAGES = [
75
  "por",
76
  "ron",
77
  "rus",
78
- "sat",
79
  "slk",
80
  "slv",
81
  "sna",
@@ -100,10 +204,11 @@ TEXT_SOURCE_LANGUAGES = [
100
  "zsm",
101
  "zul",
102
  ]
 
103
 
104
  # Target langs:
105
  # S2ST / T2ST
106
- S2ST_TARGET_LANGUAGES = [
107
  "eng",
108
  "arb",
109
  "ben",
@@ -133,94 +238,6 @@ S2ST_TARGET_LANGUAGES = [
133
  "swe",
134
  "swh",
135
  "tel",
136
- "tgl/fil",
137
- "tha",
138
- "tur",
139
- "ukr",
140
- "urd",
141
- "uzn",
142
- "vie",
143
- ]
144
- # S2TT / ASR
145
- S2TT_TARGET_LANGUAGES = [
146
- "amh",
147
- "arb",
148
- "asm",
149
- "azj",
150
- "bel",
151
- "ben",
152
- "bos",
153
- "bul",
154
- "cat",
155
- "ceb",
156
- "ces",
157
- "ckb",
158
- "cmn",
159
- "cym",
160
- "dan",
161
- "deu",
162
- "ell",
163
- "eng",
164
- "est",
165
- "fin",
166
- "fra",
167
- "ful",
168
- "gaz",
169
- "gle",
170
- "glg",
171
- "guj",
172
- "heb",
173
- "hin",
174
- "hrv",
175
- "hun",
176
- "hye",
177
- "ibo",
178
- "ind",
179
- "isl",
180
- "ita",
181
- "jav",
182
- "jpn",
183
- "kan",
184
- "kat",
185
- "kaz",
186
- "khk",
187
- "khm",
188
- "kir",
189
- "kor",
190
- "lao",
191
- "lit",
192
- "lug",
193
- "luo",
194
- "lvs",
195
- "mal",
196
- "mar",
197
- "mkd",
198
- "mlt",
199
- "mya",
200
- "nld",
201
- "nob",
202
- "npi",
203
- "nya",
204
- "ory",
205
- "pan",
206
- "pbt",
207
- "pes",
208
- "pol",
209
- "por",
210
- "ron",
211
- "rus",
212
- "slk",
213
- "slv",
214
- "sna",
215
- "snd",
216
- "som",
217
- "spa",
218
- "srp",
219
- "swe",
220
- "swh",
221
- "tam",
222
- "tel",
223
- "tgk",
224
  "tgl",
225
  "tha",
226
  "tur",
@@ -228,109 +245,10 @@ S2TT_TARGET_LANGUAGES = [
228
  "urd",
229
  "uzn",
230
  "vie",
231
- "yor",
232
- "yue",
233
- "zlm",
234
- "zul",
235
  ]
 
 
 
 
236
  # T2TT
237
- T2TT_TARGET_LANGUAGES = [
238
- "afr",
239
- "amh",
240
- "arb",
241
- "ary",
242
- "arz",
243
- "asm",
244
- "azj",
245
- "bel",
246
- "ben",
247
- "bos",
248
- "bul",
249
- "cat",
250
- "ceb",
251
- "ces",
252
- "ckb",
253
- "cmn",
254
- "cym",
255
- "dan",
256
- "deu",
257
- "ell",
258
- "eng",
259
- "est",
260
- "eus",
261
- "fin",
262
- "fra",
263
- "fuv",
264
- "gaz",
265
- "gle",
266
- "glg",
267
- "guj",
268
- "heb",
269
- "hin",
270
- "hrv",
271
- "hun",
272
- "hye",
273
- "ibo",
274
- "ind",
275
- "isl",
276
- "ita",
277
- "jav",
278
- "jpn",
279
- "kan",
280
- "kat",
281
- "kaz",
282
- "khk",
283
- "khm",
284
- "kir",
285
- "kor",
286
- "lao",
287
- "lit",
288
- "lug",
289
- "luo",
290
- "lvs",
291
- "mai",
292
- "mal",
293
- "mar",
294
- "mkd",
295
- "mlt",
296
- "mni",
297
- "mya",
298
- "nld",
299
- "nno",
300
- "nob",
301
- "npi",
302
- "nya",
303
- "ory",
304
- "pan",
305
- "pbt",
306
- "pes",
307
- "pol",
308
- "por",
309
- "ron",
310
- "rus",
311
- "sat",
312
- "slk",
313
- "slv",
314
- "sna",
315
- "snd",
316
- "som",
317
- "spa",
318
- "srp",
319
- "swe",
320
- "swh",
321
- "tam",
322
- "tel",
323
- "tgk",
324
- "tgl",
325
- "tha",
326
- "tur",
327
- "ukr",
328
- "urd",
329
- "uzn",
330
- "vie",
331
- "yor",
332
- "yue",
333
- "zho_Hant",
334
- "zsm",
335
- "zul",
336
- ]
 
1
+ # Language dict
2
+ language_code_to_name = {
3
+ "afr": "Afrikaans",
4
+ "amh": "Amharic",
5
+ "arb": "Modern Standard Arabic",
6
+ "ary": "Moroccan Arabic",
7
+ "arz": "Egyptian Arabic",
8
+ "asm": "Assamese",
9
+ "ast": "Asturian",
10
+ "azj": "North Azerbaijani",
11
+ "bel": "Belarusian",
12
+ "ben": "Bengali",
13
+ "bos": "Bosnian",
14
+ "bul": "Bulgarian",
15
+ "cat": "Catalan",
16
+ "ceb": "Cebuano",
17
+ "ces": "Czech",
18
+ "ckb": "Central Kurdish",
19
+ "cmn": "Mandarin Chinese",
20
+ "cym": "Welsh",
21
+ "dan": "Danish",
22
+ "deu": "German",
23
+ "ell": "Greek",
24
+ "eng": "English",
25
+ "est": "Estonian",
26
+ "eus": "Basque",
27
+ "fin": "Finnish",
28
+ "fra": "French",
29
+ "gaz": "West Central Oromo",
30
+ "gle": "Irish",
31
+ "glg": "Galician",
32
+ "guj": "Gujarati",
33
+ "heb": "Hebrew",
34
+ "hin": "Hindi",
35
+ "hrv": "Croatian",
36
+ "hun": "Hungarian",
37
+ "hye": "Armenian",
38
+ "ibo": "Igbo",
39
+ "ind": "Indonesian",
40
+ "isl": "Icelandic",
41
+ "ita": "Italian",
42
+ "jav": "Javanese",
43
+ "jpn": "Japanese",
44
+ "kam": "Kamba",
45
+ "kan": "Kannada",
46
+ "kat": "Georgian",
47
+ "kaz": "Kazakh",
48
+ "kea": "Kabuverdianu",
49
+ "khk": "Halh Mongolian",
50
+ "khm": "Khmer",
51
+ "kir": "Kyrgyz",
52
+ "kor": "Korean",
53
+ "lao": "Lao",
54
+ "lit": "Lithuanian",
55
+ "ltz": "Luxembourgish",
56
+ "lug": "Ganda",
57
+ "luo": "Luo",
58
+ "lvs": "Standard Latvian",
59
+ "mai": "Maithili",
60
+ "mal": "Malayalam",
61
+ "mar": "Marathi",
62
+ "mkd": "Macedonian",
63
+ "mlt": "Maltese",
64
+ "mni": "Meitei",
65
+ "mya": "Burmese",
66
+ "nld": "Dutch",
67
+ "nno": "Norwegian Nynorsk",
68
+ "nob": "Norwegian Bokm\u00e5l",
69
+ "npi": "Nepali",
70
+ "nya": "Nyanja",
71
+ "oci": "Occitan",
72
+ "ory": "Odia",
73
+ "pan": "Punjabi",
74
+ "pbt": "Southern Pashto",
75
+ "pes": "Western Persian",
76
+ "pol": "Polish",
77
+ "por": "Portuguese",
78
+ "ron": "Romanian",
79
+ "rus": "Russian",
80
+ "slk": "Slovak",
81
+ "slv": "Slovenian",
82
+ "sna": "Shona",
83
+ "snd": "Sindhi",
84
+ "som": "Somali",
85
+ "spa": "Spanish",
86
+ "srp": "Serbian",
87
+ "swe": "Swedish",
88
+ "swh": "Swahili",
89
+ "tam": "Tamil",
90
+ "tel": "Telugu",
91
+ "tgk": "Tajik",
92
+ "tgl": "Tagalog",
93
+ "tha": "Thai",
94
+ "tur": "Turkish",
95
+ "ukr": "Ukrainian",
96
+ "urd": "Urdu",
97
+ "uzn": "Northern Uzbek",
98
+ "vie": "Vietnamese",
99
+ "xho": "Xhosa",
100
+ "yor": "Yoruba",
101
+ "yue": "Cantonese",
102
+ "zlm": "Colloquial Malay",
103
+ "zsm": "Standard Malay",
104
+ "zul": "Zulu",
105
+ }
106
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
107
+
108
  # Source langs: S2ST / S2TT / ASR don't need source lang
109
  # T2TT / T2ST use this
110
+ text_source_language_codes = [
111
  "afr",
112
  "amh",
113
  "arb",
 
124
  "ces",
125
  "ckb",
126
  "cmn",
 
127
  "cym",
128
  "dan",
129
  "deu",
 
133
  "eus",
134
  "fin",
135
  "fra",
 
136
  "gaz",
137
  "gle",
138
  "glg",
 
180
  "por",
181
  "ron",
182
  "rus",
 
183
  "slk",
184
  "slv",
185
  "sna",
 
204
  "zsm",
205
  "zul",
206
  ]
207
+ TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
208
 
209
  # Target langs:
210
  # S2ST / T2ST
211
+ s2st_target_language_codes = [
212
  "eng",
213
  "arb",
214
  "ben",
 
238
  "swe",
239
  "swh",
240
  "tel",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  "tgl",
242
  "tha",
243
  "tur",
 
245
  "urd",
246
  "uzn",
247
  "vie",
 
 
 
 
248
  ]
249
+ S2ST_TARGET_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in s2st_target_language_codes])
250
+
251
+ # S2TT / ASR
252
+ S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
253
  # T2TT
254
+ T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES