import os
import base64
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
from typing import List, Dict

load_dotenv()
XAI_API_KEY = os.getenv("XAI_API_KEY")
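# Fail fast if the key is missing; otherwise the API rejects requests at call time.
if not XAI_API_KEY:
    raise RuntimeError("XAI_API_KEY is not set. Add it to your environment or a .env file.")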

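# xAI exposes an OpenAI-compatible endpoint, so the standard OpenAI client works
# once base_url points at https://api.x.ai/v1.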
client = OpenAI(
    api_key=XAI_API_KEY,
    base_url="https://api.x.ai/v1",
)

# TODO: experiment with different system prompts.
def build_system_prompt() -> dict:
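    """Return the system message sent at the start of every request."""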
    return {
        "role": "system",
        "content": (
            "You are Grok Vision, created by xAI. You're designed to understand and describe images and answer text-based queries. "
            "Use all previous conversation context to provide clear, positive, and helpful responses. "
            "Respond in markdown format when appropriate."
        )
    }

def encode_image(image_path: str) -> str:
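    """Validate a local JPEG/PNG file (max 10 MB) and return it as a base64 data URL."""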
    file_size = os.path.getsize(image_path)
    if file_size > 10 * 1024 * 1024:
        raise ValueError("Image exceeds maximum size of 10MB.")
    ext = os.path.splitext(image_path)[1].lower()
    if ext in ['.jpg', '.jpeg']:
        mime_type = 'image/jpeg'
    elif ext == '.png':
        mime_type = 'image/png'
    else:
        raise ValueError("Unsupported image format. Only JPEG and PNG are allowed.")
    # Read the file and embed it in a data URL the vision API accepts.
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"

def process_input(user_text: str, user_image_paths: List[str]) -> tuple[str, List[str]]:
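    """Split the raw input into plain text and a list of image URLs.

    Tokens starting with "http" are treated as image URLs; uploaded files are
    validated and converted to base64 data URLs.
    """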
    user_text = user_text.strip() if user_text else ""
    image_urls = []
    # Extract URLs
    text_parts = user_text.split()
    remaining_text = []
    for part in text_parts:
        if part.startswith("http"):
            image_urls.append(part)
        else:
            remaining_text.append(part)
    user_text = " ".join(remaining_text) if remaining_text else ""
    if user_image_paths:
        for path in user_image_paths:
            if path: 
                image_urls.append(encode_image(path))
    
    return user_text, image_urls

def create_message_content(text: str, image_urls: List[str]) -> list[dict]:
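    """Build the multimodal content list: image parts first, then the text part."""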
    content = []
    for image_url in image_urls:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": image_url, 
                "detail": "high"
            }
        })
    if text:
        content.append({
            "type": "text",
            "text": text
        })
    return content

def stream_response(history: List[Dict], user_text: str, user_image_paths: List[str]):
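    """Send the conversation to grok-2-vision-1212 and stream the reply.

    Yields progressively longer copies of the history so the Chatbot updates live.
    """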
    try:
        user_text, image_urls = process_input(user_text, user_image_paths)
    except ValueError as err:  # oversized file or unsupported format
        history.append({"role": "assistant", "content": str(err)})
        yield history
        return
    if not user_text and not image_urls:
        history.append({"role": "assistant", "content": "Please provide text or at least one image (JPEG/PNG only)."})
        yield history
        return
    messages = [build_system_prompt()]
    for entry in history:
        if entry["role"] == "user":
            content = create_message_content(entry["content"], entry.get("image_urls", []))
            messages.append({"role": "user", "content": content})
        elif entry["role"] == "assistant":
            messages.append({"role": "assistant", "content": entry["content"]})
    new_content = create_message_content(user_text, image_urls)
    messages.append({"role": "user", "content": new_content})
    history.append({"role": "user", "content": user_text, "image_urls": image_urls})
    stream = client.chat.completions.create(
        model="grok-2-vision-1212",
        messages=messages,
        stream=True,
        temperature=0.01,
    )
    response_text = ""
    temp_history = history.copy()
    temp_history.append({"role": "assistant", "content": ""})
    for chunk in stream:
        delta_content = chunk.choices[0].delta.content
        if delta_content is not None:
            response_text += delta_content
            temp_history[-1] = {"role": "assistant", "content": response_text}
            yield temp_history

def clear_inputs_and_chat():
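    """Reset the chatbot display, stored history, textbox, and file input."""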
    return [], [], "", None 

def update_and_clear(history: List[Dict], streamed_response: List[Dict]) -> tuple[List[Dict], str, None]:
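    """Persist the streamed assistant reply into the stored history and clear the inputs."""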
    # The streamed copy has one extra entry (the assistant reply); append it to
    # the stored history rather than overwriting the user's message.
    if streamed_response and len(streamed_response) > len(history):
        history.append(streamed_response[-1])
    return history, "", None

with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
        .chatbot-container {max-height: 80vh; overflow-y: auto;}
        .input-container {margin-top: 20px;}
        .title {text-align: center; margin-bottom: 20px;}
    """
) as demo:
    gr.Markdown(
        """
        # Grok 2 Vision Chatbot 𝕏
        
        Interact with Grok 2 Vision. You can:
        - 📸 Upload one or more JPEG/PNG images (max 10 MB each)
        - 🔗 Provide image URLs in your message (e.g. `https://example.com/image1.jpg`)
        - ✍️ Ask text-only questions
        - 💬 Chat history is preserved.
        """
    )
    
    with gr.Column(elem_classes="chatbot-container"):
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            bubble_full_width=False
        )
    
    with gr.Row(elem_classes="input-container"):
        with gr.Column(scale=1):
            image_input = gr.File(
                file_count="multiple", 
                file_types=[".jpg", ".jpeg", ".png"], 
                label="Upload JPEG or PNG Images",
                height=300,
                interactive=True
            )
        with gr.Column(scale=3):
            message_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your question or paste JPEG/PNG image URLs",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
    
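    # Stored conversation history (including base64 image URLs), kept separately
    # from the Chatbot so images can be re-sent as context on later turns.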
    state = gr.State([])

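    # Submit: stream the reply into the chatbot, then persist it to state and
    # clear the text/file inputs.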
    submit_btn.click(
        fn=stream_response,
        inputs=[state, message_input, image_input],
        outputs=chatbot,
        queue=True
    ).then(
        fn=update_and_clear,
        inputs=[state, chatbot],
        outputs=[state, message_input, image_input]
    )
    
    clear_btn.click(
        fn=clear_inputs_and_chat,
        inputs=[],
        outputs=[chatbot, state, message_input, image_input]
    )

if __name__ == "__main__":
    demo.launch()