File size: 2,483 Bytes
0479abb
 
 
 
 
 
 
 
f8cd84d
 
 
 
 
0479abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8cd84d
0479abb
 
 
f8cd84d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import torch
from omegaconf import open_dict


def evaluate_model(
        model_path: str = None,
        test_manifest: str = None,
        batch_size: int = 1,
) -> Dict:
    """Evaluate a NeMo ASR model on a test manifest and report its WER.

    Args:
        model_path: Path to a ``.nemo`` checkpoint to restore the model from.
        test_manifest: Path to the evaluation manifest JSON file.
        batch_size: Batch size for the evaluation dataloader.

    Returns:
        Dict with the aggregate word error rate under the ``"WER_score"`` key.
    """
    # Determine the device (CPU or GPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Restore the ASR model from the provided path
    model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
    model.to(device)
    model.eval()

    # Point the validation config at the test manifest for this evaluation run
    with open_dict(model.cfg):
        model.cfg.validation_ds.manifest_filepath = test_manifest
        model.cfg.validation_ds.batch_size = batch_size

    # Set up the test data using the updated configuration
    model.setup_test_data(model.cfg.validation_ds)

    wer_nums = []
    wer_denoms = []

    # BUG FIX: run inference under no_grad(); the original tracked gradients
    # for every forward pass, needlessly growing memory during evaluation.
    with torch.no_grad():
        for test_batch in model.test_dataloader():
            # Batch layout (per NeMo ASR convention): signal, signal_len,
            # targets, target_lengths — presumably; verify against the dataset.
            targets = test_batch[2].to(device)
            targets_lengths = test_batch[3].to(device)
            # Forward pass through the model
            log_probs, encoded_len, greedy_predictions = model(
                input_signal=test_batch[0].to(device),
                input_signal_length=test_batch[1].to(device),
            )
            # Accumulate per-batch WER numerator/denominator, resetting the
            # metric each time so batches don't bleed into each other.
            model._wer.update(greedy_predictions, targets, targets_lengths)
            _, wer_num, wer_denom = model._wer.compute()
            model._wer.reset()
            wer_nums.append(wer_num.detach().cpu().numpy())
            wer_denoms.append(wer_denom.detach().cpu().numpy())
            # Free up (GPU) memory before the next batch
            del test_batch, log_probs, targets, targets_lengths, encoded_len, greedy_predictions

    # Aggregate WER = total errors / total reference words
    wer_score = sum(wer_nums) / sum(wer_denoms)
    results = {"WER_score": wer_score}
    print(results)
    # BUG FIX: the original declared `-> Dict` but returned None.
    return results


if __name__ == "__main__":
    # Parse command line arguments for the evaluation run
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
    # BUG FIX: help text previously said "train manifest" — this script evaluates.
    parser.add_argument("--test_manifest", default=None, help="Path to the test manifest JSON file.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to evaluate.")
    args = parser.parse_args()

    evaluate_model(
        model_path=args.model_path,
        test_manifest=args.test_manifest,
        batch_size=args.batch_size,
    )