# NOTE(review): removed stray non-Python artifact lines here (a "File size"
# banner, git-blame commit hashes, and a line-number gutter) — residue from
# the extraction tool, not part of the program.
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
class ModelManager:
    """Loads and holds the Whisper ASR model and the emotion classifier on CPU.

    Exposes dict registries (`models`, `tokenizers`, `processors`) keyed by
    role ('whisper' / 'emotion') so the pipeline can look components up by name.
    """

    def __init__(self):
        # CPU-only on purpose; every model is explicitly moved to this device.
        self.device = torch.device("cpu")
        self.models = {}
        self.tokenizers = {}
        self.processors = {}
        self.load_models()

    def load_models(self):
        """Fetch all pretrained components; logs and re-raises on any failure."""
        try:
            print("Loading Whisper model...")
            whisper_id = "openai/whisper-base"
            self.processors['whisper'] = WhisperProcessor.from_pretrained(whisper_id)
            self.models['whisper'] = (
                WhisperForConditionalGeneration.from_pretrained(whisper_id)
                .to(self.device)
            )

            print("Loading emotion model...")
            emotion_id = "j-hartmann/emotion-english-distilroberta-base"
            self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(emotion_id)
            self.models['emotion'] = (
                AutoModelForSequenceClassification.from_pretrained(emotion_id)
                .to(self.device)
            )
            print("Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise
class AudioProcessor:
    """Loads audio files at a fixed sample rate and extracts acoustic features."""

    def __init__(self):
        self.sample_rate = 16000  # 16 kHz — the rate Whisper's processor is fed downstream
        self.n_mfcc = 13          # number of MFCC coefficients to extract

    def process_audio(self, audio_path):
        """Return (waveform, feature dict) for the file at `audio_path`.

        Logs and re-raises any load/feature-extraction error.
        """
        try:
            waveform, _sr = librosa.load(audio_path, sr=self.sample_rate)
            return waveform, self._extract_features(waveform)
        except Exception as e:
            print(f"Error processing audio: {str(e)}")
            raise

    def _extract_features(self, waveform):
        """Compute MFCCs and per-frame RMS energy for `waveform`."""
        try:
            mfcc = librosa.feature.mfcc(
                y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc
            )
            energy = librosa.feature.rms(y=waveform)[0]
            return {'mfcc': mfcc, 'energy': energy}
        except Exception as e:
            print(f"Error extracting features: {str(e)}")
            raise
class Analyzer:
    """End-to-end pipeline: audio file -> Whisper transcription -> emotion scores."""

    def __init__(self):
        print("Initializing Analyzer...")
        try:
            self.model_manager = ModelManager()
            self.audio_processor = AudioProcessor()
            print("Analyzer initialization complete")
        except Exception as e:
            print(f"Error initializing Analyzer: {str(e)}")
            raise

    def analyze(self, audio_path):
        """Transcribe `audio_path` and score emotions on the transcription.

        Returns a dict: {'transcription': str, 'emotions': {label: float}}.
        Logs and re-raises any failure.
        """
        try:
            print(f"Processing audio file: {audio_path}")
            # `features` (MFCC/energy) is currently unused by the pipeline;
            # kept so the extraction step is still exercised/validated.
            waveform, features = self.audio_processor.process_audio(audio_path)

            print("Transcribing audio...")
            # Pass sampling_rate explicitly: without it the Whisper processor
            # assumes 16 kHz silently, and newer transformers versions warn or
            # error when it is omitted.
            inputs = self.model_manager.processors['whisper'](
                waveform,
                sampling_rate=self.audio_processor.sample_rate,
                return_tensors="pt"
            ).input_features.to(self.model_manager.device)
            with torch.no_grad():
                predicted_ids = self.model_manager.models['whisper'].generate(inputs)
            transcription = self.model_manager.processors['whisper'].batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            print("Analyzing emotions...")
            inputs = self.model_manager.tokenizers['emotion'](
                transcription,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model_manager.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model_manager.models['emotion'](**inputs)
                emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # BUG FIX: labels now come from the model config instead of a
            # hardcoded 6-element list. The j-hartmann model predicts 7 classes
            # (anger, disgust, fear, joy, neutral, sadness, surprise); the old
            # zip() silently dropped one probability and misaligned every label
            # after 'anger'.
            id2label = self.model_manager.models['emotion'].config.id2label
            emotion_scores = {
                id2label[i]: float(score)
                for i, score in enumerate(emotions[0].cpu())
            }

            return {
                'transcription': transcription,
                'emotions': emotion_scores
            }
        except Exception as e:
            print(f"Error in analysis: {str(e)}")
            raise
def create_emotion_plot(emotions):
    """Render a {label: score} dict as a bar chart and return it as HTML.

    Best-effort: on any plotting failure, logs the error and returns a short
    placeholder string instead of raising.
    """
    try:
        bars = go.Bar(
            x=list(emotions.keys()),
            y=list(emotions.values()),
            marker_color='rgb(55, 83, 109)'
        )
        fig = go.Figure(data=[bars])
        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],       # scores are softmax outputs in [0, 1]
            template='plotly_white',
            height=400
        )
        # include_plotlyjs=True embeds the plotly library so the HTML is
        # self-contained when handed to the gr.HTML output.
        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"
def process_audio(audio_file):
    """Gradio callback: map an audio file path to (transcription, plot HTML).

    Returns human-readable error strings in both slots instead of raising,
    so the UI always gets something to display.
    """
    if audio_file is None:
        return "No audio file provided", "Please provide an audio file"
    try:
        print(f"Processing audio file: {audio_file}")
        # `analyzer` is the module-level instance created in the __main__ guard.
        analysis = analyzer.analyze(audio_file)
        transcription = analysis['transcription']
        plot_html = create_emotion_plot(analysis['emotions'])
        return transcription, plot_html
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        print(error_msg)
        return error_msg, "Error in analysis"
if __name__ == "__main__":
    print("Initializing application...")
    try:
        # Assigned at module level on purpose: process_audio() resolves
        # `analyzer` lazily at call time.
        analyzer = Analyzer()

        print("Creating Gradio interface...")
        demo = gr.Interface(
            fn=process_audio,
            inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
            outputs=[
                gr.Textbox(label="Transcription"),
                gr.HTML(label="Emotion Analysis"),
            ],
            title="Vocal Biomarker Analysis",
            description="Analyze voice for emotional indicators",
            examples=[],
            cache_examples=False,
        )

        print("Launching application...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as e:
        print(f"Fatal error during application startup: {str(e)}")
        raise