kz209 committed on
Commit
dd681d0
1 Parent(s): 9e26af4
Files changed (1)
  1. pages/arena.py +37 -39
pages/arena.py CHANGED
@@ -1,4 +1,3 @@
-#from utils.multiple_stream import create_interface
 import random
 import gradio as gr
 import json
@@ -8,26 +7,17 @@ from utils.multiple_stream import stream_data
 from pages.summarization_playground import get_model_batch_generation
 from pages.summarization_playground import custom_css
 
-global global_selected_choice
-
 def random_data_selection():
     datapoint = random.choice(dataset)
     datapoint = datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue']
-
     return datapoint
 
-# Function to handle user selection and disable the radio
-def lock_selection(selected_option):
-    global global_selected_choice
-    global_selected_choice = selected_option # Store the selected choice in the variable
-    return gr.update(visible=True), selected_option, gr.update(interactive=False), gr.update(interactive=False)
-
 def create_arena():
     with open("prompt/prompt.json", "r") as file:
         json_data = file.read()
         prompts = json.loads(json_data)
 
-    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm",text_size="sm"), css=custom_css) as demo:
+    with gr.Blocks(theme=gr.themes.Soft().set(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
         with gr.Group():
             datapoint = random_data_selection()
             gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
@@ -37,7 +27,7 @@ Once the streaming is complete, you can choose the best response.❤️"""
             data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)
             with gr.Row():
                 random_selection_button = gr.Button("Change Data")
-                submit_button = gr.Button("✨ Click to Streaming ✨")
+                stream_button = gr.Button("✨ Click to Streaming ✨")
 
             random_selection_button.click(
                 fn=random_data_selection,
@@ -47,21 +37,25 @@ Once the streaming is complete, you can choose the best response.❤️"""
 
             random.shuffle(prompts)
             random_selected_prompts = prompts[:3]
+
+            # Store prompts in state components
+            state_prompts = gr.State(value=prompts)
+            state_random_selected_prompts = gr.State(value=random_selected_prompts)
 
             with gr.Row():
                 columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(random_selected_prompts))]
 
-            content_list = [prompt['prompt'] + '\n{' + data_textbox.value + '}\n\nsummary:' for prompt in random_selected_prompts]
             model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
 
-            def start_streaming():
-                for data in stream_data(content_list, model):
-                    updates = [gr.update(value=data[i]) for i in range(len(columns))]
+            def start_streaming(data, random_selected_prompts):
+                content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in random_selected_prompts]
+                for response_data in stream_data(content_list, model):
+                    updates = [gr.update(value=response_data[i]) for i in range(len(columns))]
                     yield tuple(updates)
 
-            submit_button.click(
+            stream_button.click(
                 fn=start_streaming,
-                inputs=[],
+                inputs=[data_textbox, state_random_selected_prompts],
                 outputs=columns,
                 show_progress=False
             )
@@ -70,31 +64,35 @@ Once the streaming is complete, you can choose the best response.❤️"""
 
             submit_button = gr.Button("Submit")
 
-            # Output to display the selected option
             output = gr.Textbox(label="You selected:", visible=False)
 
-            submit_button.click(fn=lock_selection, inputs=choice, outputs=[output, output, choice, submit_button])
-
-            global global_selected_choice
-            if global_selected_choice == "Response 1":
-                prompt_id = random_selected_prompts[0]
-            elif global_selected_choice == "Response 2":
-                prompt_id = random_selected_prompts[1]
-            elif global_selected_choice == "Response 3":
-                prompt_id = random_selected_prompts[2]
-            else:
-                raise ValueError(f"No corresponding response of {global_selected_choice}")
-
-            for i in range(len(prompts)):
-                if prompts[i]['id'] == prompt_id:
-                    prompts[i]["metric"]["winning_number"] += 1
-                    break
-
-                if i == len(prompts)-1:
+            def update_prompt_metrics(selected_choice, prompts, random_selected_prompts):
+                if selected_choice == "Response 1":
+                    prompt_id = random_selected_prompts[0]['id']
+                elif selected_choice == "Response 2":
+                    prompt_id = random_selected_prompts[1]['id']
+                elif selected_choice == "Response 3":
+                    prompt_id = random_selected_prompts[2]['id']
+                else:
+                    raise ValueError(f"No corresponding response of {selected_choice}")
+
+                for prompt in prompts:
+                    if prompt['id'] == prompt_id:
+                        prompt["metric"]["winning_number"] += 1
+                        break
+                else:
                     raise ValueError(f"No prompt of id {prompt_id}")
 
-            with open("prompt/prompt.json", "w") as f:
-                json.dump(prompts, f)
+                with open("prompt/prompt.json", "w") as f:
+                    json.dump(prompts, f)
+
+                return gr.update(value=f"You selected: {selected_choice}", visible=True), gr.update(interactive=False), gr.update(interactive=False)
+
+            submit_button.click(
+                fn=update_prompt_metrics,
+                inputs=[choice, state_prompts, state_random_selected_prompts],
+                outputs=[output, choice, submit_button],
+            )
 
     return demo
 
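The new handlers read their per-session data from `gr.State` components passed via `inputs`, rather than from the module-level `global_selected_choice` the old code relied on. A minimal, self-contained sketch of that pattern, separate from this repo and with hypothetical component names:

```python
import gradio as gr

with gr.Blocks() as demo:
    # gr.State holds an arbitrary Python object for the current session.
    state_items = gr.State(value=["alpha", "beta", "gamma"])

    choice = gr.Radio(["Response 1", "Response 2", "Response 3"], label="Pick one")
    output = gr.Textbox(label="You selected:", visible=False)
    submit = gr.Button("Submit")

    def on_submit(selected, items):
        # `selected` comes from the Radio input, `items` from the State input;
        # nothing is read from module-level globals.
        return gr.update(value=f"You selected: {selected} ({len(items)} items in state)", visible=True)

    submit.click(fn=on_submit, inputs=[choice, state_items], outputs=output)

if __name__ == "__main__":
    demo.launch()
```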
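`update_prompt_metrics` also uses Python's `for ... else`: the `else` block runs only when the loop finishes without hitting `break`, which replaces the old `if i == len(prompts)-1` check. A standalone illustration with toy data:

```python
prompts = [{"id": "p1"}, {"id": "p2"}]
prompt_id = "p3"  # no such id, so the loop never breaks

for prompt in prompts:
    if prompt["id"] == prompt_id:
        print("found", prompt)
        break
else:
    # Runs only when the loop completed without a break.
    raise ValueError(f"No prompt of id {prompt_id}")
```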