import json
import os
import random
import threading

import gradio as gr

from pages.summarization_playground import custom_css
from pages.summarization_playground import get_model_batch_generation
from utils.data import dataset
from utils.multiple_stream import stream_data

# Prompt catalogue: read when the arena is built and rewritten on every vote.
PROMPT_PATH = "prompt/prompt.json"

# Maximum number of prompts compared side by side in one arena round.
NUM_COMPETITORS = 3

# Gradio runs event handlers on worker threads; this lock keeps concurrent
# votes (from any session in this process) from interleaving their
# read-modify-write of PROMPT_PATH.
_PROMPT_FILE_LOCK = threading.Lock()


def random_data_selection():
    """Return a random dataset sample formatted as the arena's input text.

    The text is the sample's ``section_text``, then a ``Dialogue:`` header,
    then the sample's ``dialogue`` field.
    """
    sample = random.choice(dataset)
    return sample['section_text'] + '\n\nDialogue:\n' + sample['dialogue']


def _load_prompts(path=PROMPT_PATH):
    """Read and return the list of prompt records stored as JSON at *path*."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def _save_prompts(prompts, path=PROMPT_PATH):
    """Write *prompts* as JSON to *path*, replacing the file atomically.

    The data is written to a sibling temporary file and then moved over
    *path* with ``os.replace``. A crash mid-write therefore cannot leave a
    truncated catalogue behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(prompts, file)
    os.replace(tmp_path, path)


def _record_win(prompt_id, path=PROMPT_PATH):
    """Increment ``metric.winning_number`` of the prompt whose ``id`` is *prompt_id*.

    The file is re-read under a lock rather than reusing a copy loaded at
    startup. This way, votes written by other sessions since then are kept
    instead of overwritten.

    Raises:
        ValueError: if no prompt in the file has id *prompt_id*.
    """
    with _PROMPT_FILE_LOCK:
        prompts = _load_prompts(path)
        for prompt in prompts:
            if prompt['id'] == prompt_id:
                prompt["metric"]["winning_number"] += 1
                break
        else:
            raise ValueError(f"No prompt of id {prompt_id}")
        _save_prompts(prompts, path)


def create_arena():
    """Build the prompt-comparison arena and return the ``gr.Blocks`` app.

    On each page load the session gets up to NUM_COMPETITORS prompts drawn
    at random from PROMPT_PATH and a random datapoint. Streaming produces one
    model response per prompt for the current data. Voting increments the
    winning prompt's ``metric.winning_number`` in PROMPT_PATH.

    Raises:
        ValueError: if PROMPT_PATH contains no prompts.
    """
    prompts = _load_prompts()
    if not prompts:
        raise ValueError(f"{PROMPT_PATH} contains no prompts")
    # Never ask for more competitors than exist (random.sample would raise).
    num_columns = min(NUM_COMPETITORS, len(prompts))
    choice_labels = [f"Response {i + 1}" for i in range(num_columns)]

    def pick_prompts():
        # Random subset in random order; equivalent to shuffle + slice but
        # without mutating the shared `prompts` list.
        return random.sample(prompts, num_columns)

    with gr.Blocks(theme=gr.themes.Soft().set(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
        with gr.Group():
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt. Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")
            # A callable value is evaluated on every page load, so each
            # visitor starts from freshly sampled data.
            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=random_data_selection)
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox],
            )

        # Per-session prompt selection: seeded here, re-drawn on every page load.
        state_random_selected_prompts = gr.State(value=pick_prompts())
        demo.load(fn=pick_prompts, inputs=None, outputs=[state_random_selected_prompts])

        with gr.Row():
            columns = [gr.Textbox(label=f"Prompt {i + 1}", lines=10) for i in range(num_columns)]

        model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

        def start_streaming(data, random_selected_prompts):
            # One model input per selected prompt, streamed in parallel.
            content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in random_selected_prompts]
            for response_data in stream_data(content_list, model):
                yield tuple(gr.update(value=response_data[i]) for i in range(len(columns)))

        stream_button.click(
            fn=start_streaming,
            inputs=[data_textbox, state_random_selected_prompts],
            outputs=columns,
            show_progress=False,
        )

        choice = gr.Radio(label="Choose the best response:", choices=choice_labels)
        submit_button = gr.Button("Submit")
        output = gr.Textbox(label="You selected:", visible=False)

        def update_prompt_metrics(selected_choice, random_selected_prompts):
            if selected_choice not in choice_labels:
                # Nothing picked yet: tell the user instead of failing the
                # handler, and leave the controls enabled so they can vote.
                return (
                    gr.update(value="Please choose a response before submitting.", visible=True),
                    gr.update(),
                    gr.update(),
                )
            prompt_id = random_selected_prompts[choice_labels.index(selected_choice)]['id']
            _record_win(prompt_id)
            # Lock the vote controls so a session cannot vote twice.
            return (
                gr.update(value=f"You selected: {selected_choice}", visible=True),
                gr.update(interactive=False),
                gr.update(interactive=False),
            )

        submit_button.click(
            fn=update_prompt_metrics,
            inputs=[choice, state_random_selected_prompts],
            outputs=[output, choice, submit_button],
        )

    return demo


if __name__ == "__main__":
    demo = create_arena()
    demo.queue()
    demo.launch()