import random
import gradio as gr
import json

from utils.data import dataset
from utils.multiple_stream import stream_data
from pages.summarization_playground import get_model_batch_generation
from pages.summarization_playground import custom_css

def random_data_selection():
    datapoint = random.choice(dataset)
    datapoint = datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue']

    return datapoint

def create_arena():
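    # Assumed shape of prompt/prompt.json (inferred from how `prompts` is used below):
    # a list of objects that each carry a "prompt" field, e.g.
    # [{"prompt": "Summarize the following conversation ..."}, ...]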
    with open("prompt/prompt.json", "r") as file:
        json_data = file.read()
        prompts = json.loads(json_data)

    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.

Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")

            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                submit_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection,
                inputs=[],
                outputs=[data_textbox]
            )
    
            # Shuffle and keep three prompts before building the columns so the
            # number of response boxes matches the number of streamed responses.
            random.shuffle(prompts)
            prompts = prompts[:3]

            with gr.Row():
                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(prompts))]

            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data):
                # Build the prompt contents from the current textbox value rather than
                # the value captured when the interface was first created.
                content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in prompts]
                for batch in stream_data(content_list, model):
                    updates = [gr.update(value=batch[i]) for i in range(len(columns))]
                    yield tuple(updates)

            submit_button.click(
                fn=start_streaming,
                inputs=[data_textbox],
                outputs=columns,
                show_progress=False
            )

            choice = gr.Radio(label="Choose the best response:", choices=["Response 1", "Response 2", "Response 3"])

            save_button = gr.Button("Submit")
            save_status = gr.Markdown()

            def save_choice(selected):
                # Persisting the preferred response (e.g. the original save_to_db idea) can be added here.
                return f"Response '{selected}' saved successfully!"

            save_button.click(
                fn=save_choice,
                inputs=[choice],
                outputs=[save_status]
            )

    return demo

if __name__ == "__main__":
    demo = create_arena()
    demo.queue()
    demo.launch()