import torch
import gradio as gr
import whisper
import numpy as np
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForCausalLM
)
from datetime import datetime
from gtts import gTTS
import tempfile
import os


class EnhancedCaregiverSystem:
    def __init__(self):
        # Device configuration
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Initialize Whisper for speech recognition
        self.whisper_model = whisper.load_model("base", device=self.device)

        # Initialize emotion detection
        self.emotion_model_name = "bhadresh-savani/bert-base-uncased-emotion"
        self.emotion_tokenizer = AutoTokenizer.from_pretrained(self.emotion_model_name)
        self.emotion_model = AutoModelForSequenceClassification.from_pretrained(
            self.emotion_model_name
        ).to(self.device)

        # Initialize conversation model
        self.conv_model_name = "microsoft/DialoGPT-medium"
        self.conv_tokenizer = AutoTokenizer.from_pretrained(self.conv_model_name)
        self.conv_model = AutoModelForCausalLM.from_pretrained(
            self.conv_model_name
        ).to(self.device)

        # Initialize chat history
        self.chat_history = []

    def transcribe_audio(self, audio):
        """Convert speech to text using Whisper."""
        try:
            # Gradio's Audio component (type="numpy") passes a
            # (sample_rate, data) tuple, not a bare array.
            sample_rate, audio = audio
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)  # Downmix stereo to mono
            # Whisper expects float32 samples in [-1, 1] at 16 kHz;
            # Gradio delivers 16-bit integer PCM.
            audio = audio.astype(np.float32) / 32768.0
            if sample_rate != 16000:
                # Crude linear resample; adequate for speech input
                n = int(len(audio) * 16000 / sample_rate)
                audio = np.interp(np.linspace(0, len(audio), n, endpoint=False),
                                  np.arange(len(audio)), audio).astype(np.float32)
            result = self.whisper_model.transcribe(audio)
            return result["text"].strip()
        except Exception as e:
            print(f"Error in transcription: {e}")
            return "Could not transcribe audio. Please try again."

    def detect_emotion(self, text):
        """Detect emotion from text."""
        try:
            inputs = self.emotion_tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.device)
            with torch.no_grad():
                outputs = self.emotion_model(**inputs)
            emotion_id = outputs.logits.argmax().item()
            # Label order follows the model's id2label config
            emotion_map = {
                0: "sadness",
                1: "joy",
                2: "love",
                3: "anger",
                4: "fear",
                5: "surprise"
            }
            return emotion_map.get(emotion_id, "neutral")
        except Exception as e:
            print(f"Error in emotion detection: {e}")
            return "neutral"

    def generate_response(self, text, emotion):
        """Generate a contextual response based on input and emotion."""
        try:
            emotion_context = f"[Emotion: {emotion}] "
            full_prompt = emotion_context + text + self.conv_tokenizer.eos_token
            inputs = self.conv_tokenizer.encode(
                full_prompt,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.conv_model.generate(
                    inputs,
                    max_length=1000,
                    pad_token_id=self.conv_tokenizer.eos_token_id,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    no_repeat_ngram_size=3
                )
            # Decode only the newly generated tokens, not the prompt
            response = self.conv_tokenizer.decode(
                outputs[:, inputs.shape[-1]:][0],
                skip_special_tokens=True
            )
            if emotion in ["sadness", "anger", "fear"]:
                response += f"\n\nI notice you might be feeling {emotion}. I'm here to support you."
            elif emotion == "joy":
                response += "\n\nIt's wonderful to hear you're feeling positive!"
            return response
        except Exception as e:
            print(f"Error in response generation: {e}")
            return "I apologize, but I'm having trouble generating a response. Could you rephrase that?"

    def text_to_speech(self, text):
        """Convert the response to speech using gTTS."""
        try:
            # delete=False keeps the file on disk so Gradio can serve it
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as fp:
                tts = gTTS(text=text, lang='en')
                tts.save(fp.name)
                return fp.name
        except Exception as e:
            print(f"Error in text-to-speech: {e}")
            return None
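
    def cleanup_audio_files(self, paths):
        """Remove generated speech files once they're no longer needed.

        A minimal housekeeping sketch (an addition, not part of the original
        design): text_to_speech writes with delete=False, so .mp3 files
        accumulate in the temp directory until removed explicitly.
        """
        for path in paths:
            try:
                os.remove(path)
            except OSError as e:
                print(f"Could not remove {path}: {e}")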

    def process_input(self, input_type, input_data):
        """Process either text or audio input."""
        if input_type == "audio":
            text = self.transcribe_audio(input_data)
        else:
            text = input_data.strip()

        if not text:
            return {
                "text": "Please provide some input.",
                "emotion": "neutral",
                "response": "I couldn't understand that. Could you try again?",
                "audio_response": None
            }

        emotion = self.detect_emotion(text)
        response = self.generate_response(text, emotion)
        audio_file = self.text_to_speech(response)

        self.chat_history.append({
            "timestamp": datetime.now().strftime("%H:%M:%S"),
            "user_input": text,
            "emotion": emotion,
            "response": response
        })

        return {
            "text": text,
            "emotion": emotion,
            "response": response,
            "audio_response": audio_file
        }


# Initialize the system
caregiver = EnhancedCaregiverSystem()


# Define Gradio interface functions
def process_text(message):
    result = caregiver.process_input("text", message)
    return result["text"], result["emotion"], result["response"], result["audio_response"]


def process_audio(audio):
    result = caregiver.process_input("audio", audio)
    return result["text"], result["emotion"], result["response"], result["audio_response"]
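

# Optional smoke test for the text path (a sketch, assuming the models above
# loaded successfully): flip the flag to exercise the pipeline without the UI.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    print(process_text("I had a rough day at work."))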


# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# AI Caregiver System")

    with gr.Tab("Text Input"):
        text_input = gr.Textbox(label="Your message")
        text_button = gr.Button("Send Message")
        text_output = [
            gr.Textbox(label="Transcription"),
            gr.Textbox(label="Emotion"),
            gr.Textbox(label="Response"),
            gr.Audio(label="Audio Response")
        ]
        text_button.click(process_text, inputs=text_input, outputs=text_output)

    with gr.Tab("Voice Input"):
        # Gradio 4 replaced the old `source` argument with `sources`;
        # allow both microphone recording and file upload
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Your voice message"
        )
        audio_button = gr.Button("Process Voice")
        audio_output = [
            gr.Textbox(label="Transcription"),
            gr.Textbox(label="Emotion"),
            gr.Textbox(label="Response"),
            gr.Audio(label="Audio Response")
        ]
        audio_button.click(process_audio, inputs=audio_input, outputs=audio_output)

iface.launch(share=True)