File size: 5,589 Bytes
056c529
 
 
 
 
 
 
 
 
 
 
 
 
d320d92
056c529
 
 
 
6aa1052
056c529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3715573
056c529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3715573
056c529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d320d92
056c529
 
 
 
 
 
 
d320d92
056c529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d320d92
056c529
 
 
d320d92
056c529
 
 
 
 
 
 
 
d320d92
056c529
d320d92
056c529
 
 
71dc112
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoTokenizer, 
    AutoModelWithLMHead 
)
import torch
import re
import sys
import soundfile as sf




# Hugging Face checkpoint used for both the CTC acoustic model and its processor
# (feature extractor + tokenizer); the processor reuses the same identifier.
model_name = "voidful/wav2vec2-xlsr-multilingual-56"
processor_name = model_name

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

import pickle

# Table indexed by language code whose values are vocabulary ids for that
# language (predict_lang_specific uses them to build a vocabulary mask).
# NOTE(review): pickle.load can execute arbitrary code -- only deploy a trusted
# lang_ids.pk alongside this script.
with open("lang_ids.pk", "rb") as lang_ids_file:
    lang_ids = pickle.load(lang_ids_file)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

# Inference only: switch layers such as dropout to evaluation behaviour.
model.eval()

def load_file_to_data(file, sampling_rate=16_000):
    """Load an audio file as a mono 16 kHz waveform ready for the processor.

    Args:
        file: Path (or any source accepted by ``torchaudio.load``) of the audio.
        sampling_rate: Kept for backward compatibility. Earlier versions used it
            as the assumed source rate; the rate stored in the audio file itself
            is now used instead, so this argument no longer affects the result.

    Returns:
        dict with ``"speech"`` (1-D float numpy array sampled at 16 kHz) and
        ``"sampling_rate"`` (the int 16000).
    """
    target_rate = 16_000
    waveform, source_rate = torchaudio.load(file)
    # torchaudio.load returns (channels, frames) by default. Average the channels
    # so a stereo upload becomes one utterance instead of a 2-row "batch".
    waveform = waveform.mean(dim=0)
    # Resample from the rate recorded in the file header, so uploads that are
    # not already at 16 kHz are actually converted.
    if source_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
        waveform = resampler(waveform)
    return {"speech": waveform.numpy(), "sampling_rate": target_rate}


def predict(data):
    """Transcribe an audio file with unrestricted (all-language) greedy CTC decoding.

    Args:
        data: Path to the audio file (the Gradio input uses ``type="filepath"``).
            ``None`` (nothing uploaded) yields an empty list.

    Returns:
        list[str]: One transcript per row of the model's output batch.
    """
    if data is None:
        return []
    batch = load_file_to_data(data, sampling_rate=16_000)
    features = processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"],
        padding=True,
        return_tensors="pt",
    )
    input_values = features.input_values.to(device)
    # Some feature-extractor configs do not return an attention mask; the model
    # accepts None in that case.
    attention_mask = features.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    # Greedy CTC: keep blank (pad) frames in the id sequence. The CTC tokenizer
    # collapses repeated ids *before* dropping blanks, so stripping blanks first
    # would wrongly merge genuine double letters (e.g. the "ll" in "hello").
    pred_ids = torch.argmax(logits, dim=-1)
    return [processor.decode(row_ids) for row_ids in pred_ids]

