import gradio as gr |
import numpy as np |
import pandas as pd |
import plotly.graph_objects as go |
from datasets import load_dataset |
from evaluate.utils import parse_readme |
from scipy.stats import gaussian_kde, spearmanr |
import generate_annotated_diffs |
from api_wrappers import hf_data_loader |
from generation_steps.metrics_analysis import AGGR_METRICS, edit_distance_fn |
# Hex colour used for each dataset split in the plots; the keys are also the split labels shown on the axes.
colors = {
    "Expert-labeled": "#C19C0B",
    "Synthetic Backward": "#913632",
    "Synthetic Forward": "#58136a",
    "Full": "#000000",
}

# Display name of a text-similarity metric -> identifier used in column names (e.g. ``online_<identifier>``).
METRICS = {
    "Edit Distance": "editdist",
    "Edit Similarity": "editsim",
    "BLEU": "bleu",
    "METEOR": "meteor",
    "ROUGE-1": "rouge1",
    "ROUGE-2": "rouge2",
    "ROUGE-L": "rougeL",
    "BERTScore": "bertscore",
    "ChrF": "chrF",
}
# Frame shared by every split selector and plot in this module; built once at import time.
# The code below reads its G_type, E_type, G_text, E_text, annotated_diff, repo and hash columns.
df_related: pd.DataFrame = generate_annotated_diffs.data_with_annotated_diffs()
def golden():
    """Return the expert-labeled rows of ``df_related``.

    A row qualifies when its ``G_type`` is ``"initial"`` and its ``E_type`` is ``"expert_labeled"``.
    The result carries a fresh 0-based index.
    """
    generated_initially = df_related["G_type"] == "initial"
    edited_by_expert = df_related["E_type"] == "expert_labeled"
    return df_related.loc[generated_initially & edited_by_expert].reset_index(drop=True)
def backward():
    """Return the synthetic-backward rows of ``df_related``.

    A row qualifies when its ``G_type`` is ``"synthetic_backward"`` and its ``E_type`` is ``"expert_labeled"``.
    The result carries a fresh 0-based index.
    """
    generated_backward = df_related["G_type"] == "synthetic_backward"
    edited_by_expert = df_related["E_type"] == "expert_labeled"
    return df_related.loc[generated_backward & edited_by_expert].reset_index(drop=True)
def forward():
    """Return the synthetic-forward rows of ``df_related`` that start from an initial message.

    A row qualifies when its ``G_type`` is ``"initial"`` and its ``E_type`` is ``"synthetic_forward"``.
    The result carries a fresh 0-based index.
    """
    generated_initially = df_related["G_type"] == "initial"
    edited_forward = df_related["E_type"] == "synthetic_forward"
    return df_related.loc[generated_initially & edited_forward].reset_index(drop=True)
def forward_from_backward():
    """Return the synthetic-forward rows of ``df_related`` that start from a synthetic-backward message.

    A row qualifies when its ``G_type`` is ``"synthetic_backward"`` and its ``E_type`` is either
    ``"synthetic_forward"`` or ``"synthetic_forward_from_backward"``. The result carries a fresh 0-based index.
    """
    forward_edit_types = ["synthetic_forward", "synthetic_forward_from_backward"]
    generated_backward = df_related["G_type"] == "synthetic_backward"
    edited_forward = df_related["E_type"].isin(forward_edit_types)
    return df_related.loc[generated_backward & edited_forward].reset_index(drop=True)
# Number of rows in each split; these become the upper bounds of the example sliders.
(
    n_diffs_manual,
    n_diffs_synthetic_backward,
    n_diffs_synthetic_forward,
    n_diffs_synthetic_forward_backward,
) = (len(select_split()) for select_split in (golden, backward, forward, forward_from_backward))
def update_dataset_view(diff_idx, df):
    """Collect the values displayed for one example of a split.

    Args:
        diff_idx: 1-based position of the example inside ``df``.
        df: Frame holding the split. It must provide ``annotated_diff``, ``repo`` and ``hash``, plus the
            message columns: ``commit_msg_start``/``commit_msg_end`` when present, otherwise
            ``G_text``/``E_text``.

    Returns:
        A 4-tuple ``(annotated diff, initial message, edited message, GitHub commit URL)``.

    Raises:
        IndexError: If ``diff_idx`` is not within ``1..len(df)``.
    """
    position = int(diff_idx) - 1
    # ``iloc`` reads a negative position as "count from the end", which would silently show the wrong example.
    if position < 0:
        raise IndexError(f"diff_idx must be >= 1, got {diff_idx}")

    # Look the row up once instead of re-running ``iloc`` for every field.
    row = df.iloc[position]
    start_column = "commit_msg_start" if "commit_msg_start" in df.columns else "G_text"
    end_column = "commit_msg_end" if "commit_msg_end" in df.columns else "E_text"
    return (
        row["annotated_diff"],
        row[start_column],
        row[end_column],
        f"https://github.com/{row['repo']}/commit/{row['hash']}",
    )
