johntsi commited on
Commit
2d18c76
·
verified ·
1 Parent(s): 7656b04

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +319 -3
README.md CHANGED
@@ -1,3 +1,319 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - ace
4
+ - acm
5
+ - acq
6
+ - aeb
7
+ - af
8
+ - ajp
9
+ - ak
10
+ - als
11
+ - am
12
+ - apc
13
+ - ar
14
+ - ars
15
+ - ary
16
+ - arz
17
+ - as
18
+ - ast
19
+ - awa
20
+ - ayr
21
+ - azb
22
+ - azj
23
+ - ba
24
+ - bm
25
+ - ban
26
+ - be
27
+ - bem
28
+ - bn
29
+ - bho
30
+ - bjn
31
+ - bo
32
+ - bs
33
+ - bug
34
+ - bg
35
+ - ca
36
+ - ceb
37
+ - cs
38
+ - cjk
39
+ - ckb
40
+ - crh
41
+ - cy
42
+ - da
43
+ - de
44
+ - dik
45
+ - dyu
46
+ - dz
47
+ - el
48
+ - en
49
+ - eo
50
+ - et
51
+ - eu
52
+ - ee
53
+ - fo
54
+ - fj
55
+ - fi
56
+ - fon
57
+ - fr
58
+ - fur
59
+ - fuv
60
+ - gaz
61
+ - gd
62
+ - ga
63
+ - gl
64
+ - gn
65
+ - gu
66
+ - ht
67
+ - ha
68
+ - he
69
+ - hi
70
+ - hne
71
+ - hr
72
+ - hu
73
+ - hy
74
+ - ig
75
+ - ilo
76
+ - id
77
+ - is
78
+ - it
79
+ - jv
80
+ - ja
81
+ - kab
82
+ - kac
83
+ - kam
84
+ - kn
85
+ - ks
86
+ - ka
87
+ - kk
88
+ - kbp
89
+ - kea
90
+ - khk
91
+ - km
92
+ - ki
93
+ - rw
94
+ - ky
95
+ - kmb
96
+ - kmr
97
+ - knc
98
+ - kg
99
+ - ko
100
+ - lo
101
+ - lij
102
+ - li
103
+ - ln
104
+ - lt
105
+ - lmo
106
+ - ltg
107
+ - lb
108
+ - lua
109
+ - lg
110
+ - luo
111
+ - lus
112
+ - lvs
113
+ - mag
114
+ - mai
115
+ - ml
116
+ - mar
117
+ - min
118
+ - mk
119
+ - mt
120
+ - mni
121
+ - mos
122
+ - mi
123
+ - my
124
+ - nl
125
+ - nn
126
+ - nb
127
+ - npi
128
+ - nso
129
+ - nus
130
+ - ny
131
+ - oc
132
+ - ory
133
+ - pag
134
+ - pa
135
+ - pap
136
+ - pbt
137
+ - pes
138
+ - plt
139
+ - pl
140
+ - pt
141
+ - prs
142
+ - quy
143
+ - ro
144
+ - rn
145
+ - ru
146
+ - sg
147
+ - sa
148
+ - sat
149
+ - scn
150
+ - shn
151
+ - si
152
+ - sk
153
+ - sl
154
+ - sm
155
+ - sn
156
+ - sd
157
+ - so
158
+ - st
159
+ - es
160
+ - sc
161
+ - sr
162
+ - ss
163
+ - su
164
+ - sv
165
+ - swh
166
+ - szl
167
+ - ta
168
+ - taq
169
+ - tt
170
+ - te
171
+ - tg
172
+ - tl
173
+ - th
174
+ - ti
175
+ - tpi
176
+ - tn
177
+ - ts
178
+ - tk
179
+ - tum
180
+ - tr
181
+ - tw
182
+ - tzm
183
+ - ug
184
+ - uk
185
+ - umb
186
+ - ur
187
+ - uzn
188
+ - vec
189
+ - vi
190
+ - war
191
+ - wo
192
+ - xh
193
+ - ydd
194
+ - yo
195
+ - yue
196
+ - zh
197
+ - zsm
198
+ - zu
199
+ language_details: >-
200
+ ace_Arab, ace_Latn, acm_Arab, acq_Arab, aeb_Arab, afr_Latn, ajp_Arab,
201
+ aka_Latn, amh_Ethi, apc_Arab, arb_Arab, ars_Arab, ary_Arab, arz_Arab,
202
+ asm_Beng, ast_Latn, awa_Deva, ayr_Latn, azb_Arab, azj_Latn, bak_Cyrl,
203
+ bam_Latn, ban_Latn,bel_Cyrl, bem_Latn, ben_Beng, bho_Deva, bjn_Arab, bjn_Latn,
204
+ bod_Tibt, bos_Latn, bug_Latn, bul_Cyrl, cat_Latn, ceb_Latn, ces_Latn,
205
+ cjk_Latn, ckb_Arab, crh_Latn, cym_Latn, dan_Latn, deu_Latn, dik_Latn,
206
+ dyu_Latn, dzo_Tibt, ell_Grek, eng_Latn, epo_Latn, est_Latn, eus_Latn,
207
+ ewe_Latn, fao_Latn, pes_Arab, fij_Latn, fin_Latn, fon_Latn, fra_Latn,
208
+ fur_Latn, fuv_Latn, gla_Latn, gle_Latn, glg_Latn, grn_Latn, guj_Gujr,
209
+ hat_Latn, hau_Latn, heb_Hebr, hin_Deva, hne_Deva, hrv_Latn, hun_Latn,
210
+ hye_Armn, ibo_Latn, ilo_Latn, ind_Latn, isl_Latn, ita_Latn, jav_Latn,
211
+ jpn_Jpan, kab_Latn, kac_Latn, kam_Latn, kan_Knda, kas_Arab, kas_Deva,
212
+ kat_Geor, knc_Arab, knc_Latn, kaz_Cyrl, kbp_Latn, kea_Latn, khm_Khmr,
213
+ kik_Latn, kin_Latn, kir_Cyrl, kmb_Latn, kon_Latn, kor_Hang, kmr_Latn,
214
+ lao_Laoo, lvs_Latn, lij_Latn, lim_Latn, lin_Latn, lit_Latn, lmo_Latn,
215
+ ltg_Latn, ltz_Latn, lua_Latn, lug_Latn, luo_Latn, lus_Latn, mag_Deva,
216
+ mai_Deva, mal_Mlym, mar_Deva, min_Latn, mkd_Cyrl, plt_Latn, mlt_Latn,
217
+ mni_Beng, khk_Cyrl, mos_Latn, mri_Latn, zsm_Latn, mya_Mymr, nld_Latn,
218
+ nno_Latn, nob_Latn, npi_Deva, nso_Latn, nus_Latn, nya_Latn, oci_Latn,
219
+ gaz_Latn, ory_Orya, pag_Latn, pan_Guru, pap_Latn, pol_Latn, por_Latn,
220
+ prs_Arab, pbt_Arab, quy_Latn, ron_Latn, run_Latn, rus_Cyrl, sag_Latn,
221
+ san_Deva, sat_Beng, scn_Latn, shn_Mymr, sin_Sinh, slk_Latn, slv_Latn,
222
+ smo_Latn, sna_Latn, snd_Arab, som_Latn, sot_Latn, spa_Latn, als_Latn,
223
+ srd_Latn, srp_Cyrl, ssw_Latn, sun_Latn, swe_Latn, swh_Latn, szl_Latn,
224
+ tam_Taml, tat_Cyrl, tel_Telu, tgk_Cyrl, tgl_Latn, tha_Thai, tir_Ethi,
225
+ taq_Latn, taq_Tfng, tpi_Latn, tsn_Latn, tso_Latn, tuk_Latn, tum_Latn,
226
+ tur_Latn, twi_Latn, tzm_Tfng, uig_Arab, ukr_Cyrl, umb_Latn, urd_Arab,
227
+ uzn_Latn, vec_Latn, vie_Latn, war_Latn, wol_Latn, xho_Latn, ydd_Hebr,
228
+ yor_Latn, yue_Hant, zho_Hans, zho_Hant, zul_Latn
229
+ license: mit
230
+ metrics:
231
+ - bleu
232
+ datasets:
233
+ - mozilla-foundation/common_voice_8_0
234
+ pipeline_tag: automatic-speech-recognition
235
+ tags:
236
+ - zeroswot
237
+ - speech translation
238
+ - zero-shot
239
+ - end-to-end
240
+ - nllb
241
+ - wav2vec2
242
+ ---
243
+
244
+ # ZeroSwot ✨🤖✨
245
+
246
+ ZeroSwot is a state-of-the-art zero-shot end-to-end Speech Translation system.
247
+
248
+ The model is created by adapting a wav2vec2.0-based encoder to the embedding space of NLLB, using a novel subword compression module and Optimal Transport, while using only ASR data. It thus enables **Speech Translation to all the 200 languages supported by NLLB**. The compression module is a light-weight transformer that takes as input the hidden state of wav2vec2.0 and the corresponding CTC predictions, and compresses them to subword-like embeddings similar to those expected from NLLB and aligns them using Optimal Transport. For inference we simply pass the output of the speech encoder to NLLB encoder.
249
+
250
+ For more details please refer to our [paper](https://arxiv.org/abs/2402.10422) and the [original repo](https://github.com/mt-upc/ZeroSwot) build on fairseq.
251
+
252
+ This version of ZeroSwot is trained with ASR data from CommonVoice, and adapting [wav2vec2.0-large](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self) to the [nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M) model.
253
+
254
+ <div align=center><img src="resource/structure.png" height="100%" width="75%"/></div>
255
+
256
+ ## Usage
257
+
258
+ ```python
259
+ from transformers import Wav2Vec2Processor, NllbTokenizer, AutoModel, AutoModelForSeq2SeqLM
260
+ import soundfile as sf
261
+
262
+ # Load processors and tokenizers
263
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
264
+ tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
265
+
266
+ # Load ZeroSwot Encoder
267
+ commit_hash = "1d38f5dbf4f89adefe06961e4ec344b21f74ebae"
268
+ zeroswot_encoder = AutoModel.from_pretrained(
269
+ "johntsi/ZeroSwot-Medium_asr-cv_en-to-200", trust_remote_code=True, revision=commit_hash,
270
+ )
271
+ model.eval()
272
+ model.to("cuda")
273
+
274
+ # Load NLLB Model
275
+ nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
276
+ nllb_model.eval()
277
+ nllb_model.to("cuda")
278
+
279
+ # Load sample .wav
280
+ audio, sr = sf.read("sample.wav")
281
+ assert sr == 16000, "Input of wav2vec2.0 is expected to have sampling rate of 16,000"
282
+ input_values = processor(audio, sampling_rate=16000, return_tensors="pt").cuda()
283
+
284
+ # translation to German
285
+ emb, mask = zeroswot_encoder(**input_values)
286
+ predicted_ids = nllb_model.generate(
287
+ inputs_embeds=emb,
288
+ attention_mask=~mask,
289
+ forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"],
290
+ num_beams=5,
291
+ )
292
+ translation = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
293
+ print(translation)
294
+ ```
295
+
296
+ ## Results
297
+
298
+ BLEU scores on CoVoST-2 test compared to supervised SOTA models [XLS-R-1B](https://huggingface.co/facebook/wav2vec2-xls-r-1b) and [SeamlessM4T-Medium](https://huggingface.co/facebook/seamless-m4t-medium). You can refer to Table 5 of the Results section in the paper for more details.
299
+
300
+ | Models | ZS | Size (B) | Ar | Ca | Cy | De | Et | Fa | Id | Ja | Lv | Mn | Sl | Sv | Ta | Tr | Zh | Average |
301
+ |:--------------:|:----:|:----------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:-------:|
302
+ | XLS-R-1B | ✗ | 1.0 | 19.2 | 32.1 | **31.8** | 26.2 | 22.4 | 21.3 | 30.3 | 39.9 | 22.0 | 14.9 | 25.4 | 32.3 | 18.1 | 17.1 | 36.7 | 26.0 |
303
+ | SeamlessM4T-M | ✗ | 1.2 | 20.8 | 37.3 | 29.9 | **31.4** | 23.3 | 17.2 | 34.8 | 37.5 | 19.5 | 12.9 | 29.0 | 37.3 | 18.9 | **19.8** | 30.0 | 26.6 |
304
+ | ZeroSwot-M_asr-cv | ✓ | 0.35/0.95 | **24.4** | **38.7** | 28.8 | 31.2 | **26.2** | **26.0** | **36.0** | **46.0** | **24.8** | **19.0** | **31.6** | **37.8** | **24.4** | 18.6 | **39.0** | **30.2** |
305
+
306
+ ## Citation
307
+
308
+ If you find ZeroSwot useful for your research, please cite our paper :)
309
+
310
+ ```
311
+ @misc{tsiamas2024pushing,
312
+ title={{Pushing the Limits of Zero-shot End-to-End Speech Translation}},
313
+ author={Ioannis Tsiamas and Gerard I. Gállego and José A. R. Fonollosa and Marta R. Costa-jussà},
314
+ year={2024},
315
+ eprint={2402.10422},
316
+ archivePrefix={arXiv},
317
+ primaryClass={cs.CL}
318
+ }
319
+ ```