"""Module for mental health chatbot with improved crisis handling.
This module provides a Gradio-based chat interface for mental health support
with proper crisis intervention protocols and streamlined code organization.
Example:
To run the chat interface:
$ python app2.py
Attributes:
TIMA_API_KEY (str): The API token for the AI service, loaded from environment variables.
Created by Nyabuti 2025-06-01
"""
import os
import time
import logging
from typing import Generator, List, Dict
import gradio as gr
import openai
import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from ratelimit import limits, sleep_and_retry
from prompts import load_system_prompt
# Constants for API interaction
MAX_RETRIES = 5
INITIAL_RETRY_DELAY = 1 # seconds
MAX_RETRY_DELAY = 60 # seconds
RATE_LIMIT_CALLS = 40 # Cerebras recommended rate limit
RATE_LIMIT_PERIOD = 60 # 1 minute period
# Configure logging with both file and console handlers for streaming logs
def setup_logging():
"""Set up logging with both file and console output for streaming visibility"""
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Clear any existing handlers
logger.handlers.clear()
# File handler for persistent logging
file_handler = logging.FileHandler("chat_interactions.log")
file_handler.setLevel(logging.INFO)
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(file_formatter)
# Console handler for streaming logs
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
console_handler.setFormatter(console_formatter)
# Add both handlers
logger.addHandler(file_handler)
logger.addHandler(console_handler)
return logger
logger = setup_logging()
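# Illustrative output of the two handlers configured above, derived from their
# format strings (the timestamps are placeholders):
#   chat_interactions.log: 2025-06-01 12:00:00,000 - INFO - Processing new chat request
#   console:               2025-06-01 12:00:00,000 [INFO] __main__: Processing new chat request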
# Load environment variables
TIMA_API_KEY = os.getenv("TIMA_API_KEY", os.getenv("CEREBRAS_API_KEY", None))
if not TIMA_API_KEY:
    raise ValueError("Neither TIMA_API_KEY nor CEREBRAS_API_KEY was found in the environment")
# Rate limiting (RATE_LIMIT_CALLS per RATE_LIMIT_PERIOD, defined above) is applied via decorators below
class ChatError(Exception):
"""Custom exception for chat-related errors"""
pass
class APIError(Exception):
"""Base exception for API errors"""
pass
class RateLimitError(APIError):
"""Exception for rate limit errors"""
pass
class TokenLimitError(APIError):
"""Exception for token limit errors"""
pass
class InvalidRequestError(APIError):
"""Exception for invalid request errors"""
pass
class AuthenticationError(APIError):
"""Exception for authentication errors"""
pass
class ServerError(APIError):
"""Exception for server-side errors"""
pass
def detect_crisis_situation(messages: List[Dict[str, str]]) -> bool:
"""Detect if the conversation indicates a crisis situation.
Args:
messages: List of conversation messages
Returns:
bool: True if crisis situation detected, False otherwise
"""
if not messages:
return False
# Check for PHQ-9 suicide question with concerning response
last_message = messages[-1].get("content", "").lower()
prev_message = messages[-2].get("content", "").lower() if len(messages) >= 2 else ""
# Look for PHQ-9 suicide question and high score response
suicide_question_keywords = [
"thoughts that you would be better off dead",
"hurting yourself",
]
    # NOTE: matching "3" is a coarse substring check; it will also match e.g. "13"
    concerning_responses = ["3", "nearly every day", "more than half"]
has_suicide_question = any(
keyword in prev_message for keyword in suicide_question_keywords
)
has_concerning_response = any(
response in last_message for response in concerning_responses
)
crisis_detected = has_suicide_question and has_concerning_response
if crisis_detected:
logger.warning("Crisis situation detected in conversation")
return crisis_detected
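# A minimal, hypothetical exchange that detect_crisis_situation() would flag;
# both messages are invented for illustration:
#
#   detect_crisis_situation([
#       {"role": "assistant",
#        "content": "Have you had thoughts that you would be better off dead?"},
#       {"role": "user", "content": "Nearly every day"},
#   ])  # -> True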
def get_crisis_response() -> str:
"""Generate an appropriate crisis response with local resources
Returns:
str: A compassionate response with immediate and local resources
"""
logger.info("Generating crisis response with local resources")
response_parts = []
# Initial validation and support
response_parts.append(
"I hear you, and I want you to know that your life has immense value. "
"What you're going through sounds incredibly difficult, and it's so important "
"that you get the support you need right now."
)
# Immediate crisis resources
    response_parts.append(
        "\n\nThere are caring people available 24/7 to support you:"
        "\n• National Suicide Prevention Lifeline: 1-800-273-TALK (8255)"
        "\n• Crisis Text Line: Text HOME to 741741"
        "\n• Emergency Services: 911 (if you're in immediate danger)"
    )
    # Local mental health professionals
    response_parts.append(
        "\n\nI also want to connect you with some local mental health professionals in Nairobi "
        "who can provide ongoing support:"
        "\n\n1. Dr Nancy Nyagah at Blossom Out Consultants"
        "\n • Location: 2nd floor Park View Towers (opposite Parklands police station)"
        "\n • Cost: 5,000 KES (as of 2018)"
        "\n • Accepts insurance"
        "\n • Contact: 0722938606 or 0780938606"
        "\n\n2. Dr. Judy Kamau at Scripture Union, Hurlingham"
        "\n • Cost: 3,500 KES (as of 2018)"
        "\n • Contact: +254202712852"
        "\n\n3. Rhoda Mutiso (specializes in depression)"
        "\n • Cost: 2,000 KES (as of 2018)"
        "\n • Contact: 0722 333378"
    )
# Next steps and validation
response_parts.append(
"\n\nI know reaching out for help can feel overwhelming, but it's a sign of strength, "
"not weakness. Would you like me to:"
"\n1. Help you make a safety plan for the next 24 hours?"
"\n2. Talk about what you're going through right now?"
"\n3. Give you more information about any of these mental health professionals?"
"\n\nYou don't have to go through this alone. I'm here to support you in taking the next step, "
"whatever that looks like for you."
)
return "".join(response_parts)
def handle_api_error(e: Exception) -> APIError:
"""
Convert API exceptions to our custom exception types.
Args:
e (Exception): The caught exception
Returns:
APIError: The appropriate custom exception type
"""
error_msg = str(e).lower()
logger.debug(f"Handling API error: {error_msg}")
if "rate limit" in error_msg:
logger.warning("Rate limit exceeded")
return RateLimitError("Rate limit exceeded. Please try again later.")
elif "token limit" in error_msg:
logger.warning("Token limit exceeded")
return TokenLimitError(
"Input too long. Please reduce the length of your message."
)
elif "authentication" in error_msg or "api key" in error_msg:
logger.error("Authentication failed")
return AuthenticationError("Authentication failed. Please check your API key.")
elif "invalid request" in error_msg:
logger.warning("Invalid request received")
return InvalidRequestError("Invalid request. Please check your input.")
elif any(code in error_msg for code in ["502", "503", "504"]):
logger.warning("Server error encountered")
return ServerError("Server is temporarily unavailable. Please try again later.")
logger.error(f"Unhandled API error: {error_msg}")
return APIError(f"API error occurred: {str(e)}")
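# Illustrative mappings for handle_api_error(); the error strings are invented
# for the example:
#   handle_api_error(Exception("Rate limit exceeded"))  -> RateLimitError
#   handle_api_error(Exception("Invalid API key"))      -> AuthenticationError
#   handle_api_error(Exception("502 Bad Gateway"))      -> ServerError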
@sleep_and_retry
@limits(calls=RATE_LIMIT_CALLS, period=RATE_LIMIT_PERIOD)
@retry(
stop=stop_after_attempt(MAX_RETRIES),
wait=wait_exponential(
multiplier=INITIAL_RETRY_DELAY, min=INITIAL_RETRY_DELAY, max=MAX_RETRY_DELAY
),
    # tenacity passes a RetryCallState (not the exception) to the `retry`
    # predicate, so use retry_if_exception_type rather than an isinstance lambda
    retry=retry_if_exception_type((ServerError, RateLimitError)),
reraise=True,
)
def create_chat_completion(
messages: List[Dict[str, str]],
) -> Generator[str, None, None]:
"""
Create a chat completion with comprehensive error handling and rate limiting
Args:
messages (List[Dict[str, str]]): List of messages in the chat
Yields:
Generator[str, None, None]: A generator of chat completion response chunks
Raises:
APIError: Base class for all API-related errors
RateLimitError: When API rate limit is exceeded
TokenLimitError: When input tokens exceed model's limit
AuthenticationError: When API key is invalid
InvalidRequestError: When request is malformed
ServerError: When API server has issues
"""
try:
# Initialize the OpenAI client with Cerebras API endpoint
client = openai.OpenAI(
base_url="https://api.cerebras.ai/v1",
api_key=TIMA_API_KEY,
timeout=60.0, # 60 second timeout
max_retries=0, # We handle retries ourselves
)
logger.info("π Starting chat completion request to Cerebras API")
logger.info(f"π Sending {len(messages)} messages to API")
logger.debug("Messages sent to API: %s", messages)
try:
stream = client.chat.completions.create(
model="llama-3.3-70b", # Cerebras recommended model
messages=messages,
                temperature=0.8,  # Sampling temperature (0 to 1); higher is more creative
max_tokens=500, # Limit response length to prevent token limit errors
top_p=0.9, # Use top-p sampling for more diverse responses
stream=True,
)
except openai.APIError as e:
raise handle_api_error(e)
except httpx.TimeoutException:
raise ServerError("Request timed out. Please try again.")
except httpx.RequestError as e:
raise ServerError(f"Network error occurred: {str(e)}")
try:
# Check for crisis situations and provide appropriate response
crisis_detected = detect_crisis_situation(messages)
if crisis_detected:
logger.warning("π¨ Crisis situation detected - Serving crisis response")
crisis_response = get_crisis_response()
yield crisis_response
else:
logger.info("β
Processing normal chat completion stream")
chunk_count = 0
content_length = 0
for chunk in stream:
if (
chunk.choices
and chunk.choices[0].delta
and chunk.choices[0].delta.content
):
chunk_count += 1
chunk_content = chunk.choices[0].delta.content
content_length += len(chunk_content)
if chunk_count % 10 == 0: # Log every 10th chunk to avoid spam
logger.info(
f"π‘ Streaming... {chunk_count} chunks, {content_length} chars"
)
yield chunk_content
            logger.info(
                f"Chat completion stream finished - Chunks: {chunk_count}, Total chars: {content_length}"
            )
except Exception as e:
logger.error("Error during stream processing: %s", str(e), exc_info=True)
raise handle_api_error(e)
except APIError as e:
logger.error("API Error in chat completion: %s", str(e), exc_info=True)
raise
except Exception as e:
logger.error("Unexpected error in chat completion: %s", str(e), exc_info=True)
raise APIError(f"Unexpected error occurred: {str(e)}")
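# A minimal sketch of consuming the stream outside Gradio (assumes TIMA_API_KEY
# is set and the Cerebras endpoint is reachable):
#
#   messages = [{"role": "user", "content": "Hello"}]
#   reply = "".join(create_chat_completion(messages))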
def chat_handler(message: str, history: List) -> Generator[str, None, None]:
    """Handle chat interactions with proper history management and error handling

    Args:
        message (str): The user message
        history (List): The chat history, as Gradio "messages" dicts or
            legacy [user_msg, assistant_msg] pairs
Yields:
Generator[str, None, None]: The response message chunks
Raises:
Exception: If an error occurs during chat handling
"""
if not isinstance(message, str):
raise InvalidRequestError("Message must be a string")
if not message.strip():
raise InvalidRequestError("Message cannot be empty")
try:
        logger.info(
            f"Processing new chat request - History entries: {len(history) if history else 0}"
        )
# Validate and sanitize history format
if history:
if not isinstance(history, (list, tuple)):
logger.warning(
"History is not a list or tuple, converting to empty list"
)
history = []
else:
# Filter out invalid history entries and handle both old and new Gradio formats
sanitized_history = []
skipped_entries = 0
for h in history:
# Handle new Gradio format with message dictionaries
if isinstance(h, dict) and "role" in h and "content" in h:
role = h.get("role")
content = h.get("content")
if role and content:
# Store as temporary format for later processing
sanitized_history.append(
{"role": role, "content": str(content)}
)
else:
skipped_entries += 1
logger.debug(
f"Skipping message with missing role or content: {h}"
)
# Handle old Gradio format with [user_msg, assistant_msg] pairs
elif isinstance(h, (list, tuple)) and len(h) == 2:
# Ensure both elements are strings or None
user_msg = str(h[0]) if h[0] is not None else None
assistant_msg = str(h[1]) if h[1] is not None else None
if (
user_msg or assistant_msg
): # At least one message must be non-None
sanitized_history.append([user_msg, assistant_msg])
else:
skipped_entries += 1
logger.debug(f"Skipping empty message pair: {h}")
else:
skipped_entries += 1
logger.debug(f"Skipping unrecognized history entry format: {h}")
if skipped_entries > 0:
logger.debug(f"Skipped {skipped_entries} invalid history entries")
history = sanitized_history
logger.debug(f"Sanitized history to {len(history)} valid entries")
# Apply rate limiting at the conversation level
current_time = time.time()
if hasattr(chat_handler, "last_call_time"):
time_since_last_call = current_time - chat_handler.last_call_time
if time_since_last_call < (
60 / RATE_LIMIT_CALLS
): # Minimum time between calls
sleep_time = (60 / RATE_LIMIT_CALLS) - time_since_last_call
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
time.sleep(sleep_time)
chat_handler.last_call_time = current_time
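        # Worked example: with RATE_LIMIT_CALLS = 40 in a 60-second window,
        # calls from this handler are spaced at least 60 / 40 = 1.5s apart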
        # Load the system prompt
system_prompt = load_system_prompt()
formatted_messages = [{"role": "system", "content": system_prompt}]
if history:
for h in history:
# Handle new Gradio format with message dictionaries
if isinstance(h, dict) and "role" in h and "content" in h:
formatted_messages.append(
{"role": h["role"], "content": str(h["content"])}
)
# Handle old Gradio format with [user_msg, assistant_msg] pairs
elif isinstance(h, (list, tuple)) and len(h) == 2:
user_msg, assistant_msg = h
if user_msg:
formatted_messages.append(
{"role": "user", "content": str(user_msg)}
)
if assistant_msg:
formatted_messages.append(
{"role": "assistant", "content": str(assistant_msg)}
)
formatted_messages.append({"role": "user", "content": message})
logger.info(f"π User message received - Length: {len(message)} chars")
logger.info(
f"π¬ Message preview: {message}"
)
logger.debug("User message content: %s", message)
full_response = ""
response_start_time = time.time()
for chunk in create_chat_completion(formatted_messages):
full_response += chunk
yield full_response # Stream response to Gradio
response_time = time.time() - response_start_time
        logger.info(
            f"AI response completed - Length: {len(full_response)} chars, Time: {response_time:.2f}s"
        )
        logger.debug("AI response content: %s", full_response)
    except ChatError as e:
        logger.error(f"ChatError in chat_handler: {e}", exc_info=True)
        yield f"Error: {e}"  # Display error in Gradio UI
    except Exception as e:
        logger.error(f"Unexpected error in chat_handler: {e}", exc_info=True)
        yield "An unexpected error occurred. Please try again later."
def main():
"""Main function to launch the Gradio interface"""
logger.info("π Starting Tima mental health chatbot application")
logger.info("π Application configuration:")
logger.info(f" - Rate limit: {RATE_LIMIT_CALLS} calls per {RATE_LIMIT_PERIOD}s")
logger.info(f" - Max retries: {MAX_RETRIES}")
logger.info(f" - API timeout: 60s")
# Block sensitive files and directories
blocked_paths = [
"app2.py", # Hide main application logic
"app_enhanced.py", # Hide enhanced version
"prompts.py", # Hide prompt management
".env", # Hide environment variables
"*.log", # Hide log files
"__pycache__", # Hide Python cache
"*.pyc", # Hide compiled Python files
".git", # Hide git directory
".gitignore", # Hide git configuration
"tests", # Hide test directory
"*.md", # Hide documentation
"requirements.txt", # Hide dependencies
"Dockerfile"
]
# Create and launch the Gradio interface
chat_interface = gr.ChatInterface(
fn=chat_handler,
title="Tima - Your Mental Health Companion",
description="A safe space to talk. Tima is here to listen, offer support, and provide understanding. β οΈ This is not a replacement for professional medical advice.",
examples=[
"I feel like giving up on everything",
"I'm feeling really anxious lately and can't stop worrying",
"I've been feeling down and hopeless for weeks",
"I think people are watching me and I keep hearing voices",
"Can you recommend a therapist in Nairobi?",
"I need someone to talk to about my depression",
],
type="messages",
)
logger.info("π Launching Gradio interface on 0.0.0.0:7860")
logger.info("π‘ Access the application at: http://localhost:7860")
    chat_interface.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,  # Default Gradio port
        share=False,  # Do not create a public share link
        max_threads=16,  # Allow more concurrent requests
        show_error=True,  # Surface error messages in the UI
        inbrowser=True,  # Open in a browser tab on launch
        show_api=False,  # Hide the auto-generated API docs
        enable_monitoring=True,
        state_session_capacity=50,
        blocked_paths=blocked_paths,
    )
if __name__ == "__main__":
logger.info("π¬ Application started from command line")
try:
main()
except KeyboardInterrupt:
logger.info("βΉοΈ Application stopped by user (Ctrl+C)")
except Exception as e:
logger.error(
f"π Fatal error during application startup: {str(e)}", exc_info=True
)
raise