tima_chatbot / app2.py
Skier8402's picture
Update app2.py
ccf2da0 verified
"""Module for mental health chatbot with improved crisis handling.
This module provides a Gradio-based chat interface for mental health support
with proper crisis intervention protocols and streamlined code organization.
Example:
To run the chat interface:
$ python app2.py
Attributes:
TIMA_API_KEY (str): The API token for the AI service, loaded from environment variables.
Created by Nyabuti 2025-06-01
"""
import logging
import os
import time
from typing import Dict, Generator, List

import gradio as gr
import httpx
import openai
from ratelimit import limits, sleep_and_retry
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from prompts import load_system_prompt
# Retry / back-off and client-side rate-limit settings used when calling the
# chat-completion API.
MAX_RETRIES: int = 5  # attempt ceiling passed to stop_after_attempt
INITIAL_RETRY_DELAY: int = 1  # seconds; multiplier and minimum wait of the back-off
MAX_RETRY_DELAY: int = 60  # seconds; upper bound of the exponential back-off
RATE_LIMIT_CALLS: int = 40  # calls allowed per window (Cerebras recommended rate limit)
RATE_LIMIT_PERIOD: int = 60  # window length in seconds (one minute)
# Configure logging with both file and console handlers for streaming logs
def setup_logging(log_file: str = "chat_interactions.log") -> logging.Logger:
    """Configure and return this module's logger with file and console output.

    The logger (named after this module) logs at INFO level to a persistent
    log file and to the console stream, so interactions can be followed live
    and reviewed later. Calling this again replaces the previous handlers
    instead of stacking duplicates.

    Args:
        log_file: Path of the log file to append to. Defaults to
            ``"chat_interactions.log"`` (relative to the working directory).

    Returns:
        logging.Logger: The configured module logger.
    """
    module_logger = logging.getLogger(__name__)
    module_logger.setLevel(logging.INFO)

    # Detach AND close handlers from an earlier call; clearing the list alone
    # would leave the previous log file open.
    for old_handler in list(module_logger.handlers):
        module_logger.removeHandler(old_handler)
        old_handler.close()

    # File handler for persistent logging. Explicit UTF-8 because the log
    # messages contain emoji and arbitrary user text, which a platform default
    # encoding such as cp1252 cannot represent.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    )

    # Console handler for streaming logs
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )

    module_logger.addHandler(file_handler)
    module_logger.addHandler(console_handler)
    return module_logger
logger = setup_logging()

# Load the API key: TIMA_API_KEY takes precedence and CEREBRAS_API_KEY is the
# fallback. Using ``or`` (instead of a getenv default) also falls back when
# TIMA_API_KEY is present but empty.
TIMA_API_KEY = os.getenv("TIMA_API_KEY") or os.getenv("CEREBRAS_API_KEY")
if not TIMA_API_KEY:
    raise ValueError(
        "API key not found: set the TIMA_API_KEY or CEREBRAS_API_KEY "
        "environment variable"
    )

# Rate limiting helps to avoid exceeding the API rate limits.
# NOTE(review): these two constants are not referenced elsewhere in this
# module (the limiter uses RATE_LIMIT_CALLS / RATE_LIMIT_PERIOD); kept in case
# other modules import them.
CALLS_PER_MINUTE = 60
PERIOD = 60
class ChatError(Exception):
    """Chat-level failure, independent of the API error hierarchy."""


class APIError(Exception):
    """Root of the API error hierarchy; every specific API error derives from it."""


class RateLimitError(APIError):
    """The API refused the request because too many calls were made."""


class TokenLimitError(APIError):
    """The request exceeded the model's token limit."""


class InvalidRequestError(APIError):
    """The request was rejected as malformed or invalid."""


class AuthenticationError(APIError):
    """The API key was missing, invalid, or rejected."""


class ServerError(APIError):
    """The API server or the network path to it failed."""
