bel32123 committed
Commit 6d3dc99 · 1 Parent(s): e41278c

Add Wav2Vec ASR Model Files

wav2vecasr/MispronounciationDetector.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ import jiwer
+
+ class MispronounciationDetector:
+     def __init__(self, l2_phoneme_recogniser, l2_phoneme_recogniser_processor, g2p, device):
+         self.l2_phoneme_recogniser = l2_phoneme_recogniser
+         self.l2_phoneme_recogniser_processor = l2_phoneme_recogniser_processor
+         self.g2p = g2p
+         self.device = device
+
+     def detect(self, audio, text):
+         l2_phones = self.get_l2_phoneme_sequence(audio)
+         native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
+         raw_info = self.get_mispronounciation_output(text, l2_phones, native_speaker_phones)
+         return raw_info
+
+     def get_l2_phoneme_sequence(self, audio):
+         # transcribe the learner's audio into phonemes with the fine-tuned wav2vec2 CTC model
+         input_dict = self.l2_phoneme_recogniser_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
+         logits = self.l2_phoneme_recogniser(input_dict.input_values.to(self.device)).logits
+         pred_ids = torch.argmax(logits, dim=-1)[0]
+         pred_phones = [phoneme for phoneme in self.l2_phoneme_recogniser_processor.batch_decode(pred_ids) if phoneme != ""]
+         return pred_phones
+
+     def get_native_speaker_phoneme_sequence(self, text):
+         # grapheme-to-phoneme conversion of the prompt text gives the expected (native-speaker) phonemes
+         phonemes = self.g2p(text)
+         return phonemes
+
+     def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
+         # phoneme error rate (PER) between expected and predicted phoneme sequences
+         label_phones = [phone for phone in org_label_phones if phone != " "]
+         reference = " ".join(label_phones)  # expected phones
+         hypothesis = " ".join(pred_phones)  # L2 speaker phones
+         res = jiwer.process_words(reference, hypothesis)
+         per = res.wer
+         # print(jiwer.visualize_alignment(res))
+
+         # build width-padded ref/hyp rows plus an error row marking a(ddition)/d(eletion)/s(ubstitution)
+         alignments = res.alignments
+         error_bool = []
+         ref, hyp = [], []
+         for alignment_chunk in alignments[0]:
+             alignment_type = alignment_chunk.type
+             ref_start_idx = alignment_chunk.ref_start_idx
+             ref_end_idx = alignment_chunk.ref_end_idx
+             hyp_start_idx = alignment_chunk.hyp_start_idx
+             hyp_end_idx = alignment_chunk.hyp_end_idx
+             if alignment_type != "equal":
+                 if alignment_type == "insert":
+                     for i in range(hyp_start_idx, hyp_end_idx):
+                         ref.append("*" * len(pred_phones[i]))
+                         space_padding = " " * (len(pred_phones[i]) - 1)
+                         error_bool.append(space_padding + "a")
+                     hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
+                 elif alignment_type == "delete":
+                     ref.extend(label_phones[ref_start_idx:ref_end_idx])
+                     for i in range(ref_start_idx, ref_end_idx):
+                         hyp.append("*" * len(label_phones[i]))
+                         space_padding = " " * (len(label_phones[i]) - 1)
+                         error_bool.append(space_padding + alignment_type[0])
+                 else:
+                     # substitution: pad the shorter phone so the two rows stay column-aligned
+                     for i in range(ref_end_idx - ref_start_idx):
+                         correct_phone = label_phones[ref_start_idx + i]
+                         pred_phone = pred_phones[hyp_start_idx + i]
+                         if len(correct_phone) > len(pred_phone):
+                             space_padding = " " * (len(correct_phone) - len(pred_phone))
+                             ref.append(correct_phone)
+                             hyp.append(space_padding + pred_phone)
+                             error_bool.append(" " * (len(correct_phone) - 1) + alignment_type[0])
+                         else:
+                             space_padding = " " * (len(pred_phone) - len(correct_phone))
+                             ref.append(space_padding + correct_phone)
+                             hyp.append(pred_phone)
+                             error_bool.append(" " * (len(pred_phone) - 1) + alignment_type[0])
+             else:
+                 ref.extend(label_phones[ref_start_idx:ref_end_idx])
+                 hyp.extend(pred_phones[hyp_start_idx:hyp_end_idx])
+                 # ref or hyp does not matter here: matching chunks are marked with dashes
+                 for i in range(ref_start_idx, ref_end_idx):
+                     space_padding = "-" * (len(label_phones[i]))
+                     error_bool.append(space_padding)
+
+         # re-insert "|" word delimiters at the word boundaries given by the g2p output
+         delimiter_idx = 0
+         for phone in org_label_phones:
+             if phone == " ":
+                 hyp.insert(delimiter_idx + 1, "|")
+                 ref.insert(delimiter_idx + 1, "|")
+                 error_bool.insert(delimiter_idx + 1, "|")
+                 continue
+             while delimiter_idx < len(ref) and ref[delimiter_idx].strip() != phone:
+                 delimiter_idx += 1
+         # word ends
+         ref.append("|")
+         hyp.append("|")
+
+         # get mispronounced words and the resulting word error rate
+         words = text.split(" ")
+         word_error_bool = self.get_mispronounced_words(error_bool)
+         wer = sum(word_error_bool) / len(words)
+
+         raw_info = {"ref": ref, "hyp": hyp, "per": per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors": word_error_bool}
+
+         return raw_info
+
+     def get_mispronounced_words(self, phoneme_error_bool):
+         # map mispronounced phones back to the words that were mispronounced:
+         # a word is flagged if any of its phones was substituted (s), deleted (d) or added (a)
+         word_error_bool = []
+         phoneme_error_bool.append("|")
+         word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
+         for phones in word_phones:
+             if "s" in phones or "d" in phones or "a" in phones:
+                 word_error_bool.append(True)
+             else:
+                 word_error_bool.append(False)
+         return word_error_bool
+
+     def split_lst_by_delim(self, lst, delimiter):
+         temp = []
+         res = []
+         for item in lst:
+             if item != delimiter:
+                 temp.append(item.strip())
+             else:
+                 res.append(temp)
+                 temp = []
+         return res
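Note that `detect` only returns this dictionary; nothing is printed. A minimal sketch (not part of the commit, function name illustrative) of turning `raw_info` into the aligned three-row view its fields are padded for:

def print_alignment(raw_info):
    # ref, hyp and phoneme_errors are parallel lists whose entries are padded to equal
    # width and share "|" word delimiters, so a single-space join keeps the rows aligned
    print(" ".join(raw_info["ref"]))
    print(" ".join(raw_info["hyp"]))
    print(" ".join(raw_info["phoneme_errors"]))
    print(f"PER: {raw_info['per']:.3f}, WER: {raw_info['wer']:.3f}")
    # word_errors flags, per word of the prompt, whether any of its phones was s/d/a
    for word, is_error in zip(raw_info["words"], raw_info["word_errors"]):
        print(f"{word}: {'mispronounced' if is_error else 'ok'}")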
wav2vecasr/data/arctic_a0003.txt ADDED
@@ -0,0 +1 @@
+ For the twentieth time that evening the two men shook hands
wav2vecasr/data/arctic_a0003.wav ADDED
Binary file (283 kB).
wav2vecasr/demo.py ADDED
@@ -0,0 +1,32 @@
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from speechbrain.pretrained import GraphemeToPhoneme
+ import os
+ import torchaudio
+ from MispronounciationDetector import MispronounciationDetector
+
+ # Load sample data
+ audio_path, transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav"), os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
+ audio, org_sr = torchaudio.load(audio_path)
+ audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
+ audio = audio.view(audio.shape[1])
+ with open(transcript_path) as f:
+     text = f.read()
+ print("Done loading sample data")
+
+ # Load processors and models
+ device = "cpu"
+ path = os.path.join(os.getcwd(), "model", "checkpoint-1200")
+ model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
+ processor = Wav2Vec2Processor.from_pretrained(path)
+ g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
+ mispronounciation_detector = MispronounciationDetector(model, processor, g2p, device)
+ print("Done loading models and processors")
+
+ # Predict
+ raw_info = mispronounciation_detector.detect(audio, text)
+ aligned_phoneme_output_delimited_by_words = " ".join(raw_info['ref']) + "\n" + " ".join(raw_info['hyp']) + "\n" + \
+     " ".join(raw_info['phoneme_errors'])
+ print(f"PER: {raw_info['per']}\n")
+ print(f"Phoneme level errors:\n{aligned_phoneme_output_delimited_by_words}\n")
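Because every path in the demo is built from os.getcwd(), the script presumably has to be launched from inside wav2vecasr/ so that data/ and model/checkpoint-1200/ resolve. A purely illustrative variant (not part of the commit) that anchors the paths to the script's own location instead:

import os

# resolve the sample data and checkpoint relative to this file rather than the working directory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
audio_path = os.path.join(BASE_DIR, "data", "arctic_a0003.wav")
transcript_path = os.path.join(BASE_DIR, "data", "arctic_a0003.txt")
model_path = os.path.join(BASE_DIR, "model", "checkpoint-1200")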
wav2vecasr/model/checkpoint-1200/config.json ADDED
@@ -0,0 +1,108 @@
+ {
+   "_name_or_path": "/content/drive/MyDrive/NUS/Y4S1/Sound and Music Computing/CS4347 Project/Experiments/Wav2Vec Baselines/L2 Artic 3 Speakers: Baseline 2/wav2vec-baseline2-model-checkpoints/checkpoint-600",
+   "activation_dropout": 0.0,
+   "adapter_kernel_size": 3,
+   "adapter_stride": 2,
+   "add_adapter": false,
+   "apply_spec_augment": true,
+   "architectures": [
+     "Wav2Vec2ForCTC"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "classifier_proj_size": 256,
+   "codevector_dim": 768,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": true,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "mean",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": true,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "layer",
+   "feat_proj_dropout": 0.0,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.0,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.0,
+   "mask_feature_length": 64,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.25,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_prob": 0.75,
+   "model_type": "wav2vec2",
+   "num_adapter_layers": 3,
+   "num_attention_heads": 16,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 24,
+   "num_negatives": 100,
+   "output_hidden_size": 1024,
+   "pad_token_id": 82,
+   "proj_codevector_dim": 768,
+   "tdnn_dilation": [
+     1,
+     2,
+     3,
+     1,
+     1
+   ],
+   "tdnn_dim": [
+     512,
+     512,
+     512,
+     512,
+     1500
+   ],
+   "tdnn_kernel": [
+     5,
+     3,
+     3,
+     1,
+     1
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.17.0",
+   "use_weighted_layer_sum": false,
+   "vocab_size": 83,
+   "xvector_output_dim": 512
+ }
wav2vecasr/model/checkpoint-1200/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
wav2vecasr/model/checkpoint-1200/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:44bd3813b64d85faa8f88091f160cc107e340ee71372b470dd6c4b09cb00906d
+ size 1262269741
wav2vecasr/model/checkpoint-1200/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6686b1782523e688cd46835b2db33ae51a6ffd852401967b311db1a20efad2ee
+ size 14639
wav2vecasr/model/checkpoint-1200/scaler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:662b2a6102fe369b78bf169eb2bcea08b4dc636d31dfd2652b32a63eda7e03e8
+ size 557
wav2vecasr/model/checkpoint-1200/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcbd8882ac5e67f1b9d59f4eaa2583483d429dfecec5ce45fd99da4d06e47847
+ size 627
wav2vecasr/model/checkpoint-1200/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}
wav2vecasr/model/checkpoint-1200/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "replace_word_delimiter_char": " ", "tokenizer_class": "Wav2Vec2CTCTokenizer"}
wav2vecasr/model/checkpoint-1200/trainer_state.json ADDED
@@ -0,0 +1,106 @@
+ {
+   "best_metric": 17.541647491946616,
+   "best_model_checkpoint": "/content/drive/MyDrive/NUS/Y4S1/Sound and Music Computing/CS4347 Project/Experiments/Wav2Vec Baselines/L2 Artic 3 Speakers: Baseline 2/wav2vec-baseline2-model-checkpoints/checkpoint-200",
+   "epoch": 4.411764705882353,
+   "global_step": 1200,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 0.74,
+       "learning_rate": 0.0001194,
+       "loss": 4.1351,
+       "step": 200
+     },
+     {
+       "epoch": 0.74,
+       "eval_loss": 1.9543204307556152,
+       "eval_per": 17.541647491946616,
+       "eval_runtime": 143.4439,
+       "eval_samples_per_second": 15.149,
+       "eval_steps_per_second": 1.896,
+       "step": 200
+     },
+     {
+       "epoch": 1.47,
+       "learning_rate": 0.0002394,
+       "loss": 1.7915,
+       "step": 400
+     },
+     {
+       "epoch": 1.47,
+       "eval_loss": 1.692239761352539,
+       "eval_per": 7.960756135308424,
+       "eval_runtime": 137.113,
+       "eval_samples_per_second": 15.848,
+       "eval_steps_per_second": 1.984,
+       "step": 400
+     },
+     {
+       "epoch": 2.21,
+       "learning_rate": 0.0002971387283236994,
+       "loss": 1.2246,
+       "step": 600
+     },
+     {
+       "epoch": 2.21,
+       "eval_loss": 0.5273078083992004,
+       "eval_per": 0.31805393535991217,
+       "eval_runtime": 136.1021,
+       "eval_samples_per_second": 15.966,
+       "eval_steps_per_second": 1.999,
+       "step": 600
+     },
+     {
+       "epoch": 2.94,
+       "learning_rate": 0.0002913872832369942,
+       "loss": 0.9433,
+       "step": 800
+     },
+     {
+       "epoch": 2.94,
+       "eval_loss": 0.41386935114860535,
+       "eval_per": 0.2565853269749339,
+       "eval_runtime": 136.2091,
+       "eval_samples_per_second": 15.953,
+       "eval_steps_per_second": 1.997,
+       "step": 800
+     },
+     {
+       "epoch": 3.68,
+       "learning_rate": 0.000285606936416185,
+       "loss": 0.8842,
+       "step": 1000
+     },
+     {
+       "epoch": 3.68,
+       "eval_loss": 0.3962230980396271,
+       "eval_per": 0.24980343554684897,
+       "eval_runtime": 139.9847,
+       "eval_samples_per_second": 15.523,
+       "eval_steps_per_second": 1.943,
+       "step": 1000
+     },
+     {
+       "epoch": 4.41,
+       "learning_rate": 0.00027982658959537567,
+       "loss": 0.8542,
+       "step": 1200
+     },
+     {
+       "epoch": 4.41,
+       "eval_loss": 0.3784765601158142,
+       "eval_per": 0.24003603985584057,
+       "eval_runtime": 136.6045,
+       "eval_samples_per_second": 15.907,
+       "eval_steps_per_second": 1.991,
+       "step": 1200
+     }
+   ],
+   "max_steps": 10880,
+   "num_train_epochs": 40,
+   "total_flos": 4.4309288969819863e+18,
+   "trial_name": null,
+   "trial_params": null
+ }
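The trainer state above logs an evaluation pass every 200 steps. A small sketch (not part of the commit, path assumed relative to the repository root) for pulling the eval PER curve out of it:

import json

with open("wav2vecasr/model/checkpoint-1200/trainer_state.json") as f:
    state = json.load(f)

# evaluation entries are the log_history records that carry an "eval_per" field
for entry in state["log_history"]:
    if "eval_per" in entry:
        print(f"step {entry['step']:>5}: eval_loss={entry['eval_loss']:.3f} eval_per={entry['eval_per']:.3f}")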
wav2vecasr/model/checkpoint-1200/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d30a9efd4c6d95a24828b0e52d71ed7ad4f3c83075c158d1488ebe5f50b6719
+ size 3323
wav2vecasr/model/checkpoint-1200/vocab.json ADDED
@@ -0,0 +1 @@
+ {"K*": 0, "Z*": 1, "AE*": 2, "B": 3, "UH*": 4, "W": 5, "SIL": 6, "CH": 7, "EH*": 8, "T": 9, "D_": 10, "W*": 11, "K": 12, "spn": 13, "AH": 14, "AH*": 15, "NG": 16, "P*": 17, "B*": 18, "G": 19, "OY": 20, "D": 21, "ZH": 22, "sp": 23, "V": 24, "EY": 25, "V``": 26, "UW": 27, "s": 28, "P": 29, "UW*": 30, "ER*": 31, "sil": 32, "R*": 33, "IH": 34, "OW": 35, "HH*": 36, "Y": 37, "AO": 38, "AW*": 39, "ER": 40, "OW*": 41, "AY": 42, "M": 43, "T*": 44, "DH": 45, "AA*": 46, "L": 47, "AX": 48, "N*": 49, "EH": 50, "DH*": 51, "t": 52, "ERR": 53, "AO*": 54, "Z": 55, "S": 56, "ZH*": 57, "EY*": 58, "JH*": 59, "F": 60, "L*": 61, "Y*": 62, "R": 63, "G*": 64, "JH": 65, "W`": 66, "D*": 67, "AA": 68, "IY": 69, "AE": 70, "Ah": 71, "AW": 72, "SH": 73, "TH": 74, "N": 75, "V*": 76, "HH": 77, "UH": 78, "err": 79, "|": 80, "[UNK]": 81, "[PAD]": 82}
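The vocabulary above maps 83 phoneme labels (including starred mispronunciation variants, silence marks such as "SIL" and "sp", "|" as the word delimiter, and "[PAD]") to CTC output ids. A minimal sketch (not part of the commit, path assumed) of inverting it to inspect what each output index stands for:

import json

with open("wav2vecasr/model/checkpoint-1200/vocab.json") as f:
    vocab = json.load(f)

# invert phoneme -> id into id -> phoneme to see what each CTC output index emits
id_to_phone = {idx: phone for phone, idx in vocab.items()}
print(len(id_to_phone))                   # 83, matching vocab_size in config.json
print(id_to_phone[80], id_to_phone[82])   # "|" word delimiter and the "[PAD]" token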