Koshti10 committed on
Commit
aa96db8
·
verified ·
1 Parent(s): 988cfdd

Upload 6 files

Browse files
Files changed (1) hide show
  1. src/trend_utils.py +64 -38
src/trend_utils.py CHANGED
@@ -7,10 +7,12 @@ import plotly.express as px
7
  import plotly.graph_objects as go
8
  import numpy as np
9
 
10
- from src.assets.text_content import REGISTRY_URL, REPO
11
  from src.leaderboard_utils import get_github_data
12
 
 
13
  START_DATE = '2023-06-01'
 
14
  def get_param_size(params: str) -> float:
15
  """Convert parameter size from string to float.
16
 
@@ -87,13 +89,13 @@ def populate_list(df: pd.DataFrame, abs_diff: float) -> list:
87
  return l
88
 
89
 
90
- def get_models_to_display(result_df: pd.DataFrame, open_dip: float = -0.5, comm_dip: float = -10) -> tuple:
91
  """Retrieve models to display based on clemscore differences.
92
 
93
  Args:
94
  result_df (pd.DataFrame): DataFrame containing model data.
95
- open_dip (float, optional): Threshold for open models. Defaults to -0.5.
96
- comm_dip (float, optional): Threshold for commercial models. Defaults to -10.
97
 
98
  Returns:
99
  tuple: Two lists of model names (open and commercial).
@@ -144,7 +146,7 @@ def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
144
 
145
 
146
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
147
- benchmark_ticks: dict = {}, **plot_kwargs) -> go.Figure:
148
  """Generate a scatter plot for the given DataFrame.
149
 
150
  Args:
@@ -152,6 +154,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
152
  start_date (str, optional): Start date for filtering. Defaults to '2023-06-01'.
153
  end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
154
  benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
 
155
 
156
  Keyword Args:
157
  open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
@@ -172,7 +175,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
172
  max_clemscore = df['clemscore'].max()
173
  # Convert 'release_date' to datetime
174
  df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
175
- # Filter out data before April 2023
176
  df = df[df['Release date'] >= pd.to_datetime(start_date)]
177
  open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
178
  models_to_display = open_model_list + comm_model_list
@@ -180,8 +183,8 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
180
 
181
  # Create a column to indicate if the model should be labeled
182
  df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
183
- # If select_all is False, then show only the models in models_to_display
184
 
 
185
  if mobile_view:
186
  df = df[df['model'].isin(models_to_display)]
187
 
@@ -216,24 +219,48 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
216
  date_range = pd.date_range(start=start_date, end=end_date, freq='2MS') # '2MS' stands for 2 Months Start frequency
217
  # Create labels for these ticks
218
  custom_ticks = {date: date.strftime('%b %Y') for date in date_range}
 
 
219
  benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
220
  custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
221
  custom_tickvals = list(custom_ticks.keys())
222
 
223
 
224
- # Plot Benchmark X-axis ticks with Vertical Dotted Lines
225
- for date in benchmark_tickvals:
226
- fig.add_shape(
227
- go.layout.Shape(
228
- type='line',
229
- x0=date,
230
- x1=date,
231
- y0=0,
232
- y1=1,
233
- yref='paper',
234
- line=dict(color='#A9A9A9', dash='dash')
 
 
 
 
 
235
  )
236
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  if mobile_view:
239
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
@@ -246,13 +273,13 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
246
  benchmark_tick_texts = []
247
  for i in range(len(benchmark_tickvals)):
248
  if i == 0:
249
- benchmark_tick_texts.append(f"<br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
250
  else:
251
  date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
252
  if date_diff <= 60:
253
- benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
254
  else:
255
- benchmark_tick_texts.append(f"<br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
256
  fig.update_xaxes(
257
  tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
258
  ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
@@ -266,7 +293,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
266
  fig.update_xaxes(
267
  tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
268
  ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
269
- [f"<br><b>{benchmark_ticks[date]}</b>" for date in benchmark_tickvals], # Added <br> for vertical alignment
270
  tickangle=0,
271
  tickfont=dict(size=10)
272
  )
@@ -303,6 +330,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
303
  x=0.01
304
  )
305
  )
 
306
  return fig
307
 
308
  def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
@@ -319,8 +347,7 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
319
  response = requests.get(REGISTRY_URL)
320
  model_registry_data = response.json()
321
  # Custom tick labels
322
- base_repo = REPO
323
- json_url = base_repo + "benchmark_runs.json"
324
  response = requests.get(json_url)
325
 
326
  # Check if the JSON file request was successful
@@ -335,37 +362,36 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
335
  else:
336
  height = 1000
337
 
 
338
  plot_kwargs = {'height': height, 'open_dip': 0, 'comm_dip': 0,
339
  'mobile_view': mobile_view}
340
 
341
- # plot_kwargs = {'height': height, 'open_dip': -0.5, 'comm_dip': -5,
342
- # 'mobile_view': mobile_view}
343
-
344
  if benchmark == "Text":
345
  text_dfs = get_github_data()['text']['dataframes']
346
  text_result_df = get_trend_data(text_dfs, model_registry_data)
347
-
348
  ## Get benchmark tickvalues as dates for X-axis
349
- benchmark_ticks = {}
350
  for ver in versions:
351
  if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
352
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
353
- fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, **plot_kwargs)
 
 
 
 
 
354
  else:
355
  mm_dfs = get_github_data()['multimodal']['dataframes']
356
  result_df = get_trend_data(mm_dfs, model_registry_data)
357
  df = result_df
358
-
359
- ## Get benchmark tickvalues as dates for X-axis
360
- benchmark_ticks = {}
361
  for ver in versions:
362
  if 'multimodal' in ver['version']:
363
  temp_ver = ver['version']
364
  temp_ver = temp_ver.replace('_multimodal', '')
365
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
366
-
367
- print("benchmark_ticks")
368
- print(benchmark_ticks)
369
- fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, **plot_kwargs)
370
 
371
  return fig
 
7
  import plotly.graph_objects as go
8
  import numpy as np
9
 
10
+ from src.assets.text_content import REGISTRY_URL, REPO, BENCHMARK_FILE
11
  from src.leaderboard_utils import get_github_data
12
 
13
+ # Cut-off date from where to start the trendgraph
14
  START_DATE = '2023-06-01'
15
+
16
  def get_param_size(params: str) -> float:
17
  """Convert parameter size from string to float.
