import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go

class ModelManager:
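    """Load the Whisper and emotion models, with their processors/tokenizers, onto CPU."""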
    def __init__(self):
        self.device = torch.device("cpu")
        self.models = {}
        self.tokenizers = {}
        self.processors = {}
        self.load_models()
        
    def load_models(self):
        try:
            print("Loading Whisper model...")
            self.processors['whisper'] = WhisperProcessor.from_pretrained(
                "openai/whisper-base"
            )
            self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
                "openai/whisper-base"
            ).to(self.device)
            
            print("Loading emotion model...")
            self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
                "j-hartmann/emotion-english-distilroberta-base"
            )
            self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
                "j-hartmann/emotion-english-distilroberta-base"
            ).to(self.device)
            
            print("Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise

class AudioProcessor:
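    """Load audio at the 16 kHz rate Whisper expects and extract basic acoustic features."""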
    def __init__(self):
        self.sample_rate = 16000
        self.n_mfcc = 13
        
    def process_audio(self, audio_path):
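        """Return the resampled waveform together with its MFCC/energy features."""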
        try:
            waveform, sr = librosa.load(audio_path, sr=self.sample_rate)
            return waveform, self._extract_features(waveform)
        except Exception as e:
            print(f"Error processing audio: {str(e)}")
            raise
        
    def _extract_features(self, waveform):
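        """Compute MFCCs and frame-level RMS energy.

        These features are returned to callers for potential acoustic analysis;
        the current analyze() pipeline does not consume them.
        """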
        try:
            return {
                'mfcc': librosa.feature.mfcc(y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc),
                'energy': librosa.feature.rms(y=waveform)[0]
            }
        except Exception as e:
            print(f"Error extracting features: {str(e)}")
            raise

class Analyzer:
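    """End-to-end pipeline: audio file -> Whisper transcription -> emotion scores."""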
    def __init__(self):
        print("Initializing Analyzer...")
        try:
            self.model_manager = ModelManager()
            self.audio_processor = AudioProcessor()
            print("Analyzer initialization complete")
        except Exception as e:
            print(f"Error initializing Analyzer: {str(e)}")
            raise
        
    def analyze(self, audio_path):
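        """Transcribe the audio file, then classify emotions in the transcript.

        Returns a dict with 'transcription' (str) and 'emotions'
        (label -> probability) keys.
        """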
        try:
            print(f"Processing audio file: {audio_path}")
            waveform, features = self.audio_processor.process_audio(audio_path)
            
            print("Transcribing audio...")
            # Pass the sampling rate explicitly; recent transformers releases
            # warn when it is omitted and assume 16 kHz otherwise.
            inputs = self.model_manager.processors['whisper'](
                waveform,
                sampling_rate=self.audio_processor.sample_rate,
                return_tensors="pt"
            ).input_features.to(self.model_manager.device)
            
            with torch.no_grad():
                predicted_ids = self.model_manager.models['whisper'].generate(inputs)
            transcription = self.model_manager.processors['whisper'].batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0]
            
            print("Analyzing emotions...")
            inputs = self.model_manager.tokenizers['emotion'](
                transcription,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model_manager.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model_manager.models['emotion'](**inputs)
            emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            
            # Map scores to labels via the model config so the label order and
            # count always match the classifier head (this checkpoint predicts
            # seven classes: anger, disgust, fear, joy, neutral, sadness,
            # surprise); a hardcoded list risks silently dropping a score.
            id2label = self.model_manager.models['emotion'].config.id2label
            emotion_scores = {
                id2label[i]: float(score)
                for i, score in enumerate(emotions[0].cpu())
            }
            
            return {
                'transcription': transcription,
                'emotions': emotion_scores
            }
        except Exception as e:
            print(f"Error in analysis: {str(e)}")
            raise

def create_emotion_plot(emotions):
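    """Render emotion scores as a Plotly bar chart and return embeddable HTML.

    include_plotlyjs=True inlines the plotly.js library, keeping the output
    self-contained at the cost of a larger payload per render.
    """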
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()), 
                y=list(emotions.values()),
                marker_color='rgb(55, 83, 109)'
            )
        ])
        
        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )
        
        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"

def process_audio(audio_file):
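    """Gradio callback: run the module-level `analyzer` on the uploaded file.

    Returns (transcription, emotion-plot HTML); errors come back as plain
    strings so the UI displays them instead of crashing.
    """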
    try:
        if audio_file is None:
            return "No audio file provided", "Please provide an audio file"
            
        print(f"Processing audio file: {audio_file}")
        results = analyzer.analyze(audio_file)
        
        return (
            results['transcription'],
            create_emotion_plot(results['emotions'])
        )
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"

if __name__ == "__main__":
    print("Initializing application...")
    try:
        analyzer = Analyzer()
        
        print("Creating Gradio interface...")
        interface = gr.Interface(
            fn=process_audio,
            inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
            outputs=[
                gr.Textbox(label="Transcription"),
                gr.HTML(label="Emotion Analysis")
            ],
            title="Vocal Biomarker Analysis",
            description="Analyze voice for emotional indicators",
            examples=[],
            cache_examples=False
        )
        
        print("Launching application...")
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"Fatal error during application startup: {str(e)}")
        raise