"""Module for mental health chatbot with improved crisis handling.

This module provides a Gradio-based chat interface for mental health support
with proper crisis intervention protocols and streamlined code organization.

Example:
    To run the chat interface:
        $ python app2.py

Attributes:
    TIMA_API_KEY (str): The API token for the AI service, loaded from environment variables.

Created by Nyabuti 2025-06-01
"""

import os
import time
import logging
from typing import Generator, List, Dict
import gradio as gr
import openai
import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from ratelimit import limits, sleep_and_retry

from prompts import load_system_prompt

# Constants for API interaction
MAX_RETRIES = 5
INITIAL_RETRY_DELAY = 1  # seconds
MAX_RETRY_DELAY = 60  # seconds
RATE_LIMIT_CALLS = 40  # Cerebras recommended rate limit
RATE_LIMIT_PERIOD = 60  # 1 minute period


# Configure logging with both file and console handlers for streaming logs
def setup_logging():
    """Set up logging with both file and console output for streaming visibility"""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Clear any existing handlers
    logger.handlers.clear()

    # File handler for persistent logging
    file_handler = logging.FileHandler("chat_interactions.log")
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    file_handler.setFormatter(file_formatter)

    # Console handler for streaming logs
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    )
    console_handler.setFormatter(console_formatter)

    # Add both handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger


logger = setup_logging()

# Load environment variables
TIMA_API_KEY = os.getenv("TIMA_API_KEY", os.getenv("CEREBRAS_API_KEY"))
if not TIMA_API_KEY:
    raise ValueError(
        "Neither TIMA_API_KEY nor CEREBRAS_API_KEY was found in the environment"
    )


class ChatError(Exception):
    """Custom exception for chat-related errors"""

    pass


class APIError(Exception):
    """Base exception for API errors"""

    pass


class RateLimitError(APIError):
    """Exception for rate limit errors"""

    pass


class TokenLimitError(APIError):
    """Exception for token limit errors"""

    pass


class InvalidRequestError(APIError):
    """Exception for invalid request errors"""

    pass


class AuthenticationError(APIError):
    """Exception for authentication errors"""

    pass


class ServerError(APIError):
    """Exception for server-side errors"""

    pass


def detect_crisis_situation(messages: List[Dict[str, str]]) -> bool:
    """Detect if the conversation indicates a crisis situation.

    Args:
        messages: List of conversation messages

    Returns:
        bool: True if crisis situation detected, False otherwise
    """
    if not messages:
        return False

    # Check for PHQ-9 suicide question with concerning response
    last_message = messages[-1].get("content", "").lower()
    prev_message = messages[-2].get("content", "").lower() if len(messages) >= 2 else ""

    # Look for PHQ-9 suicide question and high score response
    suicide_question_keywords = [
        "thoughts that you would be better off dead",
        "hurting yourself",
    ]
    concerning_responses = ["3", "nearly every day", "more than half"]

    has_suicide_question = any(
        keyword in prev_message for keyword in suicide_question_keywords
    )
    has_concerning_response = any(
        response in last_message for response in concerning_responses
    )

    crisis_detected = has_suicide_question and has_concerning_response
    if crisis_detected:
        logger.warning("Crisis situation detected in conversation")

    return crisis_detected
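
# Illustrative sketch (hypothetical turns) of what trips the detector:
#
#   messages = [
#       {"role": "assistant",
#        "content": "...thoughts that you would be better off dead..."},
#       {"role": "user", "content": "nearly every day"},
#   ]
#   detect_crisis_situation(messages)  # -> True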


def get_crisis_response() -> str:
    """Generate an appropriate crisis response with local resources

    Returns:
        str: A compassionate response with immediate and local resources
    """
    logger.info("Generating crisis response with local resources")
    response_parts = []

    # Initial validation and support
    response_parts.append(
        "I hear you, and I want you to know that your life has immense value. "
        "What you're going through sounds incredibly difficult, and it's so important "
        "that you get the support you need right now."
    )

    # Immediate crisis resources
    response_parts.append(
        "\n\nThere are caring people available 24/7 to support you:"
        "\nβ€’ National Suicide Prevention Lifeline: 1-800-273-TALK (8255)"
        "\nβ€’ Crisis Text Line: Text HOME to 741741"
        "\nβ€’ Emergency Services: 911 (if you're in immediate danger)"
    )

    # Local mental health professionals
    response_parts.append(
        "\n\nI also want to connect you with some local mental health professionals in Nairobi "
        "who can provide ongoing support:"
        "\n\n1. Dr Nancy Nyagah at Blossom Out Consultants"
        "\n   β€’ Location: 2nd floor Park View Towers (opposite Parklands police station)"
        "\n   β€’ Cost: 5000 KES (as of 2018)"
        "\n   β€’ Accepts insurance"
        "\n   β€’ Contact: 0722938606 or 0780938606"
        "\n\n2. Dr. Judy Kamau at Scripture Union, Hurlingham"
        "\n   β€’ Cost: 3,500 KES (as of 2018)"
        "\n   β€’ Contact: +254202712852"
        "\n\n3. Rhoda Mutiso (specializes in depression)"
        "\n   β€’ Cost: 2000 KES (as of 2018)"
        "\n   β€’ Contact: 0722 333378"
    )

    # Next steps and validation
    response_parts.append(
        "\n\nI know reaching out for help can feel overwhelming, but it's a sign of strength, "
        "not weakness. Would you like me to:"
        "\n1. Help you make a safety plan for the next 24 hours?"
        "\n2. Talk about what you're going through right now?"
        "\n3. Give you more information about any of these mental health professionals?"
        "\n\nYou don't have to go through this alone. I'm here to support you in taking the next step, "
        "whatever that looks like for you."
    )

    return "".join(response_parts)


def handle_api_error(e: Exception) -> APIError:
    """
    Convert API exceptions to our custom exception types.

    Args:
        e (Exception): The caught exception

    Returns:
        APIError: The appropriate custom exception type
    """
    error_msg = str(e).lower()
    logger.debug(f"Handling API error: {error_msg}")

    if "rate limit" in error_msg:
        logger.warning("Rate limit exceeded")
        return RateLimitError("Rate limit exceeded. Please try again later.")
    elif "token limit" in error_msg:
        logger.warning("Token limit exceeded")
        return TokenLimitError(
            "Input too long. Please reduce the length of your message."
        )
    elif "authentication" in error_msg or "api key" in error_msg:
        logger.error("Authentication failed")
        return AuthenticationError("Authentication failed. Please check your API key.")
    elif "invalid request" in error_msg:
        logger.warning("Invalid request received")
        return InvalidRequestError("Invalid request. Please check your input.")
    elif any(code in error_msg for code in ["502", "503", "504"]):
        logger.warning("Server error encountered")
        return ServerError("Server is temporarily unavailable. Please try again later.")

    logger.error(f"Unhandled API error: {error_msg}")
    return APIError(f"API error occurred: {str(e)}")
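
# Mapping sketch: handle_api_error keys off substrings of the error text, e.g.
#   handle_api_error(Exception("Rate limit exceeded"))  # -> RateLimitError
#   handle_api_error(Exception("502 Bad Gateway"))      # -> ServerError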


@sleep_and_retry
@limits(calls=RATE_LIMIT_CALLS, period=RATE_LIMIT_PERIOD)
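# NOTE: create_chat_completion is a generator, so these decorators wrap only
# the call that creates the generator; exceptions raised while the stream is
# being consumed propagate to the caller rather than triggering a retry.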
@retry(
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(
        multiplier=INITIAL_RETRY_DELAY, min=INITIAL_RETRY_DELAY, max=MAX_RETRY_DELAY
    ),
    retry=retry_if_exception_type((ServerError, RateLimitError)),
    reraise=True,
)
def create_chat_completion(
    messages: List[Dict[str, str]],
) -> Generator[str, None, None]:
    """
    Create a chat completion with comprehensive error handling and rate limiting

    Args:
        messages (List[Dict[str, str]]): List of messages in the chat

    Yields:
        Generator[str, None, None]: A generator of chat completion response chunks

    Raises:
        APIError: Base class for all API-related errors
        RateLimitError: When API rate limit is exceeded
        TokenLimitError: When input tokens exceed model's limit
        AuthenticationError: When API key is invalid
        InvalidRequestError: When request is malformed
        ServerError: When API server has issues
    """
    try:
        # Short-circuit on a detected crisis before spending an API call
        if detect_crisis_situation(messages):
            logger.warning("🚨 Crisis situation detected - Serving crisis response")
            yield get_crisis_response()
            return

        # Initialize the OpenAI client with Cerebras API endpoint
        client = openai.OpenAI(
            base_url="https://api.cerebras.ai/v1",
            api_key=TIMA_API_KEY,
            timeout=60.0,  # 60 second timeout
            max_retries=0,  # We handle retries ourselves
        )

        logger.info("πŸš€ Starting chat completion request to Cerebras API")
        logger.info(f"πŸ“Š Sending {len(messages)} messages to API")
        logger.debug("Messages sent to API: %s", messages)

        try:
            stream = client.chat.completions.create(
                model="llama-3.3-70b",  # Cerebras recommended model
                messages=messages,
                temperature=0.8,  # Adjust temperature for creativity 0 to 1
                max_tokens=500,  # Limit response length to prevent token limit errors
                top_p=0.9,  # Use top-p sampling for more diverse responses
                stream=True,
            )
        except openai.APIError as e:
            raise handle_api_error(e)
        except httpx.TimeoutException:
            raise ServerError("Request timed out. Please try again.")
        except httpx.RequestError as e:
            raise ServerError(f"Network error occurred: {str(e)}")

        try:
            logger.info("βœ… Processing normal chat completion stream")
            chunk_count = 0
            content_length = 0
            for chunk in stream:
                if (
                    chunk.choices
                    and chunk.choices[0].delta
                    and chunk.choices[0].delta.content
                ):
                    chunk_count += 1
                    chunk_content = chunk.choices[0].delta.content
                    content_length += len(chunk_content)
                    if chunk_count % 10 == 0:  # Log every 10th chunk to avoid spam
                        logger.info(
                            f"πŸ“‘ Streaming... {chunk_count} chunks, {content_length} chars"
                        )
                    yield chunk_content
            logger.info(
                f"βœ… Chat completion stream finished - Chunks: {chunk_count}, Total chars: {content_length}"
            )
        except Exception as e:
            logger.error("Error during stream processing: %s", str(e), exc_info=True)
            raise handle_api_error(e)

    except APIError as e:
        logger.error("API Error in chat completion: %s", str(e), exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in chat completion: %s", str(e), exc_info=True)
        raise APIError(f"Unexpected error occurred: {str(e)}")
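
# Usage sketch (assumes a valid API key and a reachable Cerebras endpoint):
#
#   msgs = [{"role": "user", "content": "Hello"}]
#   reply = "".join(create_chat_completion(msgs))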


def chat_handler(message: str, history: List) -> Generator[str, None, None]:
    """Handle chat interactions with proper history management and error handling

    Args:
        message (str): The user message
        history (List): The chat history, as openai-style message dicts
            (Gradio ``type="messages"``) or legacy [user_msg, assistant_msg] pairs

    Yields:
        Generator[str, None, None]: The response message chunks

    Raises:
        Exception: If an error occurs during chat handling
    """
    if not isinstance(message, str):
        raise InvalidRequestError("Message must be a string")

    if not message.strip():
        raise InvalidRequestError("Message cannot be empty")

    try:
        logger.info(
            f"πŸ”„ Processing new chat request - History entries: {len(history) if history else 0}"
        )
        logger.debug(
            f"Processing chat request with {len(history) if history else 0} history entries"
        )

        # Validate and sanitize history format
        if history:
            if not isinstance(history, (list, tuple)):
                logger.warning(
                    "History is not a list or tuple, converting to empty list"
                )
                history = []
            else:
                # Filter out invalid history entries and handle both old and new Gradio formats
                sanitized_history = []
                skipped_entries = 0
                for h in history:
                    # Handle new Gradio format with message dictionaries
                    if isinstance(h, dict) and "role" in h and "content" in h:
                        role = h.get("role")
                        content = h.get("content")
                        if role and content:
                            # Store as temporary format for later processing
                            sanitized_history.append(
                                {"role": role, "content": str(content)}
                            )
                        else:
                            skipped_entries += 1
                            logger.debug(
                                f"Skipping message with missing role or content: {h}"
                            )
                    # Handle old Gradio format with [user_msg, assistant_msg] pairs
                    elif isinstance(h, (list, tuple)) and len(h) == 2:
                        # Ensure both elements are strings or None
                        user_msg = str(h[0]) if h[0] is not None else None
                        assistant_msg = str(h[1]) if h[1] is not None else None
                        if (
                            user_msg or assistant_msg
                        ):  # At least one message must be non-None
                            sanitized_history.append([user_msg, assistant_msg])
                        else:
                            skipped_entries += 1
                            logger.debug(f"Skipping empty message pair: {h}")
                    else:
                        skipped_entries += 1
                        logger.debug(f"Skipping unrecognized history entry format: {h}")

                if skipped_entries > 0:
                    logger.debug(f"Skipped {skipped_entries} invalid history entries")

                history = sanitized_history
                logger.debug(f"Sanitized history to {len(history)} valid entries")

        # Apply rate limiting at the conversation level
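        # e.g. with RATE_LIMIT_CALLS = 40, calls are spaced at least 60/40 = 1.5 s apart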
        current_time = time.time()
        if hasattr(chat_handler, "last_call_time"):
            time_since_last_call = current_time - chat_handler.last_call_time
            if time_since_last_call < (
                60 / RATE_LIMIT_CALLS
            ):  # Minimum time between calls
                sleep_time = (60 / RATE_LIMIT_CALLS) - time_since_last_call
                logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)
        chat_handler.last_call_time = current_time

        # Define your system prompt
        system_prompt = load_system_prompt()

        formatted_messages = [{"role": "system", "content": system_prompt}]

        if history:
            for h in history:
                # Handle new Gradio format with message dictionaries
                if isinstance(h, dict) and "role" in h and "content" in h:
                    formatted_messages.append(
                        {"role": h["role"], "content": str(h["content"])}
                    )
                # Handle old Gradio format with [user_msg, assistant_msg] pairs
                elif isinstance(h, (list, tuple)) and len(h) == 2:
                    user_msg, assistant_msg = h
                    if user_msg:
                        formatted_messages.append(
                            {"role": "user", "content": str(user_msg)}
                        )
                    if assistant_msg:
                        formatted_messages.append(
                            {"role": "assistant", "content": str(assistant_msg)}
                        )

        formatted_messages.append({"role": "user", "content": message})

        logger.info(f"πŸ“ User message received - Length: {len(message)} chars")
        logger.info(f"πŸ’¬ Message preview: {message[:100]}")
        logger.debug("User message content: %s", message)

        full_response = ""
        response_start_time = time.time()
        for chunk in create_chat_completion(formatted_messages):
            full_response += chunk
            yield full_response  # Stream response to Gradio

        response_time = time.time() - response_start_time
        logger.info(
            f"🎯 AI response completed - Length: {len(full_response)} chars, Time: {response_time:.2f}s"
        )
        logger.info(f"πŸ“€ Full AI response: {full_response}")
        logger.debug("AI response content: %s", full_response)

    except ChatError as e:
        logger.error(f"❌ ChatError in chat_handler: {e}", exc_info=True)
        yield f"Error: {e}"  # Display error in Gradio UI
    except Exception as e:
        logger.error(f"πŸ’₯ Unexpected error in chat_handler: {e}", exc_info=True)
        yield "An unexpected error occurred. Please try again later."


def main():
    """Main function to launch the Gradio interface"""
    logger.info("πŸš€ Starting Tima mental health chatbot application")
    logger.info("πŸ“‹ Application configuration:")
    logger.info(f"   - Rate limit: {RATE_LIMIT_CALLS} calls per {RATE_LIMIT_PERIOD}s")
    logger.info(f"   - Max retries: {MAX_RETRIES}")
    logger.info(f"   - API timeout: 60s")

    # Block sensitive files and directories
    blocked_paths = [
        "app2.py",           # Hide main application logic
        "app_enhanced.py",   # Hide enhanced version
        "prompts.py",        # Hide prompt management
        ".env",              # Hide environment variables
        "*.log",             # Hide log files
        "__pycache__",       # Hide Python cache
        "*.pyc",             # Hide compiled Python files
        ".git",              # Hide git directory
        ".gitignore",        # Hide git configuration
        "tests",             # Hide test directory
        "*.md",              # Hide documentation
        "requirements.txt",  # Hide dependencies
        "Dockerfile",        # Hide container build file
    ]

    # Create and launch the Gradio interface
    chat_interface = gr.ChatInterface(
        fn=chat_handler,
        title="Tima - Your Mental Health Companion",
        description="A safe space to talk. Tima is here to listen, offer support, and provide understanding. ⚠️ This is not a replacement for professional medical advice.",
        examples=[
            "I feel like giving up on everything",
            "I'm feeling really anxious lately and can't stop worrying",
            "I've been feeling down and hopeless for weeks",
            "I think people are watching me and I keep hearing voices",
            "Can you recommend a therapist in Nairobi?",
            "I need someone to talk to about my depression",
        ],
        type="messages",
    )
    logger.info("🌐 Launching Gradio interface on 0.0.0.0:7860")
    logger.info("πŸ’‘ Access the application at: http://localhost:7860")
    chat_interface.launch(
        server_name="0.0.0.0",  # Listen on all interfaces
        server_port=7860,  # Default Gradio port
        share=False,  # Do not create a public share link
        max_threads=16,  # Allow more concurrent requests
        show_error=True,  # Show error messages in the UI
        inbrowser=True,  # Open in browser automatically
        show_api=False,  # Hide the auto-generated API docs
        enable_monitoring=True,
        state_session_capacity=50,
        blocked_paths=blocked_paths,
    )


if __name__ == "__main__":
    logger.info("🎬 Application started from command line")
    try:
        main()
    except KeyboardInterrupt:
        logger.info("⏹️  Application stopped by user (Ctrl+C)")
    except Exception as e:
        logger.error(
            f"πŸ’€ Fatal error during application startup: {str(e)}", exc_info=True
        )
        raise