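"""Vocal Biomarker Analysis Gradio app.

Transcribes microphone audio with openai/whisper-base, scores the transcript
with the j-hartmann/emotion-english-distilroberta-base emotion classifier,
and renders the emotion scores as a Plotly bar chart.
"""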
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import plotly.graph_objects as go
class ModelManager:
    """Loads the speech-to-text and emotion models onto the available device."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}
        self.tokenizers = {}
        self.processors = {}
        self.load_models()

    def load_models(self):
        print("Loading Whisper model...")
        self.processors['whisper'] = WhisperProcessor.from_pretrained("openai/whisper-base")
        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(self.device)
        print("Loading emotion model...")
        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base").to(self.device)
class AudioProcessor:
    """Loads audio at 16 kHz and extracts basic acoustic features."""

    def __init__(self):
        self.sample_rate = 16000  # Whisper expects 16 kHz input
        self.n_mfcc = 13

    def process_audio(self, audio_path):
        waveform, sr = librosa.load(audio_path, sr=self.sample_rate)
        return waveform, self._extract_features(waveform)

    def _extract_features(self, waveform):
        # MFCCs and RMS energy are computed per clip but not yet consumed
        # by the analysis pipeline below.
        return {
            'mfcc': librosa.feature.mfcc(y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc),
            'energy': librosa.feature.rms(y=waveform)[0]
        }
class Analyzer:
    def __init__(self):
        print("Initializing Analyzer...")
        self.model_manager = ModelManager()
        self.audio_processor = AudioProcessor()
        print("Analyzer initialization complete")

    def analyze(self, audio_path):
        print(f"Processing audio file: {audio_path}")
        # Process audio (the extracted features are currently unused)
        waveform, features = self.audio_processor.process_audio(audio_path)

        # Transcribe
        print("Transcribing audio...")
        inputs = self.model_manager.processors['whisper'](
            waveform,
            sampling_rate=self.audio_processor.sample_rate,
            return_tensors="pt"
        ).input_features.to(self.model_manager.device)
        with torch.no_grad():
            predicted_ids = self.model_manager.models['whisper'].generate(inputs)
        transcription = self.model_manager.processors['whisper'].batch_decode(predicted_ids, skip_special_tokens=True)[0]

        # Analyze emotions; inputs must live on the same device as the model.
        print("Analyzing emotions...")
        inputs = self.model_manager.tokenizers['emotion'](
            transcription, return_tensors="pt", padding=True, truncation=True
        ).to(self.model_manager.device)
        with torch.no_grad():
            outputs = self.model_manager.models['emotion'](**inputs)
        emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Read label names from the model config so scores line up with the
        # checkpoint's seven classes (anger, disgust, fear, joy, neutral,
        # sadness, surprise) rather than a hardcoded six-entry list.
        id2label = self.model_manager.models['emotion'].config.id2label
        emotion_scores = {
            id2label[i]: float(score)
            for i, score in enumerate(emotions[0])
        }
        return {
            'transcription': transcription,
            'emotions': emotion_scores
        }
def create_emotion_plot(emotions):
    fig = go.Figure(data=[
        go.Bar(x=list(emotions.keys()), y=list(emotions.values()))
    ])
    fig.update_layout(
        title='Emotion Analysis',
        yaxis_range=[0, 1]
    )
    return fig.to_html()
print("Initializing application...")
analyzer = Analyzer()
def process_audio(audio_file):
try:
print(f"Processing audio file: {audio_file}")
results = analyzer.analyze(audio_file)
return (
results['transcription'],
create_emotion_plot(results['emotions'])
)
except Exception as e:
print(f"Error processing audio: {str(e)}")
return str(e), "Error in analysis"
print("Creating Gradio interface...")
interface = gr.Interface(
fn=process_audio,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs=[
gr.Textbox(label="Transcription"),
gr.HTML(label="Emotion Analysis")
],
title="Vocal Biomarker Analysis",
description="Analyze voice for emotional indicators"
)
print("Launching application...")
if __name__ == "__main__":
interface.launch()