invincible-jha's picture
Upload app.py
f7af1db verified
raw
history blame
6.42 kB
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os
# Suppress every Python warning for the whole process.
warnings.filterwarnings('ignore')

# Model handles shared by the functions below. They stay None until
# load_models() assigns them at startup.
processor = whisper_model = emotion_tokenizer = emotion_model = None
def load_models():
    """Load the speech-recognition and emotion models into module globals.

    Loads the ``openai/whisper-tiny`` processor and model, then the
    ``j-hartmann/emotion-english-distilroberta-base`` tokenizer and classifier,
    all via ``from_pretrained``. Both models are then moved to the CPU.

    Returns:
        True if every step succeeded. False if any step raised; the error is
        printed, and any globals assigned before the failure keep their new
        values.
    """
    global processor, whisper_model, emotion_tokenizer, emotion_model
    whisper_checkpoint = "openai/whisper-tiny"
    emotion_checkpoint = "j-hartmann/emotion-english-distilroberta-base"
    try:
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained(whisper_checkpoint)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_checkpoint)

        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained(emotion_checkpoint)
        emotion_model = AutoModelForSequenceClassification.from_pretrained(emotion_checkpoint)

        # Pin both networks to the CPU explicitly (speech model first).
        for network in (whisper_model, emotion_model):
            network.to("cpu")

        print("Models loaded successfully!")
        return True
    except Exception as err:
        print(f"Error loading models: {err}")
        return False
def _waveform_from_array(sample_rate, samples, target_sr):
    """Convert an in-memory ``(sample_rate, samples)`` pair to mono float32 at ``target_sr``."""
    data = np.asarray(samples)
    if np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        if info.min < 0:
            # Signed PCM: scale by the magnitude of the most negative value so
            # the output lies in [-1, 1].
            data = data.astype(np.float32) / float(-int(info.min))
        else:
            # Unsigned PCM is offset-binary: re-centre it on zero.
            half = (int(info.max) + 1) / 2.0
            data = (data.astype(np.float32) - half) / half
    else:
        data = data.astype(np.float32)

    if data.ndim == 2:
        # Assumes a (n_samples, n_channels) layout, as in Gradio's numpy audio
        # format. Average the channels down to mono.
        data = data.mean(axis=1)
    elif data.ndim != 1:
        raise ValueError(f"Unsupported audio array shape: {data.shape}")

    sample_rate = int(sample_rate)
    if sample_rate != target_sr:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
    return data.astype(np.float32, copy=False)


