joaogante HF staff commited on
Commit
c8e9390
·
1 Parent(s): 6f5c011

replicate to other generation types

Browse files
Files changed (1) hide show
  1. app.py +53 -4
app.py CHANGED
@@ -153,7 +153,7 @@ def get_plot(model_name, plot_eager, generate_type):
153
  ci="sd", palette="dark", alpha=.6, height=6
154
  )
155
  g.despine(left=True)
156
- g.set_axis_labels("GPU", "Generation time (ms)")
157
  g.legend.set_title("Framework")
158
  return plt.gcf()
159
 
@@ -164,7 +164,7 @@ with demo:
164
  """
165
  # TensorFlow XLA Text Generation Benchmark
166
  Pick a tab for the type of generation (or other information), and then select a model from the dropdown menu.
167
- You can also ommit results from TensorFlow Eager Execution, if you wish to better compare the performance of
168
  PyTorch to TensorFlow with XLA.
169
  """
170
  )
@@ -195,9 +195,58 @@ with demo:
195
  model_selector.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
196
  eager_enabler.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
197
  with gr.TabItem("Sample"):
198
- gr.Button("New Tiger")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  with gr.TabItem("Beam Search"):
200
- gr.Button("New Tiger")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  with gr.TabItem("Benchmark Information"):
202
  gr.Dataframe(
203
  headers=["Parameter", "Value"],
 
153
  ci="sd", palette="dark", alpha=.6, height=6
154
  )
155
  g.despine(left=True)
156
+ g.set_axis_labels("GPU", "Generation time (ms) -- LOWER IS BETTER")
157
  g.legend.set_title("Framework")
158
  return plt.gcf()
159
 
 
164
  """
165
  # TensorFlow XLA Text Generation Benchmark
166
  Pick a tab for the type of generation (or other information), and then select a model from the dropdown menu.
167
+ You can also omit results from TensorFlow Eager Execution, if you wish to better compare the performance of
168
  PyTorch to TensorFlow with XLA.
169
  """
170
  )
 
195
  model_selector.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
196
  eager_enabler.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
197
  with gr.TabItem("Sample"):
198
+ gr.Markdown(
199
+ """
200
+ ### Sample benchmark parameters
201
+ - `max_new_tokens = 128`;
202
+ - `temperature = 2.0`;
203
+ - `top_k = 50`;
204
+ - `pad_to_multiple_of = 64` for Tensorflow XLA models. Others do not pad (input prompts between 2 and 33 tokens).
205
+ """
206
+ )
207
+ with gr.Row():
208
+ model_selector = gr.Dropdown(
209
+ choices=["DistilGPT2", "GPT2", "OPT-1.3B", "GPTJ-6B", "T5 Small", "T5 Base", "T5 Large", "T5 3B"],
210
+ value="T5 Small",
211
+ label="Model",
212
+ interactive=True,
213
+ )
214
+ eager_enabler = gr.Radio(
215
+ ["Yes", "No"],
216
+ value="Yes",
217
+ label="Plot TF Eager Execution?",
218
+ interactive=True
219
+ )
220
+ plot_fn = functools.partial(get_plot, generate_type="Sample")
221
+ plot = gr.Plot(value=plot_fn("T5 Small", "Yes")) # Show plot when the gradio app is initialized
222
+ model_selector.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
223
+ eager_enabler.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
224
  with gr.TabItem("Beam Search"):
225
+ gr.Markdown(
226
+ """
227
+ ### Beam Search benchmark parameters
228
+ - `max_new_tokens = 256`;
229
+ - `num_beams = 16`;
230
+ - `pad_to_multiple_of = 64` for Tensorflow XLA models. Others do not pad (input prompts between 2 and 33 tokens).
231
+ """
232
+ )
233
+ with gr.Row():
234
+ model_selector = gr.Dropdown(
235
+ choices=["DistilGPT2", "GPT2", "OPT-1.3B", "GPTJ-6B", "T5 Small", "T5 Base", "T5 Large", "T5 3B"],
236
+ value="T5 Small",
237
+ label="Model",
238
+ interactive=True,
239
+ )
240
+ eager_enabler = gr.Radio(
241
+ ["Yes", "No"],
242
+ value="Yes",
243
+ label="Plot TF Eager Execution?",
244
+ interactive=True
245
+ )
246
+ plot_fn = functools.partial(get_plot, generate_type="Beam Search")
247
+ plot = gr.Plot(value=plot_fn("T5 Small", "Yes")) # Show plot when the gradio app is initialized
248
+ model_selector.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
249
+ eager_enabler.change(fn=plot_fn, inputs=[model_selector, eager_enabler], outputs=plot)
250
  with gr.TabItem("Benchmark Information"):
251
  gr.Dataframe(
252
  headers=["Parameter", "Value"],