"""Module for mental health chatbot with improved crisis handling. This module provides a Gradio-based chat interface for mental health support with proper crisis intervention protocols and streamlined code organization. Example: To run the chat interface: $ python app2.py Attributes: TIMA_API_KEY (str): The API token for the AI service, loaded from environment variables. Created by Nyabuti 2025-06-01 """ import os import time import logging from typing import Generator, List, Dict import gradio as gr import openai import httpx from tenacity import retry, stop_after_attempt, wait_exponential from ratelimit import limits, sleep_and_retry from prompts import load_system_prompt # Constants for API interaction MAX_RETRIES = 5 INITIAL_RETRY_DELAY = 1 # seconds MAX_RETRY_DELAY = 60 # seconds RATE_LIMIT_CALLS = 40 # Cerebras recommended rate limit RATE_LIMIT_PERIOD = 60 # 1 minute period # Configure logging with both file and console handlers for streaming logs def setup_logging(): """Set up logging with both file and console output for streaming visibility""" logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # Clear any existing handlers logger.handlers.clear() # File handler for persistent logging file_handler = logging.FileHandler("chat_interactions.log") file_handler.setLevel(logging.INFO) file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") file_handler.setFormatter(file_formatter) # Console handler for streaming logs console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_formatter = logging.Formatter( "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) console_handler.setFormatter(console_formatter) # Add both handlers logger.addHandler(file_handler) logger.addHandler(console_handler) return logger logger = setup_logging() # Load environment variables TIMA_API_KEY = os.getenv("TIMA_API_KEY", os.getenv("CEREBRAS_API_KEY", None)) if not TIMA_API_KEY: raise ValueError("CEREBRAS_API_KEY environment variable not found in environment") # Rate limiting helps to avoid exceeding the API rate limits CALLS_PER_MINUTE = 60 PERIOD = 60 class ChatError(Exception): """Custom exception for chat-related errors""" pass class APIError(Exception): """Base exception for API errors""" pass class RateLimitError(APIError): """Exception for rate limit errors""" pass class TokenLimitError(APIError): """Exception for token limit errors""" pass class InvalidRequestError(APIError): """Exception for invalid request errors""" pass class AuthenticationError(APIError): """Exception for authentication errors""" pass class ServerError(APIError): """Exception for server-side errors""" pass def detect_crisis_situation(messages: List[Dict[str, str]]) -> bool: """Detect if the conversation indicates a crisis situation. Args: messages: List of conversation messages Returns: bool: True if crisis situation detected, False otherwise """ if not messages: return False # Check for PHQ-9 suicide question with concerning response last_message = messages[-1].get("content", "").lower() prev_message = messages[-2].get("content", "").lower() if len(messages) >= 2 else "" # Look for PHQ-9 suicide question and high score response suicide_question_keywords = [ "thoughts that you would be better off dead", "hurting yourself", ] concerning_responses = ["3", "nearly every day", "more than half"] has_suicide_question = any( keyword in prev_message for keyword in suicide_question_keywords ) has_concerning_response = any( response in last_message for response in concerning_responses ) crisis_detected = has_suicide_question and has_concerning_response if crisis_detected: logger.warning("Crisis situation detected in conversation") return crisis_detected def get_crisis_response() -> str: """Generate an appropriate crisis response with local resources Returns: str: A compassionate response with immediate and local resources """ logger.info("Generating crisis response with local resources") response_parts = [] # Initial validation and support response_parts.append( "I hear you, and I want you to know that your life has immense value. " "What you're going through sounds incredibly difficult, and it's so important " "that you get the support you need right now." ) # Immediate crisis resources response_parts.append( "\n\nThere are caring people available 24/7 to support you:" "\n• National Suicide Prevention Lifeline: 1-800-273-TALK (8255)" "\n• Crisis Text Line: Text HOME to 741741" "\n• Emergency Services: 911 (if you're in immediate danger)" ) # Local mental health professionals response_parts.append( "\n\nI also want to connect you with some local mental health professionals in Nairobi " "who can provide ongoing support:" "\n\n1. Dr Nancy Nyagah at Blossom Out Consultants" "\n • Location: 2nd floor Park View Towers (opposite Parklands police station)" "\n • Cost: 5000 KES (as of 2018)" "\n • Accepts insurance" "\n • Contact: 0722938606 or 0780938606" "\n\n2. Dr. Judy Kamau at Scripture Union, Hurlingham" "\n • Cost: 3,500 KES (as of 2018)" "\n • Contact: +254202712852" "\n\n3. Rhoda Mutiso (specializes in depression)" "\n • Cost: 2000 KES (as of 2018)" "\n • Contact: 0722 333378" ) # Next steps and validation response_parts.append( "\n\nI know reaching out for help can feel overwhelming, but it's a sign of strength, " "not weakness. Would you like me to:" "\n1. Help you make a safety plan for the next 24 hours?" "\n2. Talk about what you're going through right now?" "\n3. Give you more information about any of these mental health professionals?" "\n\nYou don't have to go through this alone. I'm here to support you in taking the next step, " "whatever that looks like for you." ) return "".join(response_parts) def handle_api_error(e: Exception) -> APIError: """ Convert API exceptions to our custom exception types. Args: e (Exception): The caught exception Returns: APIError: The appropriate custom exception type """ error_msg = str(e).lower() logger.debug(f"Handling API error: {error_msg}") if "rate limit" in error_msg: logger.warning("Rate limit exceeded") return RateLimitError("Rate limit exceeded. Please try again later.") elif "token limit" in error_msg: logger.warning("Token limit exceeded") return TokenLimitError( "Input too long. Please reduce the length of your message." ) elif "authentication" in error_msg or "api key" in error_msg: logger.error("Authentication failed") return AuthenticationError("Authentication failed. Please check your API key.") elif "invalid request" in error_msg: logger.warning("Invalid request received") return InvalidRequestError("Invalid request. Please check your input.") elif any(code in error_msg for code in ["502", "503", "504"]): logger.warning("Server error encountered") return ServerError("Server is temporarily unavailable. Please try again later.") logger.error(f"Unhandled API error: {error_msg}") return APIError(f"API error occurred: {str(e)}") @sleep_and_retry @limits(calls=RATE_LIMIT_CALLS, period=RATE_LIMIT_PERIOD) @retry( stop=stop_after_attempt(MAX_RETRIES), wait=wait_exponential( multiplier=INITIAL_RETRY_DELAY, min=INITIAL_RETRY_DELAY, max=MAX_RETRY_DELAY ), retry=lambda e: isinstance(e, (ServerError, RateLimitError)), reraise=True, ) def create_chat_completion( messages: List[Dict[str, str]], ) -> Generator[str, None, None]: """ Create a chat completion with comprehensive error handling and rate limiting Args: messages (List[Dict[str, str]]): List of messages in the chat Yields: Generator[str, None, None]: A generator of chat completion response chunks Raises: APIError: Base class for all API-related errors RateLimitError: When API rate limit is exceeded TokenLimitError: When input tokens exceed model's limit AuthenticationError: When API key is invalid InvalidRequestError: When request is malformed ServerError: When API server has issues """ try: # Initialize the OpenAI client with Cerebras API endpoint client = openai.OpenAI( base_url="https://api.cerebras.ai/v1", api_key=TIMA_API_KEY, timeout=60.0, # 60 second timeout max_retries=0, # We handle retries ourselves ) logger.info("🚀 Starting chat completion request to Cerebras API") logger.info(f"📊 Sending {len(messages)} messages to API") logger.debug("Messages sent to API: %s", messages) try: stream = client.chat.completions.create( model="llama-3.3-70b", # Cerebras recommended model messages=messages, temperature=0.8, # Adjust temperature for creativity 0 to 1 max_tokens=500, # Limit response length to prevent token limit errors top_p=0.9, # Use top-p sampling for more diverse responses stream=True, ) except openai.APIError as e: raise handle_api_error(e) except httpx.TimeoutException: raise ServerError("Request timed out. Please try again.") except httpx.RequestError as e: raise ServerError(f"Network error occurred: {str(e)}") try: # Check for crisis situations and provide appropriate response crisis_detected = detect_crisis_situation(messages) if crisis_detected: logger.warning("🚨 Crisis situation detected - Serving crisis response") crisis_response = get_crisis_response() yield crisis_response else: logger.info("✅ Processing normal chat completion stream") chunk_count = 0 content_length = 0 for chunk in stream: if ( chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content ): chunk_count += 1 chunk_content = chunk.choices[0].delta.content content_length += len(chunk_content) if chunk_count % 10 == 0: # Log every 10th chunk to avoid spam logger.info( f"📡 Streaming... {chunk_count} chunks, {content_length} chars" ) yield chunk_content logger.info( f"✅ Chat completion stream finished - Chunks: {chunk_count}, Total chars: {content_length}" ) except Exception as e: logger.error("Error during stream processing: %s", str(e), exc_info=True) raise handle_api_error(e) except APIError as e: logger.error("API Error in chat completion: %s", str(e), exc_info=True) raise except Exception as e: logger.error("Unexpected error in chat completion: %s", str(e), exc_info=True) raise APIError(f"Unexpected error occurred: {str(e)}") def chat_handler(message: str, history: List[List[str]]) -> Generator[str, None, None]: """Handle chat interactions with proper history management and error handling Args: message (str): The user message history (List[List[str]]): The chat history Yields: Generator[str, None, None]: The response message chunks Raises: Exception: If an error occurs during chat handling """ if not isinstance(message, str): raise InvalidRequestError("Message must be a string") if not message.strip(): raise InvalidRequestError("Message cannot be empty") try: logger.info( f"🔄 Processing new chat request - History entries: {len(history) if history else 0}" ) logger.debug( f"Processing chat request with {len(history) if history else 0} history entries" ) # Validate and sanitize history format if history: if not isinstance(history, (list, tuple)): logger.warning( "History is not a list or tuple, converting to empty list" ) history = [] else: # Filter out invalid history entries and handle both old and new Gradio formats sanitized_history = [] skipped_entries = 0 for h in history: # Handle new Gradio format with message dictionaries if isinstance(h, dict) and "role" in h and "content" in h: role = h.get("role") content = h.get("content") if role and content: # Store as temporary format for later processing sanitized_history.append( {"role": role, "content": str(content)} ) else: skipped_entries += 1 logger.debug( f"Skipping message with missing role or content: {h}" ) # Handle old Gradio format with [user_msg, assistant_msg] pairs elif isinstance(h, (list, tuple)) and len(h) == 2: # Ensure both elements are strings or None user_msg = str(h[0]) if h[0] is not None else None assistant_msg = str(h[1]) if h[1] is not None else None if ( user_msg or assistant_msg ): # At least one message must be non-None sanitized_history.append([user_msg, assistant_msg]) else: skipped_entries += 1 logger.debug(f"Skipping empty message pair: {h}") else: skipped_entries += 1 logger.debug(f"Skipping unrecognized history entry format: {h}") if skipped_entries > 0: logger.debug(f"Skipped {skipped_entries} invalid history entries") history = sanitized_history logger.debug(f"Sanitized history to {len(history)} valid entries") # Apply rate limiting at the conversation level current_time = time.time() if hasattr(chat_handler, "last_call_time"): time_since_last_call = current_time - chat_handler.last_call_time if time_since_last_call < ( 60 / RATE_LIMIT_CALLS ): # Minimum time between calls sleep_time = (60 / RATE_LIMIT_CALLS) - time_since_last_call logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds") time.sleep(sleep_time) chat_handler.last_call_time = current_time # Define your system prompt system_prompt = load_system_prompt() formatted_messages = [{"role": "system", "content": system_prompt}] if history: for h in history: # Handle new Gradio format with message dictionaries if isinstance(h, dict) and "role" in h and "content" in h: formatted_messages.append( {"role": h["role"], "content": str(h["content"])} ) # Handle old Gradio format with [user_msg, assistant_msg] pairs elif isinstance(h, (list, tuple)) and len(h) == 2: user_msg, assistant_msg = h if user_msg: formatted_messages.append( {"role": "user", "content": str(user_msg)} ) if assistant_msg: formatted_messages.append( {"role": "assistant", "content": str(assistant_msg)} ) formatted_messages.append({"role": "user", "content": message}) logger.info(f"📝 User message received - Length: {len(message)} chars") logger.info( f"💬 Message preview: {message}" ) logger.debug("User message content: %s", message) full_response = "" response_start_time = time.time() for chunk in create_chat_completion(formatted_messages): full_response += chunk yield full_response # Stream response to Gradio response_time = time.time() - response_start_time logger.info( f"🎯 AI response completed - Length: {len(full_response)} chars, Time: {response_time:.2f}s" ) logger.info(f"📤 Full AI response: {full_response}") logger.debug("AI response content: %s", full_response) except ChatError as e: logger.error(f"❌ ChatError in chat_handler: {e}", exc_info=True) yield f"Error: {e}" # Display error in Gradio UI except Exception as e: logger.error(f"💥 Unexpected error in chat_handler: {e}", exc_info=True) yield "An unexpected error occurred. Please try again later." def main(): """Main function to launch the Gradio interface""" logger.info("🚀 Starting Tima mental health chatbot application") logger.info("📋 Application configuration:") logger.info(f" - Rate limit: {RATE_LIMIT_CALLS} calls per {RATE_LIMIT_PERIOD}s") logger.info(f" - Max retries: {MAX_RETRIES}") logger.info(f" - API timeout: 60s") # Block sensitive files and directories blocked_paths = [ "app2.py", # Hide main application logic "app_enhanced.py", # Hide enhanced version "prompts.py", # Hide prompt management ".env", # Hide environment variables "*.log", # Hide log files "__pycache__", # Hide Python cache "*.pyc", # Hide compiled Python files ".git", # Hide git directory ".gitignore", # Hide git configuration "tests", # Hide test directory "*.md", # Hide documentation "requirements.txt", # Hide dependencies "Dockerfile" ] # Create and launch the Gradio interface chat_interface = gr.ChatInterface( fn=chat_handler, title="Tima - Your Mental Health Companion", description="A safe space to talk. Tima is here to listen, offer support, and provide understanding. ⚠️ This is not a replacement for professional medical advice.", examples=[ "I feel like giving up on everything", "I'm feeling really anxious lately and can't stop worrying", "I've been feeling down and hopeless for weeks", "I think people are watching me and I keep hearing voices", "Can you recommend a therapist in Nairobi?", "I need someone to talk to about my depression", ], type="messages", ) logger.info("🌐 Launching Gradio interface on 0.0.0.0:7860") logger.info("💡 Access the application at: http://localhost:7860") chat_interface.launch( server_name="0.0.0.0", # Server name server_port=7860, # Different port to avoid conflicts share=False, # Share the server publicly max_threads=16, # Increased maximum number of threads to handle more concurrent jobs show_error=True, # Show error messages inbrowser=True, # Open in browser show_api=False, # Show API enable_monitoring=True, state_session_capacity=50, blocked_paths=blocked_paths ) if __name__ == "__main__": logger.info("🎬 Application started from command line") try: main() except KeyboardInterrupt: logger.info("⏹️ Application stopped by user (Ctrl+C)") except Exception as e: logger.error( f"💀 Fatal error during application startup: {str(e)}", exc_info=True ) raise