invincible-jha
commited on
Commit
•
1cd7ce8
1
Parent(s):
784383b
Upload app.py
Browse files
app.py
CHANGED
@@ -14,24 +14,27 @@ class ModelManager:
|
|
14 |
self.load_models()
|
15 |
|
16 |
def load_models(self):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
class AudioProcessor:
|
37 |
def __init__(self):
|
@@ -59,9 +62,13 @@ class AudioProcessor:
|
|
59 |
class Analyzer:
|
60 |
def __init__(self):
|
61 |
print("Initializing Analyzer...")
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def analyze(self, audio_path):
|
67 |
try:
|
@@ -72,9 +79,10 @@ class Analyzer:
|
|
72 |
inputs = self.model_manager.processors['whisper'](
|
73 |
waveform,
|
74 |
return_tensors="pt"
|
75 |
-
).input_features
|
76 |
|
77 |
-
|
|
|
78 |
transcription = self.model_manager.processors['whisper'].batch_decode(
|
79 |
predicted_ids,
|
80 |
skip_special_tokens=True
|
@@ -88,14 +96,16 @@ class Analyzer:
|
|
88 |
truncation=True,
|
89 |
max_length=512
|
90 |
)
|
|
|
91 |
|
92 |
-
|
|
|
93 |
emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
94 |
|
95 |
emotion_labels = ['anger', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
|
96 |
emotion_scores = {
|
97 |
label: float(score)
|
98 |
-
for label, score in zip(emotion_labels, emotions[0])
|
99 |
}
|
100 |
|
101 |
return {
|
@@ -130,9 +140,6 @@ def create_emotion_plot(emotions):
|
|
130 |
print(f"Error creating plot: {str(e)}")
|
131 |
return "Error creating visualization"
|
132 |
|
133 |
-
print("Initializing application...")
|
134 |
-
analyzer = Analyzer()
|
135 |
-
|
136 |
def process_audio(audio_file):
|
137 |
try:
|
138 |
if audio_file is None:
|
@@ -150,24 +157,31 @@ def process_audio(audio_file):
|
|
150 |
print(error_msg)
|
151 |
return error_msg, "Error in analysis"
|
152 |
|
153 |
-
print("Creating Gradio interface...")
|
154 |
-
interface = gr.Interface(
|
155 |
-
fn=process_audio,
|
156 |
-
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
|
157 |
-
outputs=[
|
158 |
-
gr.Textbox(label="Transcription"),
|
159 |
-
gr.HTML(label="Emotion Analysis")
|
160 |
-
],
|
161 |
-
title="Vocal Biomarker Analysis",
|
162 |
-
description="Analyze voice for emotional indicators",
|
163 |
-
examples=[],
|
164 |
-
cache_examples=False
|
165 |
-
)
|
166 |
-
|
167 |
if __name__ == "__main__":
|
168 |
-
print("
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
self.load_models()
|
15 |
|
16 |
def load_models(self):
|
17 |
+
try:
|
18 |
+
print("Loading Whisper model...")
|
19 |
+
self.processors['whisper'] = WhisperProcessor.from_pretrained(
|
20 |
+
"openai/whisper-base" # Removed device_map parameter
|
21 |
+
)
|
22 |
+
self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
|
23 |
+
"openai/whisper-base" # Removed device_map parameter
|
24 |
+
).to(self.device)
|
25 |
+
|
26 |
+
print("Loading emotion model...")
|
27 |
+
self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
|
28 |
+
"j-hartmann/emotion-english-distilroberta-base"
|
29 |
+
)
|
30 |
+
self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
|
31 |
+
"j-hartmann/emotion-english-distilroberta-base" # Removed device_map parameter
|
32 |
+
).to(self.device)
|
33 |
+
|
34 |
+
print("Models loaded successfully")
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error loading models: {str(e)}")
|
37 |
+
raise
|
38 |
|
39 |
class AudioProcessor:
|
40 |
def __init__(self):
|
|
|
62 |
class Analyzer:
|
63 |
def __init__(self):
|
64 |
print("Initializing Analyzer...")
|
65 |
+
try:
|
66 |
+
self.model_manager = ModelManager()
|
67 |
+
self.audio_processor = AudioProcessor()
|
68 |
+
print("Analyzer initialization complete")
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Error initializing Analyzer: {str(e)}")
|
71 |
+
raise
|
72 |
|
73 |
def analyze(self, audio_path):
|
74 |
try:
|
|
|
79 |
inputs = self.model_manager.processors['whisper'](
|
80 |
waveform,
|
81 |
return_tensors="pt"
|
82 |
+
).input_features.to(self.model_manager.device)
|
83 |
|
84 |
+
with torch.no_grad():
|
85 |
+
predicted_ids = self.model_manager.models['whisper'].generate(inputs)
|
86 |
transcription = self.model_manager.processors['whisper'].batch_decode(
|
87 |
predicted_ids,
|
88 |
skip_special_tokens=True
|
|
|
96 |
truncation=True,
|
97 |
max_length=512
|
98 |
)
|
99 |
+
inputs = {k: v.to(self.model_manager.device) for k, v in inputs.items()}
|
100 |
|
101 |
+
with torch.no_grad():
|
102 |
+
outputs = self.model_manager.models['emotion'](**inputs)
|
103 |
emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
104 |
|
105 |
emotion_labels = ['anger', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
|
106 |
emotion_scores = {
|
107 |
label: float(score)
|
108 |
+
for label, score in zip(emotion_labels, emotions[0].cpu())
|
109 |
}
|
110 |
|
111 |
return {
|
|
|
140 |
print(f"Error creating plot: {str(e)}")
|
141 |
return "Error creating visualization"
|
142 |
|
|
|
|
|
|
|
143 |
def process_audio(audio_file):
|
144 |
try:
|
145 |
if audio_file is None:
|
|
|
157 |
print(error_msg)
|
158 |
return error_msg, "Error in analysis"
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
if __name__ == "__main__":
|
161 |
+
print("Initializing application...")
|
162 |
+
try:
|
163 |
+
analyzer = Analyzer()
|
164 |
+
|
165 |
+
print("Creating Gradio interface...")
|
166 |
+
interface = gr.Interface(
|
167 |
+
fn=process_audio,
|
168 |
+
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
|
169 |
+
outputs=[
|
170 |
+
gr.Textbox(label="Transcription"),
|
171 |
+
gr.HTML(label="Emotion Analysis")
|
172 |
+
],
|
173 |
+
title="Vocal Biomarker Analysis",
|
174 |
+
description="Analyze voice for emotional indicators",
|
175 |
+
examples=[],
|
176 |
+
cache_examples=False
|
177 |
+
)
|
178 |
+
|
179 |
+
print("Launching application...")
|
180 |
+
interface.launch(
|
181 |
+
server_name="0.0.0.0",
|
182 |
+
server_port=7860,
|
183 |
+
share=False
|
184 |
+
)
|
185 |
+
except Exception as e:
|
186 |
+
print(f"Fatal error during application startup: {str(e)}")
|
187 |
+
raise
|