leemeng committed
Commit af43a3d · 1 Parent(s): 433d9a2

refactor: add feedback function, update ui

Files changed (1)
  1. app.py +113 -87
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 
 import os
 import threading
+import arrow
 import time
 import argparse
 import logging
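
Note: the new arrow import is used by the functions below to timestamp generation and feedback logs in JST. A minimal sketch of the call as it is used in this commit (assuming the arrow package is installed; arrow.now() accepts a UTC-offset string like "+09:00"):

    import arrow

    # ISO-8601 timestamp in Japan Standard Time,
    # e.g. '2023-08-01T12:34:56.789012+09:00'
    print(str(arrow.now("+09:00")))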
@@ -90,27 +91,17 @@ class SentencePieceStreamer(BaseStreamer):
     def end(self):
         self.ended = True
 
-def user(prompt, history):
-    logging.info(f"[user] prompt: {prompt}")
-    logging.debug(f"[user] history: {history}")
-
-    res = ("", history + [[prompt, None]])
-    return res
-
-def bot(
-    history,
-    do_sample,
+def generate(
+    prompt,
+    max_new_tokens,
     temperature,
     repetition_penalty,
+
+    do_sample,
     no_repeat_ngram_size,
-    max_new_tokens,
 ):
-    logging.info("[bot]")
-    logging.info(dict(locals()))
-    logging.debug(f"history: {history}")
-
-    # TODO: modify `<br>` back to `\n` based on the original user prompt
-    prompt = history[-1][0]
 
     tokens = sp.encode(prompt)
     input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
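
Note: this hunk replaces the chat-style user/bot handler pair (which threaded a chatbot history through every call) with a single generate() generator that streams plain text. The diff only shows end() of SentencePieceStreamer; a plausible sketch of the rest of the class, inferred from how generate() reads .generated_text and .ended below. put()'s exact behavior is an assumption, not shown in this commit (a real implementation would typically skip the first put() call, which carries the prompt tokens):

    from transformers.generation.streamers import BaseStreamer

    class SentencePieceStreamer(BaseStreamer):
        """Accumulate token ids from model.generate() and keep a decoded string."""

        def __init__(self, sp):
            self.sp = sp              # sentencepiece.SentencePieceProcessor
            self.token_ids = []
            self.generated_text = ""
            self.ended = False

        def put(self, value):
            # called by model.generate() with each new tensor of token ids
            self.token_ids.extend(value.reshape(-1).tolist())
            self.generated_text = self.sp.decode(self.token_ids)

        def end(self):
            self.ended = True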
@@ -120,7 +111,6 @@ def bot(
     max_possilbe_new_tokens = model.config.max_position_embeddings - len(tokens)
     max_possilbe_new_tokens = min(max_possilbe_new_tokens, max_new_tokens)
 
-
     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
         input_ids=input_ids,
         do_sample=do_sample,
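
Note: the commit keeps the background-thread pattern: model.generate() runs on its own thread and writes into the streamer while the foreground generator polls it every 50 ms. For comparison, transformers ships TextIteratorStreamer, which packages the same thread-plus-streamer pattern as an iterator; this sketch assumes a Hugging Face tokenizer rather than the raw sentencepiece processor used in this app:

    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128),
    )
    thread.start()
    for chunk in streamer:   # yields decoded text as it is generated
        print(chunk, end="")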
@@ -137,87 +127,123 @@ def bot(
     ))
     thr.start()
 
-    history[-1][1] = ""
     while not streamer.ended:
-        history[-1][1] = streamer.generated_text
         time.sleep(0.05)
-        yield history
+        yield streamer.generated_text
 
     # TODO: optimize for final few tokens
-    history[-1][1] = streamer.generated_text
-    logging.info(f"generation: {history[-1][1]}")
-    yield history
+    gen = streamer.generated_text
+    log.update(dict(generation=gen, time=str(arrow.now("+09:00"))))
+    logging.info(log)
+    yield gen
+
+def process_feedback(
+    rating,
+    prompt,
+    generation,
+
+    max_new_tokens,
+    temperature,
+    repetition_penalty,
+    do_sample,
+    no_repeat_ngram_size,
+):
+    log = dict(locals())
+    log["time"] = str(arrow.now("+09:00"))
+    logging.info(log)
 
 if gr_interface:
     gr_interface.close(verbose=False)
 
 with gr.Blocks() as gr_interface:
-    chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)
-
-    # generation params
-    do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
-
     with gr.Row():
-        temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
-        repetition_penalty = gr.Slider(1, 1.5, value=1.2, step=0.05, label="Repetition Penalty")
+        gr.Markdown("# 日本語 StableLM Pre-Alpha")
     with gr.Row():
-        no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size")
-        max_new_tokens = gr.Slider(
-            128,
-            model.config.max_position_embeddings,
-            value=128, step=64, label="Max New Tokens")
-
-    # prompt
-    # TODO: add more options
-    # prompt_options = gr.Dropdown(
-    #     choices=[
-    #         "運が良かったのか悪かったのか日本に帰ってきたタイミングでコロナが猛威を振るい始め、",
-    #         """[問題]に対する[答え]を[選択肢]の中から選んでください。
-
-    # [問題]: ある場所の周辺地域を指す言葉は?
-    # [選択肢]: [空, オレゴン州, 街, 歩道橋, 近辺]
-    # [答え]: 近辺
-
-    # [問題]: 若くて世間に慣れていないことを何という?
-    # [選択肢]: [青っぽい, 若い, ベテラン, 生々しい, 玄人]
-    # [答え]: """
-    #     ],
-    #     label="Prompt Options",
-    #     info="Select 1 option for quick start",
-    #     allow_custom_value=False,
-    # )
-    prompt = gr.Textbox(label="Prompt", info="Pro tip: press Enter to submit directly")
-
-
-    # def on_prompt_options_change(pmt_opts, pmt):
-    #     return pmt_opts
-
-    # prompt_options.change(on_prompt_options_change, [prompt_options, prompt], prompt)
-
+        gr.Markdown("Description about this page. ホゲホゲ")
     with gr.Row():
-        submit = gr.Button("Submit")
-        stop = gr.Button("Stop")
-        clear = gr.Button("Clear")
-
-    # event handling
-    user_io = [prompt, chatbot]
-    bot_inputs = [
-        chatbot,
-        do_sample,
-        temperature,
-        repetition_penalty,
-        no_repeat_ngram_size,
-        max_new_tokens,
-    ]
-
-    submit_event = prompt.submit(user, user_io, user_io, queue=False)\
-        .then(bot, bot_inputs, chatbot, queue=True)
-
-    submit_click_event = submit.click(user, user_io, user_io, queue=False)\
-        .then(bot, bot_inputs, chatbot, queue=True)
-
-    stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)
-    clear.click(lambda: None, None, chatbot, queue=False)
+
+        # left panel
+        with gr.Column(scale=1):
+
+            # generation params
+            with gr.Box():
+                gr.Markdown("パラメータ")
+
+                # hidden default params
+                do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
+                no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)
+
+                # visible params
+                max_new_tokens = gr.Slider(
+                    128,
+                    min(512, model.config.max_position_embeddings),
+                    value=128,
+                    step=128,
+                    label="max tokens",
+                    info="生成するトークンの最大数を指定する",
+                )
+                temperature = gr.Slider(
+                    0, 1, value=0.7, step=0.05, label="temperature",
+                    info="低い値は出力をより集中させて決定論的にする")
+                repetition_penalty = gr.Slider(
+                    1, 1.5, value=1.2, step=0.05, label="frequency penalty",
+                    info="高い値はAIが繰り返す可能性を減少させる")
+
+            # grouping params for easier reference
+            gr_params = [
+                max_new_tokens,
+                temperature,
+                repetition_penalty,
+
+                do_sample,
+                no_repeat_ngram_size,
+            ]
+
+        # right panel
+        with gr.Column(scale=2):
+            # user input block
+            with gr.Box():
+                textbox_prompt = gr.Textbox(
+                    label="Human",
+                    placeholder="AIに続きを書いて欲しいプロンプト",
+                    interactive=True,
+                    lines=5,
+                    value=""
+                )
+            with gr.Box():
+                with gr.Row():
+                    btn_submit = gr.Button(value="実行", variant="primary")
+                    btn_stop = gr.Button(value="中止", variant="stop")
+
+            # model output block
+            with gr.Box():
+                textbox_generation = gr.Textbox(
+                    label="AI",
+                    lines=5,
+                    value=""
+                )
+            with gr.Box():
+                with gr.Row():
+                    rating_options = [
+                        "😫すごく悪い",
+                        "😞微妙",
+                        "😐アリ",
+                        "🙂合格",
+                        "😄すごく良い",
+                    ]
+                    btn_ratings = [gr.Button(value=v) for v in rating_options]
+
+            with gr.Box():
+                gr.Markdown("TODO:For more feedback link for google form")
+
+    # event handling
+    inputs = [textbox_prompt] + gr_params
+    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
+    btn_stop.click(None, None, None, cancels=click_event, queue=False)
+
+    for btn_rating in btn_ratings:
+        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)
 
+
 gr_interface.queue(max_size=32, concurrency_count=2)
 gr_interface.launch(server_port=args.port, share=args.make_public)
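
Note: because generate() is a generator, btn_submit.click streams each yielded string into textbox_generation, and the saved click_event lets the stop button cancel an in-flight generation. The five rating buttons share one handler: in Gradio 3.x, passing a Button component as an event input hands the handler the button's label string, which is how process_feedback learns which rating was clicked. A minimal sketch of that pattern (component and function names here are illustrative, not from the app):

    import gradio as gr

    def on_rate(rating, generation):
        # `rating` is the label of whichever button was clicked
        print(f"rated {generation!r}: {rating}")

    with gr.Blocks() as demo:
        out = gr.Textbox(value="...")
        for label in ["good", "bad"]:
            btn = gr.Button(value=label)
            btn.click(on_rate, [btn, out], None)  # the Button itself is an input

    # demo.launch()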