def detect_crisis_situation(messages: List[Dict[str, str]]) -> bool:
    """Detect if the conversation indicates a crisis situation.

    Only the last two messages are inspected. The previous message must
    contain the PHQ-9 self-harm question, and the latest message must contain
    a high-frequency answer. Matching is case-insensitive substring matching.

    Args:
        messages: Conversation messages, oldest first; each a mapping whose
            ``"content"`` entry holds the text (missing or ``None`` content
            is treated as empty).

    Returns:
        bool: True if crisis situation detected, False otherwise (including
        for an empty conversation).
    """
    if not messages:
        return False

    def _text(msg) -> str:
        # Content may be missing, None, or a non-string payload; normalise to
        # lowercase text so ``.lower()`` can never raise.
        content = msg.get("content")
        return "" if content is None else str(content).lower()

    # Check for PHQ-9 suicide question with concerning response
    last_message = _text(messages[-1])
    prev_message = _text(messages[-2]) if len(messages) >= 2 else ""

    suicide_question_keywords = [
        "thoughts that you would be better off dead",
        "hurting yourself",
    ]
    # NOTE(review): "3" is a bare substring, so answers like "13" also match;
    # this errs toward false positives, which is the safer failure mode here.
    concerning_responses = ["3", "nearly every day", "more than half"]

    has_suicide_question = any(
        keyword in prev_message for keyword in suicide_question_keywords
    )
    has_concerning_response = any(
        response in last_message for response in concerning_responses
    )
    crisis_detected = has_suicide_question and has_concerning_response
    if crisis_detected:
        logger.warning("Crisis situation detected in conversation")
    return crisis_detected
def get_crisis_response() -> str:
    """Generate an appropriate crisis response with local resources.

    The text validates the user's feelings, lists 24/7 crisis lines, lists
    mental health professionals in Nairobi, and offers next steps.

    Returns:
        str: A compassionate response with immediate and local resources.
    """
    logger.info("Generating crisis response with local resources")
    # NOTE(review): hotline numbers are country-specific and change over time;
    # the US Lifeline moved to 988 in July 2022 (the old 1-800-273-8255 number
    # forwards there). Re-verify these resources periodically.
    sections = (
        # Initial validation and support
        (
            "I hear you, and I want you to know that your life has immense value. "
            "What you're going through sounds incredibly difficult, and it's so important "
            "that you get the support you need right now."
        ),
        # Immediate crisis resources
        (
            "\n\nThere are caring people available 24/7 to support you:"
            "\n• 988 Suicide & Crisis Lifeline (US): call or text 988"
            "\n• Crisis Text Line (US): Text HOME to 741741"
            "\n• Emergency Services: 999 or 112 in Kenya, 911 in the US "
            "(if you're in immediate danger)"
        ),
        # Local mental health professionals
        (
            "\n\nI also want to connect you with some local mental health professionals in Nairobi "
            "who can provide ongoing support:"
            "\n\n1. Dr Nancy Nyagah at Blossom Out Consultants"
            "\n   • Location: 2nd floor Park View Towers (opposite Parklands police station)"
            "\n   • Cost: 5000 KES (as of 2018)"
            "\n   • Accepts insurance"
            "\n   • Contact: 0722938606 or 0780938606"
            "\n\n2. Dr. Judy Kamau at Scripture Union, Hurlingham"
            "\n   • Cost: 3,500 KES (as of 2018)"
            "\n   • Contact: +254202712852"
            "\n\n3. Rhoda Mutiso (specializes in depression)"
            "\n   • Cost: 2000 KES (as of 2018)"
            "\n   • Contact: 0722 333378"
        ),
        # Next steps and validation
        (
            "\n\nI know reaching out for help can feel overwhelming, but it's a sign of strength, "
            "not weakness. Would you like me to:"
            "\n1. Help you make a safety plan for the next 24 hours?"
            "\n2. Talk about what you're going through right now?"
            "\n3. Give you more information about any of these mental health professionals?"
            "\n\nYou don't have to go through this alone. I'm here to support you in taking the next step, "
            "whatever that looks like for you."
        ),
    )
    return "".join(sections)
def handle_api_error(e: Exception) -> APIError:
    """
    Convert API exceptions to our custom exception types.

    The HTTP status code is used when the exception carries one (a
    ``status_code`` attribute, as the OpenAI client's status errors do).
    Phrases in the error message are used alongside it as a fallback.

    Args:
        e (Exception): The caught exception.

    Returns:
        APIError: The appropriate custom exception type. An exception that is
        already an APIError is returned unchanged.
    """
    if isinstance(e, APIError):
        # Already classified. Re-classifying by message text would turn e.g.
        # ServerError("Request timed out...") into a generic APIError that is
        # never retried.
        return e

    error_msg = str(e).lower()
    logger.debug("Handling API error: %s", error_msg)

    status = getattr(e, "status_code", None)
    if not isinstance(status, int):
        status = None

    if status == 429 or "rate limit" in error_msg:
        logger.warning("Rate limit exceeded")
        return RateLimitError("Rate limit exceeded. Please try again later.")
    if "token limit" in error_msg or "context length" in error_msg:
        logger.warning("Token limit exceeded")
        return TokenLimitError(
            "Input too long. Please reduce the length of your message."
        )
    if status in (401, 403) or "authentication" in error_msg or "api key" in error_msg:
        logger.error("Authentication failed")
        return AuthenticationError("Authentication failed. Please check your API key.")
    if status in (400, 422) or "invalid request" in error_msg:
        logger.warning("Invalid request received")
        return InvalidRequestError("Invalid request. Please check your input.")
    if (
        (status is not None and status >= 500)
        or any(code in error_msg for code in ("502", "503", "504"))
        # The OpenAI client's timeout / connection errors ("Request timed
        # out.", "Connection error.") are transient: classify them as
        # server-side so callers can retry.
        or "timed out" in error_msg
        or "connection error" in error_msg
    ):
        logger.warning("Server error encountered")
        return ServerError("Server is temporarily unavailable. Please try again later.")

    logger.error("Unhandled API error: %s", error_msg)
    return APIError(f"API error occurred: {str(e)}")
