mostafaashahin committed
Commit · e7ae2d2
Parent(s): 1a25737
Update app.py
app.py
CHANGED
@@ -0,0 +1,144 @@
+from os.path import join
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
+import torch
+import pandas as pd
+import librosa
+import gradio as gr
+from gradio.components import Audio, Dropdown, Textbox
+
+
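+# Maps each attribute name shown in the UI to the index of its group in `groups` below,
+# e.g. 'Dental' -> groups[2] == g3 == ['p_dental','n_dental'].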
+Attributes = {'Dental':2,
+              'Labial':4,
+              'Consonant':15,
+              'Vowel':19,
+              'Fricative':21,
+              'Nasal':22,
+              'Stop':23,
+              'Affricate':25,
+              'Voiced':31,
+              'Bilabial':32,
+              }
+
+# Define the phonological attribute groups (positive/negative label pairs).
+# Make sure that every phoneme is covered in each group.
+g1 = ['p_alveolar','n_alveolar']
+g2 = ['p_palatal','n_palatal']
+g3 = ['p_dental','n_dental']
+g4 = ['p_glottal','n_glottal']
+g5 = ['p_labial','n_labial']
+g6 = ['p_velar','n_velar']
+g7 = ['p_anterior','n_anterior']
+g8 = ['p_posterior','n_posterior']
+g9 = ['p_retroflex','n_retroflex']
+g10 = ['p_mid','n_mid']
+g11 = ['p_high_v','n_high_v']
+g12 = ['p_low','n_low']
+g13 = ['p_front','n_front']
+g14 = ['p_back','n_back']
+g15 = ['p_central','n_central']
+g16 = ['p_consonant','n_consonant']
+g17 = ['p_sonorant','n_sonorant']
+g18 = ['p_long','n_long']
+g19 = ['p_short','n_short']
+g20 = ['p_vowel','n_vowel']
+g21 = ['p_semivowel','n_semivowel']
+g22 = ['p_fricative','n_fricative']
+g23 = ['p_nasal','n_nasal']
+g24 = ['p_stop','n_stop']
+g25 = ['p_approximant','n_approximant']
+g26 = ['p_affricate','n_affricate']
+g27 = ['p_liquid','n_liquid']
+g28 = ['p_continuant','n_continuant']
+g29 = ['p_monophthong','n_monophthong']
+g30 = ['p_diphthong','n_diphthong']
+g31 = ['p_round','n_round']
+g32 = ['p_voiced','n_voiced']
+g33 = ['p_bilabial','n_bilabial']
+g34 = ['p_coronal','n_coronal']
+g35 = ['p_dorsal','n_dorsal']
+groups = [g1,g2,g3,g4,g5,g6,g7,g8,g9,g10,g11,g12,g13,g14,g15,g16,g17,g18,g19,g20,g21,g22,g23,g24,g25,g26,g27,g28,g29,g30,g31,g32,g33,g34,g35]
+
+
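+# Load the wav2vec2 CTC model and its processor from the local model/ directory,
+# plus a phoneme-level tokenizer and the phoneme-to-attribute lookup table.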
+model_dir = 'model/'
+processor = Wav2Vec2Processor.from_pretrained(model_dir)
+model = Wav2Vec2ForCTC.from_pretrained(model_dir)
+tokenizer_phoneme = Wav2Vec2CTCTokenizer(join(model_dir,"phoneme_vocab.json"), pad_token="<pad>", word_delimiter_token="")
+
+phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
+p_att = pd.read_csv(join(model_dir,"phonological_attributes_v12.csv"),index_col=0)
+
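+# For each attribute group, build a dict that maps every phoneme to the group label
+# (p_*/n_*) it carries according to the attribute table.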
+mappers = []
+for g in groups:
+    p2att = {}
+    for att in g:
+        att_phs = p_att[p_att[att]==1].index
+        for ph in att_phs:
+            p2att[ph] = att
+    mappers.append(p2att)
+
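+# Binary (phoneme vocab x attribute vocab) matrix; getPhonemes later multiplies the
+# per-group attribute log-probs through it to score each phoneme.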
+p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)
+
+for p in phoneme_list:
+    for mapper in mappers:
+        if p == processor.tokenizer.pad_token:
+            p2att[tokenizer_phoneme.convert_tokens_to_ids(p),processor.tokenizer.pad_token_id] = 1
+        else:
+            p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1
+
+
+group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
+group_ids = [dict([(x[0]+1,x[1]) for x in list(enumerate(g))]) for g in group_ids] #This is the inversion of the one used in training as here we need to map prediction back to original tokens
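+# Log-softmax computed only over the unmasked vocabulary entries.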
+def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    if mask is not None:
+        mask = mask.float()
+        while mask.dim() < vector.dim():
+            mask = mask.unsqueeze(1)
+        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
+        # results in nans when the whole vector is masked. We need a very small value instead of a
+        # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
+        # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
+        # becomes 0 - this is just the smallest value we can actually use.
+        vector = vector + (mask + 1e-45).log()
+    return torch.nn.functional.log_softmax(vector, dim=dim)
+
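+# Sum the masked per-group log-probs and project them through p2att to score phonemes,
+# then greedily decode to an ARPA phoneme string.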
+def getPhonemes(logits):
+    ngroups = len(group_ids)
+    log_props_all_masked = []
+    for i in range(ngroups):
+        mask = torch.zeros(logits.size()[2], dtype = torch.bool)
+        mask[0] = True
+        mask[list(group_ids[i].values())] = True
+        mask.unsqueeze_(0).unsqueeze_(0)
+        log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask,0)
+        log_props_all_masked.append(log_probs)
+    log_probs_cat = torch.stack(log_props_all_masked, dim=0).sum(dim=0)
+    log_probs_phoneme = torch.matmul(p2att,log_probs_cat.transpose(1,2)).transpose(1,2).type(torch.FloatTensor)
+    pred_ids = torch.argmax(log_probs_phoneme,dim=-1)
+    pred = tokenizer_phoneme.batch_decode(pred_ids,spaces_between_special_tokens=True)[0]
+    return pred
+
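+# Greedy decoding restricted to a single attribute group; the p_/n_ labels are
+# rendered as '+'/'-' in the output.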
+def getAtt(logits,i):
+    mask = torch.zeros(logits.size()[2], dtype = torch.bool)
+    mask[0] = True
+    mask[list(group_ids[i].values())] = True
+    logits_g = logits[:,:,mask]
+    pred_ids = torch.argmax(logits_g,dim=-1)
+    pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x,x))
+    pred = processor.batch_decode(pred_ids,spaces_between_special_tokens=True)[0]
+    return pred.replace('p_','+').replace('n_','-')
+
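+# Gradio callback: prefer the microphone recording when present, resample the audio
+# to 16 kHz, run the model once, and return the phoneme string plus the selected
+# attribute track.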
+def recognizeAudio(audio, mic_audioFilePath, att):
+    i = Attributes[att]
+    audio = mic_audioFilePath if mic_audioFilePath else audio
+    y, sr = librosa.load(audio, sr=16000)
+    input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
+    with torch.no_grad():
+        logits = model(input_values).logits
+    return getPhonemes(logits), getAtt(logits,i)
+
+
+gui = gr.Interface(fn=recognizeAudio, inputs=[Audio(label="Upload Audio File", type="filepath"), Audio(source="microphone", type="filepath", label="Record from microphone"),
+                                              Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute")],
+                   outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")])
+
+gui.launch()
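
A minimal sketch of exercising the same pipeline without the web UI, assuming the definitions above have been run (for example, before the gui.launch() call) and that sample.wav is a hypothetical speech recording placed next to app.py (librosa resamples it to 16 kHz):

    # Hypothetical local test; 'sample.wav' is not part of the repository.
    phonemes, nasal_track = recognizeAudio("sample.wav", None, "Nasal")
    print(phonemes)      # space-separated ARPA phoneme sequence
    print(nasal_track)   # '+'/'-' labels for the Nasal group over time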