feat: add additional params for reducing repetition
app.py
CHANGED
@@ -29,8 +29,7 @@ class DefaultArgs:
 
 if os.getenv("RUNNING_ON_HF_SPACE"):
     login(token=os.getenv("HF_TOKEN"))
-    hf_repo = "
-
+    hf_repo = os.getenv("HF_MODEL_REPO")
     args = DefaultArgs()
     args.hf_model_name_or_path = hf_repo
     args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
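For context, a minimal sketch of the configuration pattern this hunk switches to: the model repo id is read from an environment variable (set as a Space secret/variable) instead of being hard-coded. The repo id in the comment is a placeholder, not the Space's actual model.

import os
from huggingface_hub import hf_hub_download, login

if os.getenv("RUNNING_ON_HF_SPACE"):
    # HF_TOKEN and HF_MODEL_REPO are injected by the Space, not committed to the repo
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")  # e.g. "my-org/my-model" (placeholder)
    spm_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")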
@@ -86,35 +85,35 @@ class SentencePieceStreamer(BaseStreamer):
             return
 
         self.generated_text += text
-
-        # yield text
+        logging.debug(f"[streamer]: {self.generated_text}")
 
     def end(self):
         self.ended = True
 
-def user(
-    logging.
+def user(prompt, history):
+    logging.info(f"[user] prompt: {prompt}")
     logging.debug(f"[user] history: {history}")
 
-    res = ("", history + [[
+    res = ("", history + [[prompt, None]])
    return res
 
 def bot(
     history,
+    do_sample,
     temperature,
+    repetition_penalty,
+    no_repeat_ngram_size,
     max_new_tokens,
 ):
-    logging.
-    logging.
+    logging.info("[bot]")
+    logging.info(dict(locals()))
+    logging.debug(f"history: {history}")
 
     # TODO: modify `<br>` back to `\n` based on the original user prinpt
     prompt = history[-1][0]
 
     tokens = sp.encode(prompt)
     input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
-
-    # TODO: parametrize setting on UI
-    do_sample = True
 
     streamer = SentencePieceStreamer(sp=sp)
 
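The bot() changes above sit on top of the app's streaming setup: generate() runs in a worker thread while a streamer collects text that bot() yields back to the UI. Below is a minimal sketch of the same idea using transformers' built-in TextIteratorStreamer rather than the app's custom SentencePieceStreamer; "gpt2" is only a small stand-in model for illustration.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt):
    inputs = tok(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs in a worker thread while this generator
    # drains the streamer and yields partial text, like bot() yielding history
    thr = threading.Thread(target=lm.generate, kwargs=dict(
        **inputs, streamer=streamer, do_sample=True, max_new_tokens=64,
    ))
    thr.start()
    text = ""
    for piece in streamer:
        text += piece
        yield text
    thr.join()

for partial in stream_reply("Once upon a time"):
    print(partial)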
@@ -124,14 +123,15 @@ def bot(
 
     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
         input_ids=input_ids,
+        do_sample=do_sample,
         temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        no_repeat_ngram_size=no_repeat_ngram_size,
         max_new_tokens=max_possilbe_new_tokens,
-        do_sample=do_sample,
         streamer=streamer,
         # max_length=4096,
         # top_k=100,
         # top_p=0.9,
-        # repetition_penalty=1.0,
         # num_return_sequences=2,
         # num_beams=2,
     ))
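The two new kwargs are standard transformers generation parameters: repetition_penalty (> 1.0) down-weights tokens that have already been generated, and no_repeat_ngram_size forbids any n-gram of that length from appearing twice in the output. A standalone sketch with "gpt2" as a small stand-in model and the same default values the UI exposes:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
out = lm.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    repetition_penalty=1.2,   # >1.0 penalizes tokens that already appeared
    no_repeat_ngram_size=5,   # no 5-gram may repeat within the output
    max_new_tokens=64,
)
print(tok.decode(out[0], skip_special_tokens=True))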
@@ -145,6 +145,7 @@ def bot(
 
     # TODO: optimize for final few tokens
    history[-1][1] = streamer.generated_text
+    logging.info(f"generation: {history[-1][1]}")
    yield history
 
 if gr_interface:
@@ -154,8 +155,13 @@ with gr.Blocks() as gr_interface:
     chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)
 
     # generation params
+    do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
+
     with gr.Row():
         temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
+        repetition_penalty = gr.Slider(1, 1.5, value=1.2, step=0.05, label="Repetition Penalty")
+    with gr.Row():
+        no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size")
         max_new_tokens = gr.Slider(
             128,
             model.config.max_position_embeddings,
@@ -191,18 +197,27 @@ with gr.Blocks() as gr_interface:
     with gr.Row():
         submit = gr.Button("Submit")
         stop = gr.Button("Stop")
-
-        clear = gr.Button("Clear History")
+        clear = gr.Button("Clear")
 
     # event handling
-
-
-
-
-
+    user_io = [prompt, chatbot]
+    bot_inputs = [
+        chatbot,
+        do_sample,
+        temperature,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        max_new_tokens,
+    ]
+
+    submit_event = prompt.submit(user, user_io, user_io, queue=False)\
+        .then(bot, bot_inputs, chatbot, queue=True)
+
+    submit_click_event = submit.click(user, user_io, user_io, queue=False)\
+        .then(bot, bot_inputs, chatbot, queue=True)
 
     stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)
     clear.click(lambda: None, None, chatbot, queue=False)
 
-gr_interface.queue()
+gr_interface.queue(max_size=32, concurrency_count=2)
 gr_interface.launch(server_port=args.port, share=args.make_public)
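The event wiring follows the usual Gradio two-step pattern: an unqueued user() call appends the prompt to the chat history immediately, then the queued bot() generator streams the reply into the Chatbot, and Stop cancels both chains. A cut-down sketch of the same pattern, assuming Gradio 3.x (implied by the app's use of .style() and concurrency_count); echo_bot is a stand-in for the real streaming bot:

import gradio as gr

def user(prompt, history):
    return "", history + [[prompt, None]]

def echo_bot(history, temperature):
    # generator: each yield pushes an updated history into the Chatbot
    history[-1][1] = f"(temperature={temperature}) you said: {history[-1][0]}"
    yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
    prompt = gr.Textbox(label="Prompt")
    submit = gr.Button("Submit")
    stop = gr.Button("Stop")

    user_io = [prompt, chatbot]
    submit_event = prompt.submit(user, user_io, user_io, queue=False)\
        .then(echo_bot, [chatbot, temperature], chatbot, queue=True)
    submit_click_event = submit.click(user, user_io, user_io, queue=False)\
        .then(echo_bot, [chatbot, temperature], chatbot, queue=True)
    stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()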