@retry(
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(
        multiplier=INITIAL_RETRY_DELAY, min=INITIAL_RETRY_DELAY, max=MAX_RETRY_DELAY
    ),
    # tenacity hands ``retry`` a RetryCallState, not the exception, so the
    # predicate must come from retry_if_exception_type.
    retry=retry_if_exception_type((ServerError, RateLimitError)),
    reraise=True,
)
@sleep_and_retry
@limits(calls=RATE_LIMIT_CALLS, period=RATE_LIMIT_PERIOD)
def _open_completion_stream(client, messages):
    """Send the streaming chat-completion request and return the stream object.

    This is deliberately a plain (non-generator) function: errors surface
    when it is called, which is what lets the retry and rate-limit decorators
    act on them. Every attempt, including retries, passes through the rate
    limiter because each one is a real API call.

    Raises:
        APIError: Or a subclass, translated from the client/network error.
            ServerError and RateLimitError are retried with exponential
            back-off up to MAX_RETRIES attempts before being re-raised.
    """
    try:
        return client.chat.completions.create(
            model="llama-3.3-70b",  # Cerebras recommended model
            messages=messages,
            temperature=0.8,  # Adjust temperature for creativity 0 to 1
            max_tokens=500,  # Limit response length to prevent token limit errors
            top_p=0.9,  # Use top-p sampling for more diverse responses
            stream=True,
        )
    except openai.APIError as e:
        raise handle_api_error(e) from e
    except httpx.TimeoutException as e:
        raise ServerError("Request timed out. Please try again.") from e
    except httpx.RequestError as e:
        raise ServerError(f"Network error occurred: {str(e)}") from e


