import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os

from encoders.transformer import Wav2Vec2EmotionClassifier

# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4,
        },
        "l1_lambda": 0.0,
    }
}

model = Wav2Vec2EmotionClassifier(
    num_classes=len(emotions),
    optimizer_cfg=cfg["model"]["optimizer"],
)
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)
model.eval()

# Debug: print the parameters that are still trainable (e.g. the LoRA weights)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")

# Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
MIN_SAMPLES = 10  # or 16000 if you want at least 1 second


# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or not os.path.exists(file_path):
        # file_path could be None or an empty string if the user didn't record properly
        return None

    # Load with librosa (which resamples to sample_rate and merges to mono by default)
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to torch tensor, shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
    return waveform_tensor


# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            "",
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization (plain-text layout; adjust as needed)
    probabilities_output = [
        f"""{emotion}: {prob:.2%}"""
        for emotion, prob in zip(emotions, probabilities)
    ]

    return predicted_emotion, "\n".join(probabilities_output)
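
For a quick sanity check of the full pipeline before wiring up any UI, the prediction function can be called directly on a local clip; the `sample.wav` path below is a placeholder, not a file shipped with the project.

# Illustrative smoke test (placeholder path, not part of the original script)
emotion, probs = predict_emotion("sample.wav")
print("Predicted emotion:", emotion)
print(probs)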
🎵 Upload or record an audio file (max 1 minute) to detect emotions.
Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise
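
The two lines above read like the description shown in the app's UI. Below is a minimal sketch of how `predict_emotion` could be exposed through a Gradio interface carrying that description; the component choices, the `demo` name, and the title are assumptions, not the original launch code.

# Sketch only: the actual app may use different components or gr.Blocks instead of gr.Interface.
description = (
    "🎵 Upload or record an audio file (max 1 minute) to detect emotions.\n\n"
    "Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | "
    "😨 Fear | 🤢 Disgust | 😮 Surprise"
)

demo = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(type="filepath", label="Audio clip"),
    outputs=[
        gr.Textbox(label="Predicted emotion"),
        gr.Textbox(label="Class probabilities"),
    ],
    title="Speech Emotion Recognition",  # assumed title
    description=description,
)

if __name__ == "__main__":
    demo.launch()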