"""Module for mental health chatbot with improved crisis handling. | |
This module provides a Gradio-based chat interface for mental health support | |
with proper crisis intervention protocols and streamlined code organization. | |
Example: | |
To run the chat interface: | |
$ python app2.py | |
Attributes: | |
TIMA_API_KEY (str): The API token for the AI service, loaded from environment variables. | |
Created by Nyabuti 2025-06-01 | |
""" | |
import os
import time
import logging
from typing import Dict, Generator, List

import gradio as gr
import httpx
import openai

from prompts import load_system_prompt

# Constants for API interaction
MAX_RETRIES = 5
INITIAL_RETRY_DELAY = 1  # seconds
MAX_RETRY_DELAY = 60  # seconds
RATE_LIMIT_CALLS = 40  # Cerebras recommended rate limit
RATE_LIMIT_PERIOD = 60  # 1 minute period
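# With the constants above, the conversation-level limiter in chat_handler()
# enforces a minimum spacing of RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS
# = 60 / 40 = 1.5 seconds between successive API calls (illustrative; the
# provider's actual quota may differ).
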
# Configure logging with both file and console handlers for streaming logs
def setup_logging():
    """Set up logging with both file and console output for streaming visibility."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Clear any existing handlers
    logger.handlers.clear()

    # File handler for persistent logging
    file_handler = logging.FileHandler("chat_interactions.log")
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    file_handler.setFormatter(file_formatter)

    # Console handler for streaming logs
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    )
    console_handler.setFormatter(console_formatter)

    # Add both handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger


logger = setup_logging()
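# Sample console line under the format above (illustrative; the logger name is
# "__main__" when this module is run directly):
#   2025-06-01 12:00:00,000 [INFO] __main__: Starting chat completion request
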
# Load environment variables
TIMA_API_KEY = os.getenv("TIMA_API_KEY", os.getenv("CEREBRAS_API_KEY"))
if not TIMA_API_KEY:
    raise ValueError(
        "Neither TIMA_API_KEY nor CEREBRAS_API_KEY is set in the environment"
    )

class ChatError(Exception):
    """Custom exception for chat-related errors."""


class APIError(Exception):
    """Base exception for API errors."""


class RateLimitError(APIError):
    """Exception for rate limit errors."""


class TokenLimitError(APIError):
    """Exception for token limit errors."""


class InvalidRequestError(APIError):
    """Exception for invalid request errors."""


class AuthenticationError(APIError):
    """Exception for authentication errors."""


class ServerError(APIError):
    """Exception for server-side errors."""

def detect_crisis_situation(messages: List[Dict[str, str]]) -> bool:
    """Detect if the conversation indicates a crisis situation.

    Args:
        messages: List of conversation messages

    Returns:
        bool: True if a crisis situation is detected, False otherwise
    """
    if not messages:
        return False

    # Check for a PHQ-9 suicide question followed by a concerning response
    last_message = messages[-1].get("content", "").lower()
    prev_message = messages[-2].get("content", "").lower() if len(messages) >= 2 else ""

    # Look for the PHQ-9 suicide question and a high-score response
    suicide_question_keywords = [
        "thoughts that you would be better off dead",
        "hurting yourself",
    ]
    concerning_responses = ["3", "nearly every day", "more than half"]

    has_suicide_question = any(
        keyword in prev_message for keyword in suicide_question_keywords
    )
    has_concerning_response = any(
        response in last_message for response in concerning_responses
    )

    crisis_detected = has_suicide_question and has_concerning_response
    if crisis_detected:
        logger.warning("Crisis situation detected in conversation")
    return crisis_detected
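# Illustrative behavior (assumed message shapes; substring matching is
# intentionally loose, so e.g. "3" also matches inside longer replies):
#   detect_crisis_situation([
#       {"role": "assistant",
#        "content": "...thoughts that you would be better off dead..."},
#       {"role": "user", "content": "Nearly every day"},
#   ])  # -> True
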
def get_crisis_response() -> str:
    """Generate an appropriate crisis response with local resources.

    Returns:
        str: A compassionate response with immediate and local resources
    """
    logger.info("Generating crisis response with local resources")
    response_parts = []

    # Initial validation and support
    response_parts.append(
        "I hear you, and I want you to know that your life has immense value. "
        "What you're going through sounds incredibly difficult, and it's so important "
        "that you get the support you need right now."
    )

    # Immediate crisis resources
    response_parts.append(
        "\n\nThere are caring people available 24/7 to support you:"
        "\n• National Suicide Prevention Lifeline: 1-800-273-TALK (8255)"
        "\n• Crisis Text Line: Text HOME to 741741"
        "\n• Emergency Services: 911 (if you're in immediate danger)"
    )

    # Local mental health professionals
    response_parts.append(
        "\n\nI also want to connect you with some local mental health professionals in Nairobi "
        "who can provide ongoing support:"
        "\n\n1. Dr Nancy Nyagah at Blossom Out Consultants"
        "\n   • Location: 2nd floor Park View Towers (opposite Parklands police station)"
        "\n   • Cost: 5000 KES (as of 2018)"
        "\n   • Accepts insurance"
        "\n   • Contact: 0722938606 or 0780938606"
        "\n\n2. Dr. Judy Kamau at Scripture Union, Hurlingham"
        "\n   • Cost: 3,500 KES (as of 2018)"
        "\n   • Contact: +254202712852"
        "\n\n3. Rhoda Mutiso (specializes in depression)"
        "\n   • Cost: 2000 KES (as of 2018)"
        "\n   • Contact: 0722 333378"
    )

    # Next steps and validation
    response_parts.append(
        "\n\nI know reaching out for help can feel overwhelming, but it's a sign of strength, "
        "not weakness. Would you like me to:"
        "\n1. Help you make a safety plan for the next 24 hours?"
        "\n2. Talk about what you're going through right now?"
        "\n3. Give you more information about any of these mental health professionals?"
        "\n\nYou don't have to go through this alone. I'm here to support you in taking the next step, "
        "whatever that looks like for you."
    )

    return "".join(response_parts)

def handle_api_error(e: Exception) -> APIError:
    """Convert API exceptions to our custom exception types.

    Args:
        e (Exception): The caught exception

    Returns:
        APIError: The appropriate custom exception type
    """
    error_msg = str(e).lower()
    logger.debug(f"Handling API error: {error_msg}")

    if "rate limit" in error_msg:
        logger.warning("Rate limit exceeded")
        return RateLimitError("Rate limit exceeded. Please try again later.")
    elif "token limit" in error_msg:
        logger.warning("Token limit exceeded")
        return TokenLimitError(
            "Input too long. Please reduce the length of your message."
        )
    elif "authentication" in error_msg or "api key" in error_msg:
        logger.error("Authentication failed")
        return AuthenticationError("Authentication failed. Please check your API key.")
    elif "invalid request" in error_msg:
        logger.warning("Invalid request received")
        return InvalidRequestError("Invalid request. Please check your input.")
    elif any(code in error_msg for code in ["502", "503", "504"]):
        logger.warning("Server error encountered")
        return ServerError("Server is temporarily unavailable. Please try again later.")

    logger.error(f"Unhandled API error: {error_msg}")
    return APIError(f"API error occurred: {str(e)}")
def create_chat_completion(
    messages: List[Dict[str, str]],
) -> Generator[str, None, None]:
    """Stream a chat completion with comprehensive error handling.

    Args:
        messages (List[Dict[str, str]]): List of messages in the chat

    Yields:
        Generator[str, None, None]: A generator of chat completion response chunks

    Raises:
        APIError: Base class for all API-related errors
        RateLimitError: When the API rate limit is exceeded
        TokenLimitError: When input tokens exceed the model's limit
        AuthenticationError: When the API key is invalid
        InvalidRequestError: When the request is malformed
        ServerError: When the API server has issues
    """
    try:
        # Check for a crisis situation first, so we can serve the canned
        # response without making (and then discarding) an API request
        if detect_crisis_situation(messages):
            logger.warning("Crisis situation detected - serving crisis response")
            yield get_crisis_response()
            return

        # Initialize the OpenAI client with the Cerebras API endpoint
        client = openai.OpenAI(
            base_url="https://api.cerebras.ai/v1",
            api_key=TIMA_API_KEY,
            timeout=60.0,  # 60 second timeout
            max_retries=0,  # We handle retries ourselves
        )
        logger.info("Starting chat completion request to Cerebras API")
        logger.info(f"Sending {len(messages)} messages to API")
        logger.debug("Messages sent to API: %s", messages)

        try:
            stream = client.chat.completions.create(
                model="llama-3.3-70b",  # Cerebras recommended model
                messages=messages,
                temperature=0.8,  # Between 0 and 1; higher is more creative
                max_tokens=500,  # Limit response length to prevent token limit errors
                top_p=0.9,  # Top-p sampling for more diverse responses
                stream=True,
            )
        except openai.APIError as e:
            raise handle_api_error(e)
        except httpx.TimeoutException:
            raise ServerError("Request timed out. Please try again.")
        except httpx.RequestError as e:
            raise ServerError(f"Network error occurred: {str(e)}")

        try:
            logger.info("Processing chat completion stream")
            chunk_count = 0
            content_length = 0
            for chunk in stream:
                if (
                    chunk.choices
                    and chunk.choices[0].delta
                    and chunk.choices[0].delta.content
                ):
                    chunk_count += 1
                    chunk_content = chunk.choices[0].delta.content
                    content_length += len(chunk_content)
                    if chunk_count % 10 == 0:  # Log every 10th chunk to avoid spam
                        logger.info(
                            f"Streaming... {chunk_count} chunks, {content_length} chars"
                        )
                    yield chunk_content
            logger.info(
                f"Chat completion stream finished - Chunks: {chunk_count}, "
                f"Total chars: {content_length}"
            )
        except Exception as e:
            logger.error("Error during stream processing: %s", str(e), exc_info=True)
            raise handle_api_error(e)

    except APIError as e:
        logger.error("API Error in chat completion: %s", str(e), exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in chat completion: %s", str(e), exc_info=True)
        raise APIError(f"Unexpected error occurred: {str(e)}")
def chat_handler(
    message: str, history: List[Dict[str, str]]
) -> Generator[str, None, None]:
    """Handle chat interactions with proper history management and error handling.

    Args:
        message (str): The user message
        history (List[Dict[str, str]]): The chat history; legacy
            [user_msg, assistant_msg] pairs are also accepted

    Yields:
        Generator[str, None, None]: The response message chunks

    Raises:
        Exception: If an error occurs during chat handling
    """
    if not isinstance(message, str):
        raise InvalidRequestError("Message must be a string")
    if not message.strip():
        raise InvalidRequestError("Message cannot be empty")

    try:
        logger.info(
            f"Processing new chat request - History entries: {len(history) if history else 0}"
        )
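        # Illustrative history shapes accepted below (Gradio's two
        # ChatInterface formats; assumed examples):
        #   messages-style: [{"role": "user", "content": "hi"},
        #                    {"role": "assistant", "content": "hello"}]
        #   legacy pairs:   [["hi", "hello"], ["how are you?", None]]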
        # Validate and sanitize history format
        if history:
            if not isinstance(history, (list, tuple)):
                logger.warning(
                    "History is not a list or tuple, converting to empty list"
                )
                history = []
            else:
                # Filter out invalid entries and handle both old and new Gradio formats
                sanitized_history = []
                skipped_entries = 0
                for h in history:
                    # Handle the new Gradio format with message dictionaries
                    if isinstance(h, dict) and "role" in h and "content" in h:
                        role = h.get("role")
                        content = h.get("content")
                        if role and content:
                            sanitized_history.append(
                                {"role": role, "content": str(content)}
                            )
                        else:
                            skipped_entries += 1
                            logger.debug(
                                f"Skipping message with missing role or content: {h}"
                            )
                    # Handle the old Gradio format with [user_msg, assistant_msg] pairs
                    elif isinstance(h, (list, tuple)) and len(h) == 2:
                        # Ensure both elements are strings or None
                        user_msg = str(h[0]) if h[0] is not None else None
                        assistant_msg = str(h[1]) if h[1] is not None else None
                        # At least one message must be non-None
                        if user_msg or assistant_msg:
                            sanitized_history.append([user_msg, assistant_msg])
                        else:
                            skipped_entries += 1
                            logger.debug(f"Skipping empty message pair: {h}")
                    else:
                        skipped_entries += 1
                        logger.debug(f"Skipping unrecognized history entry format: {h}")
                if skipped_entries > 0:
                    logger.debug(f"Skipped {skipped_entries} invalid history entries")
                history = sanitized_history
                logger.debug(f"Sanitized history to {len(history)} valid entries")
        # Apply rate limiting at the conversation level
        if hasattr(chat_handler, "last_call_time"):
            time_since_last_call = time.time() - chat_handler.last_call_time
            min_interval = 60 / RATE_LIMIT_CALLS  # Minimum time between calls
            if time_since_last_call < min_interval:
                sleep_time = min_interval - time_since_last_call
                logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)
        # Record the time of this call after any throttling sleep
        chat_handler.last_call_time = time.time()
        # Prepend the system prompt, then replay the history
        system_prompt = load_system_prompt()
        formatted_messages = [{"role": "system", "content": system_prompt}]
        if history:
            for h in history:
                # Handle the new Gradio format with message dictionaries
                if isinstance(h, dict) and "role" in h and "content" in h:
                    formatted_messages.append(
                        {"role": h["role"], "content": str(h["content"])}
                    )
                # Handle the old Gradio format with [user_msg, assistant_msg] pairs
                elif isinstance(h, (list, tuple)) and len(h) == 2:
                    user_msg, assistant_msg = h
                    if user_msg:
                        formatted_messages.append(
                            {"role": "user", "content": str(user_msg)}
                        )
                    if assistant_msg:
                        formatted_messages.append(
                            {"role": "assistant", "content": str(assistant_msg)}
                        )
        formatted_messages.append({"role": "user", "content": message})
logger.info(f"π User message received - Length: {len(message)} chars") | |
logger.info( | |
f"π¬ Message preview: {message}" | |
) | |
logger.debug("User message content: %s", message) | |
full_response = "" | |
response_start_time = time.time() | |
for chunk in create_chat_completion(formatted_messages): | |
full_response += chunk | |
yield full_response # Stream response to Gradio | |
response_time = time.time() - response_start_time | |
logger.info( | |
f"π― AI response completed - Length: {len(full_response)} chars, Time: {response_time:.2f}s" | |
) | |
logger.info(f"π€ Full AI response: {full_response}") | |
logger.debug("AI response content: %s", full_response) | |
    except (ChatError, APIError) as e:
        logger.error(f"Chat or API error in chat_handler: {e}", exc_info=True)
        yield f"Error: {e}"  # Display the error in the Gradio UI
    except Exception as e:
        logger.error(f"Unexpected error in chat_handler: {e}", exc_info=True)
        yield "An unexpected error occurred. Please try again later."

def main():
    """Main function to launch the Gradio interface."""
    logger.info("Starting Tima mental health chatbot application")
    logger.info("Application configuration:")
    logger.info(f"  - Rate limit: {RATE_LIMIT_CALLS} calls per {RATE_LIMIT_PERIOD}s")
    logger.info(f"  - Max retries: {MAX_RETRIES}")
    logger.info("  - API timeout: 60s")
    # Block sensitive files and directories from being served by Gradio
    blocked_paths = [
        "app2.py",  # Hide main application logic
        "app_enhanced.py",  # Hide enhanced version
        "prompts.py",  # Hide prompt management
        ".env",  # Hide environment variables
        "*.log",  # Hide log files
        "__pycache__",  # Hide Python cache
        "*.pyc",  # Hide compiled Python files
        ".git",  # Hide git directory
        ".gitignore",  # Hide git configuration
        "tests",  # Hide test directory
        "*.md",  # Hide documentation
        "requirements.txt",  # Hide dependencies
        "Dockerfile",
    ]
    # Create and launch the Gradio interface
    chat_interface = gr.ChatInterface(
        fn=chat_handler,
        title="Tima - Your Mental Health Companion",
        description=(
            "A safe space to talk. Tima is here to listen, offer support, and "
            "provide understanding. ⚠️ This is not a replacement for "
            "professional medical advice."
        ),
        examples=[
            "I feel like giving up on everything",
            "I'm feeling really anxious lately and can't stop worrying",
            "I've been feeling down and hopeless for weeks",
            "I think people are watching me and I keep hearing voices",
            "Can you recommend a therapist in Nairobi?",
            "I need someone to talk to about my depression",
        ],
        type="messages",
    )
logger.info("π Launching Gradio interface on 0.0.0.0:7860") | |
logger.info("π‘ Access the application at: http://localhost:7860") | |
chat_interface.launch( | |
server_name="0.0.0.0", # Server name | |
server_port=7860, # Different port to avoid conflicts | |
share=False, # Share the server publicly | |
max_threads=16, # Increased maximum number of threads to handle more concurrent jobs | |
show_error=True, # Show error messages | |
inbrowser=True, # Open in browser | |
show_api=False, # Show API | |
enable_monitoring=True, | |
state_session_capacity=50, | |
blocked_paths=blocked_paths | |
) | |
if __name__ == "__main__": | |
logger.info("π¬ Application started from command line") | |
try: | |
main() | |
except KeyboardInterrupt: | |
logger.info("βΉοΈ Application stopped by user (Ctrl+C)") | |
except Exception as e: | |
logger.error( | |
f"π Fatal error during application startup: {str(e)}", exc_info=True | |
) | |
raise | |