leemeng committed
Commit 433d9a2 · 1 Parent(s): 54ac395

feat: add additional params for reducing repetition

Files changed (1)
app.py  +37 -22
app.py CHANGED
@@ -29,8 +29,7 @@ class DefaultArgs:
 
 if os.getenv("RUNNING_ON_HF_SPACE"):
     login(token=os.getenv("HF_TOKEN"))
-    hf_repo = "leemeng/stablelm-jp-alpha"
-
+    hf_repo = os.getenv("HF_MODEL_REPO")
     args = DefaultArgs()
     args.hf_model_name_or_path = hf_repo
     args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
@@ -86,35 +85,35 @@ class SentencePieceStreamer(BaseStreamer):
             return
 
         self.generated_text += text
-        # print(f"[streamer]: {self.generated_text}")
-        # yield text
+        logging.debug(f"[streamer]: {self.generated_text}")
 
     def end(self):
         self.ended = True
 
-def user(user_message, history):
-    logging.debug(f"[user] user_message: {user_message}")
+def user(prompt, history):
+    logging.info(f"[user] prompt: {prompt}")
     logging.debug(f"[user] history: {history}")
 
-    res = ("", history + [[user_message, None]])
+    res = ("", history + [[prompt, None]])
     return res
 
 def bot(
     history,
+    do_sample,
     temperature,
+    repetition_penalty,
+    no_repeat_ngram_size,
     max_new_tokens,
 ):
-    logging.debug(f"[bot] history: {history}")
-    logging.debug(f"temperature: {temperature}")
+    logging.info("[bot]")
+    logging.info(dict(locals()))
+    logging.debug(f"history: {history}")
 
     # TODO: modify `<br>` back to `\n` based on the original user prinpt
     prompt = history[-1][0]
 
     tokens = sp.encode(prompt)
     input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(model.device)
-
-    # TODO: parametrize setting on UI
-    do_sample = True
 
     streamer = SentencePieceStreamer(sp=sp)
 
@@ -124,14 +123,15 @@ def bot(
 
     thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
         input_ids=input_ids,
+        do_sample=do_sample,
        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        no_repeat_ngram_size=no_repeat_ngram_size,
         max_new_tokens=max_possilbe_new_tokens,
-        do_sample=do_sample,
         streamer=streamer,
         # max_length=4096,
         # top_k=100,
         # top_p=0.9,
-        # repetition_penalty=1.0,
         # num_return_sequences=2,
         # num_beams=2,
     ))
@@ -145,6 +145,7 @@ def bot(
 
     # TODO: optimize for final few tokens
     history[-1][1] = streamer.generated_text
+    logging.info(f"generation: {history[-1][1]}")
     yield history
 
 if gr_interface:
@@ -154,8 +155,13 @@ with gr.Blocks() as gr_interface:
     chatbot = gr.Chatbot(label="StableLM JP Alpha").style(height=500)
 
     # generation params
+    do_sample = gr.Checkbox(True, label="Do Sample", visible=False)
+
     with gr.Row():
         temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Temperature")
+        repetition_penalty = gr.Slider(1, 1.5, value=1.2, step=0.05, label="Repetition Penalty")
+    with gr.Row():
+        no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size")
         max_new_tokens = gr.Slider(
             128,
             model.config.max_position_embeddings,
@@ -191,18 +197,27 @@ with gr.Blocks() as gr_interface:
     with gr.Row():
         submit = gr.Button("Submit")
         stop = gr.Button("Stop")
-
-        clear = gr.Button("Clear History")
+        clear = gr.Button("Clear")
 
     # event handling
-    submit_event = prompt.submit(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
-        .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)
-
-    submit_click_event = submit.click(user, [prompt, chatbot], [prompt, chatbot], queue=False)\
-        .then(bot, [chatbot, temperature, max_new_tokens], chatbot, queue=True)
+    user_io = [prompt, chatbot]
+    bot_inputs = [
+        chatbot,
+        do_sample,
+        temperature,
+        repetition_penalty,
+        no_repeat_ngram_size,
+        max_new_tokens,
+    ]
+
+    submit_event = prompt.submit(user, user_io, user_io, queue=False)\
+        .then(bot, bot_inputs, chatbot, queue=True)
+
+    submit_click_event = submit.click(user, user_io, user_io, queue=False)\
+        .then(bot, bot_inputs, chatbot, queue=True)
 
     stop.click(None, None, None, cancels=[submit_event, submit_click_event], queue=False)
     clear.click(lambda: None, None, chatbot, queue=False)
 
-    gr_interface.queue()
+    gr_interface.queue(max_size=32, concurrency_count=2)
     gr_interface.launch(server_port=args.port, share=args.make_public)
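
For context, a minimal, self-contained sketch of how the newly exposed decoding knobs behave with transformers' generate(). This is only an illustration under assumptions: "gpt2" is a hypothetical stand-in checkpoint, since the Space resolves its actual model from the HF_MODEL_REPO environment variable and tokenizes with SentencePiece rather than AutoTokenizer.

# Sketch: effect of the repetition-reduction parameters added in this commit.
# NOTE: "gpt2" is a stand-in checkpoint, not the model this Space serves.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

input_ids = tokenizer("The weather today is", return_tensors="pt").input_ids

output = model.generate(
    input_ids,
    do_sample=True,            # sample from the distribution instead of greedy decoding
    temperature=0.7,           # <1.0 sharpens the distribution, >1.0 flattens it
    repetition_penalty=1.2,    # >1.0 down-weights tokens that already appeared
    no_repeat_ngram_size=5,    # never emit the same 5-gram twice
    max_new_tokens=64,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

A repetition penalty around 1.2 combined with a 5-gram block is a common starting point for curbing the verbatim loops that causal LMs can fall into; moving either slider back to its neutral value (1.0 for the penalty, 0 for the n-gram size) restores the previous behavior.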