import random
import gradio as gr
import json

from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation
from pages.summarization_playground import custom_css

def random_data_selection():
    """Return one randomly chosen example from ``dataset`` as arena input text.

    The result is the example's ``section_text``, followed by a blank line,
    a ``Dialogue:`` header line, and the example's ``dialogue``.
    """
    sample = random.choice(dataset)
    header, dialogue = sample['section_text'], sample['dialogue']
    return '\n\nDialogue:\n'.join((header, dialogue))

def create_arena(prompt_path="prompt/prompt.json",
                 model_name="Qwen/Qwen2-1.5B-Instruct",
                 num_prompts=3):
    """Build the prompt-comparison arena as a ``gr.Blocks`` app.

    Each browser session gets up to ``num_prompts`` prompts sampled at random
    from ``prompt_path``. Clicking the stream button wraps the current data in
    every sampled prompt and streams the model outputs into one column per
    prompt. The user then votes for the best response, which increments that
    prompt's ``["metric"]["winning_number"]`` in ``prompt_path``.

    Args:
        prompt_path: JSON file containing a list of prompt dicts. Each dict is
            read for ``'id'``, ``'prompt'`` and ``['metric']['winning_number']``.
        model_name: Identifier passed to ``get_model_batch_generation``.
        num_prompts: Maximum number of prompts (response columns) compared.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` demo.

    Raises:
        OSError, json.JSONDecodeError: if ``prompt_path`` cannot be read or
            parsed while building the UI.
    """
    import os
    import threading

    def load_prompts():
        """Read the current prompt list (including metrics) from disk."""
        with open(prompt_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def save_prompts(prompts):
        """Write the prompt list back to disk through a temporary file."""
        tmp_path = f"{prompt_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(prompts, file)
        # Swap in the finished file in one step so an interrupted write cannot
        # leave a truncated prompt file behind (os.replace is atomic on POSIX).
        os.replace(tmp_path, prompt_path)

    # Handlers from different sessions may run concurrently; serialize the
    # read-increment-write of the metrics file so simultaneous votes survive.
    metrics_lock = threading.Lock()

    initial_pool = load_prompts()
    num_columns = min(num_prompts, len(initial_pool))
    choice_labels = [f"Response {i + 1}" for i in range(num_columns)]

    def pick_prompts():
        """Sample a fresh set of prompts (at most one per column) for a session."""
        pool = load_prompts()
        return random.sample(pool, k=min(num_columns, len(pool)))

    with gr.Blocks(css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.

Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")

            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox]
            )

            # Default selection; replaced with a fresh sample whenever a page
            # loads, so sessions do not all share the prompts drawn at build time.
            state_random_selected_prompts = gr.State(
                value=random.sample(initial_pool, k=num_columns)
            )
            demo.load(fn=pick_prompts, inputs=None, outputs=[state_random_selected_prompts])

            with gr.Row():
                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(num_columns)]

            model = get_model_batch_generation(model_name)

            def start_streaming(data, random_selected_prompts):
                """Stream one model response per selected prompt into the columns."""
                content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in random_selected_prompts]
                for response_data in stream_data(content_list, model):
                    # Columns without a prompt (the prompt file shrank after the
                    # UI was built) are blanked instead of indexing past the end.
                    yield tuple(
                        gr.update(value=response_data[i] if i < len(content_list) else "")
                        for i in range(len(columns))
                    )

            stream_button.click(
                fn=start_streaming,
                inputs=[data_textbox, state_random_selected_prompts],
                outputs=columns,
                show_progress=False
            )

            choice = gr.Radio(label="Choose the best response:", choices=choice_labels)

            submit_button = gr.Button("Submit")

            output = gr.Textbox(label="You selected:", visible=False)

            def update_prompt_metrics(selected_choice, random_selected_prompts):
                """Record a vote for the prompt behind ``selected_choice``.

                The prompt file is re-read under ``metrics_lock`` rather than
                written from a per-session snapshot, which would overwrite votes
                cast in other sessions after that snapshot was taken.

                Raises:
                    ValueError: if the choice maps to no selected prompt, or the
                        chosen prompt's id no longer exists in the file.
                """
                if not selected_choice:
                    # Nothing chosen yet: tell the user and keep the controls active.
                    return (gr.update(value="Please choose a response before submitting.", visible=True),
                            gr.update(), gr.update())

                try:
                    prompt_id = random_selected_prompts[choice_labels.index(selected_choice)]['id']
                except (ValueError, IndexError):
                    raise ValueError(f"No corresponding response of {selected_choice}") from None

                with metrics_lock:
                    prompts = load_prompts()
                    for prompt in prompts:
                        if prompt['id'] == prompt_id:
                            prompt["metric"]["winning_number"] += 1
                            break
                    else:
                        raise ValueError(f"No prompt of id {prompt_id}")
                    save_prompts(prompts)

                # Lock the vote controls so one session cannot vote repeatedly.
                return gr.update(value=f"You selected: {selected_choice}", visible=True), gr.update(interactive=False), gr.update(interactive=False)

            submit_button.click(
                fn=update_prompt_metrics,
                inputs=[choice, state_random_selected_prompts],
                outputs=[output, choice, submit_button],
            )

    return demo

if __name__ == "__main__":
    # Build the arena, enable the request queue (the streaming handler is a
    # generator and relies on it), then start the local server.
    arena_app = create_arena()
    arena_app.queue()
    arena_app.launch()