---
language:
- ace
- acm
- acq
- aeb
- af
- ajp
- ak
- als
- am
- apc
- ar
- ars
- ary
- arz
- as
- ast
- awa
- ayr
- azb
- azj
- ba
- bm
- ban
- be
- bem
- bn
- bho
- bjn
- bo
- bs
- bug
- bg
- ca
- ceb
- cs
- cjk
- ckb
- crh
- cy
- da
- de
- dik
- dyu
- dz
- el
- en
- eo
- et
- eu
- ee
- fo
- fj
- fi
- fon
- fr
- fur
- fuv
- gaz
- gd
- ga
- gl
- gn
- gu
- ht
- ha
- he
- hi
- hne
- hr
- hu
- hy
- ig
- ilo
- id
- is
- it
- jv
- ja
- kab
- kac
- kam
- kn
- ks
- ka
- kk
- kbp
- kea
- khk
- km
- ki
- rw
- ky
- kmb
- kmr
- knc
- kg
- ko
- lo
- lij
- li
- ln
- lt
- lmo
- ltg
- lb
- lua
- lg
- luo
- lus
- lvs
- mag
- mai
- ml
- mar
- min
- mk
- mt
- mni
- mos
- mi
- my
- nl
- nn
- nb
- npi
- nso
- nus
- ny
- oc
- ory
- pag
- pa
- pap
- pbt
- pes
- plt
- pl
- pt
- prs
- quy
- ro
- rn
- ru
- sg
- sa
- sat
- scn
- shn
- si
- sk
- sl
- sm
- sn
- sd
- so
- st
- es
- sc
- sr
- ss
- su
- sv
- swh
- szl
- ta
- taq
- tt
- te
- tg
- tl
- th
- ti
- tpi
- tn
- ts
- tk
- tum
- tr
- tw
- tzm
- ug
- uk
- umb
- ur
- uzn
- vec
- vi
- war
- wo
- xh
- ydd
- yo
- yue
- zh
- zsm
- zu
language_details: >-
  ace_Arab, ace_Latn, acm_Arab, acq_Arab, aeb_Arab, afr_Latn, ajp_Arab,
  aka_Latn, amh_Ethi, apc_Arab, arb_Arab, ars_Arab, ary_Arab, arz_Arab,
  asm_Beng, ast_Latn, awa_Deva, ayr_Latn, azb_Arab, azj_Latn, bak_Cyrl,
  bam_Latn, ban_Latn, bel_Cyrl, bem_Latn, ben_Beng, bho_Deva, bjn_Arab, bjn_Latn,
  bod_Tibt, bos_Latn, bug_Latn, bul_Cyrl, cat_Latn, ceb_Latn, ces_Latn,
  cjk_Latn, ckb_Arab, crh_Latn, cym_Latn, dan_Latn, deu_Latn, dik_Latn,
  dyu_Latn, dzo_Tibt, ell_Grek, eng_Latn, epo_Latn, est_Latn, eus_Latn,
  ewe_Latn, fao_Latn, pes_Arab, fij_Latn, fin_Latn, fon_Latn, fra_Latn,
  fur_Latn, fuv_Latn, gla_Latn, gle_Latn, glg_Latn, grn_Latn, guj_Gujr,
  hat_Latn, hau_Latn, heb_Hebr, hin_Deva, hne_Deva, hrv_Latn, hun_Latn,
  hye_Armn, ibo_Latn, ilo_Latn, ind_Latn, isl_Latn, ita_Latn, jav_Latn,
  jpn_Jpan, kab_Latn, kac_Latn, kam_Latn, kan_Knda, kas_Arab, kas_Deva,
  kat_Geor, knc_Arab, knc_Latn, kaz_Cyrl, kbp_Latn, kea_Latn, khm_Khmr,
  kik_Latn, kin_Latn, kir_Cyrl, kmb_Latn, kon_Latn, kor_Hang, kmr_Latn,
  lao_Laoo, lvs_Latn, lij_Latn, lim_Latn, lin_Latn, lit_Latn, lmo_Latn,
  ltg_Latn, ltz_Latn, lua_Latn, lug_Latn, luo_Latn, lus_Latn, mag_Deva,
  mai_Deva, mal_Mlym, mar_Deva, min_Latn, mkd_Cyrl, plt_Latn, mlt_Latn,
  mni_Beng, khk_Cyrl, mos_Latn, mri_Latn, zsm_Latn, mya_Mymr, nld_Latn,
  nno_Latn, nob_Latn, npi_Deva, nso_Latn, nus_Latn, nya_Latn, oci_Latn,
  gaz_Latn, ory_Orya, pag_Latn, pan_Guru, pap_Latn, pol_Latn, por_Latn,
  prs_Arab, pbt_Arab, quy_Latn, ron_Latn, run_Latn, rus_Cyrl, sag_Latn,
  san_Deva, sat_Beng, scn_Latn, shn_Mymr, sin_Sinh, slk_Latn, slv_Latn,
  smo_Latn, sna_Latn, snd_Arab, som_Latn, sot_Latn, spa_Latn, als_Latn,
  srd_Latn, srp_Cyrl, ssw_Latn, sun_Latn, swe_Latn, swh_Latn, szl_Latn,
  tam_Taml, tat_Cyrl, tel_Telu, tgk_Cyrl, tgl_Latn, tha_Thai, tir_Ethi,
  taq_Latn, taq_Tfng, tpi_Latn, tsn_Latn, tso_Latn, tuk_Latn, tum_Latn,
  tur_Latn, twi_Latn, tzm_Tfng, uig_Arab, ukr_Cyrl, umb_Latn, urd_Arab,
  uzn_Latn, vec_Latn, vie_Latn, war_Latn, wol_Latn, xho_Latn, ydd_Hebr,
  yor_Latn, yue_Hant, zho_Hans, zho_Hant, zul_Latn
license: mit
metrics:
- bleu
pipeline_tag: automatic-speech-recognition
tags:
- zeroswot
- speech translation
- zero-shot
- end-to-end
- nllb
- wav2vec2
---
# ZeroSwot ✨🤖✨
<!-- <div style='display:flex; gap: 0.25rem; '>
<a href='https://arxiv.org/abs/2402.10422'><img src='https://img.shields.io/badge/paper-PDF-green'></a>
<a href='https://github.com/mt-upc/ZeroSwot/blob/main/LICENSE'><img src='https://img.shields.io/badge/License-MIT-blue.svg'></a>
<a href='https://github.com/mt-upc/ZeroSwot'><img src='https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white'></a>
</div> -->
ZeroSwot is a state-of-the-art zero-shot end-to-end Speech Translation system.
<div align=center><img src="resources/intro.png" height="65%" width="65%"/></div>
The model is created by adapting a wav2vec2.0-based encoder to the embedding space of NLLB, using a novel subword compression module and Optimal Transport, while only utilizing ASR data. It thus enables **Zero-shot E2E Speech Translation to all the 200 languages supported by NLLB**.
For more details, please refer to our [paper](https://arxiv.org/abs/2402.10422) and the [original repo](https://github.com/mt-upc/ZeroSwot), which is built on fairseq.
## Architecture
The compression module is a lightweight transformer that takes as input the hidden states of wav2vec2.0 and the corresponding CTC predictions. It compresses them into subword-like embeddings similar to those expected by NLLB, and aligns them with the NLLB embedding space using Optimal Transport. For inference, we simply pass the output of the speech encoder to the NLLB encoder.
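To give a feel for what CTC-guided compression means, here is a minimal sketch. It is not the ZeroSwot module itself, which is a learned transformer. It only shows the basic idea of shortening a frame sequence by mean-pooling consecutive frames that share the same CTC prediction and dropping blank segments. The function name `ctc_compress`, the `blank_id=0` default, and the toy inputs are assumptions made for illustration.

```python
import numpy as np


def ctc_compress(hidden_states, ctc_predictions, blank_id=0):
    """Mean-pool runs of frames with the same CTC label and drop blank runs.

    hidden_states:   array of shape (num_frames, dim)
    ctc_predictions: array of shape (num_frames,) with the argmax label per frame
    Returns (pooled, labels): pooled has shape (num_segments, dim).
    """
    hidden_states = np.asarray(hidden_states, dtype=np.float32)
    ctc_predictions = np.asarray(ctc_predictions)
    dim = hidden_states.shape[1]
    if len(ctc_predictions) == 0:
        return np.zeros((0, dim), dtype=np.float32), []

    # Frame indices at which the predicted label changes, i.e. segment boundaries
    change = np.flatnonzero(np.diff(ctc_predictions)) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [len(ctc_predictions)]))

    pooled, labels = [], []
    for start, end in zip(starts, ends):
        label = int(ctc_predictions[start])
        if label == blank_id:
            continue  # blank segments carry no label and are removed
        pooled.append(hidden_states[start:end].mean(axis=0))
        labels.append(label)

    if not pooled:
        return np.zeros((0, dim), dtype=np.float32), labels
    return np.stack(pooled), labels


# Toy example: 6 frames of dimension 2, with two labelled segments (5 and 7)
frames = np.arange(12, dtype=np.float32).reshape(6, 2)
predictions = [0, 5, 5, 0, 7, 7]
compressed, segment_labels = ctc_compress(frames, predictions)
print(compressed.shape, segment_labels)
```

In this toy case the six input frames are reduced to two vectors, one per labelled segment, which is the kind of length reduction that brings speech representations closer to text-like sequence lengths.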
<div align=center><img src="resources/methodology.png" height="120%" width="120%"/></div>
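The Optimal Transport alignment mentioned in this section can be illustrated with a small, generic example. The sketch below computes an entropy-regularised transport cost between two sets of embeddings using Sinkhorn iterations. This is a common way to compute such a cost, but the solver, the squared Euclidean cost, the uniform weights, and the values of `epsilon` and `n_iters` are all assumptions; this card does not state the exact formulation used in training.

```python
import numpy as np


def sinkhorn_ot_cost(speech_embeds, text_embeds, epsilon=0.1, n_iters=200):
    """Entropy-regularised OT cost between two embedding sequences.

    Both inputs have shape (length, dim); lengths may differ.
    Returns (cost, plan), where plan has shape (len(speech), len(text)).
    """
    x = np.asarray(speech_embeds, dtype=np.float64)
    y = np.asarray(text_embeds, dtype=np.float64)

    # Pairwise squared Euclidean distances between every speech/text position
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)

    # Uniform mass on each position (an assumption for this sketch)
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))

    # Plain Sinkhorn scaling; a log-domain version is more robust for large costs
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)

    plan = u[:, None] * kernel * v[None, :]
    return float((plan * cost).sum()), plan


# Identical sequences should be nearly free to align; shifted ones should not
text = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
same_cost, same_plan = sinkhorn_ot_cost(text, text)
shifted_cost, _ = sinkhorn_ot_cost(text + 1.0, text)
print(same_cost < shifted_cost)
```

A cost of this kind is small when the two sequences occupy the same region of the embedding space, which is why minimising it can pull speech representations toward the text embeddings.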
## Version
This version of ZeroSwot is trained with ASR data from MuST-C v1.0 and adapts [wav2vec2.0-large](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self) to the [nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M) model.
We have more versions available:
| Models | ASR data | NLLB version |
|:------:|:--------:|:------------:|
| [ZeroSwot-Medium_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_en-to-200) | MuST-C v1.0 | [distilled-600M original](https://huggingface.co/facebook/nllb-200-distilled-600M)|
| [ZeroSwot-Medium_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_mt-mustc_en-to-8) | MuST-C v1.0 | [distilled-600M finetuned w/ MuST-C](https://huggingface.co/johntsi/nllb-200-distilled-600M_mustc_en-to-8) |
| [ZeroSwot-Large_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_en-to-200) | MuST-C v1.0 | [distilled-1.3B original](https://huggingface.co/facebook/nllb-200-distilled-1.3B) |
| [ZeroSwot-Large_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_mt-mustc_en-to-8) | MuST-C v1.0 | [distilled-1.3B finetuned w/ MuST-C](https://huggingface.co/johntsi/nllb-200-distilled-1.3B_mustc_en-to-8) |
| [ZeroSwot-Medium_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_en-to-200) | CommonVoice | [distilled-600M original](https://huggingface.co/facebook/nllb-200-distilled-600M)|
| [ZeroSwot-Medium_asr-cv_mt-covost2](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_mt-covost2_en-to-15) | CommonVoice | [distilled-600M finetuned w/ CoVoST2](https://huggingface.co/johntsi/nllb-200-distilled-600M_covost2_en-to-15) |
| [ZeroSwot-Large_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Large_asr-cv_en-to-200) | CommonVoice | [distilled-1.3B original](https://huggingface.co/facebook/nllb-200-distilled-1.3B) |
| [ZeroSwot-Large_asr-cv_mt-covost2](https://huggingface.co/johntsi/ZeroSwot-Large_asr-cv_mt-covost2_en-to-15) | CommonVoice | [distilled-1.3B finetuned w/ CoVoST2](https://huggingface.co/johntsi/nllb-200-distilled-1.3B_covost2_en-to-15) |
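Since each encoder in the table is listed with a specific NLLB version, it can help to keep the pairs together in code. The snippet below is a small convenience sketch: the repository IDs are copied from the links in the table, while the dictionary name `ZEROSWOT_VARIANTS` and the helper `checkpoints_for` are hypothetical and not part of any released package.

```python
# Encoder repo -> NLLB repo pairs, taken from the links in the table above
ZEROSWOT_VARIANTS = {
    "ZeroSwot-Medium_asr-mustc": (
        "johntsi/ZeroSwot-Medium_asr-mustc_en-to-200",
        "facebook/nllb-200-distilled-600M",
    ),
    "ZeroSwot-Medium_asr-mustc_mt-mustc": (
        "johntsi/ZeroSwot-Medium_asr-mustc_mt-mustc_en-to-8",
        "johntsi/nllb-200-distilled-600M_mustc_en-to-8",
    ),
    "ZeroSwot-Large_asr-mustc": (
        "johntsi/ZeroSwot-Large_asr-mustc_en-to-200",
        "facebook/nllb-200-distilled-1.3B",
    ),
    "ZeroSwot-Large_asr-mustc_mt-mustc": (
        "johntsi/ZeroSwot-Large_asr-mustc_mt-mustc_en-to-8",
        "johntsi/nllb-200-distilled-1.3B_mustc_en-to-8",
    ),
    "ZeroSwot-Medium_asr-cv": (
        "johntsi/ZeroSwot-Medium_asr-cv_en-to-200",
        "facebook/nllb-200-distilled-600M",
    ),
    "ZeroSwot-Medium_asr-cv_mt-covost2": (
        "johntsi/ZeroSwot-Medium_asr-cv_mt-covost2_en-to-15",
        "johntsi/nllb-200-distilled-600M_covost2_en-to-15",
    ),
    "ZeroSwot-Large_asr-cv": (
        "johntsi/ZeroSwot-Large_asr-cv_en-to-200",
        "facebook/nllb-200-distilled-1.3B",
    ),
    "ZeroSwot-Large_asr-cv_mt-covost2": (
        "johntsi/ZeroSwot-Large_asr-cv_mt-covost2_en-to-15",
        "johntsi/nllb-200-distilled-1.3B_covost2_en-to-15",
    ),
}


def checkpoints_for(variant):
    """Return (encoder_repo, nllb_repo) for a variant name from the table."""
    try:
        return ZEROSWOT_VARIANTS[variant]
    except KeyError:
        known = ", ".join(sorted(ZEROSWOT_VARIANTS))
        raise ValueError(f"Unknown variant {variant!r}. Known variants: {known}") from None


encoder_repo, nllb_repo = checkpoints_for("ZeroSwot-Large_asr-cv")
print(encoder_repo, nllb_repo)
```

The two returned IDs can then be passed to the `from_pretrained` calls in place of hard-coded strings.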
## Usage
The model is tested with Python 3.9.16 and Transformers v4.41.2. You also need torchaudio and sentencepiece for audio and text processing.
```bash
pip install transformers torchaudio sentencepiece
```
```python
from transformers import Wav2Vec2Processor, NllbTokenizer, AutoModel, AutoModelForSeq2SeqLM
import torchaudio


def load_and_resample_audio(audio_path, target_sr=16000):
    # Load the waveform and resample it to 16 kHz if needed
    audio, orig_freq = torchaudio.load(audio_path)
    if orig_freq != target_sr:
        audio = torchaudio.functional.resample(audio, orig_freq=orig_freq, new_freq=target_sr)
    audio = audio.squeeze(0).numpy()
    return audio


# Load processors and tokenizers
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# Load ZeroSwot Encoder
commit_hash = "30d17145fd8e040430bbfcf74a011070fa83debd"
zeroswot_encoder = AutoModel.from_pretrained(
    "johntsi/ZeroSwot-Medium_asr-mustc_en-to-200", trust_remote_code=True, revision=commit_hash,
)
zeroswot_encoder.eval()
zeroswot_encoder.to("cuda")

# Load NLLB Model
nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
nllb_model.eval()
nllb_model.to("cuda")

# Load audio file
path_to_audio_file = "resources/sample.wav"  # sample file for testing; replace with your own audio
audio = load_and_resample_audio(path_to_audio_file)
input_values = processor(audio, sampling_rate=16000, return_tensors="pt").to("cuda")

# translation to German
compressed_embeds, attention_mask = zeroswot_encoder(**input_values)
predicted_ids = nllb_model.generate(
    inputs_embeds=compressed_embeds,
    attention_mask=attention_mask,
    forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"],
    num_beams=5,
)
translation = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
print(translation)
```
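To translate into a language other than German, only the `forced_bos_token_id` argument needs to change. The helper below is a hedged sketch of one way to look up that ID with the general `convert_tokens_to_ids` tokenizer method, which may be useful if your Transformers version does not expose `lang_code_to_id`. The function name `target_lang_token_id` is hypothetical, and the unknown-token check assumes that unrecognised codes map to the tokenizer's `unk_token_id`.

```python
def target_lang_token_id(tokenizer, lang_code):
    """Return the token ID for an NLLB language code such as "fra_Latn".

    `tokenizer` is expected to behave like the NllbTokenizer loaded above.
    Raises ValueError if the code is not in the tokenizer vocabulary.
    """
    token_id = tokenizer.convert_tokens_to_ids(lang_code)
    if token_id is None or token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown NLLB language code: {lang_code!r}")
    return token_id


# Usage with the objects from the example above (not executed here):
# predicted_ids = nllb_model.generate(
#     inputs_embeds=compressed_embeds,
#     attention_mask=attention_mask,
#     forced_bos_token_id=target_lang_token_id(tokenizer, "fra_Latn"),
#     num_beams=5,
# )
```

Valid codes are the ones listed under `language_details` in the metadata of this card, for example `spa_Latn` or `zho_Hans`.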
## Results
BLEU scores on MuST-C v1.0 tst-COMMON compared to _supervised_ SOTA models from the literature. You can refer to Table 4 of the Results section in the paper for more details.
| Models | ZS | Size (B) | De | Es | Fr | It | Nl | Pt | Ro | Ru | Average |
|:-----------------------:|:----:|:----------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:-------:|
| Chimera (Han et al., 2021) | ✗ | 0.15 | 27.1 | 30.6 | 35.6 | 25.0 | 29.2 | 30.2 | 24.0 | 17.4 | 27.4 |
| STEMM (Fang et al., 2022) | ✗ | 0.15 | 28.7 | 31.0 | 37.4 | 25.8 | 30.5 | 31.7 | 24.5 | 17.8 | 28.4 |
| SpeechUT (Zhang et al., 2022) | ✗ | 0.15 | 30.1 | 33.6 | 41.4 | - | - | - | - | - | - |
| Siamese-PT (Le et al., 2023) | ✗ | 0.25 | 27.9 | 31.8 | 39.2 | 27.7 | 31.7 | 34.2 | 27.0 | 18.5 | 29.8 |
| CRESS (Fang and Feng, 2023) | ✗ | 0.15 | 29.4 | 33.2 | 40.1 | 27.6 | 32.2 | 33.6 | 26.4 | 19.7 | 30.3 |
| SimRegCR (Gao et al., 2023b) | ✗ | 0.15 | 29.2 | 33.0 | 40.0 | 28.2 | 32.7 | 34.2 | 26.7 | 20.1 | 30.5 |
| LST (LLaMA2-13B) (Zhang et al., 2023)| ✗ | 13 | 30.4 | 35.3 | **41.6** | - | - | - | - | - | - |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| [ZeroSwot-Medium_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_en-to-200) | ✓ | 0.35/0.95 | 24.8 | 30.0 | 32.6 | 24.1 | 28.6 | 28.8 | 22.9 | 16.4 | 26.0 |
| [ZeroSwot-Medium_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_en-to-200) | ✓ | 0.35/0.95 | 28.5 | 33.1 | 37.5 | 28.2 | 32.3 | 32.9 | 26.0 | 18.7 | 29.6 |
| [ZeroSwot-Medium_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_mt-mustc_en-to-8) | ✓ | 0.35/0.95†| 30.5 | 34.9 | 39.4 | 30.6 | 35.0 | 37.1 | 27.8 | 20.3 | 31.9 |
| [ZeroSwot-Large_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Large_asr-cv_en-to-200) | ✓ | 0.35/1.65 | 26.5 | 31.1 | 33.5 | 25.4 | 29.9 | 30.6 | 24.3 | 18.0 | 27.4 |
| [ZeroSwot-Large_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_en-to-200)| ✓ | 0.35/1.65 | 30.1 | 34.8 | 38.9 | 29.8 | 34.4 | 35.3 | 27.6 | 20.4 | 31.4 |
| [ZeroSwot-Large_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_mt-mustc_en-to-8)| ✓ | 0.35/1.65†| **31.2** | **35.8** | 40.5 | **31.4** | **36.3** | **38.3** | **28.0** | **21.5** | **32.9** |
## Citation
If you find ZeroSwot useful for your research, please cite our paper :)
```
@inproceedings{tsiamas-etal-2024-pushing,
title = {{Pushing the Limits of Zero-shot End-to-End Speech Translation}},
author = "Tsiamas, Ioannis and
G{\'a}llego, Gerard and
Fonollosa, Jos{\'e} and
Costa-juss{\`a}, Marta",
editor = "Ku, Lun-Wei and
Martins, Andre and
Srikumar, Vivek",
booktitle = "Findings of the Association for Computational Linguistics ACL 2024",
month = aug,
year = "2024",
address = "Bangkok, Thailand and virtual meeting",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2024.findings-acl.847",
pages = "14245--14267",
}
```