mzboito commited on
Commit
763071e
·
1 Parent(s): 54c1aff

pushing code to the hub

Browse files
README.md CHANGED
@@ -30,7 +30,6 @@ It is based on the [mHuBERT-147](https://huggingface.co/utter-project/mHuBERT-14
30
 
31
  ## Training Parameters
32
  The training parameters are available in config.yaml.
33
- We downsample the commonvoice dataset to 70,000 utterances.
34
 
35
  ## ASR Model class
36
 
@@ -41,4 +40,15 @@ The code is available in [CTC_model.py](https://huggingface.co/naver/mHuBERT-147
41
  ## Running inference
42
 
43
  The run_asr.py file illustrates how to load the model for inference (**load_asr_model**), and how to produce transcription for a file (**run_asr_inference**).
44
- Please follow the [requirements file](https://huggingface.co/naver/mHuBERT-147-ASR-fr/blob/main/requirements.txt) to avoid incorrect model loading.
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  ## Training Parameters
32
  The training parameters are available in config.yaml.
 
33
 
34
  ## ASR Model class
35
 
 
40
  ## Running inference
41
 
42
  The run_asr.py file illustrates how to load the model for inference (**load_asr_model**), and how to produce transcription for a file (**run_asr_inference**).
43
+ Please follow the [requirements file](https://huggingface.co/naver/mHuBERT-147-ASR-fr/blob/main/requirements.txt) to avoid incorrect model loading.
44
+
45
+ Here is a simple example of the inference loop. Please notice that the sampling rate must be 16,000Hz.
46
+
47
+ ```
48
+ from inference_code.run_inference import load_asr_model, run_asr_inference
49
+
50
+ model, processor = load_asr_model()
51
+
52
+ prediction = run_inference(model, processor, your_audio_file)
53
+
54
+ ```
config.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ group_by_length: True
2
+ evaluation_strategy: "steps"
3
+ num_train_epochs: 100
4
+ fp16: False
5
+ gradient_checkpointing: True
6
+ eval_steps: 10000
7
+ save_steps: 10000
8
+ logging_steps: 10000
9
+ learning_rate: 1e-4
10
+ adam_beta1: 0.9
11
+ adam_beta2: 0.98
12
+ adam_epsilon: 1e-08
13
+ warmup_ratio: 0.2
14
+ save_total_limit: 4
15
+ load_best_model_at_end: True
16
+ per_device_train_batch_size: 8
17
+ per_device_eval_batch_size: 2
18
+ metric_for_best_model: "cer"
19
+ greater_is_better: False
20
+ gradient_accumulation_steps: 8
21
+ final_dropout: 0.3
22
+ seed: 3452
23
+ add_interface_layer: True
24
+ num_interface_layers: 3
CTC_model.py → inference_code/CTC_model.py RENAMED
File without changes
inference_code/__pycache__/CTC_model.cpython-39.pyc ADDED
Binary file (3.47 kB). View file
 
inference_code/__pycache__/run_inference.cpython-39.pyc ADDED
Binary file (2.21 kB). View file
 
inference_code/run_inference.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference main class.
3
+
4
+ Author: Marcely Zanon Boito, 2024
5
+ """
6
+
7
+ from .CTC_model import mHubertForCTC
8
+
9
+ import torch
10
+ from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
11
+ from transformers import HubertConfig
12
+
13
+ from datasets import load_dataset
14
+
15
+ fbk_test_id = 'FBK-MT/Speech-MASSIVE-test'
16
+ mhubert_id = 'utter-project/mHuBERT-147'
17
+
18
+ def load_asr_model():
19
+ def init_config():
20
+ config = HubertConfig.from_pretrained(mhubert_id)
21
+ config.pad_token_id = processor.tokenizer.pad_token_id
22
+ config.ctc_token_id = processor.tokenizer.convert_tokens_to_ids('[CTC]')
23
+ config.vocab_size = len(processor.tokenizer)
24
+
25
+ config.output_hidden_states = False
26
+ config.add_interface = True
27
+ config.num_interface_layers = 3
28
+ return config
29
+
30
+ # Load the ASR model
31
+ tokenizer = Wav2Vec2CTCTokenizer('vocab.json', unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
32
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(mhubert_id)
33
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
34
+
35
+ config = init_config()
36
+ model = mHubertForCTC.from_pretrained("naver/mHuBERT-147-ASR-fr", config=config)
37
+ model.eval()
38
+ return model, processor
39
+
40
+ def run_asr_inference(model, processor, example):
41
+ audio = processor(example["array"], sampling_rate=example["sampling_rate"]).input_values[0]
42
+ input_values = torch.tensor(audio).unsqueeze(0)
43
+
44
+ with torch.no_grad():
45
+ logits = model(input_values).logits
46
+
47
+ pred_ids = torch.argmax(logits, dim=-1)
48
+
49
+ prediction = processor.batch_decode(pred_ids)[0].replace('[CTC]', "")
50
+ return prediction
51
+
52
+ if __name__ == '__main__':
53
+
54
+ # Load the dataset in streaming mode
55
+ dataset = load_dataset(fbk_test_id, 'fr-FR', streaming=True)
56
+ dataset = dataset['test']
57
+ generator = iter(dataset)
58
+
59
+ # load model
60
+ model, processor = load_asr_model()
61
+ print(model)
62
+
63
+ # decode 10 examples from speech-MASSIVE
64
+ num_examples= 10
65
+ while num_examples >= 0:
66
+ example = next(generator)
67
+
68
+ prediction = run_inference(model, processor, example['audio'])
69
+
70
+ gold_standard = example['utt']
71
+
72
+ print("Gold standard:", gold_standard)
73
+ print("Prediction:", prediction)
74
+ print()
75
+ num_examples-=1
76
+
77
+
78
+
79
+
80
+