zR commited on
Commit
e64071c
·
1 Parent(s): c9f689e
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -74,7 +74,7 @@ def preprocess_messages(history, img_path, platform_str, format_str):
74
 
75
 
76
  @spaces.GPU()
77
- def predict(history, max_length, top_p, temperature, img_path, platform_str, format_str, output_dir):
78
  # Reset the stop_event at the start of prediction
79
  stop_event.clear()
80
 
@@ -98,9 +98,7 @@ def predict(history, max_length, top_p, temperature, img_path, platform_str, for
98
  "attention_mask": model_inputs["attention_mask"].to(model.device),
99
  "streamer": streamer,
100
  "max_new_tokens": max_length,
101
- "do_sample": True if temperature > 0.0 else False,
102
- "top_p": top_p,
103
- "temperature": temperature,
104
  "top_k": 1,
105
  }
106
  t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -201,8 +199,6 @@ def main():
201
  submitBtn = gr.Button("Submit")
202
  with gr.Column(scale=1):
203
  max_length = gr.Slider(0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True)
204
- top_p = gr.Slider(0, 1, value=0.0, step=0.01, label="Top P", interactive=True)
205
- temperature = gr.Slider(0.01, 1, value=0.0, step=0.01, label="Temperature", interactive=True)
206
  undo_last_round_btn = gr.Button("Back to Last Round")
207
  clear_history_btn = gr.Button("Clear All History")
208
 
@@ -213,7 +209,7 @@ def main():
213
  user, [task, chatbot], [task, chatbot], queue=False
214
  ).then(
215
  predict,
216
- [chatbot, max_length, top_p, temperature, img_path, gr.State(platform_str), gr.State(format_str),
217
  gr.State(args.output_dir)],
218
  [chatbot, output_img],
219
  queue=True
 
74
 
75
 
76
  @spaces.GPU()
77
+ def predict(history, max_length, img_path, platform_str, format_str, output_dir):
78
  # Reset the stop_event at the start of prediction
79
  stop_event.clear()
80
 
 
98
  "attention_mask": model_inputs["attention_mask"].to(model.device),
99
  "streamer": streamer,
100
  "max_new_tokens": max_length,
101
+ "do_sample": True,
 
 
102
  "top_k": 1,
103
  }
104
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
199
  submitBtn = gr.Button("Submit")
200
  with gr.Column(scale=1):
201
  max_length = gr.Slider(0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True)
 
 
202
  undo_last_round_btn = gr.Button("Back to Last Round")
203
  clear_history_btn = gr.Button("Clear All History")
204
 
 
209
  user, [task, chatbot], [task, chatbot], queue=False
210
  ).then(
211
  predict,
212
+ [chatbot, max_length, img_path, gr.State(platform_str), gr.State(format_str),
213
  gr.State(args.output_dir)],
214
  [chatbot, output_img],
215
  queue=True