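"""Evaluate a NeMo ASR model: restore a .nemo checkpoint, run it over a test
manifest, and report the aggregate Word Error Rate (WER)."""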
import argparse
from typing import Dict, Optional

import nemo.collections.asr as nemo_asr
import torch
from omegaconf import open_dict


def evaluate_model(
    model_path: Optional[str] = None,
    test_manifest: Optional[str] = None,
    batch_size: int = 1,
) -> Dict:
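    """Restore an ASR model and compute its aggregate WER over a test manifest.

    Args:
        model_path: Path to a .nemo checkpoint to restore.
        test_manifest: Path to a NeMo-style manifest JSON file describing the test set.
        batch_size: Batch size used when building the evaluation dataloader.

    Returns:
        A dict with the aggregate WER under the key "WER_score".
    """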
    # Determine the device (CPU or GPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Restore the ASR model from the provided path
    model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
    model.to(device)
    model.eval()

    # Point the validation config at the test manifest for evaluation
    with open_dict(model.cfg):
        model.cfg.validation_ds.manifest_filepath = test_manifest
        model.cfg.validation_ds.batch_size = batch_size

    # Set up the test dataloader using the updated configuration
    model.setup_test_data(model.cfg.validation_ds)
    wer_nums = []
    wer_denoms = []

    # Iterate through the test data without tracking gradients
    with torch.no_grad():
        for test_batch in model.test_dataloader():
            # Move every tensor in the batch to the evaluation device
            test_batch = [x.to(device) for x in test_batch]
            targets = test_batch[2]
            targets_lengths = test_batch[3]

            # Forward pass through the model
            log_probs, encoded_len, greedy_predictions = model(
                input_signal=test_batch[0],
                input_signal_length=test_batch[1],
            )

            # Update the model's internal WER metric and collect the batch counts
            model._wer.update(greedy_predictions, targets, targets_lengths)
            _, wer_num, wer_denom = model._wer.compute()
            model._wer.reset()
            wer_nums.append(wer_num.detach().cpu().numpy())
            wer_denoms.append(wer_denom.detach().cpu().numpy())

            # Free up memory before the next batch
            del test_batch, log_probs, targets, targets_lengths, encoded_len, greedy_predictions

    # Aggregate WER: total word errors divided by total reference words
    wer_score = float(sum(wer_nums) / sum(wer_denoms))
    print({"WER_score": wer_score})
    return {"WER_score": wer_score}


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Evaluate an ASR model's WER on a test set.")
    parser.add_argument("--model_path", default=None, help="Path to the .nemo model to evaluate.")
    parser.add_argument("--test_manifest", default=None, help="Path to the test manifest JSON file.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size to use during evaluation.")
    args = parser.parse_args()
    evaluate_model(
        model_path=args.model_path,
        test_manifest=args.test_manifest,
        batch_size=args.batch_size,
    )
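
# Example invocation (the script and file names below are placeholders):
#   python evaluate_wer.py --model_path model.nemo --test_manifest test_manifest.json --batch_size 4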