|
|
|
from datasets import load_dataset |
|
from datasets import Audio |
|
import numpy as np |
|
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline |
|
import torch |
|
from jiwer import wer |
|
import whisper |
|
|
|
# Numeric precision of the Hugging Face model and of its input features.
# Change to torch.float16 to benchmark half precision instead.
PRECISION = torch.float32

# Whether each decoding window is conditioned on the text decoded for the
# previous window; the same flag is handed to both implementations.
DO_COND = True
|
|
|
|
|
# Hugging Face checkpoint under test. The short name of the equivalent
# original-OpenAI checkpoint (e.g. "tiny.en") is derived from it below.
model_id = "openai/whisper-tiny.en"

# Hugging Face side: processor (feature extractor + tokenizer) and the
# generation model, loaded in PRECISION and placed on the GPU.
processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=PRECISION,
).to("cuda")

# Reference side: the openai-whisper package, loaded by the text after the
# last "whisper-" in model_id. NOTE(review): load_model picks its own device
# here (no device argument is passed) — confirm it lands on the same GPU.
model_orig = whisper.load_model(model_id.rsplit("whisper-", 1)[-1])
|
|
|
|
|
|
|
# Test split of distil-whisper/earnings21 ("full" config), with the audio
# column resampled to 16 kHz.
ds = load_dataset("distil-whisper/earnings21", "full")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

# Evaluate `num_samples` consecutive examples starting at index `start`.
num_samples = 3
start = 2

# Materialize the slice once and reuse it: every `ds[a:b]` builds a fresh
# batch of all columns (audio included), so slicing twice would decode the
# audio twice. The end index is start + num_samples, not num_samples.
batch = ds[start:start + num_samples]
audios = [x["array"] for x in batch["audio"]]

# The reference-transcript column name varies between datasets; take the
# first one present and fail loudly if none is.
for name in ["text", "transcription"]:
    if name in ds.column_names:
        labels = batch[name]
        break
else:
    raise ValueError(
        "No reference transcript column found; expected one of "
        f"['text', 'transcription'], got {ds.column_names}"
    )
|
|
|
# Upper bound on tokens the HF model may generate (loop-invariant).
gen_length = 448

for audio, label in zip(audios, labels):
    # Features for the whole recording; truncation=False presumably keeps
    # audio longer than one 30 s window so generate() sees the full input.
    inputs = processor(
        audio,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=16_000,
    )
    inputs = inputs.to("cuda", PRECISION)

    # Only long-form audio is compared: inputs with fewer than 3000 feature
    # frames (a single 30 s window) are skipped.
    if inputs["input_features"].shape[-1] < 3000:
        continue

    # Reference transcription from openai-whisper: greedy (temperature 0)
    # with the threshold-based fallbacks disabled. Its `fp16` decode option
    # defaults to True, so pin it to PRECISION to compare like with like.
    result = model_orig.transcribe(
        audio.astype(dtype=np.float32),
        condition_on_previous_text=DO_COND,
        temperature=0.0,
        logprob_threshold=None,
        compression_ratio_threshold=None,
        no_speech_threshold=None,
        fp16=PRECISION == torch.float16,
    )

    result_hf = model.generate(
        **inputs,
        condition_on_prev_tokens=DO_COND,
        max_new_tokens=gen_length,
        return_timestamps=True,
    )
    decoded = processor.batch_decode(result_hf, skip_special_tokens=True)

    # Apply the same text normalizer to both hypotheses and the reference.
    result_text_norm = processor.tokenizer._normalize(result["text"])
    decoded_norm = processor.tokenizer._normalize(decoded[0])
    label_norm = processor.tokenizer._normalize(label)

    # jiwer.wer takes (reference, hypothesis).
    wer_orig = wer(label_norm, result_text_norm)
    wer_hf = wer(label_norm, decoded_norm)

    print("Cond:\n", decoded_norm)
    print(50 * "-")
    print("Orig Cond:\n", result_text_norm)
    print(50 * "-")
    print("Label:\n", label_norm)
    print("Result:")
    print("WER Orig", wer_orig)
    print("WER HF", wer_hf)

print("Done")
|
|