|
|
|
import requests |
|
import pandas as pd |
|
from datetime import datetime |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
|
|
from src.assets.text_content import REGISTRY_URL, REPO, BENCHMARK_FILE |
|
from src.leaderboard_utils import get_github_data |
|
|
|
|
|
# Earliest release date ('YYYY-MM-DD') for the trend plot: get_final_trend_plot
# passes it to get_plot as start_date, which drops models released before it and
# starts the x-axis tick range there.
START_DATE = '2023-06-01'
|
|
|
def get_param_size(params: str) -> float:
    """Convert a parameter-count string into a number of billions.

    Args:
        params (str): The parameter size with a unit suffix, e.g. '7B', '1T'
            or '350M'. The suffix is case-insensitive and surrounding
            whitespace is ignored.

    Returns:
        float: The parameter count in billions ('B' as-is, 'T' multiplied by
        1000, 'M' divided by 1000). Empty/None input yields 0.0. Strings with
        an unknown suffix or an unparseable number also yield 0.0, after
        printing "Not a valid parameter size". Callers then treat the size as
        unknown rather than crashing.
    """
    if not params:
        return 0.0

    text = str(params).strip().upper()
    if not text:
        return 0.0

    unit, number = text[-1], text[:-1]
    if unit not in ("B", "T", "M"):
        # Previously this path fell through to an unbound local and crashed.
        print("Not a valid parameter size")
        return 0.0

    try:
        value = float(number)
    except ValueError:
        print("Not a valid parameter size")
        return 0.0

    if unit == "T":
        return value * 1000
    if unit == "M":
        return value / 1000
    return value
|
|
|
def date_difference(date_str1: str, date_str2: str) -> int:
    """Return the signed number of days from ``date_str2`` to ``date_str1``.

    Args:
        date_str1 (str): Minuend date, formatted 'YYYY-MM-DD'.
        date_str2 (str): Subtrahend date, formatted 'YYYY-MM-DD'.

    Returns:
        int: ``date_str1 - date_str2`` in whole days. The result is positive
        when the first date is later, negative when it is earlier and 0 when
        they are equal.

    Raises:
        ValueError: If either string does not match '%Y-%m-%d'.
    """
    fmt = "%Y-%m-%d"
    first, second = (datetime.strptime(text, fmt) for text in (date_str1, date_str2))
    delta = first - second
    return delta.days
|
|
|
|
|
def populate_list(df: pd.DataFrame, abs_diff: float) -> list:
    """Pick the models that form a clemscore trendline.

    Walks ``df`` in its current row order (get_models_to_display sorts it by
    release date first). A model is kept when its clemscore minus the
    clemscore of the last kept model is at least ``abs_diff``. The first row
    is always kept. If a kept model has the same release date as the
    previously kept one, it replaces that entry instead of being appended.

    Args:
        df (pd.DataFrame): Rows with 'model', 'clemscore' and 'release_date'
            columns. Release dates are 'YYYY-MM-DD' strings, as parsed by
            date_difference.
        abs_diff (float): Minimum clemscore change relative to the last kept
            model. A negative value tolerates dips of up to ``-abs_diff``.

    Returns:
        list: Names of the kept models in row order. Empty when ``df`` has no
        rows (previously this case raised IndexError).
    """
    if df.empty:
        return []

    rows = zip(df['model'], df['clemscore'], df['release_date'])
    first_model, prev_clemscore, prev_date = next(rows)
    trend_models = [first_model]

    for model, clemscore, release_date in rows:
        if clemscore - prev_clemscore >= abs_diff:
            if date_difference(release_date, prev_date) == 0:
                # Same release day as the last kept model: keep only the better one.
                trend_models[-1] = model
            else:
                trend_models.append(model)
            # Comparisons are always against the last *kept* model.
            prev_clemscore = clemscore
            prev_date = release_date

    return trend_models
|
|
|
|
|
def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip: float = 0) -> tuple:
    """Select the trendline models for the open-weight and commercial groups.

    Rows are split on the ``open_weight`` flag. Each group is ordered by
    ``release_date`` (ascending) and passed to ``populate_list`` with its own
    threshold.

    Args:
        result_df (pd.DataFrame): Model data with 'open_weight', 'release_date',
            'model' and 'clemscore' columns.
        open_dip (float, optional): populate_list threshold for open-weight
            models. Defaults to 0.
        comm_dip (float, optional): populate_list threshold for commercial
            models. Defaults to 0.

    Returns:
        tuple: ``(open_models, comm_models)``, two lists of model names.
    """
    def _trend_models(flag, threshold):
        group = result_df[result_df['open_weight'] == flag]
        ordered = group.sort_values(by='release_date', ascending=True)
        return populate_list(ordered, threshold)

    return _trend_models(True, open_dip), _trend_models(False, comm_dip)
