""" Inference main class. Author: Marcely Zanon Boito, 2024 """ from .CTC_model import mHubertForCTC import torch from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers import HubertConfig from datasets import load_dataset fbk_test_id = 'FBK-MT/Speech-MASSIVE-test' mhubert_id = 'utter-project/mHuBERT-147' def load_asr_model(): # Load the ASR model tokenizer = Wav2Vec2CTCTokenizer("asr/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(mhubert_id) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) config = HubertConfig.from_pretrained("naver/mHuBERT-147-ASR-fr") model = mHubertForCTC.from_pretrained("naver/mHuBERT-147-ASR-fr", config=config) model.eval() return model, processor def run_asr_inference(model, processor, example): audio = processor(example["array"], sampling_rate=example["sampling_rate"]).input_values[0] input_values = torch.tensor(audio).unsqueeze(0) with torch.no_grad(): logits = model(input_values).logits pred_ids = torch.argmax(logits, dim=-1) prediction = processor.batch_decode(pred_ids)[0].replace('[CTC]', "") return prediction