# Source: vaw2tmp / metric_utils.py (committed by nguyenvulebinh)
# Last commit 5021159: "add random print sample when eval" (945 bytes)
import numpy as np
from datasets import load_metric
from transformers import logging
import random
# Metric object used by compute_metrics_fn below, loaded once at import time
# from a local metric script directory (the WER metric, per the directory
# name). NOTE(review): the path is relative, so it presumably resolves against
# the process's current working directory — confirm how training is launched.
_WER_METRIC_PATH = "./model-bin/metrics/wer"
wer_metric = load_metric(_WER_METRIC_PATH)
def compute_metrics_fn(processor):
    """Build a ``compute_metrics`` callback that scores predictions by WER.

    Args:
        processor: Object exposing ``tokenizer.pad_token_id`` and
            ``batch_decode(ids, group_tokens=...)``. Presumably a
            ``Wav2Vec2Processor`` — confirm against the caller.

    Returns:
        A function that takes an evaluation-prediction object with
        ``predictions`` (logits) and ``label_ids`` attributes and returns
        ``{"wer": <value from wer_metric.compute>}``. As a side effect it logs
        one randomly chosen (reference, prediction) pair at INFO level.
    """

    def compute(pred):
        # Greedy decoding: take the highest-scoring id along the last axis
        # (presumably the vocabulary axis of the logits).
        pred_ids = np.argmax(pred.predictions, axis=-1)

        # -100 marks ignored label positions; map them to the pad token so the
        # tokenizer can decode them. np.where returns a new array, so the
        # caller's pred.label_ids is no longer modified in place.
        label_ids = np.where(
            pred.label_ids == -100,
            processor.tokenizer.pad_token_id,
            pred.label_ids,
        )

        pred_str = processor.batch_decode(pred_ids)
        # We do not want to group tokens when computing the metrics.
        label_str = processor.batch_decode(label_ids, group_tokens=False)

        # randrange's upper bound is exclusive, so the index is always valid.
        # The old randint(0, len) could return len and raise IndexError.
        # Skip the sample entirely when there is nothing to show.
        if label_str:
            random_idx = random.randrange(len(label_str))
            logging.get_logger().info(
                '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                    label_str[random_idx], pred_str[random_idx]))

        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}

    return compute