File size: 2,269 Bytes
5f72cc6
 
 
 
 
 
 
 
 
aa45081
 
5f72cc6
 
 
aa45081
 
5f72cc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import csv

import torch
import torchaudio
import numpy as np
import evaluate

from transformers import HubertForCTC, Wav2Vec2Processor

batch_size = 1
device = "cuda:0"  # cuda:0, or cpu
# Half precision is intended for the GPU; on CPU float16 is slow and not every
# kernel supports it, so fall back to float32 when running off-GPU.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32
sampling_rate = 16_000  # target sampling rate (Hz) for the model input

model_name = "Yehor/mHuBERT-147-uk"
testset_file = "examples.csv"  # CSV with "path" and "text" columns

# Load the test dataset: one dict per CSV row, keyed by the header row
# (the inference loop reads the "path" and "text" columns).
# encoding="utf-8" makes non-ASCII transcripts decode identically on every
# platform; newline="" is what the csv module requires so quoted fields that
# contain line breaks are parsed correctly.
with open(testset_file, encoding="utf-8", newline="") as f:
    samples = list(csv.DictReader(f))

# Load the HuBERT CTC model onto `device` in `torch_dtype`.
# (Optional, if your setup supports it: attn_implementation="flash_attention_2".)
asr_model = HubertForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
# The processor bundles the feature extractor (audio -> input values) and the
# tokenizer used to decode predicted ids back to text.
processor = Wav2Vec2Processor.from_pretrained(model_name)


# A util function to make batches
def make_batches(iterable, n=1):
    """Yield consecutive slices of *iterable* holding at most *n* items each.

    Args:
        iterable: A sized, sliceable sequence (e.g. a list). It must support
            ``len()`` and slicing, so plain iterators/generators won't work.
        n: Maximum batch size; must be a positive integer.

    Yields:
        Slices of ``iterable`` (lists when given a list). The final slice may
        be shorter than ``n``.

    Raises:
        ValueError: If ``n`` is less than 1. Because this is a generator, the
            error surfaces on the first iteration, not at call time.
    """
    if n < 1:
        raise ValueError(f"batch size must be >= 1, got {n}")
    # Slicing clamps the end index to len(iterable), so no min() is needed.
    for start in range(0, len(iterable), n):
        yield iterable[start : start + n]


# Accumulators for every prediction/reference pair across all batches
predictions_all = []
references_all = []

# Inference in the batched mode
for batch in make_batches(samples, batch_size):
    paths = [it["path"] for it in batch]
    references = [it["text"] for it in batch]

    # Extract audio as 1-D waveforms at the configured `sampling_rate`.
    # The file's own rate goes into a separate name so the configured
    # `sampling_rate` constant is never overwritten.
    audio_inputs = []
    for path in paths:
        waveform, file_sampling_rate = torchaudio.load(path, backend="sox")
        # torchaudio.load returns (channels, frames). Averaging over channels
        # leaves mono audio unchanged and down-mixes multi-channel audio.
        waveform = waveform.mean(dim=0)
        if file_sampling_rate != sampling_rate:
            # Without this the model would see audio at the wrong rate.
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=file_sampling_rate, new_freq=sampling_rate
            )
        audio_inputs.append(waveform.numpy())

    # Turn the batch into padded model inputs with the processor
    inputs = processor(
        audio_inputs, sampling_rate=sampling_rate, padding=True
    ).input_values

    features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

    with torch.inference_mode():
        logits = asr_model(features).logits

    # Greedy decoding: most likely token id per frame, then ids -> text
    predicted_ids = torch.argmax(logits, dim=-1)
    predictions = processor.batch_decode(predicted_ids)

    # Log outputs
    print("---")
    print("Predictions:")
    print(predictions)
    print("References:")
    print(references)
    print("---")

    # Add predictions and references
    predictions_all.extend(predictions)
    references_all.extend(references)

# Load the word- and character-error-rate metrics
wer, cer = evaluate.load("wer"), evaluate.load("cer")

# Score the whole run at once, rounding each metric to 4 decimals
wer_value = round(
    wer.compute(references=references_all, predictions=predictions_all), 4
)
cer_value = round(
    cer.compute(references=references_all, predictions=predictions_all), 4
)

# Report the aggregate scores
print("Final:")
print(f"WER: {wer_value} | CER: {cer_value}")