File size: 945 Bytes
778e524
 
5021159
 
778e524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5021159
 
 
 
778e524
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
from datasets import load_metric
from transformers import logging
import random

# Word-error-rate metric, loaded once at import time from a local metric
# script directory rather than by hub name.
# NOTE(review): the path is relative, so it presumably resolves against the
# process's current working directory -- confirm how this module is launched.
wer_metric = load_metric("./model-bin/metrics/wer")


# print(wer_metric)


def compute_metrics_fn(processor):
    """Build a metrics callback that reports word error rate (WER).

    Args:
        processor: Object exposing ``batch_decode(ids, ...)`` and
            ``tokenizer.pad_token_id``; used to turn token ids back into
            text. (Presumably a Wav2Vec2-style processor -- confirm against
            the caller.)

    Returns:
        A function ``compute(pred)`` that takes an object with
        ``predictions`` (logits) and ``label_ids`` attributes and returns
        ``{"wer": <value computed by wer_metric>}``.
    """

    def compute(pred):
        """Decode predictions and labels, log one sample, and return the WER."""
        pred_logits = pred.predictions
        # Greedy decoding: pick the highest-scoring token id along the last axis.
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 entries (presumably the loss ignore index used for padding) are
        # not valid token ids; map them to the pad token so they can be decoded.
        # NOTE: this modifies ``pred.label_ids`` in place, as the original did.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        # Log one random (truth, prediction) pair for eyeballing quality.
        # ``randrange`` excludes the upper bound; the previous
        # ``randint(0, len(label_str))`` could return ``len(label_str)`` and
        # raise IndexError. Skip logging entirely when there is nothing to sample.
        if label_str:
            random_idx = random.randrange(len(label_str))
            logging.get_logger().info(
                '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                    label_str[random_idx], pred_str[random_idx]))

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    return compute