ahmedJaafari committed
Commit 9bb0768 · 1 Parent(s): 61a05a4

Upload app.py

Files changed (1)
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ import streamlit as st  # used only to read the HF access token from st.secrets
+ import numpy as np
+ from transformers.file_utils import cached_path, hf_bucket_url
+ import os
+ from transformers import Wav2Vec2ProcessorWithLM, AutoModelForCTC
+ from datasets import load_dataset
+ import torch
+ import kenlm  # required by the LM-boosted decoder
+ import torchaudio
+
+ cache_dir = './cache/'
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("ahmedJaafari/Annarabic3.2", cache_dir=cache_dir, use_auth_token=st.secrets["AnnarabicToken"])
+ model = AutoModelForCTC.from_pretrained("ahmedJaafari/Annarabic3.2", cache_dir=cache_dir, use_auth_token=st.secrets["AnnarabicToken"])
+
+ # read an audio file into a 16 kHz mono array, truncated to max_seconds
+ def speech_file_to_array_fn(path, max_seconds=10):
+     batch = {"file": path}
+     speech_array, sampling_rate = torchaudio.load(batch["file"])
+     if sampling_rate != 16000:
+         transform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+                                                    new_freq=16000)
+         speech_array = transform(speech_array)
+     speech_array = speech_array[0]
+     if max_seconds > 0:
+         speech_array = speech_array[:max_seconds * 16000]
+     batch["speech"] = speech_array.numpy()
+     batch["sampling_rate"] = 16000
+     return batch
+
+ # transcribe an uploaded audio file
+ def inference(audio):
+     # read in the sound file
+     ds = speech_file_to_array_fn(audio.name)
+     # featurize the waveform for the model
+     input_values = processor(
+         ds["speech"],
+         sampling_rate=ds["sampling_rate"],
+         return_tensors="pt"
+     ).input_values
+     # forward pass
+     with torch.no_grad():
+         logits = model(input_values).logits
+
+     # pred_ids = torch.argmax(logits, dim=-1)
+     # LM-boosted CTC decoding: pad the logits, then trim the trailing characters the padding introduces
+     h = logits.numpy()[0, :, :]
+     v = np.pad(h, [0, 2], mode='constant')
+
+     output = processor.decode(v).text
+
+     return output[:-4]
+
+ inputs = gr.inputs.Audio(label="Input Audio", type="file")
+ outputs = gr.outputs.Textbox(label="Output Text")
+ title = "Annarabic Speech Recognition System"
+ description = "Gradio demo for Annarabic ASR. To use it, simply upload your audio, or click one of the examples to load them. Read more at the links below."
+ examples = [['Aya.mp3'], ['Loubna.mp3']]
+ gr.Interface(inference, inputs, outputs, title=title, description=description, examples=examples).launch()
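
A minimal sketch of smoke-testing the same pipeline outside the Gradio UI (assumptions: "sample.mp3" is a placeholder path, and when the app runs outside Hugging Face Spaces, st.secrets requires an AnnarabicToken entry in a local .streamlit/secrets.toml):

    # reuse the already-loaded processor/model to transcribe one local file
    batch = speech_file_to_array_fn("sample.mp3")  # placeholder path
    input_values = processor(batch["speech"], sampling_rate=batch["sampling_rate"], return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    # same pad-then-trim decoding hack as in inference()
    print(processor.decode(np.pad(logits.numpy()[0], [0, 2], mode="constant")).text[:-4])

Note that Wav2Vec2ProcessorWithLM also needs the pyctcdecode package installed alongside kenlm for its language-model decoding to work.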