import numpy as np | |
from datasets import load_metric | |
from transformers import logging | |
import random | |
# Word-error-rate metric loaded from a local script checkout (offline use).
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favor of the `evaluate` package — confirm the pinned version
# before upgrading dependencies.
wer_metric = load_metric("./model-bin/metrics/wer")
def compute_metrics_fn(processor):
    """Build a ``compute_metrics`` closure for Trainer-style CTC/ASR evaluation.

    Args:
        processor: an object exposing ``batch_decode`` and
            ``tokenizer.pad_token_id`` (presumably a Wav2Vec2-style
            processor — TODO confirm at the call site).

    Returns:
        A callable mapping an ``EvalPrediction``-like object (with
        ``.predictions`` logits and ``.label_ids``) to ``{"wer": float}``.
    """
    def compute(pred):
        # Greedy decoding: take the argmax token id at every frame.
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)
        # The data collator masks padded label positions with -100; restore
        # the tokenizer's pad id so they decode (and group) correctly.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.batch_decode(pred_ids)
        # We do not want to group tokens when computing the metrics.
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        # Log one random (truth, prediction) pair for eyeballing quality.
        # BUG FIX: random.randint's upper bound is INCLUSIVE, so the original
        # randint(0, len(label_str)) could yield len(label_str) and raise
        # IndexError below. randrange uses a half-open interval.
        random_idx = random.randrange(len(label_str))
        logging.get_logger().info(
            '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                label_str[random_idx], pred_str[random_idx]))
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}
    return compute