import spaces
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from open_lm.hf import *
from open_lm.precision import get_autocast

# Dropdown label -> pretrained identifier handed to ``from_pretrained``.
# Labels beginning with "[IT]" mark instruction-tuned checkpoints; the UI
# uses that prefix to choose the chat interface over text completion.
MODEL_OPTIONS = dict(
    (
        ("TRI DCLM-1B", "TRI-ML/DCLM-1B"),
        ("Apple DCLM-Baseline-7B", "apple/DCLM-Baseline-7B"),
        ("[IT] TRI DCLM-1B", "TRI-ML/DCLM-1B-IT"),
        ("[IT] Apple DCLM-Baseline-7B", "mlfoundations/dclm-7b-it"),
    )
)

# Active model/tokenizer pair, shared module-wide; None until load_model runs.
current_model = current_tokenizer = None

def load_model(model_name):
    """Load the tokenizer and model registered under ``model_name``.

    The previously loaded model is dropped *before* the new checkpoint is
    read, so (barring other live references to it) the two do not have to
    coexist in memory. The new tokenizer/model pair is published to the
    module-level globals only after both have loaded, so the globals never
    refer to two different checkpoints; while loading (or after a failed
    load) both are None.

    Args:
        model_name: A key of ``MODEL_OPTIONS`` (the dropdown label).

    Returns:
        A status message for the UI.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    global current_model, current_tokenizer
    import gc  # local import: only needed to reclaim the previous model

    repo_id = MODEL_OPTIONS[model_name]

    # Release the old pair first; otherwise it would stay resident until the
    # new model replaced it, doubling peak (GPU) memory during a switch.
    current_model = None
    current_tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Publish both together so readers never see a mismatched pair.
    current_tokenizer, current_model = tokenizer, model
    return f"Loaded model: {model_name}"

@spaces.GPU
def generate_completion(
    prompt, model_choice, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a continuation of ``prompt`` from the currently loaded model.

    This is a generator. Each yielded value is the full text rendered so far:
    the HTML-escaped prompt wrapped in a blue ``<span>``, followed by the
    generated text. The UI can therefore replace its contents on every update.

    Args:
        prompt: Text to continue.
        model_choice: Dropdown label of the selected model. Not read here; the
            model loaded by ``load_model`` is used. Kept for the UI wiring.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an int >= 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty forwarded to ``generate``.

    Yields:
        The accumulated output string, or "Please select a model first." when
        no model has been loaded.

    Raises:
        Exception: Any error raised by ``generate`` in the worker thread is
            re-raised here after the stream has been drained.
    """
    import html  # local import: only needed to escape the echoed prompt

    # Snapshot the globals so a concurrent model switch cannot pair this
    # request's tokenizer with a different model mid-generation.
    model, tokenizer = current_model, current_tokenizer
    if model is None or tokenizer is None:
        # Must be yielded: a generator's ``return value`` never reaches the UI.
        yield "Please select a model first."
        return

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # generate() rejects max_new_tokens <= 0, and slider values may be floats.
    max_new_tokens = max(int(max_new_tokens), 1)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    autocast = get_autocast("amp_bf16")

    # Text of the EOS token; output is truncated at it. Kept in a local rather
    # than assigned to ``streamer.stop_signal``, which the streamer itself uses
    # as its internal end-of-stream sentinel.
    eos_text = tokenizer.decode(tokenizer.eos_token_id)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    errors = []

    def _worker():
        # Autocast state is thread-local in PyTorch, so it must be entered in
        # the thread that runs generate(), not in the caller.
        try:
            with autocast():
                model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Close the stream so the consumer loop below does not block forever.
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()

    output = "<span style='color: blue;'>" + html.escape(prompt) + "</span>"
    for new_text in streamer:
        hit_eos = bool(eos_text) and eos_text in new_text
        if hit_eos:
            new_text = new_text.split(eos_text)[0]
        output += new_text
        # Yield before breaking so the text preceding EOS is not dropped.
        yield output
        if hit_eos:
            break

    thread.join()
    if errors:
        raise errors[0]

def format_prompt(message, history):
    """Render a chat transcript as a plain-text User/Assistant prompt.

    Each prior turn becomes a ``User: ...`` line followed by an
    ``Assistant: ...`` line. The new message is appended as a final ``User:``
    line followed by a bare ``Assistant:``, so the model continues with the
    assistant's reply.

    Args:
        message: The new user message.
        history: Prior turns as ``(user_prompt, bot_response)`` pairs (tuples
            or 2-element lists). ``None`` is treated as an empty history, and
            a ``None`` response is rendered as an empty reply.

    Returns:
        The formatted prompt string.
    """
    turns = [
        f"User: {user_prompt}\nAssistant: {'' if bot_response is None else bot_response}\n"
        for user_prompt, bot_response in (history or ())
    ]
    turns.append(f"User: {message}\nAssistant:")
    return "".join(turns)