|
|
|
|
|
def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
    """Join leaderboard clemscores with model-registry metadata.

    Each model is taken from the first DataFrame in ``text_dfs`` in which it
    appears; later occurrences are ignored. Models without a registry entry
    are skipped.

    Args:
        text_dfs (list): DataFrames with 'Model' and 'Clemscore' columns.
        model_registry_data (list): Registry dicts with 'model_name',
            'parameters', 'open_weight' and 'release_date' keys. If a name
            occurs more than once, the first entry is used.

    Returns:
        pd.DataFrame: One row per matched model, with these columns:
            'model', 'clemscore', 'open_weight', 'release_date';
            'parameters' (float, via get_param_size);
            'est_flag' (True when the registry's 'parameters' was empty and
            '1000B' was assumed).
    """
    columns = ['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag']

    # Index the registry once instead of scanning it for every model.
    # setdefault keeps the first entry per name, matching a first-match scan.
    registry = {}
    for entry in model_registry_data:
        registry.setdefault(entry.get("model_name"), entry)

    seen = set()
    records = []
    for df in text_dfs:
        for model_name, clemscore in zip(df['Model'], df['Clemscore']):
            if model_name in seen:
                continue
            seen.add(model_name)

            entry = registry.get(model_name)
            if entry is None:
                continue

            # An empty parameter count is replaced by a 1000B estimate and flagged.
            est_flag = entry["parameters"] == ""
            params = "1000B" if est_flag else entry["parameters"]

            records.append({
                'model': model_name,
                'clemscore': clemscore,
                'open_weight': entry['open_weight'],
                'release_date': entry['release_date'],
                'parameters': get_param_size(params),
                'est_flag': est_flag,
            })

    # Build the frame in one go rather than growing it row by row.
    return pd.DataFrame(records, columns=columns)
