Commit
·
5021159
1
Parent(s):
1e275bf
add a randomly selected printed sample during evaluation
Browse files- metric_utils.py +6 -0
metric_utils.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import numpy as np
|
2 |
from datasets import load_metric
|
|
|
|
|
3 |
|
4 |
wer_metric = load_metric("./model-bin/metrics/wer")
|
5 |
|
@@ -18,6 +20,10 @@ def compute_metrics_fn(processor):
|
|
18 |
# we do not want to group tokens when computing the metrics
|
19 |
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
20 |
|
|
|
|
|
|
|
|
|
21 |
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
22 |
|
23 |
return {"wer": wer}
|
|
|
1 |
import numpy as np
|
2 |
from datasets import load_metric
|
3 |
+
from transformers import logging
|
4 |
+
import random
|
5 |
|
6 |
wer_metric = load_metric("./model-bin/metrics/wer")
|
7 |
|
|
|
20 |
# we do not want to group tokens when computing the metrics
|
21 |
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
22 |
|
23 |
+
random_idx = random.randint(0, len(label_str) - 1)  # NOTE(review): original diff had randint(0, len(label_str)); randint is inclusive on both ends, so the unpatched line can index one past the end and raise IndexError
|
24 |
+
logging.get_logger().info(
|
25 |
+
'\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(label_str[random_idx], pred_str[random_idx]))
|
26 |
+
|
27 |
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
28 |
|
29 |
return {"wer": wer}
|