def process_audio(audio_input, target_sr=16000):
    """Turn a Gradio audio value into a mono float32 waveform.

    Accepted inputs:

    * a file path (what ``gr.Audio(type="filepath")`` returns);
    * a tuple whose first element is a file path (previous behaviour, kept);
    * a ``(sample_rate, samples)`` tuple, Gradio's numpy audio format. Integer
      PCM is scaled to [-1, 1] and 2-D arrays are averaged to mono.

    Args:
        audio_input: One of the forms listed above.
        target_sr: Sample rate of the returned waveform. The default of 16 kHz
            matches the rate ``analyze_audio`` passes to the Whisper processor.

    Returns:
        A 1-D numpy waveform sampled at ``target_sr``.

    Raises:
        FileNotFoundError: If a path is given that does not exist.
        ValueError: If no path is given or the array shape is unsupported.
        Any error raised by ``librosa`` while loading or resampling is printed
        and re-raised.
    """
    try:
        print(f"Audio input received: {type(audio_input)}")
        # Handle tuple input from Gradio
        if isinstance(audio_input, tuple):
            print(f"Audio input is tuple: {audio_input[0]}, {audio_input[1]}")
            first = audio_input[0]
            if isinstance(first, (str, os.PathLike)):
                audio_path = first
            else:
                # (sample_rate, samples): the audio is already in memory.
                return _waveform_from_array(first, audio_input[1], target_sr)
        else:
            audio_path = audio_input
        if audio_path is None:
            raise ValueError("No audio input provided")
        print(f"Processing audio from path: {audio_path}")
        # Verify file exists
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found at {audio_path}")
        # Load and resample audio (librosa.load downmixes to mono by default)
        print("Loading audio file with librosa...")
        waveform, sr = librosa.load(audio_path, sr=target_sr)
        print(f"Audio loaded successfully. Shape: {waveform.shape}, SR: {sr}")
        return waveform
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise
def create_emotion_plot(emotions, height=400):
    """Build an embeddable Plotly bar chart of emotion scores.

    Args:
        emotions: Mapping of emotion label -> score. The y-axis is fixed to
            [0, 1], so scores are expected to be probabilities.
        height: Chart height in pixels (default 400).

    Returns:
        An HTML ``<iframe>`` string whose ``srcdoc`` holds the standalone
        chart page, or ``"Error creating visualization"`` if rendering fails.
    """
    import html  # local import: used only to escape the srcdoc attribute

    try:
        fig = go.Figure(
            data=[
                go.Bar(
                    x=list(emotions.keys()),
                    y=list(emotions.values()),
                    marker_color='rgb(55, 83, 109)',
                )
            ]
        )
        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=height,
        )
        page = fig.to_html(include_plotlyjs=True, full_html=True)
        # Plotly draws the chart from inline <script> tags. Browsers do not
        # execute <script> elements inserted via innerHTML, which is how HTML
        # output components typically render their value, so the bare markup
        # shows nothing. An iframe's srcdoc is parsed as its own document, so
        # its scripts do run.
        return (
            f'<iframe srcdoc="{html.escape(page, quote=True)}" '
            f'style="width: 100%; height: {height + 30}px; border: none;">'
            '</iframe>'
        )
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"
def analyze_audio(audio_input):
    """Transcribe an audio clip with Whisper and score the transcript's emotions.

    Uses the module-level models assigned by ``load_models()`` together with
    ``process_audio`` and ``create_emotion_plot``.

    Args:
        audio_input: Value from the Gradio audio component, either a file path
            or a tuple accepted by ``process_audio``. It is ``None`` when
            nothing was provided.

    Returns:
        A ``(transcription, emotion_visualization)`` pair for the interface's
        two outputs. On any failure both elements are human-readable messages
        instead of results. Exceptions are caught and never propagate.
    """
    try:
        if audio_input is None:
            print("No audio input provided")
            return "No audio file provided", "Please provide an audio file"
        print(f"Received audio input: {audio_input}")

        # Process audio
        waveform = process_audio(audio_input)
        if waveform is None or len(waveform) == 0:
            return "Error: Invalid audio file", "Please provide a valid audio file"

        # Transcribe audio. process_audio loads at 16 kHz by default, which
        # matches the sampling_rate passed here.
        print("Transcribing audio...")
        features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
        with torch.no_grad():
            predicted_ids = whisper_model.generate(features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription completed: {transcription}")
        if not transcription or transcription.isspace():
            return "No speech detected in audio", "Unable to analyze emotions without speech"

        # Analyze emotions
        print("Analyzing emotions...")
        encoded = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            logits = emotion_model(**encoded).logits
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()

        # Read class names from the model config instead of a hard-coded list.
        # Per its model card, this checkpoint has 7 classes including "disgust".
        # A 6-name list zipped against 7 scores shifts every label after
        # "anger" and silently drops the last score.
        id2label = getattr(emotion_model.config, "id2label", None) or {}
        emotion_scores = {
            id2label.get(index, f"LABEL_{index}"): float(score)
            for index, score in enumerate(probabilities)
        }
        print(f"Emotion analysis completed: {emotion_scores}")

        # Create visualization
        emotion_viz = create_emotion_plot(emotion_scores)
        return transcription, emotion_viz
    except FileNotFoundError as e:
        error_msg = f"Audio file not found: {str(e)}"
        print(error_msg)
        return error_msg, "Please provide a valid audio file"
    except Exception as e:
        error_msg = f"Error analyzing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"
# Load models at startup
print("Initializing application...")
if not load_models():
    raise RuntimeError("Failed to load required models")

# Build the audio input so it works across Gradio versions. Gradio 3.x takes a
# single ``source=`` string. Gradio >= 4 replaced it with a ``sources=`` list,
# and passing the old keyword makes the constructor raise TypeError.
try:
    _audio_component = gr.Audio(
        source="microphone",
        type="filepath",
        label="Audio Input",
    )
except TypeError:
    # Offer both inputs, matching the "upload ... or record" description below.
    _audio_component = gr.Audio(
        sources=["upload", "microphone"],
        type="filepath",
        label="Audio Input",
    )

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_audio,
    inputs=_audio_component,
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.HTML(label="Emotion Analysis")
    ],
    title="Vocal Emotion Analysis",
    description="""
This app analyzes voice recordings to:
1. Transcribe speech to text
2. Detect emotions in the speech
Upload an audio file or record directly through your microphone.
""",
    article="""
Models used:
- Speech recognition: Whisper (tiny)
- Emotion detection: DistilRoBERTa
Note: Processing may take a few moments depending on the length of the audio.
""",
    examples=None,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch(debug=True)