18
 
 
89
  return l
90
 
91
 
92
+ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip: float = 0) -> tuple:
93
  """Retrieve models to display based on clemscore differences.
94
 
95
  Args:
96
  result_df (pd.DataFrame): DataFrame containing model data.
97
+ open_dip (float, optional): Threshold for open models. Defaults to 0.
98
+ comm_dip (float, optional): Threshold for commercial models. Defaults to 0.
99
 
100
  Returns:
101
  tuple: Two lists of model names (open and commercial).
 
146
 
147
 
148
  def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
149
+ benchmark_ticks: dict = {}, benchmark_update = {}, **plot_kwargs) -> go.Figure:
150
  """Generate a scatter plot for the given DataFrame.
151
 
152
  Args:
 
154
  start_date (str, optional): Start date for filtering. Defaults to '2023-06-01'.
155
  end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
156
  benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
157
+ benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
158
 
159
  Keyword Args:
160
  open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
 
175
  max_clemscore = df['clemscore'].max()
176
  # Convert 'release_date' to datetime
177
  df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
178
+ # Filter out rows released before start_date
179
  df = df[df['Release date'] >= pd.to_datetime(start_date)]
180
  open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
181
  models_to_display = open_model_list + comm_model_list
 
183
 
184
  # Create a column to indicate if the model should be labeled
185
  df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
 
186
 
187
+ # In mobile view, keep only the models in models_to_display (those on the trend line) for a minimal plot
188
  if mobile_view:
189
  df = df[df['model'].isin(models_to_display)]
190
 
 
219
  date_range = pd.date_range(start=start_date, end=end_date, freq='2MS') # '2MS' stands for 2 Months Start frequency
220
  # Create labels for these ticks