|
|
|
|
|
def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
             benchmark_ticks: dict = None, benchmark_update: dict = None, **plot_kwargs) -> go.Figure:
    """Generate a scatter plot of clemscore over release date with trendlines.

    Every model released on or after ``start_date`` is drawn as a bubble sized
    by its parameter count. Red (open-weight) and blue (commercial) trendlines
    connect the models chosen by get_models_to_display. Benchmark versions that
    have a matching update date are marked with dashed vertical lines carrying
    hover text. The input DataFrame is not modified.

    Args:
        df (pd.DataFrame): DataFrame containing model data ('model', 'clemscore',
            'open_weight', 'release_date', 'parameters').
        start_date (str, optional): Models released earlier are dropped. Also the
            start of the regular x-axis ticks. Defaults to '2023-06-01'.
        end_date (str, optional): End of the regular x-axis tick range.
            Defaults to '2024-12-30'.
        benchmark_ticks (dict, optional): Maps version release dates (Timestamps
            or date strings) to version labels. Defaults to None (no benchmark ticks).
        benchmark_update (dict, optional): Maps last-updated dates (Timestamps or
            date strings) to version labels. A version gets its vertical line only
            when a label here contains it. Defaults to None.

    Keyword Args:
        open_dip (float, optional): Threshold for open models' clemscore differences.
            Max dip in clemscore allowed to be considered in trend. Defaults to 0.
        comm_dip (float, optional): Threshold for commercial models' clemscore
            differences. Max dip in clemscore allowed to be considered in trend.
            Defaults to 0.
        height (int, optional): Height of the plot in pixels.
        width (int, optional): Width of the plot in pixels. Only applied when truthy.
        mobile_view (bool, optional): Flag to indicate if the plot should be
            optimized for mobile display. Defaults to False.

    Returns:
        go.Figure: The generated plot.
    """
    # None defaults instead of shared mutable {} defaults.
    benchmark_ticks = {} if benchmark_ticks is None else benchmark_ticks
    benchmark_update = {} if benchmark_update is None else benchmark_update

    open_dip = plot_kwargs.get('open_dip', 0)
    comm_dip = plot_kwargs.get('comm_dip', 0)
    height = plot_kwargs.get('height')
    width = plot_kwargs.get('width')
    mobile_view = bool(plot_kwargs.get('mobile_view', False))

    # Normalise keys to Timestamps so lookups by converted tick values work for
    # both Timestamp and string keys.
    ticks = {pd.to_datetime(date): version for date, version in benchmark_ticks.items()}
    updates = {pd.to_datetime(date): version for date, version in benchmark_update.items()}
    # Chronological order: the mobile label staggering compares neighbouring releases.
    benchmark_tickvals = sorted(ticks)

    # Y-axis headroom is based on all models, before the start-date filter.
    max_clemscore = df['clemscore'].max()

    # Work on copies: leaves the caller's frame untouched and avoids SettingWithCopyWarning.
    df = df.copy()
    df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
    df = df[df['Release date'] >= pd.to_datetime(start_date)].copy()

    open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
    models_to_display = open_model_list + comm_model_list
    print(f"open_model_list: {open_model_list}, comm_model_list: {comm_model_list}")

    # Only trendline models receive a text label.
    df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")

    if mobile_view:
        # Mobile view keeps only the trendline models to reduce clutter.
        df = df[df['model'].isin(models_to_display)].copy()

    df['Model Type'] = df['open_weight'].map({True: 'Open-Weight', False: 'Commercial'})

    # Bubble size is sqrt(parameters); unknown sizes (0) fall back to sqrt(400).
    marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float)

    open_color = 'red'
    comm_color = 'blue'

    fig = px.scatter(df,
                     x="Release date",
                     y="clemscore",
                     color="Model Type",
                     hover_name="model",
                     size=marker_size,
                     size_max=40,
                     template="plotly_white",
                     hover_data={
                         "Release date": True,
                         "clemscore": True,
                         "Model Type": True
                     },
                     custom_data=["model", "Release date", "clemscore"]
                     )

    fig.update_traces(
        hovertemplate='Model Name: %{customdata[0]}<br>Release date: %{customdata[1]}<br>Clemscore: %{customdata[2]}<br>'
    )

    df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release date')
    df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release date')

    # Regular ticks every two months (month starts), minus any that coincide with a benchmark release.
    date_range = pd.date_range(start=pd.to_datetime(start_date), end=pd.to_datetime(end_date), freq='2MS')
    benchmark_dates = set(benchmark_tickvals)
    custom_tickvals = [date for date in date_range if date not in benchmark_dates]

    _add_benchmark_version_markers(fig, ticks, updates)

    tickvals, ticktext = _build_xaxis_ticks(custom_tickvals, benchmark_tickvals, ticks, mobile_view)
    fig.update_xaxes(
        tickvals=tickvals,
        ticktext=ticktext,
        tickangle=0,
        tickfont=dict(size=10)
    )

    if mobile_view:
        fig.update_yaxes(range=[0, 110])
        display_mode = 'lines+markers'
    else:
        fig.update_yaxes(range=[0, max_clemscore + 10])
        display_mode = 'lines+markers+text'

    fig.add_trace(go.Scatter(x=df_open['Release date'], y=df_open['clemscore'],
                             mode=display_mode,
                             name='Open Models Trendline',
                             text=df_open['label_model'],
                             textposition='top center',
                             line=dict(color=open_color), showlegend=False))

    fig.add_trace(go.Scatter(x=df_commercial['Release date'], y=df_commercial['clemscore'],
                             mode=display_mode,
                             name='Commercial Models Trendline',
                             text=df_commercial['label_model'],
                             textposition='top center',
                             line=dict(color=comm_color), showlegend=False))

    fig.update_traces(textposition='top center')

    fig.update_layout(height=height,
                      legend=dict(
                          yanchor="top",
                          y=0.99,
                          xanchor="left",
                          x=0.01
                      )
                      )

    if width:
        print("Custom Setting Width :")
        fig.update_layout(width=width)

    return fig


