Update README.md
README.md (CHANGED)
@@ -239,7 +239,6 @@ pipe = pipeline(
     model_kwargs=model_kwargs
 )
 
-
 # load sample audio & downsample to 16kHz
 dataset = load_dataset("japanese-asr/ja_asr.reazonspeech_test", split="test")
 
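The first hunk only drops a stray blank line before the sample-audio comment. The resampling step mentioned in that comment is outside the hunk itself; with the datasets library it is typically done by casting the audio column. A minimal sketch, assuming the standard `cast_column`/`Audio` API rather than the README's literal code:

```python
from datasets import load_dataset, Audio

# load the ReazonSpeech test split and decode the audio column at 16kHz
dataset = load_dataset("japanese-asr/ja_asr.reazonspeech_test", split="test")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
```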
@@ -296,60 +295,45 @@ pip install --upgrade transformers datasets[audio] evaluate jiwer
 Evaluation can then be run end-to-end with the following example:
 
 ```python
-from
+from tqdm import tqdm
+
+import torch
+from transformers import pipeline
 from datasets import load_dataset, Audio
 from evaluate import load
-import torch
-from tqdm import tqdm
 
-# config
+# model config
 model_id = "kotoba-tech/kotoba-whisper-v1.0"
-dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
+model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
+generate_kwargs = {"language": "japanese", "task": "transcribe"}
+
+# data config
+generate_kwargs = {"language": "japanese", "task": "transcribe"}
+dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
 audio_column = 'audio'
 text_column = 'transcription'
-batch_size = 16
 
 # load model
-
-
-
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model=model_id,
+    torch_dtype=torch_dtype,
+    device=device,
+    model_kwargs=model_kwargs,
+    batch_size=16
+)
 
 # load the dataset and sample the audio with 16kHz
 dataset = load_dataset(dataset_name, split="test")
-
-
-
-
-def inference(batch):
-    # 1. Pre-process the audio data to log-mel spectrogram inputs
-    audio = [sample["array"] for sample in batch["audio"]]
-    input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
-    input_features = input_features.to(device, dtype=torch_dtype)
-    # 2. Auto-regressively generate the predicted token ids
-    pred_ids = model.generate(input_features, language="ja", max_new_tokens=128)
-    # 3. Decode the token ids to the final transcription
-    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
-    batch["reference"] = batch[text_column]
-    return batch
-
-dataset = dataset.map(function=inference, batched=True, batch_size=batch_size)
-
-# iterate over the dataset and run inference
-all_transcriptions = []
-all_references = []
-for result in tqdm(dataset, desc="Evaluating..."):
-    all_transcriptions.append(result["transcription"])
-    all_references.append(result["reference"])
-
-# normalize predictions and references
-all_transcriptions = [transcription.replace(" ", "") for transcription in all_transcriptions]
-all_references = [reference.replace(" ", "") for reference in all_references]
+transcriptions = pipe(dataset['audio'])
+transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
+references = [i.replace(" ", "") for i in dataset['transcription']]
 
 # compute the CER metric
 cer_metric = load("cer")
-cer = 100 * cer_metric.compute(predictions=
+cer = 100 * cer_metric.compute(predictions=transcriptions, references=references)
 print(cer)
 ```
 
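Reassembled from the context and added lines of the second hunk, the updated evaluation example amounts to the script below. This is a lightly consolidated sketch rather than the literal file contents: the unused `tqdm` import and the duplicated `generate_kwargs` definition from the commit are dropped, the column-name variables are reused, and `generate_kwargs` is assumed to be passed to the pipeline call (the added lines define it but never use it).

```python
import torch
from transformers import pipeline
from datasets import load_dataset
from evaluate import load

# model config
model_id = "kotoba-tech/kotoba-whisper-v1.0"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
generate_kwargs = {"language": "japanese", "task": "transcribe"}

# data config
dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
audio_column = 'audio'
text_column = 'transcription'

# load the model as a batched ASR pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    torch_dtype=torch_dtype,
    device=device,
    model_kwargs=model_kwargs,
    batch_size=16
)

# load the dataset (16kHz audio) and transcribe the whole test split
dataset = load_dataset(dataset_name, split="test")
transcriptions = pipe(dataset[audio_column], generate_kwargs=generate_kwargs)

# strip spaces before scoring: Japanese CER is computed on unspaced text
transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
references = [i.replace(" ", "") for i in dataset[text_column]]

# compute the CER metric
cer_metric = load("cer")
cer = 100 * cer_metric.compute(predictions=transcriptions, references=references)
print(cer)
```

As a quick sanity check on the metric itself, `cer_metric.compute(predictions=["こんにちわ"], references=["こんにちは"])` returns 0.2, i.e. one substituted character over a five-character reference.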