jcho02 committed · Commit f93d945 · verified · 1 Parent(s): 4485862

Update app.py

Files changed (1):
  1. app.py +31 -27
app.py CHANGED
@@ -1,4 +1,3 @@
-# Import necessary libraries
 import gradio as gr
 import torch
 import torch.nn as nn
@@ -8,12 +7,13 @@ import datasets
 from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
 import numpy as np
-import tempfile
-import os
 
 # Ensure you have the device setup (cuda or cpu)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Define the config for your model
+config = {"encoder": "openai/whisper-base", "num_labels": 2}
+
 # Define data class
 class SpeechInferenceDataset(Dataset):
     def __init__(self, audio_data, text_processor):
@@ -28,8 +28,7 @@ class SpeechInferenceDataset(Dataset):
                                           return_tensors="pt",
                                           sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
         input_features = inputs.input_features
-        # Modify decoder_input_ids as per your model's requirements
-        decoder_input_ids = torch.tensor([[1, 1]])
+        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
         return input_features, decoder_input_ids
 
 # Define model class
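
Note: the rewrite keeps decoder_input_ids = torch.tensor([[1, 1]]) as an acknowledged placeholder (here and again in prepare_data below). For Whisper checkpoints the decoder start token can be read from the checkpoint's own config rather than hard-coded; a minimal sketch, assuming the transformers WhisperConfig API:

    import torch
    from transformers import WhisperConfig

    # Hypothetical replacement for the [[1, 1]] placeholder: derive the
    # <|startoftranscript|> start token from the checkpoint itself.
    cfg = WhisperConfig.from_pretrained("openai/whisper-base")
    decoder_input_ids = torch.tensor([[cfg.decoder_start_token_id]])
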
@@ -55,34 +54,40 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
         logits = self.classifier(pooled_output)
         return logits
 
-# Prepare data function (may need to update for numpy input)
+# Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
-    # ... your logic for preparing data ...
-    # Must return tensor that your model can process
-    pass
+    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
+    input_features = inputs.input_features
+    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
+    return input_features.to(device), decoder_input_ids.to(device)
 
-# Prediction function (may need to update for numpy input)
-def predict(audio_data, sampling_rate, config={"encoder": "openai/whisper-base", "num_labels": 2}):
-    # Load the model from Hugging Face Hub (ensure correct loading mechanism)
+# Prediction function
+def predict(audio_data, sampling_rate, config):
+    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
+
     model = SpeechClassifier(config).to(device)
-    # ... your logic for prediction using model ...
-    # Must return a prediction
-    pass
+    # Here we load the model from Hugging Face Hub
+    model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
+
+    model.eval()
+    with torch.no_grad():
+        logits = model(input_features, decoder_input_ids)
+        predicted_ids = int(torch.argmax(logits, dim=-1))
+        return predicted_ids
 
-# Gradio Interface function for uploaded files
+# Gradio Interface functions
 def gradio_file_interface(uploaded_file):
-    # Gradio passes a file path as a string for uploaded files
-    prediction = predict(uploaded_file, config)
+    # Assuming the uploaded_file is a filepath (str)
+    with open(uploaded_file, "rb") as f:
+        audio_data = np.frombuffer(f.read(), np.int16)
+    prediction = predict(audio_data, 16000, config)  # Assume 16kHz sample rate
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
-# Gradio Interface function for microphone input
 def gradio_mic_interface(mic_input):
-    # Gradio passes mic input as a numpy array and sample rate
-    audio_data = mic_input['data']
-    sampling_rate = mic_input['sample_rate']
-    prediction = predict(audio_data, sampling_rate, config)
+    # mic_input is a dictionary with 'data' and 'sample_rate' keys
+    prediction = predict(mic_input['data'], mic_input['sample_rate'], config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
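
Note: predict pins the weights to a hard-coded pytorch_model.bin URL via torch.hub.load_state_dict_from_url. Since SpeechClassifier already inherits PyTorchModelHubMixin, the same checkpoint could plausibly be loaded through the mixin instead; a sketch, assuming the jcho02/whisper_cleft repo stores a mixin-compatible checkpoint and that from_pretrained forwards config to the constructor:

    # Alternative load path via PyTorchModelHubMixin (assumptions noted above).
    model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft", config=config).to(device)
    model.eval()
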
@@ -93,18 +98,17 @@ demo = gr.Blocks()
 with demo:
     mic_transcribe = gr.Interface(
         fn=gradio_mic_interface,
-        inputs=gr.Audio(type="numpy"),  # Receives numpy array for mic input
+        inputs=gr.Audio(source="microphone", type="numpy"),  # Correct type for microphone
         outputs=gr.Textbox(label="Prediction")
     )
 
     file_transcribe = gr.Interface(
         fn=gradio_file_interface,
-        inputs=gr.Audio(type="filepath"),  # Receives file path for file upload
+        inputs=gr.Audio(source="upload", type="file"),  # Correct type for file upload
         outputs=gr.Textbox(label="Prediction")
     )
 
-    # Use a tabbed interface to switch between the microphone and file upload interfaces
     gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
 
-    # Launch the demo with debugging enabled
+    # Launch the demo
     demo.launch(debug=True)
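
Note: with type="numpy", Gradio passes microphone audio to the callback as a (sample_rate, data) tuple, not a dict, so mic_input['data'] in gradio_mic_interface is likely to fail at runtime. A sketch of a handler matching that convention (the int16-to-float scaling is an assumption about what the feature extractor expects):

    def gradio_mic_interface(mic_input):
        sampling_rate, audio_data = mic_input  # Gradio's numpy audio convention
        audio_data = audio_data.astype(np.float32) / 32768.0  # int16 -> [-1.0, 1.0]
        prediction = predict(audio_data, sampling_rate, config)
        return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"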
 
 
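
Note: gradio_file_interface decodes uploads with np.frombuffer(f.read(), np.int16), which treats the WAV header bytes as samples and assumes a fixed 16 kHz rate. A more robust sketch, assuming the soundfile package is available in the Space:

    import soundfile as sf

    def gradio_file_interface(uploaded_file):
        # sf.read parses the container header and returns float samples
        # together with the file's true sampling rate.
        audio_data, sampling_rate = sf.read(uploaded_file, dtype="float32")
        prediction = predict(audio_data, sampling_rate, config)
        return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"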