def predict_lang_specific(data, lang_code):
    """Transcribe an audio file, forcing non-blank frames onto one language's tokens.

    Args:
        data: Path to the audio file. ``None`` (nothing uploaded) yields an empty list.
        lang_code: Key into ``lang_ids`` selecting the allowed vocabulary ids.

    Returns:
        list[str]: One transcript per row of the model's output batch; a row whose
        frames are all blank decodes to "".

    Raises:
        KeyError: if ``lang_code`` is not a key of ``lang_ids`` (checked before any
            audio is loaded).
    """
    # Look the language up first so a bad code fails fast, before the expensive work.
    allowed_ids = sorted(lang_ids[lang_code])
    if data is None:
        return []
    batch = load_file_to_data(data, sampling_rate=16_000)
    features = processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"],
        padding=True,
        return_tensors="pt",
    )
    input_values = features.input_values.to(device)
    # Some feature-extractor configs do not return an attention mask; the model
    # accepts None in that case.
    attention_mask = features.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    blank_id = processor.tokenizer.pad_token_id
    # Boolean mask over the vocabulary (last logits dim): True for ids allowed in
    # this language. Built once, outside any per-row work.
    allowed = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
    allowed[torch.tensor(allowed_ids, dtype=torch.long, device=logits.device)] = True

    # Unrestricted best token per frame; used only to decide which frames are blank.
    pred_ids = torch.argmax(logits, dim=-1)
    # Best in-language token per frame. Softmax is monotonic, so argmax over the
    # masked logits equals argmax over the masked probabilities.
    lang_pred_ids = logits.masked_fill(~allowed, float("-inf")).argmax(dim=-1)
    # Keep blank frames blank so CTC decoding collapses repeats correctly (dropping
    # them would merge genuine double letters); force every other frame to its
    # best in-language token.
    combined_ids = torch.where(pred_ids == blank_id, pred_ids, lang_pred_ids)
    return [processor.decode(row_ids) for row_ids in combined_ids]

# NOTE(review): the triple-quoted string below is disabled legacy code, not a
# docstring -- it is a no-op expression statement. If re-enabled it would fail,
# because `sp` is not imported or defined anywhere in the visible file.
'''def recognition(audio_file):
    print("audio_file", audio_file.name)
    speech, rate = sp.load_speech_with_file(audio_file.name)
    
    result = sp.predict_audio_file(speech)
    print(result)

    return result
'''
# Example: predict('audio file path')  # pass the path; predict() calls load_file_to_data itself. Beware of the audio file's sampling rate.

# Example: predict_lang_specific('audio file path', 'en')  # pass the path; the function loads the file itself. Beware of the audio file's sampling rate.
# Language codes offered in the "manual" tab; each is looked up in lang_ids by
# predict_lang_specific, so every entry must be a key of that table.
_LANGUAGE_CODES = [
    "ar", "as", "br", "ca", "cnh", "cs", "cv", "cy", "de", "dv", "el", "en",
    "eo", "es", "et", "eu", "fa", "fi", "fr", "fy-NL", "ga-IE", "hi", "hsb",
    "hu", "ia", "id", "it", "ja", "ka", "ky", "lg", "lt", "lv", "mn", "mt",
    "nl", "or", "pa-IN", "pl", "pt", "rm-sursilv", "rm-vallader", "ro", "ru",
    "sah", "sl", "sv-SE", "ta", "th", "tr", "tt", "uk", "vi", "zh-CN",
    "zh-HK", "zh-TW",
]


def _transcribe_auto(audio_path):
    """Click handler for the "Auto" tab.

    Returns the transcripts from ``predict`` joined by newlines (instead of a
    Python list repr), or "" when no audio was uploaded.
    """
    if audio_path is None:
        return ""
    return "\n".join(predict(audio_path))


def _transcribe_manual(audio_path, lang_code):
    """Click handler for the "manual" tab.

    Returns the transcripts from ``predict_lang_specific`` joined by newlines,
    or "" when no audio was uploaded.
    """
    if audio_path is None:
        return ""
    return "\n".join(predict_lang_specific(audio_path, lang_code))


with gr.Blocks() as demo:
    gr.Markdown("multilingual Speech Recognition")
    with gr.Tab("Auto"):
        gr.Markdown("automatically detects your language")
        inputs_speech = gr.Audio(source="upload", type="filepath", optional=True)
        # Plain-text output, consistent with the "manual" tab.
        output_transcribe = gr.Textbox(label="output")
        transcribe_audio = gr.Button("Submit")
    with gr.Tab("manual"):
        gr.Markdown("set your speech language")
        inputs_speech1 = [
            gr.Audio(source="upload", type="filepath"),
            gr.Dropdown(choices=_LANGUAGE_CODES, value="fa", label="language code"),
        ]
        output_transcribe1 = gr.Textbox(label="output")
        transcribe_audio1 = gr.Button("Submit")

    transcribe_audio.click(
        fn=_transcribe_auto,
        inputs=inputs_speech,
        outputs=output_transcribe,
    )
    transcribe_audio1.click(
        fn=_transcribe_manual,
        inputs=inputs_speech1,
        outputs=output_transcribe1,
    )


if __name__ == "__main__":
    demo.launch()