mostafaashahin committed on
Commit e7ae2d2 · 1 Parent(s): 1a25737

Update app.py

Files changed (1)
  1. app.py +144 -0
app.py CHANGED
@@ -0,0 +1,144 @@
+ from os.path import join
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
+ import torch
+ import pandas as pd
+ import librosa
+ import gradio as gr
+ from gradio.components import Audio, Dropdown, Textbox
+
+
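+ # Attributes selectable in the UI; each value is the index of the matching
+ # p_/n_ group in the `groups` list defined below.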
+ Attributes = {'Dental':2,
+               'Labial':4,
+               'Consonant':15,
+               'Vowel':19,
+               'Fricative':21,
+               'Nasal':22,
+               'Stop':23,
+               'Affricate':25,
+               'Voiced':31,
+               'Bilabial':32,
+               }
+
+ # Define groups.
+ # Make sure that all phonemes are covered in each group.
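+ # Each group pairs the positive (p_) and negative (n_) token of one phonological attribute.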
+ g1 = ['p_alveolar','n_alveolar']
+ g2 = ['p_palatal','n_palatal']
+ g3 = ['p_dental','n_dental']
+ g4 = ['p_glottal','n_glottal']
+ g5 = ['p_labial','n_labial']
+ g6 = ['p_velar','n_velar']
+ g7 = ['p_anterior','n_anterior']
+ g8 = ['p_posterior','n_posterior']
+ g9 = ['p_retroflex','n_retroflex']
+ g10 = ['p_mid','n_mid']
+ g11 = ['p_high_v','n_high_v']
+ g12 = ['p_low','n_low']
+ g13 = ['p_front','n_front']
+ g14 = ['p_back','n_back']
+ g15 = ['p_central','n_central']
+ g16 = ['p_consonant','n_consonant']
+ g17 = ['p_sonorant','n_sonorant']
+ g18 = ['p_long','n_long']
+ g19 = ['p_short','n_short']
+ g20 = ['p_vowel','n_vowel']
+ g21 = ['p_semivowel','n_semivowel']
+ g22 = ['p_fricative','n_fricative']
+ g23 = ['p_nasal','n_nasal']
+ g24 = ['p_stop','n_stop']
+ g25 = ['p_approximant','n_approximant']
+ g26 = ['p_affricate','n_affricate']
+ g27 = ['p_liquid','n_liquid']
+ g28 = ['p_continuant','n_continuant']
+ g29 = ['p_monophthong','n_monophthong']
+ g30 = ['p_diphthong','n_diphthong']
+ g31 = ['p_round','n_round']
+ g32 = ['p_voiced','n_voiced']
+ g33 = ['p_bilabial','n_bilabial']
+ g34 = ['p_coronal','n_coronal']
+ g35 = ['p_dorsal','n_dorsal']
+ groups = [g1,g2,g3,g4,g5,g6,g7,g8,g9,g10,g11,g12,g13,g14,g15,g16,g17,g18,g19,g20,g21,g22,g23,g24,g25,g26,g27,g28,g29,g30,g31,g32,g33,g34,g35]
+
+
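+ # Load the wav2vec2 CTC model and its processor, plus a separate phoneme-level
+ # tokenizer and the phoneme-to-attribute table, all from the local model/ directory.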
+ model_dir = 'model/'
+ processor = Wav2Vec2Processor.from_pretrained(model_dir)
+ model = Wav2Vec2ForCTC.from_pretrained(model_dir)
+ tokenizer_phoneme = Wav2Vec2CTCTokenizer(join(model_dir, "phoneme_vocab.json"), pad_token="<pad>", word_delimiter_token="")
+
+ phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
+ p_att = pd.read_csv(join(model_dir, "phonological_attributes_v12.csv"), index_col=0)
+
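+ # For each group, build a phoneme -> attribute-token mapper (p_x or n_x) from the attribute table.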
+ mappers = []
+ for g in groups:
+     p2att = {}
+     for att in g:
+         att_phs = p_att[p_att[att]==1].index
+         for ph in att_phs:
+             p2att[ph] = att
+     mappers.append(p2att)
+
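+ # Indicator matrix (phoneme vocab x attribute vocab): the row for phoneme p marks the
+ # attribute token that p carries in each group; it is used in getPhonemes() to turn
+ # per-group attribute log-probs into phoneme scores.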
+ p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)
+
+ for p in phoneme_list:
+     for mapper in mappers:
+         if p == processor.tokenizer.pad_token:
+             p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.pad_token_id] = 1
+         else:
+             p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1
+
+
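+ # Per group: sort the token ids, then re-index them starting at 1 (position 0 is kept
+ # for the pad/blank token), so each dict maps a position within the group back to the
+ # original token id.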
+ group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
+ group_ids = [dict([(x[0]+1, x[1]) for x in list(enumerate(g))]) for g in group_ids]  # This is the inversion of the mapping used in training, as here we need to map predictions back to the original tokens.
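+ # Numerically stable log-softmax restricted to the unmasked entries.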
+ def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
+     if mask is not None:
+         mask = mask.float()
+         while mask.dim() < vector.dim():
+             mask = mask.unsqueeze(1)
+         # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
+         # results in nans when the whole vector is masked. We need a very small value instead of a
+         # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
+         # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
+         # becomes 0 - this is just the smallest value we can actually use.
+         vector = vector + (mask + 1e-45).log()
+     return torch.nn.functional.log_softmax(vector, dim=dim)
+
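+ # Phoneme decoding: for every attribute group, take a masked log-softmax over that group's
+ # logits, sum the per-group log-probs, project them onto the phoneme vocabulary via p2att,
+ # and greedily decode the argmax sequence with the phoneme tokenizer.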
+ def getPhonemes(logits):
+     ngroups = len(group_ids)
+     log_props_all_masked = []
+     for i in range(ngroups):
+         mask = torch.zeros(logits.size()[2], dtype=torch.bool)
+         mask[0] = True
+         mask[list(group_ids[i].values())] = True
+         mask.unsqueeze_(0).unsqueeze_(0)
+         log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
+         log_props_all_masked.append(log_probs)
+     log_probs_cat = torch.stack(log_props_all_masked, dim=0).sum(dim=0)
+     log_probs_phoneme = torch.matmul(p2att, log_probs_cat.transpose(1,2)).transpose(1,2).type(torch.FloatTensor)
+     pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
+     pred = tokenizer_phoneme.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
+     return pred
+
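+ # Attribute decoding for the selected group i: argmax over the pad column plus the group's
+ # two tokens, mapped back to the original token ids and rendered as +attribute / -attribute.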
+ def getAtt(logits, i):
+     mask = torch.zeros(logits.size()[2], dtype=torch.bool)
+     mask[0] = True
+     mask[list(group_ids[i].values())] = True
+     logits_g = logits[:,:,mask]
+     pred_ids = torch.argmax(logits_g, dim=-1)
+     pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x, x))
+     pred = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
+     return pred.replace('p_','+').replace('n_','-')
+
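+ # Gradio callback: prefer the microphone recording if one was made, load the audio at
+ # 16 kHz, and run a single forward pass whose logits feed both decoders.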
+ def recognizeAudio(audio, mic_audioFilePath, att):
+     i = Attributes[att]
+     audio = mic_audioFilePath if mic_audioFilePath else audio
+     y, sr = librosa.load(audio, sr=16000)
+     input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
+     with torch.no_grad():
+         logits = model(input_values).logits
+     return getPhonemes(logits), getAtt(logits, i)
+
+
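+ # Two audio inputs (file upload and microphone recording) plus an attribute dropdown;
+ # outputs are the ARPA phoneme string and the +/- attribute string.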
+ gui = gr.Interface(fn=recognizeAudio,
+                    inputs=[Audio(label="Upload Audio File", type="filepath"),
+                            Audio(source="microphone", type="filepath", label="Record from microphone"),
+                            Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute")],
+                    outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")])
+
+ gui.launch()