LuisG07 commited on
Commit
7f2a1b8
1 Parent(s): 941714c

added torch no grad

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -32,7 +32,8 @@ def predict_and_ctc_lm_decode(input_file, model_name):
32
  speech = load_and_fix_data(input_file)
33
 
34
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
35
- logits = model(input_values).logits.cpu().detach().numpy()[0]
 
36
 
37
  pred = processor.decode(logits).text
38
 
@@ -45,7 +46,8 @@ def predict_and_greedy_decode(input_file, model_name):
45
  speech = load_and_fix_data(input_file)
46
 
47
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
48
- logits = model(input_values).logits
 
49
 
50
  predicted_ids = torch.argmax(logits, dim=-1)
51
  pred = processor.batch_decode(predicted_ids)
@@ -59,11 +61,11 @@ def return_all_predictions(input_file, model_name):
59
 
60
 
61
  gr.Interface(return_all_predictions,
62
- inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["jonatasgrosman/wav2vec2-large-xlsr-53-spanish", "jonatasgrosman/wav2vec2-xls-r-1b-spanish"], label="Model Name")],
63
  outputs = [gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
64
  title="ASR using Wav2Vec2 & pyctcdecode in spanish",
65
  description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
66
  layout = "horizontal",
67
- examples = [["test1.wav", "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"], ["test2.wav", "jonatasgrosman/wav2vec2-xls-r-1b-spanish"]],
68
  theme="huggingface",
69
  enable_queue=True).launch()
 
32
  speech = load_and_fix_data(input_file)
33
 
34
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
35
+ with torch.no_grad():
36
+ logits = model(input_values).logits.cpu().detach().numpy()[0]
37
 
38
  pred = processor.decode(logits).text
39
 
 
46
  speech = load_and_fix_data(input_file)
47
 
48
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
49
+ with torch.no_grad():
50
+ logits = model(input_values).logits
51
 
52
  predicted_ids = torch.argmax(logits, dim=-1)
53
  pred = processor.batch_decode(predicted_ids)
 
61
 
62
 
63
  gr.Interface(return_all_predictions,
64
+ inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["LuisG07/wav2vec2-large-xlsr-53-spanish", "jonatasgrosman/wav2vec2-xls-r-1b-spanish"], label="Model Name")],
65
  outputs = [gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
66
  title="ASR using Wav2Vec2 & pyctcdecode in spanish",
67
  description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
68
  layout = "horizontal",
69
+ examples = [["test1.wav", "LuisG07/wav2vec2-large-xlsr-53-spanish"], ["test2.wav", "jonatasgrosman/wav2vec2-xls-r-1b-spanish"]],
70
  theme="huggingface",
71
  enable_queue=True).launch()