Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- src/trend_utils.py +64 -38
src/trend_utils.py
CHANGED
@@ -7,10 +7,12 @@ import plotly.express as px
|
|
7 |
import plotly.graph_objects as go
|
8 |
import numpy as np
|
9 |
|
10 |
-
from src.assets.text_content import REGISTRY_URL, REPO
|
11 |
from src.leaderboard_utils import get_github_data
|
12 |
|
|
|
13 |
START_DATE = '2023-06-01'
|
|
|
14 |
def get_param_size(params: str) -> float:
|
15 |
"""Convert parameter size from string to float.
|
16 |
|
@@ -87,13 +89,13 @@ def populate_list(df: pd.DataFrame, abs_diff: float) -> list:
|
|
87 |
return l
|
88 |
|
89 |
|
90 |
-
def get_models_to_display(result_df: pd.DataFrame, open_dip: float =
|
91 |
"""Retrieve models to display based on clemscore differences.
|
92 |
|
93 |
Args:
|
94 |
result_df (pd.DataFrame): DataFrame containing model data.
|
95 |
-
open_dip (float, optional): Threshold for open models. Defaults to
|
96 |
-
comm_dip (float, optional): Threshold for commercial models. Defaults to
|
97 |
|
98 |
Returns:
|
99 |
tuple: Two lists of model names (open and commercial).
|
@@ -144,7 +146,7 @@ def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
|
|
144 |
|
145 |
|
146 |
def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
|
147 |
-
benchmark_ticks: dict = {}, **plot_kwargs) -> go.Figure:
|
148 |
"""Generate a scatter plot for the given DataFrame.
|
149 |
|
150 |
Args:
|
@@ -152,6 +154,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
152 |
start_date (str, optional): Start date for filtering. Defaults to '2023-06-01'.
|
153 |
end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
|
154 |
benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
|
|
|
155 |
|
156 |
Keyword Args:
|
157 |
open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
@@ -172,7 +175,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
172 |
max_clemscore = df['clemscore'].max()
|
173 |
# Convert 'release_date' to datetime
|
174 |
df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
|
175 |
-
# Filter out data before April 2023
|
176 |
df = df[df['Release date'] >= pd.to_datetime(start_date)]
|
177 |
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
178 |
models_to_display = open_model_list + comm_model_list
|
@@ -180,8 +183,8 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
180 |
|
181 |
# Create a column to indicate if the model should be labeled
|
182 |
df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
|
183 |
-
# If select_all is False, then show only the models in models_to_display
|
184 |
|
|
|
185 |
if mobile_view:
|
186 |
df = df[df['model'].isin(models_to_display)]
|
187 |
|
@@ -216,24 +219,48 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
216 |
date_range = pd.date_range(start=start_date, end=end_date, freq='2MS') # '2MS' stands for 2 Months Start frequency
|
217 |
# Create labels for these ticks
|
218 |
custom_ticks = {date: date.strftime('%b %Y') for date in date_range}
|
|
|
|
|
219 |
benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
|
220 |
custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
|
221 |
custom_tickvals = list(custom_ticks.keys())
|
222 |
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
|
|
|
|
|
|
|
|
|
|
235 |
)
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
if mobile_view:
|
239 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
@@ -246,13 +273,13 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
246 |
benchmark_tick_texts = []
|
247 |
for i in range(len(benchmark_tickvals)):
|
248 |
if i == 0:
|
249 |
-
benchmark_tick_texts.append(f"<br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
250 |
else:
|
251 |
date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
|
252 |
if date_diff <= 60:
|
253 |
-
benchmark_tick_texts.append(f"<br><br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
254 |
else:
|
255 |
-
benchmark_tick_texts.append(f"<br><b>{benchmark_ticks[benchmark_tickvals[i]]}</b>")
|
256 |
fig.update_xaxes(
|
257 |
tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
258 |
ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
|
@@ -266,7 +293,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
266 |
fig.update_xaxes(
|
267 |
tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
268 |
ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
|
269 |
-
[f"<br><b>{benchmark_ticks[date]}</b>" for date in benchmark_tickvals], # Added <br> for vertical alignment
|
270 |
tickangle=0,
|
271 |
tickfont=dict(size=10)
|
272 |
)
|
@@ -303,6 +330,7 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
303 |
x=0.01
|
304 |
)
|
305 |
)
|
|
|
306 |
return fig
|
307 |
|
308 |
def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
|
@@ -319,8 +347,7 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
319 |
response = requests.get(REGISTRY_URL)
|
320 |
model_registry_data = response.json()
|
321 |
# Custom tick labels
|
322 |
-
|
323 |
-
json_url = base_repo + "benchmark_runs.json"
|
324 |
response = requests.get(json_url)
|
325 |
|
326 |
# Check if the JSON file request was successful
|
@@ -335,37 +362,36 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
335 |
else:
|
336 |
height = 1000
|
337 |
|
|
|
338 |
plot_kwargs = {'height': height, 'open_dip': 0, 'comm_dip': 0,
|
339 |
'mobile_view': mobile_view}
|
340 |
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
if benchmark == "Text":
|
345 |
text_dfs = get_github_data()['text']['dataframes']
|
346 |
text_result_df = get_trend_data(text_dfs, model_registry_data)
|
347 |
-
|
348 |
## Get benchmark tickvalues as dates for X-axis
|
349 |
-
benchmark_ticks = {}
|
350 |
for ver in versions:
|
351 |
if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
|
352 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
|
353 |
-
|
|
|
|
|
|
|
|
|
|
|
354 |
else:
|
355 |
mm_dfs = get_github_data()['multimodal']['dataframes']
|
356 |
result_df = get_trend_data(mm_dfs, model_registry_data)
|
357 |
df = result_df
|
358 |
-
|
359 |
-
## Get benchmark tickvalues as dates for X-axis
|
360 |
-
benchmark_ticks = {}
|
361 |
for ver in versions:
|
362 |
if 'multimodal' in ver['version']:
|
363 |
temp_ver = ver['version']
|
364 |
temp_ver = temp_ver.replace('_multimodal', '')
|
365 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, **plot_kwargs)
|
370 |
|
371 |
return fig
|
|
|
7 |
import plotly.graph_objects as go
|
8 |
import numpy as np
|
9 |
|
10 |
+
from src.assets.text_content import REGISTRY_URL, REPO, BENCHMARK_FILE
|
11 |
from src.leaderboard_utils import get_github_data
|
12 |
|
13 |
+
# Cut-off date from where to start the trendgraph
|
14 |
START_DATE = '2023-06-01'
|
15 |
+
|
16 |
def get_param_size(params: str) -> float:
|
17 |
"""Convert parameter size from string to float.
|
18 |
|
|
|
89 |
return l
|
90 |
|
91 |
|
92 |
+
def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip: float = 0) -> tuple:
|
93 |
"""Retrieve models to display based on clemscore differences.
|
94 |
|
95 |
Args:
|
96 |
result_df (pd.DataFrame): DataFrame containing model data.
|
97 |
+
open_dip (float, optional): Threshold for open models. Defaults to 0.
|
98 |
+
comm_dip (float, optional): Threshold for commercial models. Defaults to 0.
|
99 |
|
100 |
Returns:
|
101 |
tuple: Two lists of model names (open and commercial).
|
|
|
146 |
|
147 |
|
148 |
def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
|
149 |
+
benchmark_ticks: dict = {}, benchmark_update = {}, **plot_kwargs) -> go.Figure:
|
150 |
"""Generate a scatter plot for the given DataFrame.
|
151 |
|
152 |
Args:
|
|
|
154 |
start_date (str, optional): Start date for filtering. Defaults to '2023-06-01'.
|
155 |
end_date (str, optional): End date for filtering. Defaults to '2024-12-30'.
|
156 |
benchmark_ticks (dict, optional): Custom benchmark ticks for the version dates. Defaults to {}.
|
157 |
+
benchmark_update (dict, optional): Custom benchmark metadata containing last_updated date for the versions. Defaults to {}.
|
158 |
|
159 |
Keyword Args:
|
160 |
open_dip (float, optional): Threshold for open models' clemscore differences. Max dip in clemscore allowed to be considered in trend.
|
|
|
175 |
max_clemscore = df['clemscore'].max()
|
176 |
# Convert 'release_date' to datetime
|
177 |
df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
|
178 |
+
# Filter out data before April 2023/START_DATE
|
179 |
df = df[df['Release date'] >= pd.to_datetime(start_date)]
|
180 |
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
181 |
models_to_display = open_model_list + comm_model_list
|
|
|
183 |
|
184 |
# Create a column to indicate if the model should be labeled
|
185 |
df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
|
|
|
186 |
|
187 |
+
# If mobile_view, then show only the models in models_to_display i.e. on the trend line #minimalistic
|
188 |
if mobile_view:
|
189 |
df = df[df['model'].isin(models_to_display)]
|
190 |
|
|
|
219 |
date_range = pd.date_range(start=start_date, end=end_date, freq='2MS') # '2MS' stands for 2 Months Start frequency
|
220 |
# Create labels for these ticks
|
221 |
custom_ticks = {date: date.strftime('%b %Y') for date in date_range}
|
222 |
+
|
223 |
+
## Benchmark Version ticks
|
224 |
benchmark_tickvals = list(pd.to_datetime(list(benchmark_ticks.keys())))
|
225 |
custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
|
226 |
custom_tickvals = list(custom_ticks.keys())
|
227 |
|
228 |
|
229 |
+
for date, version in benchmark_ticks.items():
|
230 |
+
# Find the corresponding update date from benchmark_update based on the version name
|
231 |
+
update_date = next((update_date for update_date, ver in benchmark_update.items() if version in ver), None)
|
232 |
+
|
233 |
+
if update_date:
|
234 |
+
# Add vertical black dotted line for each benchmark_tick date
|
235 |
+
fig.add_shape(
|
236 |
+
go.layout.Shape(
|
237 |
+
type='line',
|
238 |
+
x0=date,
|
239 |
+
x1=date,
|
240 |
+
y0=0,
|
241 |
+
y1=1,
|
242 |
+
yref='paper',
|
243 |
+
line=dict(color='#A9A9A9', dash='dash'), # Black dotted line
|
244 |
+
)
|
245 |
)
|
246 |
+
|
247 |
+
# Add hover information across the full y-axis range
|
248 |
+
fig.add_trace(
|
249 |
+
go.Scatter(
|
250 |
+
x=[date]*100,
|
251 |
+
y=list(range(0,100)), # Covers full y-axis range
|
252 |
+
mode='markers',
|
253 |
+
line=dict(color='rgba(255,255,255,0)', width=0), # Fully transparent line
|
254 |
+
hovertext=[
|
255 |
+
f"Version: {version} released on {date.strftime("%d %b %Y")}, last updated on: {update_date.strftime("%d %b %Y")}"
|
256 |
+
for _ in range(100)
|
257 |
+
], # Unique hovertext for all points
|
258 |
+
hoverinfo="text",
|
259 |
+
hoveron='points',
|
260 |
+
showlegend=False
|
261 |
+
)
|
262 |
+
)
|
263 |
+
|
264 |
|
265 |
if mobile_view:
|
266 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
|
|
273 |
benchmark_tick_texts = []
|
274 |
for i in range(len(benchmark_tickvals)):
|
275 |
if i == 0:
|
276 |
+
benchmark_tick_texts.append(f"<br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
|
277 |
else:
|
278 |
date_diff = (benchmark_tickvals[i] - benchmark_tickvals[i - 1]).days
|
279 |
if date_diff <= 60:
|
280 |
+
benchmark_tick_texts.append(f"<br><br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
|
281 |
else:
|
282 |
+
benchmark_tick_texts.append(f"<br><span style='font-size:7px;'><b>{benchmark_ticks[benchmark_tickvals[i]]}</b></span>")
|
283 |
fig.update_xaxes(
|
284 |
tickvals=filtered_custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
285 |
ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
|
|
|
293 |
fig.update_xaxes(
|
294 |
tickvals=custom_tickvals + benchmark_tickvals, # Use filtered_custom_tickvals
|
295 |
ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
|
296 |
+
[f"<br><span style='font-size:12px;'><b>{benchmark_ticks[date]}</b></span>" for date in benchmark_tickvals], # Added <br> for vertical alignment
|
297 |
tickangle=0,
|
298 |
tickfont=dict(size=10)
|
299 |
)
|
|
|
330 |
x=0.01
|
331 |
)
|
332 |
)
|
333 |
+
|
334 |
return fig
|
335 |
|
336 |
def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) -> go.Figure:
|
|
|
347 |
response = requests.get(REGISTRY_URL)
|
348 |
model_registry_data = response.json()
|
349 |
# Custom tick labels
|
350 |
+
json_url = REPO + BENCHMARK_FILE
|
|
|
351 |
response = requests.get(json_url)
|
352 |
|
353 |
# Check if the JSON file request was successful
|
|
|
362 |
else:
|
363 |
height = 1000
|
364 |
|
365 |
+
|
366 |
plot_kwargs = {'height': height, 'open_dip': 0, 'comm_dip': 0,
|
367 |
'mobile_view': mobile_view}
|
368 |
|
369 |
+
benchmark_ticks = {}
|
370 |
+
benchmark_update = {}
|
|
|
371 |
if benchmark == "Text":
|
372 |
text_dfs = get_github_data()['text']['dataframes']
|
373 |
text_result_df = get_trend_data(text_dfs, model_registry_data)
|
|
|
374 |
## Get benchmark tickvalues as dates for X-axis
|
|
|
375 |
for ver in versions:
|
376 |
if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
|
377 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = ver['version']
|
378 |
+
if pd.to_datetime(ver['last_updated']) not in benchmark_update:
|
379 |
+
benchmark_update[pd.to_datetime(ver['last_updated'])] = [ver['version']]
|
380 |
+
else:
|
381 |
+
benchmark_update[pd.to_datetime(ver['last_updated'])].append(ver['version'])
|
382 |
+
|
383 |
+
fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
384 |
else:
|
385 |
mm_dfs = get_github_data()['multimodal']['dataframes']
|
386 |
result_df = get_trend_data(mm_dfs, model_registry_data)
|
387 |
df = result_df
|
|
|
|
|
|
|
388 |
for ver in versions:
|
389 |
if 'multimodal' in ver['version']:
|
390 |
temp_ver = ver['version']
|
391 |
temp_ver = temp_ver.replace('_multimodal', '')
|
392 |
benchmark_ticks[pd.to_datetime(ver['release_date'])] = temp_ver ## MM benchmark dates considered after v1.6 (incl.)
|
393 |
+
benchmark_update[pd.to_datetime(ver['last_updated'])] = temp_ver
|
394 |
+
|
395 |
+
fig = get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
|
|
396 |
|
397 |
return fig
|