def _add_benchmark_version_markers(fig: go.Figure, ticks: dict, updates: dict) -> None:
    """Add a dashed vertical line plus a hover trace for each benchmark version that has an update date."""
    for date, version in ticks.items():
        # First update entry whose label contains this version (substring match).
        update_date = next((upd for upd, ver in updates.items() if version in ver), None)
        if update_date is None:
            continue

        fig.add_shape(
            go.layout.Shape(
                type='line',
                x0=date,
                x1=date,
                y0=0,
                y1=1,
                yref='paper',
                line=dict(color='#A9A9A9', dash='dash'),
            )
        )

        # 100 points stacked along the line so hovering anywhere on it shows version
        # info. The line colour is fully transparent; presumably the markers inherit
        # it and stay invisible (confirm against Plotly's marker defaults).
        hover = (f"Version: {version} released on {date.strftime('%d %b %Y')}, "
                 f"last updated on: {update_date.strftime('%d %b %Y')}")
        fig.add_trace(
            go.Scatter(
                x=[date] * 100,
                y=list(range(0, 100)),
                mode='markers',
                line=dict(color='rgba(255,255,255,0)', width=0),
                hovertext=[hover] * 100,
                hoverinfo="text",
                hoveron='points',
                showlegend=False
            )
        )


def _build_xaxis_ticks(custom_tickvals: list, benchmark_tickvals: list, ticks: dict, mobile_view: bool) -> tuple:
    """Return (tickvals, ticktext): regular month ticks followed by bold benchmark-version labels."""
    if mobile_view:
        # Drop month ticks within one month of a benchmark tick so labels don't collide.
        one_month = pd.DateOffset(months=1)
        month_vals = [
            date for date in custom_tickvals
            if not any(bench - one_month <= date <= bench + one_month for bench in benchmark_tickvals)
        ]
        month_text = [f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in month_vals]

        bench_text = []
        prev = None
        for bench in benchmark_tickvals:
            # A release within 75 days of the previous one is pushed one line lower.
            close_to_prev = prev is not None and (bench - prev).days <= 75
            breaks = "<br><br><br>" if close_to_prev else "<br><br>"
            bench_text.append(f"{breaks}<b>{ticks[bench]}</b>")
            prev = bench
    else:
        month_vals = custom_tickvals
        month_text = [f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals]
        bench_text = [f"<br><span style='font-size:12px;'><b>{ticks[bench]}</b></span>"
                      for bench in benchmark_tickvals]

    return month_vals + benchmark_tickvals, month_text + bench_text
|
|
|
def get_final_trend_plot(mobile_view: bool = False, custom_width: int = 0) -> go.Figure:
    """Fetch registry and benchmark metadata and build the multimodal trend plot.

    Args:
        mobile_view (bool, optional): Flag to indicate mobile view. Uses a
            450x375 px layout instead of height 1000. Defaults to False.
        custom_width (int, optional): The custom width to use for loading the
            graph the first time. Overrides the default width when non-zero.
            Defaults to 0.

    Returns:
        go.Figure: The trend plot for the multimodal leaderboard. Multimodal
        benchmark versions are marked on the x-axis.

    Raises:
        requests.HTTPError: If the registry or benchmark-file request returns
            a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Seconds; without a timeout a stalled server would block the caller indefinitely.
    timeout = 30

    response = requests.get(REGISTRY_URL, timeout=timeout)
    response.raise_for_status()
    model_registry_data = response.json()

    json_url = REPO + BENCHMARK_FILE
    response = requests.get(json_url, timeout=timeout)
    if response.status_code != 200:
        print(f"Failed to read JSON file: Status Code: {response.status_code}")
    # Fail with a clear HTTP error instead of an opaque JSON decode error below.
    response.raise_for_status()
    versions = response.json()['versions']

    if mobile_view:
        height, width = 450, 375
    else:
        height, width = 1000, None
    if custom_width:
        width = custom_width

    plot_kwargs = {'height': height, 'width': width, 'open_dip': 0, 'comm_dip': 0,
                   'mobile_view': mobile_view}

    mm_dfs = get_github_data()['multimodal']['dataframes']
    df = get_trend_data(mm_dfs, model_registry_data)

    # Only multimodal benchmark versions are marked; their labels drop the '_multimodal' suffix.
    benchmark_ticks = {}
    benchmark_update = {}
    for ver in versions:
        if 'multimodal' in ver['version']:
            label = ver['version'].replace('_multimodal', '')
            benchmark_ticks[pd.to_datetime(ver['release_date'])] = label
            benchmark_update[pd.to_datetime(ver['last_updated'])] = label

    return get_plot(df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
                    benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
|