File size: 945 Bytes
778e524 5021159 778e524 5021159 778e524 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import numpy as np
from datasets import load_metric
from transformers import logging
import random
wer_metric = load_metric("./model-bin/metrics/wer")
# print(wer_metric)
def compute_metrics_fn(processor):
    """Build a ``compute_metrics`` callable (for a HF Trainer) that reports WER.

    Args:
        processor: a Wav2Vec2-style processor exposing ``batch_decode`` and
            ``tokenizer.pad_token_id``, used to turn token ids back into text.

    Returns:
        A function mapping an ``EvalPrediction``-like object (with
        ``predictions`` logits and ``label_ids``) to ``{"wer": float}``.
    """
    def compute(pred):
        pred_logits = pred.predictions
        # Greedy (argmax) decoding over the vocabulary axis.
        pred_ids = np.argmax(pred_logits, axis=-1)
        # -100 marks positions ignored by the loss; restore the pad token id
        # so the tokenizer can decode the label sequence.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        # Log one random (truth, prediction) pair for a quick quality check.
        # Fix: random.randint is inclusive on BOTH ends, so the original
        # randint(0, len(label_str)) could return len(label_str) and raise
        # IndexError; randrange excludes the upper bound.
        random_idx = random.randrange(len(label_str))
        logging.get_logger().info(
            '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                label_str[random_idx], pred_str[random_idx]))
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}
    return compute
|