invincible-jha committed
Commit 1cd7ce8
1 Parent(s): 784383b

Upload app.py

Files changed (1):
  1. app.py +62 -48

app.py CHANGED
@@ -14,24 +14,27 @@ class ModelManager:
         self.load_models()
 
     def load_models(self):
-        print("Loading Whisper model...")
-        self.processors['whisper'] = WhisperProcessor.from_pretrained(
-            "openai/whisper-base",
-            device_map="cpu"
-        )
-        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
-            "openai/whisper-base",
-            device_map="cpu"
-        )
-
-        print("Loading emotion model...")
-        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
-            "j-hartmann/emotion-english-distilroberta-base"
-        )
-        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
-            "j-hartmann/emotion-english-distilroberta-base",
-            device_map="cpu"
-        )
+        try:
+            print("Loading Whisper model...")
+            self.processors['whisper'] = WhisperProcessor.from_pretrained(
+                "openai/whisper-base"  # Removed device_map parameter
+            )
+            self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
+                "openai/whisper-base"  # Removed device_map parameter
+            ).to(self.device)
+
+            print("Loading emotion model...")
+            self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
+                "j-hartmann/emotion-english-distilroberta-base"
+            )
+            self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
+                "j-hartmann/emotion-english-distilroberta-base"  # Removed device_map parameter
+            ).to(self.device)
+
+            print("Models loaded successfully")
+        except Exception as e:
+            print(f"Error loading models: {str(e)}")
+            raise
 
 class AudioProcessor:
     def __init__(self):
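Note: the revised loader moves both models with .to(self.device), but none of the visible hunks defines self.device. A minimal sketch of how the constructor presumably sets it, assuming standard torch device selection (the dict attributes are inferred from their usage elsewhere in the diff):

import torch

class ModelManager:
    def __init__(self):
        # Assumed device selection; the diff uses self.device but its
        # definition falls outside the visible hunks.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}      # filled by load_models(), e.g. 'whisper', 'emotion'
        self.processors = {}  # e.g. self.processors['whisper']
        self.tokenizers = {}  # e.g. self.tokenizers['emotion']
        self.load_models()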
@@ -59,9 +62,13 @@ class AudioProcessor:
 class Analyzer:
     def __init__(self):
         print("Initializing Analyzer...")
-        self.model_manager = ModelManager()
-        self.audio_processor = AudioProcessor()
-        print("Analyzer initialization complete")
+        try:
+            self.model_manager = ModelManager()
+            self.audio_processor = AudioProcessor()
+            print("Analyzer initialization complete")
+        except Exception as e:
+            print(f"Error initializing Analyzer: {str(e)}")
+            raise
 
     def analyze(self, audio_path):
         try:
@@ -72,9 +79,10 @@ class Analyzer:
             inputs = self.model_manager.processors['whisper'](
                 waveform,
                 return_tensors="pt"
-            ).input_features
+            ).input_features.to(self.model_manager.device)
 
-            predicted_ids = self.model_manager.models['whisper'].generate(inputs)
+            with torch.no_grad():
+                predicted_ids = self.model_manager.models['whisper'].generate(inputs)
             transcription = self.model_manager.processors['whisper'].batch_decode(
                 predicted_ids,
                 skip_special_tokens=True
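Note: torch.no_grad() skips autograd bookkeeping during generation, which reduces memory use at inference time. A self-contained sketch of the same transcription path, assuming the checkpoints named in the diff (the one-second silent input is only a placeholder):

import numpy as np
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

waveform = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
with torch.no_grad():  # no gradients needed for inference
    predicted_ids = model.generate(inputs)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])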
@@ -88,14 +96,16 @@ class Analyzer:
                 truncation=True,
                 max_length=512
             )
+            inputs = {k: v.to(self.model_manager.device) for k, v in inputs.items()}
 
-            outputs = self.model_manager.models['emotion'](**inputs)
+            with torch.no_grad():
+                outputs = self.model_manager.models['emotion'](**inputs)
             emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
             emotion_labels = ['anger', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
             emotion_scores = {
                 label: float(score)
-                for label, score in zip(emotion_labels, emotions[0])
+                for label, score in zip(emotion_labels, emotions[0].cpu())
             }
 
             return {
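Note: the diff keeps the hard-coded six-entry label list, but this checkpoint is generally documented as a seven-class model (including 'disgust'), in which case zip would silently misalign labels and scores. A sketch that reads the label order from the checkpoint config instead (the sample sentence is arbitrary):

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "j-hartmann/emotion-english-distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

inputs = tokenizer("This is wonderful news!", return_tensors="pt",
                   truncation=True, max_length=512)
with torch.no_grad():
    probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]

# Take label names from the config so the mapping always matches the
# classifier head, however many classes the checkpoint defines.
labels = [model.config.id2label[i] for i in range(probs.shape[-1])]
print({label: float(p) for label, p in zip(labels, probs)})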
@@ -130,9 +140,6 @@ def create_emotion_plot(emotions):
         print(f"Error creating plot: {str(e)}")
         return "Error creating visualization"
 
-print("Initializing application...")
-analyzer = Analyzer()
-
 def process_audio(audio_file):
     try:
         if audio_file is None:
@@ -150,24 +157,31 @@ def process_audio(audio_file):
         print(error_msg)
         return error_msg, "Error in analysis"
 
-print("Creating Gradio interface...")
-interface = gr.Interface(
-    fn=process_audio,
-    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
-    outputs=[
-        gr.Textbox(label="Transcription"),
-        gr.HTML(label="Emotion Analysis")
-    ],
-    title="Vocal Biomarker Analysis",
-    description="Analyze voice for emotional indicators",
-    examples=[],
-    cache_examples=False
-)
-
 if __name__ == "__main__":
-    print("Launching application...")
-    interface.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False
-    )
+    print("Initializing application...")
+    try:
+        analyzer = Analyzer()
+
+        print("Creating Gradio interface...")
+        interface = gr.Interface(
+            fn=process_audio,
+            inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
+            outputs=[
+                gr.Textbox(label="Transcription"),
+                gr.HTML(label="Emotion Analysis")
+            ],
+            title="Vocal Biomarker Analysis",
+            description="Analyze voice for emotional indicators",
+            examples=[],
+            cache_examples=False
+        )
+
+        print("Launching application...")
+        interface.launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=False
+        )
+    except Exception as e:
+        print(f"Fatal error during application startup: {str(e)}")
+        raise
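Note: with analyzer now created under if __name__ == "__main__", process_audio resolves it as a global at call time, so this works when app.py is run directly (as on Spaces) but not if the module is imported. Once launched, the app can be exercised for a quick local test with gradio_client; the '/predict' endpoint name is the default for a single gr.Interface, and the sample file path is hypothetical:

from gradio_client import Client, handle_file

client = Client("http://localhost:7860")
transcription, emotion_html = client.predict(
    handle_file("sample.wav"),  # hypothetical local test recording
    api_name="/predict"         # default endpoint for a lone gr.Interface
)
print(transcription)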