def update_dataset_view_manual(diff_idx):
    """Return the view values for example ``diff_idx`` of the split produced by ``golden()``."""
    expert_labeled_split = golden()
    return update_dataset_view(diff_idx, expert_labeled_split)
def update_dataset_view_synthetic_backward(diff_idx):
    """Return the view values for example ``diff_idx`` of the split produced by ``backward()``."""
    backward_split = backward()
    return update_dataset_view(diff_idx, backward_split)
def update_dataset_view_synthetic_forward(diff_idx):
    """Return the view values for example ``diff_idx`` of the split produced by ``forward()``."""
    forward_split = forward()
    return update_dataset_view(diff_idx, forward_split)
def update_dataset_view_synthetic_forward_backward(diff_idx):
    """Return the view values for example ``diff_idx`` of the split produced by ``forward_from_backward()``."""
    forward_from_backward_split = forward_from_backward()
    return update_dataset_view(diff_idx, forward_from_backward_split)
def number_of_pairs_plot():
    """Build a stacked bar chart with the number of pairs in every split.

    For each split ("Full", "Synthetic Backward", "Synthetic Forward", "Expert-labeled") two bars are stacked:
    the related pairs taken from ``df_related`` and the pairs flagged ``is_related == False`` in the frame
    returned by ``hf_data_loader.load_synthetic_as_pandas()`` (drawn with a hatched pattern).

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    related_by_split = {
        "Full": df_related,
        "Synthetic Backward": backward(),
        "Synthetic Forward": pd.concat([forward(), forward_from_backward()], axis=0, ignore_index=True),
        "Expert-labeled": golden(),
    }

    all_synthetic = hf_data_loader.load_synthetic_as_pandas()
    df_unrelated = all_synthetic.loc[~all_synthetic.is_related].copy()

    # Boolean masks over the unrelated pairs, named after the column they test.
    g_is_initial = df_unrelated["G_type"] == "initial"
    g_is_backward = df_unrelated["G_type"] == "synthetic_backward"
    e_is_forward = df_unrelated["E_type"] == "synthetic_forward"
    e_is_any_forward = df_unrelated["E_type"].isin(["synthetic_forward", "synthetic_forward_from_backward"])
    e_is_expert = df_unrelated["E_type"] == "expert_labeled"

    unrelated_by_split = {
        "Full": df_unrelated,
        "Synthetic Backward": df_unrelated.loc[g_is_backward & ~e_is_any_forward],
        "Synthetic Forward": df_unrelated.loc[(g_is_initial & e_is_forward) | (g_is_backward & e_is_any_forward)],
        "Expert-labeled": df_unrelated.loc[g_is_initial & e_is_expert],
    }

    # Two bars per split, related first, so the stacking order matches the legend order.
    traces = []
    for split, related_frame in related_by_split.items():
        split_color = colors[split]
        hatched_marker = {
            "color": split_color,
            "pattern": {"shape": "/", "fillmode": "overlay", "solidity": 0.5},
        }
        traces.extend(
            [
                go.Bar(
                    name=f"{split} - Related pairs",
                    x=[split],
                    y=[len(related_frame)],
                    marker={"color": split_color},
                ),
                go.Bar(
                    name=f"{split} - Conditionally independent pairs",
                    x=[split],
                    y=[len(unrelated_by_split[split])],
                    marker=hatched_marker,
                ),
            ]
        )

    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="stack",
        bargap=0.2,
        xaxis={"title": "Split", "showgrid": True, "gridcolor": "lightgrey"},
        yaxis={"title": "Number of Examples", "showgrid": True, "gridcolor": "lightgrey"},
        legend={"title": "Pair Type", "orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        width=1100,
    )
    return fig
def edit_distance_plot():
    """Plot the estimated density of the G-to-E edit distance for every split.

    For each split ("Full", "Synthetic Backward", "Synthetic Forward", "Expert-labeled") the edit distance
    between ``G_text`` and ``E_text`` is computed per row with ``edit_distance_fn``, a Gaussian kernel density
    estimate is fitted to those values, and the estimate is drawn as a line over the range 0..1200.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """

    def _edit_distances(frame):
        # Iterate the two text columns directly; ``iterrows`` would build a Series object for every row.
        return [
            edit_distance_fn(pred=g_text, ref=e_text) for g_text, e_text in zip(frame["G_text"], frame["E_text"])
        ]

    frames_by_split = {
        "Full": df_related,
        "Synthetic Backward": backward(),
        "Synthetic Forward": pd.concat([forward(), forward_from_backward()], axis=0, ignore_index=True),
        "Expert-labeled": golden(),
    }

    # The evaluation grid is identical for every split, so it is built once outside the loop.
    kde_x = np.linspace(0, 1200, 1000)

    traces = []
    for split, frame in frames_by_split.items():
        kde = gaussian_kde(_edit_distances(frame))
        traces.append(
            go.Scatter(x=kde_x, y=kde(kde_x), mode="lines", name=split, line=dict(color=colors[split], width=5))
        )

    fig = go.Figure(data=traces)
    fig.update_layout(
        bargap=0.1,
        xaxis=dict(title=dict(text="Edit Distance"), range=[0, 1200], showgrid=True, gridcolor="lightgrey"),
        yaxis=dict(
            title=dict(text="Probability Density"),
            range=[0, 0.004],
            showgrid=True,
            gridcolor="lightgrey",
            tickvals=[0.0005, 0.001, 0.0015, 0.002, 0.0025, 0.003, 0.0035, 0.004],
            tickformat=".4f",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        width=1100,
    )
    return fig
def get_correlations_table(online_metric_name: str) -> pd.DataFrame:
    """Tabulate Spearman correlations of the offline metrics with one online metric.

    The ``all_pairs_with_metrics_other_online_metrics`` configuration of
    ``JetBrains-Research/synthetic-commit-msg-edits`` is loaded and restricted to rows where ``is_related`` is
    false. Those rows are grouped by ``G_text``, ``G_type``, ``hash``, ``repo`` and the online metric column.
    Within each group every metric in ``AGGR_METRICS`` is reduced to a single value (the minimum for
    ``editdist``, the maximum for all others), and the reduced values are correlated with the online column.

    Args:
        online_metric_name: Metric identifier; the column ``online_<online_metric_name>`` acts as the online
            metric.

    Returns:
        One row per metric in ``AGGR_METRICS`` with the columns ``"Metric m"``, ``"Correlation Q(m, m*)"``
        (rounded to two decimals) and ``"p-value"`` (the string ``"< 0.05"`` or the raw value), ordered by
        decreasing absolute correlation.
    """
    online_column = f"online_{online_metric_name}"

    all_pairs = load_dataset(
        "JetBrains-Research/synthetic-commit-msg-edits", "all_pairs_with_metrics_other_online_metrics", split="train"
    ).to_pandas()
    unrelated_rows = all_pairs.loc[~all_pairs.is_related]

    # One row per group; the group's records are kept as a list of dicts in ``unrelated_pairs``.
    per_group = (
        unrelated_rows.groupby(["G_text", "G_type", "hash", "repo", online_column])
        .apply(lambda group: group.to_dict(orient="records"), include_groups=False)
        .reset_index(name="unrelated_pairs")
    )
    online_values = per_group[online_column].to_numpy()

    correlation_rows = []
    for metric in AGGR_METRICS:
        # Edit distance is reduced with ``min``; every other metric with ``max``.
        reduce_group = min if metric in ["editdist"] else max
        offline_values = (
            per_group["unrelated_pairs"]
            .apply(lambda group_pairs: reduce_group(pair[metric] for pair in group_pairs))
            .to_numpy()
        )
        corr, p_value = spearmanr(offline_values, online_values)
        correlation_rows.append({"metric": metric, "corr": corr, "p_value": p_value})

    table = pd.DataFrame(correlation_rows)
    table["p_value"] = ["< 0.05" if p < 0.05 else p for p in table.p_value]
    # The sort key is taken before rounding so that ties introduced by rounding do not affect the order.
    table["corr_abs"] = abs(table["corr"])
    table["corr"] = table["corr"].round(2)
    display_names = {identifier: label for label, identifier in METRICS.items()}
    table["metric"] = table["metric"].map(display_names)

    ordered = table.sort_values(by=["corr_abs"], ascending=False).drop(columns=["corr_abs"])
    return ordered.rename(columns={"metric": "Metric m", "corr": "Correlation Q(m, m*)", "p_value": "p-value"})
# JavaScript passed as the ``js`` argument of ``gr.Blocks``: unless the ``__theme`` query parameter is already
# ``light``, it sets that parameter and navigates to the resulting URL.
force_light_theme_js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
if __name__ == "__main__":
    # Builds the Gradio UI (example browser, dataset statistics, correlation tables) and starts the server.
    with gr.Blocks(theme=gr.themes.Soft(), js=force_light_theme_js_func) as application:
        gr.Markdown(parse_readme("README.md"))

        def dataset_view_tab(n_items):
            """Create the widgets of one example tab.

            Returns the slider selecting a sample (1..``n_items``) and the list of output components in the
            order produced by ``update_dataset_view``: diff, initial message, edited message, link.
            """
            slider = gr.Slider(minimum=1, maximum=n_items, step=1, value=1, label=f"Sample number (total: {n_items})")
            diff_view = gr.Highlightedtext(combine_adjacent=True, color_map={"+": "green", "-": "red"})
            start_view = gr.Textbox(interactive=False, label="Initial message G", container=True)
            end_view = gr.Textbox(interactive=False, label="Edited message E", container=True)
            link_view = gr.Markdown()
            return slider, [diff_view, start_view, end_view, link_view]

        # (tab title, number of samples, callback filling the view) for every example tab, in display order.
        example_tabs = [
            ("Manual", n_diffs_manual, update_dataset_view_manual),
            ("Synthetic Backward", n_diffs_synthetic_backward, update_dataset_view_synthetic_backward),
            ("Synthetic Forward (from initial)", n_diffs_synthetic_forward, update_dataset_view_synthetic_forward),
            (
                "Synthetic Forward (from backward)",
                n_diffs_synthetic_forward_backward,
                update_dataset_view_synthetic_forward_backward,
            ),
        ]

        # Remembered so that the same callbacks can also be registered for the initial page load below.
        wired_tabs = []
        with gr.Tab("Examples Exploration"):
            for tab_title, n_items, update_fn in example_tabs:
                with gr.Tab(tab_title):
                    slider, view = dataset_view_tab(n_items)
                    slider.change(update_fn, inputs=slider, outputs=view)
                    wired_tabs.append((update_fn, slider, view))

        with gr.Tab("Dataset Statistics"):
            gr.Markdown("## Number of examples per split")
            # NOTE(review): the function object is passed here while edit_distance_plot is called eagerly
            # below -- presumably so this figure is computed when the page loads; confirm the asymmetry is
            # intended.
            number_of_pairs_gr_plot = gr.Plot(number_of_pairs_plot, label=None)

            gr.Markdown("## Edit Distance Distribution (w/o PyCharm Logs)")
            edit_distance_gr_plot = gr.Plot(edit_distance_plot(), label=None)

        with gr.Tab("Experimental Results"):
            gr.Markdown(
                "Here, we provide the additional experimental results with different text similarity metrics used as the target online metric, "
                "in addition to edit distance between generated messages G and their edited counterparts E."
            )
            # ``\\*`` yields the same backslash-asterisk text as before without relying on the invalid
            # ``\*`` escape sequence.
            gr.Markdown(
                "Please, select one of the available metrics **m*** below to see the correlations **Q(m, m\\*)** of offline text similarity metrics with **m*** as an online metric."
            )

            for metric_label, metric_id in METRICS.items():
                with gr.Tab(metric_label):
                    gr.Markdown(
                        f"The table below presents the correlation coefficients **Q(m, m\\*)** where {metric_label} is used as an online metric **m***."
                    )
                    result_df = get_correlations_table(metric_id)
                    gr.DataFrame(result_df)

        # Fill every example tab once when the page is opened, before any slider has been moved.
        for update_fn, slider, view in wired_tabs:
            application.load(update_fn, inputs=slider, outputs=view)

    application.launch()