Hack90 commited on
Commit
c467935
·
verified ·
1 Parent(s): 4206a2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -1004,7 +1004,8 @@ with ui.navset_card_tab(id="tab"):
1004
  multiple=True,
1005
  selected=["compliment", "cross_entropy", "headless"]
1006
  )
1007
- def plot_loss_rates_model(df, param_types, loss_types, model_types):
 
1008
  # interplot each column to be same number of points
1009
  x = np.linspace(0, 1, 1000)
1010
  loss_rates = []
@@ -1022,9 +1023,15 @@ with ui.navset_card_tab(id="tab"):
1022
  labels.append(str(param_type) + '_' + loss_type + '_' + model_type)
1023
 
1024
  fig, ax = plt.subplots()
1025
- print(loss_rates)
1026
 
1027
  for i, loss_rate in enumerate(loss_rates):
 
 
 
 
 
 
1028
  ax.plot(x, loss_rate, label=labels[i])
1029
 
1030
  ax.legend()
@@ -1034,12 +1041,18 @@ with ui.navset_card_tab(id="tab"):
1034
  return fig
1035
 
1036
  import matplotlib as mpl
1037
- @render.plot()
1038
  def plot_model_scaling():
1039
  fig = None
1040
  df = pd.read_csv('training_data_5.csv')
1041
  mpl.rcParams.update(mpl.rcParamsDefault)
1042
- fig = plot_loss_rates_model(df, input.param_type(),input.loss_type(),input.model_type())
 
 
 
 
 
 
1043
  return fig
1044
  with ui.nav_panel("Scaling Laws"):
1045
  ui.page_opts(fillable=True)
 
1004
  multiple=True,
1005
  selected=["compliment", "cross_entropy", "headless"]
1006
  )
1007
+ ui.input_slider("x_filter", "x_filter", 0, 1, 0.01)
1008
+ def plot_loss_rates_model(df, param_types, loss_types, model_types, x_filter):
1009
  # interplot each column to be same number of points
1010
  x = np.linspace(0, 1, 1000)
1011
  loss_rates = []
 
1023
  labels.append(str(param_type) + '_' + loss_type + '_' + model_type)
1024
 
1025
  fig, ax = plt.subplots()
1026
+ # print(loss_rates)
1027
 
1028
  for i, loss_rate in enumerate(loss_rates):
1029
+ df_madmad = pd.DataFrame({'x':x, 'loss':loss_rate})
1030
+
1031
+ df_madmad = df_madmad.sort_values(by='x')
1032
+ df_madmad = df_madmad[df_madmad['x']>x_filter]
1033
+ x = df_madmad['x'].to_list()
1034
+ loss_rate = df_madmad['loss_rate'].to_list()
1035
  ax.plot(x, loss_rate, label=labels[i])
1036
 
1037
  ax.legend()
 
1041
  return fig
1042
 
1043
  import matplotlib as mpl
1044
+ @render.image
1045
  def plot_model_scaling():
1046
  fig = None
1047
  df = pd.read_csv('training_data_5.csv')
1048
  mpl.rcParams.update(mpl.rcParamsDefault)
1049
+ fig = plot_loss_rates_model(df, input.param_type(),input.loss_type(),input.model_type(),input.x_filter() )
1050
+
1051
+ import tempfile
1052
+ fd, path = tempfile.mkstemp(suffix = '.svg')
1053
+ if fig:
1054
+ fig.savefig(path)
1055
+ return {"src": str(path), "width": "600px", "format":"svg"}
1056
  return fig
1057
  with ui.nav_panel("Scaling Laws"):
1058
  ui.page_opts(fillable=True)