import gradio as gr
import torch
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os

warnings.filterwarnings('ignore')

# Global variables for models
processor = None
whisper_model = None
emotion_tokenizer = None
emotion_model = None


def load_models():
    """Initialize and load all required models"""
    global processor, whisper_model, emotion_tokenizer, emotion_model

    try:
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")

        # Move models to CPU explicitly
        whisper_model.to("cpu")
        emotion_model.to("cpu")

        print("Models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False


def process_audio(audio_input):
    """Process audio file and extract waveform"""
    try:
        print(f"Audio input received: {type(audio_input)}")

        # Handle tuple input from Gradio; otherwise treat the input as a file path
        if isinstance(audio_input, tuple):
            print(f"Audio input is tuple: {audio_input[0]}, {audio_input[1]}")
            audio_path = audio_input[0]  # Get the file path
        else:
            audio_path = audio_input

        print(f"Processing audio from path: {audio_path}")

        # Verify file exists
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found at {audio_path}")

        # Load and resample audio to 16 kHz, the rate Whisper expects
        print("Loading audio file with librosa...")
        waveform, sr = librosa.load(audio_path, sr=16000)
        print(f"Audio loaded successfully. Shape: {waveform.shape}, SR: {sr}")

        return waveform
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise


def create_emotion_plot(emotions):
    """Create plotly visualization for emotion scores"""
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()),
                y=list(emotions.values()),
                marker_color='rgb(55, 83, 109)'
            )
        ])

        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )

        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"


def analyze_audio(audio_input):
    """Main function to analyze audio input"""
    try:
        if audio_input is None:
            print("No audio input provided")
            return "No audio file provided", "Please provide an audio file"

        print(f"Received audio input: {audio_input}")

        # Process audio
        waveform = process_audio(audio_input)
        if waveform is None or len(waveform) == 0:
            return "Error: Invalid audio file", "Please provide a valid audio file"

        # Transcribe audio
        print("Transcribing audio...")
        input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features

        with torch.no_grad():
            predicted_ids = whisper_model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription completed: {transcription}")

        if not transcription or transcription.isspace():
            return "No speech detected in audio", "Unable to analyze emotions without speech"

        # Analyze emotions
        print("Analyzing emotions...")
        inputs = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            outputs = emotion_model(**inputs)
            emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Read label names from the model config so they always match the
        # model's output dimensions (the checkpoint includes a 'disgust' class)
        emotion_labels = [emotion_model.config.id2label[i] for i in range(emotions.shape[-1])]
        emotion_scores = {
            label: float(score)
            for label, score in zip(emotion_labels, emotions[0].cpu().numpy())
        }
        print(f"Emotion analysis completed: {emotion_scores}")

        # Create visualization
        emotion_viz = create_emotion_plot(emotion_scores)

        return transcription, emotion_viz
    except FileNotFoundError as e:
        error_msg = f"Audio file not found: {str(e)}"
        print(error_msg)
        return error_msg, "Please provide a valid audio file"
    except Exception as e:
        error_msg = f"Error analyzing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"


# Load models at startup
print("Initializing application...")
if not load_models():
    raise RuntimeError("Failed to load required models")

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_audio,
    inputs=gr.Audio(
        source="microphone",
        type="filepath",
        label="Audio Input"
    ),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.HTML(label="Emotion Analysis")
    ],
    title="Vocal Emotion Analysis",
    description="""
    This app analyzes voice recordings to:
    1. Transcribe speech to text
    2. Detect emotions in the speech

    Upload an audio file or record directly through your microphone.
    """,
    article="""
    Models used:
    - Speech recognition: Whisper (tiny)
    - Emotion detection: DistilRoBERTa

    Note: Processing may take a few moments depending on the length of the audio.
    """,
    examples=None,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch(debug=True)
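
# Optional local check (a minimal sketch, not part of the app): the pipeline can be
# exercised without launching the UI by calling analyze_audio() directly on an audio
# file path. "sample.wav" is a hypothetical file name used only for illustration.
#
#   load_models()
#   transcription, emotion_html = analyze_audio("sample.wav")
#   print(transcription)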