Spaces:
Runtime error
Runtime error
refactor: add feedback function, update ui
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
|
3 |
import os
|
4 |
import threading
|
|
|
5 |
import time
|
6 |
import argparse
|
7 |
import logging
|
@@ -90,27 +91,17 @@ class SentencePieceStreamer(BaseStreamer):
|
|
90 |
def end(self):
|
91 |
self.ended = True
|
92 |
|
93 |
-
def
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
res = ("", history + [[prompt, None]])
|
98 |
-
return res
|
99 |
-
|
100 |
-
def bot(
|
101 |
-
history,
|
102 |
-
do_sample,
|
103 |
temperature,
|
104 |
repetition_penalty,
|
|
|
|
|
105 |
no_repeat_ngram_size,
|
106 |
-
max_new_tokens,
|
107 |
):
|
108 |
-
|
109 |
-
logging.
|
110 |
-
logging.debug(f"history: {history}")
|
111 |
-
|
112 |
-
# TODO: modify `<br>` back to `\n` based on the original user prompt
|
113 |
-
prompt = history[-1][0]
|
114 |
|
115 |
tokens = sp.encode(prompt)
|
116 |
input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
|
@@ -120,7 +111,6 @@ def bot(
|
|
120 |
max_possilbe_new_tokens = model.config.max_position_embeddings - len(tokens)
|
121 |
max_possilbe_new_tokens = min(max_possilbe_new_tokens, max_new_tokens)
|
122 |
|
123 |
-
|
124 |
thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
|
125 |
input_ids=input_ids,
|
126 |
do_sample=do_sample,
|
@@ -137,87 +127,123 @@ def bot(
|
|
137 |
))
|
138 |
thr.start()
|
139 |
|
140 |
-
history[-1][1] = ""
|
141 |
while not streamer.ended:
|
142 |
-
history[-1][1] = streamer.generated_text
|
143 |
time.sleep(0.05)
|
144 |
-
yield
|
145 |
|
146 |
# TODO: optimize for final few tokens
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
if gr_interface:
|
152 |
gr_interface.close(verbose=False)
|
153 |
|
154 |
with gr.Blocks() as gr_interface:
|
155 |
-
chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)
|
156 |
-
|
157 |
-
# generation params
|
158 |
-
do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
|
159 |
-
|
160 |
with gr.Row():
|
161 |
-
|
162 |
-
repetition_penalty = gr.Slider(1, 1.5, value=1.2, step=0.05, label="Repetition Penalty")
|
163 |
with gr.Row():
|
164 |
-
|
165 |
-
max_new_tokens = gr.Slider(
|
166 |
-
128,
|
167 |
-
model.config.max_position_embeddings,
|
168 |
-
value=128, step=64, label="Max New Tokens")
|
169 |
-
|
170 |
-
# prompt
|
171 |
-
# TODO: add more options
|
172 |
-
# prompt_options = gr.Dropdown(
|
173 |
-
# choices=[
|
174 |
-
# "運が良かったのか悪かったのか日本に帰ってきたタイミングでコロナが猛威を振るい始め、",
|
175 |
-
# """[問題]に対する[答え]を[選択肢]の中から選んでください。
|
176 |
-
|
177 |
-
# [問題]: ある場所の周辺地域を指す言葉は?
|
178 |
-
# [選択肢]: [空, オレゴン州, 街, 歩道橋, 近辺]
|
179 |
-
# [答え]: 近辺
|
180 |
-
|
181 |
-
# [問題]: 若くて世間に慣れていないことを何という?
|
182 |
-
# [選択肢]: [青っぽい, 若い, ベテラン, 生々しい, 玄人]
|
183 |
-
# [答え]: """
|
184 |
-
# ],
|
185 |
-
# label="Prompt Options",
|
186 |
-
# info="Select 1 option for quick start",
|
187 |
-
# allow_custom_value=False,
|
188 |
-
# )
|
189 |
-
prompt = gr.Textbox(label="Prompt", info="Pro tip: press Enter to submit directly")
|
190 |
-
|
191 |
-
|
192 |
-
# def on_prompt_options_change(pmt_opts, pmt):
|
193 |
-
# return pmt_opts
|
194 |
-
|
195 |
-
# prompt_options.change(on_prompt_options_change, [prompt_options, prompt], prompt)
|
196 |
-
|
197 |
with gr.Row():
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
|
|
222 |
gr_interface.queue(max_size=32, concurrency_count=2)
|
223 |
gr_interface.launch(server_port=args.port, share=args.make_public)
|
|
|
2 |
|
3 |
import os
|
4 |
import threading
|
5 |
+
import arrow
|
6 |
import time
|
7 |
import argparse
|
8 |
import logging
|
|
|
91 |
def end(self):
|
92 |
self.ended = True
|
93 |
|
94 |
+
def generate(
|
95 |
+
prompt,
|
96 |
+
max_new_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
temperature,
|
98 |
repetition_penalty,
|
99 |
+
|
100 |
+
do_sample,
|
101 |
no_repeat_ngram_size,
|
|
|
102 |
):
|
103 |
+
log = dict(locals())
|
104 |
+
logging.debug(log)
|
|
|
|
|
|
|
|
|
105 |
|
106 |
tokens = sp.encode(prompt)
|
107 |
input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
|
|
|
111 |
max_possilbe_new_tokens = model.config.max_position_embeddings - len(tokens)
|
112 |
max_possilbe_new_tokens = min(max_possilbe_new_tokens, max_new_tokens)
|
113 |
|
|
|
114 |
thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
|
115 |
input_ids=input_ids,
|
116 |
do_sample=do_sample,
|
|
|
127 |
))
|
128 |
thr.start()
|
129 |
|
|
|
130 |
while not streamer.ended:
|
|
|
131 |
time.sleep(0.05)
|
132 |
+
yield streamer.generated_text
|
133 |
|
134 |
# TODO: optimize for final few tokens
|
135 |
+
gen = streamer.generated_text
|
136 |
+
log.update(dict(generation=gen, time=str(arrow.now("+09:00"))))
|
137 |
+
logging.info(log)
|
138 |
+
yield gen
|
139 |
+
|
140 |
+
def process_feedback(
|
141 |
+
rating,
|
142 |
+
prompt,
|
143 |
+
generation,
|
144 |
+
|
145 |
+
max_new_tokens,
|
146 |
+
temperature,
|
147 |
+
repetition_penalty,
|
148 |
+
do_sample,
|
149 |
+
no_repeat_ngram_size,
|
150 |
+
):
|
151 |
+
log = dict(locals())
|
152 |
+
log["time"] = str(arrow.now("+09:00"))
|
153 |
+
logging.info(log)
|
154 |
|
155 |
if gr_interface:
|
156 |
gr_interface.close(verbose=False)
|
157 |
|
158 |
with gr.Blocks() as gr_interface:
|
|
|
|
|
|
|
|
|
|
|
159 |
with gr.Row():
|
160 |
+
gr.Markdown("# 日本語 StableLM Pre-Alpha")
|
|
|
161 |
with gr.Row():
|
162 |
+
gr.Markdown("Description about this page. ホゲホゲ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
with gr.Row():
|
164 |
+
|
165 |
+
# left panel
|
166 |
+
with gr.Column(scale=1):
|
167 |
+
|
168 |
+
# generation params
|
169 |
+
with gr.Box():
|
170 |
+
gr.Markdown("パラメータ")
|
171 |
+
|
172 |
+
# hidden default params
|
173 |
+
do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
|
174 |
+
no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)
|
175 |
+
|
176 |
+
# visible params
|
177 |
+
max_new_tokens = gr.Slider(
|
178 |
+
128,
|
179 |
+
min(512, model.config.max_position_embeddings),
|
180 |
+
value=128,
|
181 |
+
step=128,
|
182 |
+
label="max tokens",
|
183 |
+
info="生成するトークンの最大数を指定する",
|
184 |
+
)
|
185 |
+
temperature = gr.Slider(
|
186 |
+
0, 1, value=0.7, step=0.05, label="temperature",
|
187 |
+
info="低い値は出力をより集中させて決定論的にする")
|
188 |
+
repetition_penalty = gr.Slider(
|
189 |
+
1, 1.5, value=1.2, step=0.05, label="frequency penalty",
|
190 |
+
info="高い値はAIが繰り返す可能性を減少させる")
|
191 |
+
|
192 |
+
# grouping params for easier reference
|
193 |
+
gr_params = [
|
194 |
+
max_new_tokens,
|
195 |
+
temperature,
|
196 |
+
repetition_penalty,
|
197 |
+
|
198 |
+
do_sample,
|
199 |
+
no_repeat_ngram_size,
|
200 |
+
]
|
201 |
+
|
202 |
+
# right panel
|
203 |
+
with gr.Column(scale=2):
|
204 |
+
# user input block
|
205 |
+
with gr.Box():
|
206 |
+
textbox_prompt = gr.Textbox(
|
207 |
+
label="Human",
|
208 |
+
placeholder="AIに続きを書いて欲しいプロンプト",
|
209 |
+
interactive=True,
|
210 |
+
lines=5,
|
211 |
+
value=""
|
212 |
+
)
|
213 |
+
with gr.Box():
|
214 |
+
with gr.Row():
|
215 |
+
btn_submit = gr.Button(value="実行", variant="primary")
|
216 |
+
btn_stop = gr.Button(value="中止", variant="stop")
|
217 |
+
|
218 |
+
# model output block
|
219 |
+
with gr.Box():
|
220 |
+
textbox_generation = gr.Textbox(
|
221 |
+
label="AI",
|
222 |
+
lines=5,
|
223 |
+
value=""
|
224 |
+
)
|
225 |
+
with gr.Box():
|
226 |
+
with gr.Row():
|
227 |
+
rating_options = [
|
228 |
+
"😫すごく悪い",
|
229 |
+
"😞微妙",
|
230 |
+
"😐アリ",
|
231 |
+
"🙂合格",
|
232 |
+
"😄すごく良い",
|
233 |
+
]
|
234 |
+
btn_ratings = [gr.Button(value=v) for v in rating_options]
|
235 |
+
|
236 |
+
with gr.Box():
|
237 |
+
gr.Markdown("TODO:For more feedback link for google form")
|
238 |
+
|
239 |
+
# event handling
|
240 |
+
inputs = [textbox_prompt] + gr_params
|
241 |
+
click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
|
242 |
+
btn_stop.click(None, None, None, cancels=click_event, queue=False)
|
243 |
+
|
244 |
+
for btn_rating in btn_ratings:
|
245 |
+
btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)
|
246 |
|
247 |
+
|
248 |
gr_interface.queue(max_size=32, concurrency_count=2)
|
249 |
gr_interface.launch(server_port=args.port, share=args.make_public)
|