import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os
warnings.filterwarnings('ignore')  # silence noisy librosa/transformers warnings

# Global variables for models
processor = None
whisper_model = None
emotion_tokenizer = None
emotion_model = None

def load_models():
    """Initialize and load all required models"""
    global processor, whisper_model, emotion_tokenizer, emotion_model
    
    try:
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        
        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        
        # Move models to CPU explicitly
        whisper_model.to("cpu")
        emotion_model.to("cpu")
        
        print("Models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False

def process_audio(audio_input):
    """Process audio file and extract waveform"""
    try:
        print(f"Audio input received: {type(audio_input)}")
        
        # With type="filepath" Gradio passes a path string; some configurations
        # (e.g. type="numpy") return a (sample_rate, data) tuple instead, so
        # handle both shapes here.
        if isinstance(audio_input, tuple):
            sr, data = audio_input
            print(f"Audio input is tuple: sr={sr}, samples={np.shape(data)}")
            # Convert raw samples to mono float32 in [-1, 1] at 16 kHz for Whisper
            waveform = np.asarray(data, dtype=np.float32)
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)
            if np.issubdtype(np.asarray(data).dtype, np.integer):
                waveform /= np.iinfo(np.asarray(data).dtype).max
            if sr != 16000:
                waveform = librosa.resample(waveform, orig_sr=sr, target_sr=16000)
            print(f"Audio converted from tuple. Shape: {waveform.shape}")
            return waveform
        audio_path = audio_input
            
        print(f"Processing audio from path: {audio_path}")
        
        # Verify file exists
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found at {audio_path}")
            
        # Load and resample audio
        print("Loading audio file with librosa...")
        waveform, sr = librosa.load(audio_path, sr=16000)
        print(f"Audio loaded successfully. Shape: {waveform.shape}, SR: {sr}")
        
        return waveform
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise

def create_emotion_plot(emotions):
    """Create plotly visualization for emotion scores"""
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()),
                y=list(emotions.values()),
                marker_color='rgb(55, 83, 109)'
            )
        ])
        
        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )
        
        # Embed the Plotly JS inline so the returned HTML is self-contained
        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"

def analyze_audio(audio_input):
    """Main function to analyze audio input"""
    try:
        if audio_input is None:
            print("No audio input provided")
            return "No audio file provided", "Please provide an audio file"
        
        print(f"Received audio input: {audio_input}")
        
        # Process audio
        waveform = process_audio(audio_input)
        
        if waveform is None or len(waveform) == 0:
            return "Error: Invalid audio file", "Please provide a valid audio file"
        
        # Transcribe audio
        print("Transcribing audio...")
        inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
        
        with torch.no_grad():
            predicted_ids = whisper_model.generate(inputs)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        
        print(f"Transcription completed: {transcription}")
        
        if not transcription or transcription.isspace():
            return "No speech detected in audio", "Unable to analyze emotions without speech"
        
        # Analyze emotions
        print("Analyzing emotions...")
        inputs = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        
        with torch.no_grad():
            outputs = emotion_model(**inputs)
        emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        
        # Read the label order from the model config so scores line up with the
        # right emotion names (this checkpoint has 7 classes, including 'disgust')
        emotion_labels = [emotion_model.config.id2label[i] for i in range(emotions.shape[-1])]
        emotion_scores = {
            label: float(score)
            for label, score in zip(emotion_labels, emotions[0].cpu().numpy())
        }
        
        print(f"Emotion analysis completed: {emotion_scores}")
        
        # Create visualization
        emotion_viz = create_emotion_plot(emotion_scores)
        
        return transcription, emotion_viz
        
    except FileNotFoundError as e:
        error_msg = f"Audio file not found: {str(e)}"
        print(error_msg)
        return error_msg, "Please provide a valid audio file"
    except Exception as e:
        error_msg = f"Error analyzing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"

# Load models at startup
print("Initializing application...")
if not load_models():
    raise RuntimeError("Failed to load required models")

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_audio,
    inputs=gr.Audio(
        source="microphone",  # Gradio 3.x API; on Gradio 4.x use sources=["microphone", "upload"]
        type="filepath",
        label="Audio Input"
    ),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.HTML(label="Emotion Analysis")
    ],
    title="Vocal Emotion Analysis",
    description="""
    This app analyzes voice recordings to:
    1. Transcribe speech to text
    2. Detect emotions in the speech
    
    Upload an audio file or record directly through your microphone.
    """,
    article="""
    Models used:
    - Speech recognition: Whisper (tiny)
    - Emotion detection: DistilRoBERTa
    
    Note: Processing may take a few moments depending on the length of the audio.
    """,
    examples=None,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch(debug=True)
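
# ---------------------------------------------------------------------------
# Optional local sanity check (a sketch, left commented out so it never runs
# alongside the Gradio app). It assumes a recording named "sample.wav" sits
# next to this script; the filename is purely illustrative.
#
#     transcription, emotion_html = analyze_audio("sample.wav")
#     print(transcription)
#     print(emotion_html[:200])
# ---------------------------------------------------------------------------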