Corey Morris commited on
Commit
c1a84da
·
1 Parent(s): ac931c6

Updated data cleanup so that column names are cleaned up appropriatly with regex=True

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -27,9 +27,12 @@ class MultiURLData:
27
 
28
  df = df.rename(columns={'acc': model_name})
29
 
30
- df.index = df.index.str.replace('hendrycksTest-', '')
31
 
32
- df.index = df.index.str.replace('harness\\|', '')
 
 
 
33
 
34
  dataframes.append(df[[model_name]])
35
 
@@ -89,11 +92,13 @@ def create_plot(df, model_column, arc_column, moral_column, models=None):
89
 
90
  # Calculate color column
91
  plot_data['color'] = 'purple'
92
- plot_data.loc[plot_data[moral_column] < plot_data[arc_column], 'color'] = 'red'
93
- plot_data.loc[plot_data[moral_column] > plot_data[arc_column], 'color'] = 'blue'
94
 
95
- # Create the scatter plot
96
- fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'])
 
 
 
 
97
  fig.update_layout(showlegend=False, # hide legend
98
  xaxis_title=arc_column,
99
  yaxis_title=moral_column,
@@ -102,14 +107,15 @@ def create_plot(df, model_column, arc_column, moral_column, models=None):
102
 
103
  return fig
104
 
 
105
  # models_to_plot = ['Model1', 'Model2', 'Model3']
106
  # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
107
 
108
- fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5')
109
  st.plotly_chart(fig)
110
 
111
  fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
112
  st.plotly_chart(fig)
113
 
114
- fig = create_plot(filtered_data, 'Model Name', 'moral_disputes|5', 'moral_scenarios|5')
115
  st.plotly_chart(fig)
 
27
 
28
  df = df.rename(columns={'acc': model_name})
29
 
30
+ df.index = df.index.str.replace('hendrycksTest-', '', regex=True)
31
 
32
+ df.index = df.index.str.replace('harness\|', '', regex=True)
33
+
34
+ # remove |5 from the index
35
+ df.index = df.index.str.replace('\|5', '', regex=True)
36
 
37
  dataframes.append(df[[model_name]])
38
 
 
92
 
93
  # Calculate color column
94
  plot_data['color'] = 'purple'
 
 
95
 
96
+ # # TODO maybe change this
97
+ # plot_data.loc[plot_data[moral_column] < plot_data[arc_column], 'color'] = 'red'
98
+ # plot_data.loc[plot_data[moral_column] > plot_data[arc_column], 'color'] = 'blue'
99
+
100
+ # Create the scatter plot with trendline
101
+ fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'], trendline="ols") #other option ols
102
  fig.update_layout(showlegend=False, # hide legend
103
  xaxis_title=arc_column,
104
  yaxis_title=moral_column,
 
107
 
108
  return fig
109
 
110
+
111
  # models_to_plot = ['Model1', 'Model2', 'Model3']
112
  # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
113
 
114
+ fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios')
115
  st.plotly_chart(fig)
116
 
117
  fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
118
  st.plotly_chart(fig)
119
 
120
+ fig = create_plot(filtered_data, 'Model Name', 'moral_disputes', 'moral_scenarios')
121
  st.plotly_chart(fig)