def create_chat_completion(
    messages: List[Dict[str, str]],
) -> Generator[str, None, None]:
    """
    Stream a chat completion, or the crisis response when one is warranted.

    The crisis check runs before any API call. The crisis response is static,
    so serving it must not depend on the API being reachable, and no request
    is spent on it. Otherwise the request is opened through a retried,
    rate-limited helper and content chunks are yielded as they arrive.

    Args:
        messages (List[Dict[str, str]]): List of messages in the chat.

    Yields:
        str: Response content chunks, or the full crisis response as a single
        chunk.

    Raises:
        APIError: Base class for all API-related errors
        RateLimitError: When API rate limit is exceeded (after retries)
        TokenLimitError: When input tokens exceed model's limit
        AuthenticationError: When API key is invalid
        InvalidRequestError: When request is malformed
        ServerError: When API server has issues (after retries)
    """
    try:
        if detect_crisis_situation(messages):
            logger.warning("🚨 Crisis situation detected - Serving crisis response")
            yield get_crisis_response()
            return

        # Initialize the OpenAI client with Cerebras API endpoint
        client = openai.OpenAI(
            base_url="https://api.cerebras.ai/v1",
            api_key=TIMA_API_KEY,
            timeout=60.0,  # 60 second timeout
            max_retries=0,  # Retries are handled by _open_completion_stream
        )
        logger.info("🚀 Starting chat completion request to Cerebras API")
        logger.info(f"📊 Sending {len(messages)} messages to API")
        logger.debug("Messages sent to API: %s", messages)

        stream = _open_completion_stream(client, messages)
        try:
            logger.info("✅ Processing normal chat completion stream")
            chunk_count = 0
            content_length = 0
            for chunk in stream:
                if (
                    chunk.choices
                    and chunk.choices[0].delta
                    and chunk.choices[0].delta.content
                ):
                    chunk_count += 1
                    chunk_content = chunk.choices[0].delta.content
                    content_length += len(chunk_content)
                    if chunk_count % 10 == 0:  # Log every 10th chunk to avoid spam
                        logger.info(
                            f"📡 Streaming... {chunk_count} chunks, {content_length} chars"
                        )
                    yield chunk_content
            logger.info(
                f"✅ Chat completion stream finished - Chunks: {chunk_count}, Total chars: {content_length}"
            )
        except Exception as e:
            logger.error("Error during stream processing: %s", str(e), exc_info=True)
            raise handle_api_error(e) from e
        finally:
            # Release the HTTP connection even when the consumer stops early
            # (GeneratorExit) or the stream fails part-way.
            close = getattr(stream, "close", None)
            if callable(close):
                close()
    except APIError as e:
        logger.error("API Error in chat completion: %s", str(e), exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in chat completion: %s", str(e), exc_info=True)
        raise APIError(f"Unexpected error occurred: {str(e)}") from e
def _history_to_messages(history) -> List[Dict[str, str]]:
    """Convert Gradio chat history into OpenAI-style message dicts.

    Accepts both Gradio history formats: "messages" format (dicts with
    ``role``/``content``) and legacy ``[user_msg, assistant_msg]`` pairs.
    Entries matching neither format, or carrying no usable text, are skipped.

    Args:
        history: History as passed by Gradio. A falsy value, or anything that
            is not a list/tuple, yields an empty result.

    Returns:
        List[Dict[str, str]]: Messages in conversation order, content as str.
    """
    if not history:
        return []
    if not isinstance(history, (list, tuple)):
        logger.warning("History is not a list or tuple, converting to empty list")
        return []

    messages: List[Dict[str, str]] = []
    skipped_entries = 0
    for entry in history:
        # New Gradio format: message dictionaries
        if isinstance(entry, dict) and "role" in entry and "content" in entry:
            role, content = entry.get("role"), entry.get("content")
            if role and content:
                messages.append({"role": role, "content": str(content)})
            else:
                skipped_entries += 1
                logger.debug(f"Skipping message with missing role or content: {entry}")
        # Old Gradio format: [user_msg, assistant_msg] pairs
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            user_msg = str(entry[0]) if entry[0] is not None else None
            assistant_msg = str(entry[1]) if entry[1] is not None else None
            if not (user_msg or assistant_msg):
                skipped_entries += 1
                logger.debug(f"Skipping empty message pair: {entry}")
                continue
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        else:
            skipped_entries += 1
            logger.debug(f"Skipping unrecognized history entry format: {entry}")

    if skipped_entries > 0:
        logger.debug(f"Skipped {skipped_entries} invalid history entries")
    logger.debug(f"Sanitized history to {len(history) - skipped_entries} valid entries")
    return messages


def _throttle_chat_requests() -> None:
    """Keep at least RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS seconds between requests.

    The previous request's timestamp is stored as ``chat_handler.last_call_time``,
    so the spacing is shared by every conversation this process serves.
    """
    min_interval = RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS
    last_call = getattr(chat_handler, "last_call_time", None)
    if last_call is not None:
        elapsed = time.monotonic() - last_call
        if elapsed < min_interval:
            sleep_time = min_interval - elapsed
            logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
            time.sleep(sleep_time)
    # Record when this request actually proceeds (after any sleep). Storing
    # the pre-sleep time let the next request start too early. Monotonic
    # clock so wall-clock adjustments cannot distort the interval.
    chat_handler.last_call_time = time.monotonic()


def chat_handler(message: str, history: List[List[str]]) -> Generator[str, None, None]:
    """Handle one chat turn: validate input, build the prompt, stream the reply.

    Args:
        message (str): The user message.
        history: The chat history from Gradio, either "messages" format
            (dicts with ``role``/``content``) or legacy
            ``[user_msg, assistant_msg]`` pairs.

    Yields:
        str: The accumulated response so far (each yield replaces the
        previous one in the UI), or a single error message on failure.

    Raises:
        InvalidRequestError: If ``message`` is not a string or is blank
            (raised on first iteration, since this is a generator).
    """
    if not isinstance(message, str):
        raise InvalidRequestError("Message must be a string")
    if not message.strip():
        raise InvalidRequestError("Message cannot be empty")

    try:
        history_size = len(history) if history else 0
        logger.info(f"🔄 Processing new chat request - History entries: {history_size}")
        logger.debug(f"Processing chat request with {history_size} history entries")

        history_messages = _history_to_messages(history)
        _throttle_chat_requests()

        system_prompt = load_system_prompt()
        formatted_messages = [
            {"role": "system", "content": system_prompt},
            *history_messages,
            {"role": "user", "content": message},
        ]

        logger.info(f"📝 User message received - Length: {len(message)} chars")
        # INFO logs go to a persistent file and these are sensitive
        # mental-health conversations: log only a short preview at INFO and
        # keep the full text at DEBUG.
        preview = message[:50] + ("..." if len(message) > 50 else "")
        logger.info(f"💬 Message preview: {preview}")
        logger.debug("User message content: %s", message)

        full_response = ""
        response_start_time = time.time()
        for chunk in create_chat_completion(formatted_messages):
            full_response += chunk
            yield full_response  # Stream the accumulated response to Gradio

        response_time = time.time() - response_start_time
        logger.info(
            f"🎯 AI response completed - Length: {len(full_response)} chars, Time: {response_time:.2f}s"
        )
        logger.debug("AI response content: %s", full_response)
    except (RateLimitError, TokenLimitError) as e:
        # These carry fixed, user-actionable messages ("try again later",
        # "reduce the length of your message"); show them instead of the
        # generic fallback.
        logger.error(f"❌ API error in chat_handler: {e}", exc_info=True)
        yield f"Error: {e}"
    except ChatError as e:
        logger.error(f"❌ ChatError in chat_handler: {e}", exc_info=True)
        yield f"Error: {e}"  # Display error in Gradio UI
    except Exception as e:
        logger.error(f"💥 Unexpected error in chat_handler: {e}", exc_info=True)
        yield "An unexpected error occurred. Please try again later."
def main():
    """Launch the Tima Gradio chat interface.

    Logs the active configuration, wraps ``chat_handler`` in a
    ``gr.ChatInterface`` and serves it on 0.0.0.0:7860.
    """
    logger.info("🚀 Starting Tima mental health chatbot application")
    logger.info("📋 Application configuration:")
    logger.info(f" - Rate limit: {RATE_LIMIT_CALLS} calls per {RATE_LIMIT_PERIOD}s")
    logger.info(f" - Max retries: {MAX_RETRIES}")
    logger.info(" - API timeout: 60s")

    # Block sensitive files and directories.
    # NOTE(review): Gradio documents blocked_paths as file paths or parent
    # directories; glob entries such as "*.log" may not be honoured, so the
    # conversation log is also listed by name.
    blocked_paths = [
        "app2.py",  # Hide main application logic
        "app_enhanced.py",  # Hide enhanced version
        "prompts.py",  # Hide prompt management
        ".env",  # Hide environment variables
        "chat_interactions.log",  # Conversation log written by setup_logging()
        "*.log",  # Hide log files
        "__pycache__",  # Hide Python cache
        "*.pyc",  # Hide compiled Python files
        ".git",  # Hide git directory
        ".gitignore",  # Hide git configuration
        "tests",  # Hide test directory
        "*.md",  # Hide documentation
        "requirements.txt",  # Hide dependencies
        "Dockerfile",
    ]

    # Create and launch the Gradio interface
    chat_interface = gr.ChatInterface(
        fn=chat_handler,
        title="Tima - Your Mental Health Companion",
        description="A safe space to talk. Tima is here to listen, offer support, and provide understanding. ⚠️ This is not a replacement for professional medical advice.",
        examples=[
            "I feel like giving up on everything",
            "I'm feeling really anxious lately and can't stop worrying",
            "I've been feeling down and hopeless for weeks",
            "I think people are watching me and I keep hearing voices",
            "Can you recommend a therapist in Nairobi?",
            "I need someone to talk to about my depression",
        ],
        type="messages",
    )

    logger.info("🌐 Launching Gradio interface on 0.0.0.0:7860")
    logger.info("💡 Access the application at: http://localhost:7860")

    chat_interface.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,  # Gradio's default port
        share=False,  # Do not create a public share link
        max_threads=16,  # Increased maximum number of threads to handle more concurrent jobs
        show_error=True,  # Show error messages in the UI
        inbrowser=True,  # Open in browser
        show_api=False,  # Hide the "Use via API" page
        enable_monitoring=True,
        state_session_capacity=50,
        blocked_paths=blocked_paths,
    )
if __name__ == "__main__":
    logger.info("🎬 Application started from command line")
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server, not an error.
        logger.info("⏹️ Application stopped by user (Ctrl+C)")
    except Exception as e:
        # logger.exception logs at ERROR level with the traceback attached;
        # re-raise so the process still exits with a failure status.
        logger.exception("💀 Fatal error during application startup: %s", e)
        raise