Commit
·
12c969c
1
Parent(s):
89d22e9
Fix examples: input_ids -> input_features (#2)
Browse files- Fix examples: input_ids -> input_features (bb4003f2f80a17a9f6b736fae47f65aa63ba3eb2)
- Update README.md (2e3b22f123401d6a3dfef7756384ca5cfc352807)
Co-authored-by: Sanchit Gandhi <[email protected]>
README.md
CHANGED
@@ -101,7 +101,7 @@ input_features = processor(
|
|
101 |
sampling_rate=16_000,
|
102 |
return_tensors="pt"
|
103 |
).input_features # Batch size 1
|
104 |
-
generated_ids = model.generate(
|
105 |
|
106 |
transcription = processor.batch_decode(generated_ids)
|
107 |
```
|
@@ -112,29 +112,28 @@ The following script shows how to evaluate this model on the [LibriSpeech](https
|
|
112 |
*"clean"* and *"other"* test dataset.
|
113 |
|
114 |
```python
|
115 |
-
from datasets import load_dataset
|
|
|
116 |
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
117 |
|
118 |
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
|
119 |
-
wer =
|
120 |
|
121 |
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
|
122 |
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
|
123 |
|
124 |
-
librispeech_eval = librispeech_eval.map(map_to_array)
|
125 |
-
|
126 |
def map_to_pred(batch):
|
127 |
features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
|
128 |
input_features = features.input_features.to("cuda")
|
129 |
attention_mask = features.attention_mask.to("cuda")
|
130 |
|
131 |
-
gen_tokens = model.generate(
|
132 |
-
batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
|
133 |
return batch
|
134 |
|
135 |
-
result = librispeech_eval.map(map_to_pred,
|
136 |
|
137 |
-
print("WER:", wer(predictions=result["transcription"], references=result["text"]))
|
138 |
```
|
139 |
|
140 |
*Result (WER)*:
|
|
|
101 |
sampling_rate=16_000,
|
102 |
return_tensors="pt"
|
103 |
).input_features # Batch size 1
|
104 |
+
generated_ids = model.generate(input_features=input_features)
|
105 |
|
106 |
transcription = processor.batch_decode(generated_ids)
|
107 |
```
|
|
|
112 |
*"clean"* and *"other"* test dataset.
|
113 |
|
114 |
```python
|
115 |
+
from datasets import load_dataset
|
116 |
+
from evaluate import load
|
117 |
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
118 |
|
119 |
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
|
120 |
+
wer = load("wer")
|
121 |
|
122 |
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
|
123 |
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
|
124 |
|
|
|
|
|
125 |
def map_to_pred(batch):
|
126 |
features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
|
127 |
input_features = features.input_features.to("cuda")
|
128 |
attention_mask = features.attention_mask.to("cuda")
|
129 |
|
130 |
+
gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
|
131 |
+
batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
|
132 |
return batch
|
133 |
|
134 |
+
result = librispeech_eval.map(map_to_pred, remove_columns=["audio"])
|
135 |
|
136 |
+
print("WER:", wer.compute(predictions=result["transcription"], references=result["text"]))
|
137 |
```
|
138 |
|
139 |
*Result (WER)*:
|