"""Gradio demo for speech emotion recognition with a LoRA-tuned Wav2Vec2 classifier.

At import time this loads ``lora_only_model.pth`` into a
``Wav2Vec2EmotionClassifier``; it then builds a web UI that takes a recorded or
uploaded clip and shows the predicted emotion plus per-class probabilities.
"""

import os

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn.functional as F

from encoders.transformer import Wav2Vec2EmotionClassifier

# Define the emotions. Index order must match the classifier's output logits,
# because the argmax index is mapped back through this list.
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4,
        },
        "l1_lambda": 0.0,
    }
}

model = Wav2Vec2EmotionClassifier(
    num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"]
)
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
# strict=False: the checkpoint presumably holds only the LoRA/head weights, so
# missing base-model keys are expected. Unexpected keys, however, mean parts of
# the checkpoint did not match this architecture and were silently dropped.
load_result = model.load_state_dict(state_dict, strict=False)
unexpected_keys = getattr(load_result, "unexpected_keys", None) or []
if unexpected_keys:
    print(
        f"WARNING: {len(unexpected_keys)} checkpoint key(s) were not loaded "
        f"into the model: {unexpected_keys}"
    )
model.eval()

# Print the parameters that are still trainable (sanity check for LoRA layers).
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {tuple(param.shape)}")

# Minimum clip length, in samples, accepted for inference. Wav2Vec2's
# convolutional feature encoder (standard config: kernels 10,3,3,3,3,2,2 and
# strides 5,2,2,2,2,2,2) needs at least 400 samples (25 ms at 16 kHz) to emit a
# single frame; shorter inputs make the conv layers raise. Assumes the
# classifier uses the standard conv config -- confirm against the encoder.
# Use 16000 instead if you want at least 1 second of audio.
MIN_SAMPLES = 400

# Longer clips are truncated to this many seconds (the limit advertised in the UI).
MAX_DURATION_SECONDS = 60


# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """Load an audio file as a float32 tensor of shape (1, samples).

    Args:
        file_path: Path to an audio file. Gradio passes None or "" when the
            user did not record/upload anything.
        sample_rate: Target sampling rate; librosa resamples to it.

    Returns:
        A ``torch.Tensor`` of shape (1, samples), or None when the path is
        missing, the file cannot be decoded, or the clip is shorter than
        ``MIN_SAMPLES``. Clips longer than ``MAX_DURATION_SECONDS`` are
        truncated.
    """
    if not file_path or not os.path.exists(file_path):
        return None

    try:
        # librosa downmixes multi-channel audio to mono by default.
        waveform, _ = librosa.load(
            file_path, sr=sample_rate, duration=MAX_DURATION_SECONDS
        )
    except Exception as err:  # decoder errors differ by backend (soundfile/audioread)
        print(f"Could not decode audio file {file_path!r}: {err}")
        return None

    if len(waveform) < MIN_SAMPLES:
        return None

    return torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)


def _format_probabilities_html(probabilities):
    """Render one labelled percentage bar per emotion inside .probabilities-container."""
    rows = [
        f"""
<div style="margin-bottom: 8px;">
  <div style="display: flex; justify-content: space-between;">
    <span>{emotion.capitalize()}</span>
    <span>{prob * 100:.2f}%</span>
  </div>
  <div style="background-color: #374151; border-radius: 4px; height: 12px;">
    <div style="width: {prob * 100:.2f}%; background-color: #FFA500; height: 100%; border-radius: 4px;"></div>
  </div>
</div>
"""
        for emotion, prob in zip(emotions, probabilities)
    ]
    return '<div class="probabilities-container">' + "\n".join(rows) + "</div>"


# Prediction function
def predict_emotion(audio_file):
    """Classify the emotion in an audio file.

    Args:
        audio_file: File path supplied by Gradio (``gr.Audio(type="filepath")``).

    Returns:
        A ``(label_text, probabilities_html)`` tuple. On invalid or too-short
        input, ``label_text`` is an error message and the HTML is empty.
    """
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            "",
        )

    with torch.no_grad():
        logits = model(waveform)
        # [0] drops the batch dimension of the single-clip batch.
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    predicted_class = int(np.argmax(probabilities))
    predicted_emotion = label_mapping[str(predicted_class)]

    return predicted_emotion, _format_probabilities_html(probabilities)


# Create Gradio interface
def gradio_interface(audio):
    """Gradio callback: return values for the (Label, HTML) output components."""
    detected_emotion, html = predict_emotion(audio)
    # Return the raw HTML string; Gradio assigns it to the gr.HTML output.
    return detected_emotion, html


# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        "# Speech Emotion Recognition\n\n"
        "🎵 Upload or record an audio file (max 1 minute) to detect emotions.\n\n"
        "**Supported Emotions:** 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | "
        "😨 Fear | 🤢 Disgust | 😮 Surprise"
    )

    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")

        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(
                label="Detected Emotion", elem_id="detected_emotion"
            )
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)