from os.path import join

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
import torch
import pandas as pd
import librosa
import gradio as gr
from gradio.components import Audio, Dropdown, Textbox

# attributes exposed in the UI, mapped to the index of their group in `groups` below
Attributes = {'Dental': 2, 'Labial': 4, 'Consonant': 15, 'Vowel': 19, 'Fricative': 21,
              'Nasal': 22, 'Stop': 23, 'Affricate': 25, 'Voiced': 31, 'Bilabial': 32}

# define groups
# make sure that all phonemes are covered in each group
g1 = ['p_alveolar','n_alveolar']
g2 = ['p_palatal','n_palatal']
g3 = ['p_dental','n_dental']
g4 = ['p_glottal','n_glottal']
g5 = ['p_labial','n_labial']
g6 = ['p_velar','n_velar']
g7 = ['p_anterior','n_anterior']
g8 = ['p_posterior','n_posterior']
g9 = ['p_retroflex','n_retroflex']
g10 = ['p_mid','n_mid']
g11 = ['p_high_v','n_high_v']
g12 = ['p_low','n_low']
g13 = ['p_front','n_front']
g14 = ['p_back','n_back']
g15 = ['p_central','n_central']
g16 = ['p_consonant','n_consonant']
g17 = ['p_sonorant','n_sonorant']
g18 = ['p_long','n_long']
g19 = ['p_short','n_short']
g20 = ['p_vowel','n_vowel']
g21 = ['p_semivowel','n_semivowel']
g22 = ['p_fricative','n_fricative']
g23 = ['p_nasal','n_nasal']
g24 = ['p_stop','n_stop']
g25 = ['p_approximant','n_approximant']
g26 = ['p_affricate','n_affricate']
g27 = ['p_liquid','n_liquid']
g28 = ['p_continuant','n_continuant']
g29 = ['p_monophthong','n_monophthong']
g30 = ['p_diphthong','n_diphthong']
g31 = ['p_round','n_round']
g32 = ['p_voiced','n_voiced']
g33 = ['p_bilabial','n_bilabial']
g34 = ['p_coronal','n_coronal']
g35 = ['p_dorsal','n_dorsal']

groups = [g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18,
          g19, g20, g21, g22, g23, g24, g25, g26, g27, g28, g29, g30, g31, g32, g33, g34, g35]

# load the fine-tuned acoustic model, its processor (attribute vocabulary),
# and a separate phoneme-level tokenizer
model_dir = 'model/'
processor = Wav2Vec2Processor.from_pretrained(model_dir)
model = Wav2Vec2ForCTC.from_pretrained(model_dir)
tokenizer_phoneme = Wav2Vec2CTCTokenizer(join(model_dir, "phoneme_vocab.json"),
                                         pad_token="", word_delimiter_token="")
phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
p_att = pd.read_csv(join(model_dir, "phonological_attributes_v12.csv"), index_col=0)

# for each group, map every phoneme to the attribute token (p_* / n_*) it carries
mappers = []
for g in groups:
    p2att = {}
    for att in g:
        att_phs = p_att[p_att[att] == 1].index
        for ph in att_phs:
            p2att[ph] = att
    mappers.append(p2att)

# phoneme -> attribute projection matrix (phoneme vocab size x attribute vocab size)
p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)
for p in phoneme_list:
    for mapper in mappers:
        if p == processor.tokenizer.pad_token:
            p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.pad_token_id] = 1
        else:
            p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1

# token ids of each group in the attribute vocabulary
group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
group_ids = [dict([(x[0] + 1, x[1]) for x in list(enumerate(g))]) for g in group_ids]
# This is the inversion of the mapping used in training, as here we need to map predictions back to the original tokens


def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
        # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
        # becomes 0 - this is just the smallest value we can actually use.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)


def getPhonemes(logits):
    # sum the per-group masked log-probabilities, project them through p2att,
    # and decode the result with the phoneme tokenizer
    ngroups = len(group_ids)
    log_props_all_masked = []
    for i in range(ngroups):
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        mask[0] = True
        mask[list(group_ids[i].values())] = True
        mask.unsqueeze_(0).unsqueeze_(0)
        log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
        log_props_all_masked.append(log_probs)
    log_probs_cat = torch.stack(log_props_all_masked, dim=0).sum(dim=0)
    log_probs_phoneme = torch.matmul(p2att, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
    pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
    pred = tokenizer_phoneme.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred


def getAtt(logits, i):
    # decode the +/- sequence for the i-th attribute group only
    mask = torch.zeros(logits.size()[2], dtype=torch.bool)
    mask[0] = True
    mask[list(group_ids[i].values())] = True
    logits_g = logits[:, :, mask]
    pred_ids = torch.argmax(logits_g, dim=-1)
    pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x, x))
    pred = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred.replace('p_', '+').replace('n_', '-')


def recognizeAudio(audio, mic_audioFilePath, att):
    # Gradio callback: prefer the microphone recording if one was made, and resample to 16 kHz
    i = Attributes[att]
    audio = mic_audioFilePath if mic_audioFilePath else audio
    y, sr = librosa.load(audio, sr=16000)
    input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    return getPhonemes(logits), getAtt(logits, i)


gui = gr.Interface(fn=recognizeAudio,
                   inputs=[Audio(label="Upload Audio File", type="filepath"),
                           Audio(source="microphone", type="filepath", label="Record from microphone"),
                           Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute")],
                   outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")])
gui.launch()
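
# Note: the recognition pipeline can also be exercised without the Gradio UI, e.g. from a
# test script that imports this module. The file name below is a placeholder, not a file
# shipped with the project; any audio path readable by librosa works.
#
#   phonemes, attribute = recognizeAudio("example.wav", None, "Nasal")
#   print(phonemes)   # space-separated ARPA phoneme sequence
#   print(attribute)  # sequence of +nasal / -nasal labels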