minjibi committed
Commit 0945e74 · Parent(s): 3172a4f

Update app.py

Files changed (1):
  1. app.py +44 -51
app.py CHANGED
@@ -1,58 +1,51 @@
- # Importing all the necessary packages
- import nltk
- import librosa
- import torch
  import gradio as gr
- from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
- nltk.download("punkt")
-
- # Loading the pre-trained model and the tokenizer
- model_name = "shizukanabasho/north2"
- tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
-
- def load_data(input_file):
-     # reading the file
-     speech, sample_rate = librosa.load(input_file)
-     # make it 1-D
-     if len(speech.shape) > 1:
-         speech = speech[:, 0] + speech[:, 1]
-     # Resampling the audio at 16 kHz
-     if sample_rate != 16000:
-         speech = librosa.resample(speech, sample_rate, 16000)
-     return speech
-
- def correct_casing(input_sentence):
-     sentences = nltk.sent_tokenize(input_sentence)
-     return ''.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
-
- def asr_transcript(input_file):
-     speech = load_data(input_file)
-     # Tokenize
-     input_values = tokenizer(speech, return_tensors="pt").input_values
-     # Take logits
-     logits = model(input_values).logits
-     # Take argmax
-     predicted_ids = torch.argmax(logits, dim=-1)
-     # Get the words from predicted word ids
-     transcription = tokenizer.decode(predicted_ids[0])
-     # Correcting the letter casing
-     # transcription = correct_casing(transcription.lower())
-     return transcription
-
- gr.Interface(asr_transcript,
-              inputs=gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload"),
-              outputs=gr.outputs.Textbox(label="Output Text"),
-              title="ASR using Wav2Vec2.0",
-              description="This application displays transcribed text for given audio input",
-              theme="grass").launch()
-
- # gr.Interface(asr_transcript,
- #              inputs=[gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker"),
- #                      gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Speaker")],
- #              outputs=gr.outputs.Textbox(label="Output Text"),
- #              title="ASR using Wav2Vec2.0",
- #              description="This application displays transcribed text for given audio input",
- #              theme="grass").launch()
 
  import gradio as gr
+ import torch
+ from transformers import (
+     MT5ForConditionalGeneration,
+     MT5TokenizerFast,
+ )
+
+ # Load the fine-tuned mT5 model and its tokenizer from the Hub
+ model = MT5ForConditionalGeneration.from_pretrained(
+     "minjibi/qa",
+     return_dict=True,
+ )
+ tokenizer = MT5TokenizerFast.from_pretrained(
+     "minjibi/qa"
+ )
+
+ # Use the GPU when available; fall back to CPU so the app still runs without one
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()  # inference only: disable dropout
+
+ def predict(text):
+     with torch.no_grad():
+         # Encode the input text into token ids
+         input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+         input_ids = input_ids.to(device)
+
+         generated_ids = model.generate(
+             input_ids=input_ids,
+             num_beams=5,
+             max_length=1000,
+             repetition_penalty=3.0,  # default = 2.5
+             length_penalty=1.0,
+             early_stopping=True,
+             top_p=0.95,  # must lie in (0, 1]; only takes effect when sampling
+             top_k=20,    # only takes effect when sampling
+             num_return_sequences=3,
+         )
+
+         # Decode each returned sequence into plain text
+         preds = [
+             tokenizer.decode(
+                 g,
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True,
+             )
+             for g in generated_ids
+         ]
+         return ['Q: ' + pred for pred in preds]
+
+ # text_to_predict = predict(text)
+ # predicted = ['Q: ' + text for text in predict(text_to_predict)]
+ # predicted
+
+ iface = gr.Interface(fn=predict, inputs="text", outputs="text")
+ iface.launch()
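
For reference, predict() returns three beam-search candidates (num_return_sequences=3), each prefixed with "Q: ". A minimal smoke-test sketch, assuming it runs after the definitions above and before iface.launch(); the sample passage is a hypothetical placeholder, not part of the commit:

    # Hypothetical smoke test: print the three generated question candidates.
    sample_passage = "..."  # substitute a short passage in the model's training language
    for candidate in predict(sample_passage):
        print(candidate)  # each candidate prints as "Q: <generated question>"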