Update README.md
Browse files
README.md
CHANGED
@@ -300,6 +300,7 @@ import torch
|
|
300 |
from transformers import pipeline
|
301 |
from datasets import load_dataset
|
302 |
from evaluate import load
|
|
|
303 |
|
304 |
# model config
|
305 |
model_id = "kotoba-tech/kotoba-whisper-v1.0"
|
@@ -307,6 +308,7 @@ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
|
307 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
308 |
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
|
309 |
generate_kwargs = {"language": "japanese", "task": "transcribe"}
|
|
|
310 |
|
311 |
# data config
|
312 |
dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
|
@@ -326,8 +328,8 @@ pipe = pipeline(
|
|
326 |
# load the dataset and resample the audio to 16kHz
|
327 |
dataset = load_dataset(dataset_name, split="test")
|
328 |
transcriptions = pipe(dataset['audio'])
|
329 |
-
transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
|
330 |
-
references = [i.replace(" ", "") for i in dataset['transcription']]
|
331 |
|
332 |
# compute the CER metric
|
333 |
cer_metric = load("cer")
|
|
|
300 |
from transformers import pipeline
|
301 |
from datasets import load_dataset
|
302 |
from evaluate import load
|
303 |
+
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
304 |
|
305 |
# model config
|
306 |
model_id = "kotoba-tech/kotoba-whisper-v1.0"
|
|
|
308 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
309 |
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
|
310 |
generate_kwargs = {"language": "japanese", "task": "transcribe"}
|
311 |
+
normalizer = BasicTextNormalizer()
|
312 |
|
313 |
# data config
|
314 |
dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
|
|
|
328 |
# load the dataset and resample the audio to 16kHz
|
329 |
dataset = load_dataset(dataset_name, split="test")
|
330 |
transcriptions = pipe(dataset['audio'])
|
331 |
+
transcriptions = [normalizer(i['text']).replace(" ", "") for i in transcriptions]
|
332 |
+
references = [normalizer(i).replace(" ", "") for i in dataset['transcription']]
|
333 |
|
334 |
# compute the CER metric
|
335 |
cer_metric = load("cer")
|