@spaces.GPU
def generate_chat(
    message, chat_history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream the assistant's reply to ``message`` into the chat history.

    The prompt is built by ``format_prompt`` from ``chat_history`` plus the
    new message. The reply comes from the currently loaded model.

    Args:
        message: The user's new chat message.
        chat_history: Prior turns as (user, assistant) pairs; None is treated
            as an empty history.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to an int >= 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty forwarded to ``generate``.

    Yields:
        The history with ``(message, reply_so_far)`` appended, updated as text
        streams in. If no model is loaded, a single history ending in
        ``("Error", "Please select a model first.")`` is yielded instead.

    Raises:
        Exception: Any error raised by ``generate`` in the worker thread is
            re-raised here once the stream has been drained.
    """
    history = list(chat_history or [])
    # Snapshot the globals so a concurrent model switch cannot pair this
    # request's tokenizer with a different model.
    model, tokenizer = current_model, current_tokenizer
    if model is None or tokenizer is None:
        # NOTE(review): this pair stays in the chat history and will be fed
        # to format_prompt on later turns.
        yield history + [("Error", "Please select a model first.")]
        return

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # generate() rejects max_new_tokens <= 0, and slider values may be floats.
    max_new_tokens = max(int(max_new_tokens), 1)

    formatted_prompt = format_prompt(message, history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Text of the EOS token; the reply is truncated at it. Kept local instead
    # of overwriting ``streamer.stop_signal`` (the streamer's own sentinel).
    eos_text = tokenizer.decode(tokenizer.eos_token_id)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    errors = []

    def _worker():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Close the stream so the consumer loop below does not block forever.
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()

    reply = ""
    new_history = history + [(message, reply)]
    for new_text in streamer:
        hit_eos = bool(eos_text) and eos_text in new_text
        if hit_eos:
            new_text = new_text.split(eos_text)[0]
        reply += new_text
        new_history[-1] = (message, reply)
        # Yield before breaking so the text preceding EOS reaches the UI.
        yield new_history
        if hit_eos:
            break

    thread.join()
    if errors:
        raise errors[0]

# Sampling controls shared by both interfaces. They are passed, in this order,
# as the trailing arguments (temperature, max_new_tokens, top_p,
# repetition_penalty) of generate_completion and generate_chat.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        # 0.0 is allowed here; the generators raise values below 0.01 to 0.01.
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # transformers' generate() rejects max_new_tokens <= 0, so the
        # smallest selectable value is one step (64) instead of 0.
        minimum=64,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

# Build the Gradio UI. Components created inside the Blocks context are laid
# out in creation order, so the order of the statements below defines the page.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # DCLM Demo
        This demo allows you to generate text using DCLM models in two modes: 
        1. Text Completion:
            For non-Instruction-Tuned models, it generates the continuation of the input text.
        2. Chatbot:
            For Instruction-Tuned [IT] models, it generates responses to user messages as a chatbot.
        
        Select a model from the dropdown to start, it might take a few seconds to load. 
        The interface will automatically switch between Text Completion and Chatbot modes based on the selected model.
        """
    )

    # Model picker plus a status box that switch_interface fills in.
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
        model_status = gr.Textbox(label="Model Status")

    # Text Completion interface (hidden until a non-[IT] model is selected)
    with gr.Row(visible=False) as completion_interface:
        with gr.Column():
            text_input = gr.Textbox(lines=3, label="Input Text")
            text_output = gr.Markdown(label="Generated Text")
            generate_button = gr.Button("Generate")

    # Chatbot interface (hidden until an [IT] model is selected)
    with gr.Row(visible=False) as chat_interface:
        with gr.Column():
            chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel")
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    # The sampling sliders were created outside this context; render() places
    # them here, inside the collapsed accordion.
    with gr.Accordion("Advanced Options", open=False):
        for input_component in additional_inputs:
            input_component.render()

    def switch_interface(model_name):
        """Load ``model_name`` and show the interface that matches it.

        Labels starting with "[IT]" (instruction-tuned) get the chat UI; all
        other models get the text-completion UI.

        Returns:
            Updates for (completion_interface, chat_interface, model_status).
        """
        is_it_model = model_name.startswith("[IT]")
        # NOTE(review): load_model replaces module-level globals, so switching
        # models here affects every session using the app.
        status = load_model(model_name)
        return (
            gr.Row(visible=not is_it_model),  # completion_interface
            gr.Row(visible=is_it_model),      # chat_interface
            status                            # model_status
        )

    model_dropdown.change(
        switch_interface,
        inputs=[model_dropdown],
        outputs=[completion_interface, chat_interface, model_status]
    )

    # model_dropdown is passed as generate_completion's model_choice argument,
    # which that function does not read as written.
    generate_button.click(
        generate_completion,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output]
    )

    # Stream the reply into the chatbot; the message box is not cleared afterwards.
    msg.submit(generate_chat, [msg, chatbot, *additional_inputs], chatbot)
    # Reset the chat history; queue=False runs this handler outside the queue.
    clear.click(lambda: None, None, chatbot, queue=False)

# Enable the request queue (used by the streaming generator handlers) and start the server.
demo.queue().launch()