221
  custom_ticks = {date: date.strftime('%b %Y') for date in date_range}
222
+
223
+ ## Benchmark Version ticks
224
  benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
225
  custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
226
  custom_tickvals = list(custom_ticks.keys())
227
 
228
 
229
+ for date, version in benchmark_ticks.items():
230
+ # Find the corresponding update date from benchmark_update based on the version name
231
+ update_date = next((update_date for update_date, ver in benchmark_update.items() if version in ver), None)
232
+
233
+ if update_date:
234
+ # Add a vertical gray dashed line at each benchmark_tick date that has a matching update date
235
+ fig.add_shape(
236
+ go.layout.Shape(
237
+ type='line',
238
+ x0=date,
239
+ x1=date,
240
+ y0=0,
241
+ y1=1,
242
+ yref='paper',
243
+ line=dict(color='#A9A9A9', dash='dash'), # Black dotted line
244
+ )
245
  )
246
+
247
+ # Add hover information across the full y-axis range
248
+ fig.add_trace(
249
+ go.Scatter(
250
+ x=[date]*100,
251
+ y=list(range(0,100)), # Covers full y-axis range
252
+ mode='markers',
253
+ line=dict(color='rgba(255,255,255,0)', width=0), # Fully transparent line
254
+ hovertext=[
255
+ f"Version: {version} released on {date.strftime("%d %b %Y")}, last updated on: {update_date.strftime("%d %b %Y")}"
256
+ for _ in range(100)
257
+ ], # Unique hovertext for all points
258
+ hoverinfo="text",
259
+ hoveron='points',
260
+ showlegend=False
261
+ )
262
+ )
263
+
264
 
265
  if mobile_view:
266
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
 
273
  benchmark_tick_texts = []
274
  for i in range(len(benchmark_tickvals)):
275
  if i == 0:
276
+ benchmark_tick_texts.append(f"<br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
277
  else:
278
  date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
279
  if date_diff <= 60:
280
+ benchmark_tick_texts.append(f"<br><br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
281
  else:
282
+ benchmark_tick_texts.append(f"<br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
283
  fig.update_xaxes(
284
  tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
285
  ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
 
293
  fig.update_xaxes(
294
  tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
295
  ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
296
+ [f"<br><span style='font-size:12px;'><b>{benchmark_ticks[date]}</b></span>" for date in benchmark_tickvals], # Added <br> for vertical alignment
297
  tickangle=0,
298
  tickfont=dict(size=10)
299
  )
 
330
  x=0.01
331
  )
332
  )
333
+
334
  return fig
335
 
336
  def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
 
347
  response = requests.get(REGISTRY_URL)
348
  model_registry_data = response.json()
349
  # Custom tick labels
350
+ json_url = REPO + BENCHMARK_FILE
 
351
  response = requests.get(json_url)
352
 
353
  # Check if the JSON file request was successful
 
362
  else:
363
  height = 1000
364
 
365
+
366
  plot_kwargs = {'height': height, 'open_dip': 0, 'comm_dip': 0,
367
  'mobile_view': mobile_view}
368
 
369
+ benchmark_ticks = {}
370
+ benchmark_update = {}
 
371
  if benchmark == "Text":
372
  text_dfs = get_github_data()['text']['dataframes']
373
  text_result_df = get_trend_data(text_dfs, model_registry_data)
 
374
  ## Get benchmark tickvalues as dates for X-axis
 
375
  for ver in versions:
376
  if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
377
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
378
+ if pd.to_datetime(ver['last_updated']) not in benchmark_update:
379
+ benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
380
+ else:
381
+ benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
382
+
383
+ fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
384
  else:
385
  mm_dfs = get_github_data()['multimodal']['dataframes']
386
  result_df = get_trend_data(mm_dfs, model_registry_data)
387
  df = result_df
 
 
 
388
  for ver in versions:
389
  if 'multimodal' in ver['version']:
390
  temp_ver = ver['version']
391
  temp_ver = temp_ver.replace('_multimodal', '')
392
  benchmark_ticks[pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
393
+ benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
394
+
395
+ fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
 
396
 
397
  return fig