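# app.py — uploaded by invincible-jha ("Upload app.py", commit 1cd7ce8, verified), 6.53 kB.
# Captured from a file-viewer page; its "raw" / "history blame" navigation
# labels were page chrome, not part of the program.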
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
class ModelManager:
    """Registry of the pretrained components the application relies on.

    Models, tokenizers and processors are kept in separate dicts keyed by a
    short name ('whisper' or 'emotion'). Models are moved to ``self.device``,
    which is always the CPU.
    """

    def __init__(self):
        """Set up the empty registries and load every model right away."""
        self.device = torch.device("cpu")
        self.models = {}
        self.tokenizers = {}
        self.processors = {}
        self.load_models()

    def load_models(self):
        """Load the Whisper and emotion checkpoints via ``from_pretrained``.

        Progress is reported with ``print``. Any exception is printed and then
        re-raised unchanged.
        """
        whisper_checkpoint = "openai/whisper-base"
        emotion_checkpoint = "j-hartmann/emotion-english-distilroberta-base"
        try:
            print("Loading Whisper model...")
            whisper_processor = WhisperProcessor.from_pretrained(whisper_checkpoint)
            self.processors['whisper'] = whisper_processor
            whisper_model = WhisperForConditionalGeneration.from_pretrained(
                whisper_checkpoint
            )
            self.models['whisper'] = whisper_model.to(self.device)

            print("Loading emotion model...")
            emotion_tokenizer = AutoTokenizer.from_pretrained(emotion_checkpoint)
            self.tokenizers['emotion'] = emotion_tokenizer
            emotion_model = AutoModelForSequenceClassification.from_pretrained(
                emotion_checkpoint
            )
            self.models['emotion'] = emotion_model.to(self.device)

            print("Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise
class AudioProcessor:
    """Loads an audio file with librosa and derives MFCC and RMS features."""

    def __init__(self):
        """Fix the loading sample rate (16000) and the MFCC count (13)."""
        self.sample_rate = 16000
        self.n_mfcc = 13

    def process_audio(self, audio_path):
        """Load ``audio_path`` and return ``(waveform, features)``.

        The file is loaded through ``librosa.load`` at ``self.sample_rate``;
        ``features`` is the dict built by ``_extract_features``. Any exception
        is printed and re-raised.
        """
        try:
            loaded = librosa.load(audio_path, sr=self.sample_rate)
            waveform = loaded[0]
            features = self._extract_features(waveform)
            return waveform, features
        except Exception as e:
            print(f"Error processing audio: {str(e)}")
            raise

    def _extract_features(self, waveform):
        """Return a dict with the 'mfcc' matrix and the 'energy' (RMS) row.

        Any exception is printed and re-raised.
        """
        try:
            mfcc = librosa.feature.mfcc(
                y=waveform,
                sr=self.sample_rate,
                n_mfcc=self.n_mfcc,
            )
            # librosa.feature.rms returns a 2-D array; only its first row is kept.
            energy = librosa.feature.rms(y=waveform)[0]
            return {'mfcc': mfcc, 'energy': energy}
        except Exception as e:
            print(f"Error extracting features: {str(e)}")
            raise
class Analyzer:
    """Transcribes an audio file and scores the transcript for emotions.

    Combines a ``ModelManager`` (models, tokenizers, processors, device) with
    an ``AudioProcessor`` (audio loading and feature extraction).
    """

    def __init__(self):
        """Build the model manager and audio processor.

        Progress is printed; any exception is printed and re-raised.
        """
        print("Initializing Analyzer...")
        try:
            self.model_manager = ModelManager()
            self.audio_processor = AudioProcessor()
            print("Analyzer initialization complete")
        except Exception as e:
            print(f"Error initializing Analyzer: {str(e)}")
            raise

    def analyze(self, audio_path):
        """Run the full pipeline on one audio file.

        Args:
            audio_path: Path of the audio file to analyse.

        Returns:
            A dict with 'transcription' (the decoded text) and 'emotions'
            (a mapping of emotion label to softmax probability).

        Raises:
            Exception: whatever the underlying steps raise; it is printed
                first and then re-raised unchanged.
        """
        try:
            print(f"Processing audio file: {audio_path}")
            # The extracted features are not used by the rest of the pipeline.
            waveform, _features = self.audio_processor.process_audio(audio_path)

            print("Transcribing audio...")
            transcription = self._transcribe(waveform)

            print("Analyzing emotions...")
            emotion_scores = self._score_emotions(transcription)

            return {
                'transcription': transcription,
                'emotions': emotion_scores
            }
        except Exception as e:
            print(f"Error in analysis: {str(e)}")
            raise

    def _transcribe(self, waveform):
        """Return the Whisper transcription of ``waveform`` as a string."""
        processor = self.model_manager.processors['whisper']
        model = self.model_manager.models['whisper']

        # Tell the processor the rate the audio was loaded at, so it can
        # validate the input instead of silently assuming a rate.
        input_features = processor(
            waveform,
            sampling_rate=self.audio_processor.sample_rate,
            return_tensors="pt"
        ).input_features.to(self.model_manager.device)

        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        return processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

    def _score_emotions(self, text):
        """Return ``{label: probability}`` for ``text`` from the emotion model."""
        tokenizer = self.model_manager.tokenizers['emotion']
        model = self.model_manager.models['emotion']

        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        encoded = {
            key: tensor.to(self.model_manager.device)
            for key, tensor in encoded.items()
        }

        with torch.no_grad():
            logits = model(**encoded).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()

        # Take the label names from the model's own config rather than a
        # hand-written list: a list with fewer entries than the model has
        # classes pairs scores with the wrong names and drops the last class.
        id2label = model.config.id2label
        return {
            id2label[index]: float(probability)
            for index, probability in enumerate(probabilities)
        }
def create_emotion_plot(emotions):
    """Render emotion scores as a plotly bar chart and return it as HTML.

    ``emotions`` maps label -> score. The y axis is fixed to [0, 1]. If
    anything goes wrong the error is printed and the plain string
    "Error creating visualization" is returned instead of HTML.
    """
    try:
        bars = go.Bar(
            x=list(emotions.keys()),
            y=list(emotions.values()),
            marker_color='rgb(55, 83, 109)',
        )
        figure = go.Figure(data=[bars])

        layout_options = {
            'title': 'Emotion Analysis',
            'xaxis_title': 'Emotion',
            'yaxis_title': 'Score',
            'yaxis_range': [0, 1],
            'template': 'plotly_white',
            'height': 400,
        }
        figure.update_layout(**layout_options)

        html = figure.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"
    return html
def process_audio(audio_file):
    """Gradio callback: analyse one audio file and return two display values.

    Returns a ``(transcription, plot_html)`` pair on success. When no file is
    given, or when the analysis raises, a pair of explanatory strings is
    returned instead, so the exception never propagates to the caller.
    """
    if audio_file is None:
        return "No audio file provided", "Please provide an audio file"

    try:
        print(f"Processing audio file: {audio_file}")
        # `analyzer` is the module-level Analyzer created at startup.
        analysis = analyzer.analyze(audio_file)
        transcript = analysis['transcription']
        plot_html = create_emotion_plot(analysis['emotions'])
        return (transcript, plot_html)
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"
if __name__ == "__main__":
    # Script entry point: build the analyzer, wire it into a Gradio interface
    # and serve it. Any startup failure is printed and re-raised.
    print("Initializing application...")
    try:
        # Module-level name: process_audio() looks this up when it is called.
        analyzer = Analyzer()

        print("Creating Gradio interface...")
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
        result_outputs = [
            gr.Textbox(label="Transcription"),
            gr.HTML(label="Emotion Analysis"),
        ]
        interface = gr.Interface(
            fn=process_audio,
            inputs=audio_input,
            outputs=result_outputs,
            title="Vocal Biomarker Analysis",
            description="Analyze voice for emotional indicators",
            examples=[],
            cache_examples=False,
        )

        print("Launching application...")
        launch_options = {
            "server_name": "0.0.0.0",  # listen on all network interfaces
            "server_port": 7860,
            "share": False,
        }
        interface.launch(**launch_options)
    except Exception as e:
        print(f"Fatal error during application startup: {str(e)}")
        raise