|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import json, os |
|
|
|
from statannot import add_stat_annotation |
|
from statannotations.Annotator import Annotator |
|
|
|
df = pd.read_csv('results/all_eval/all_metrics_mean.csv') |
|
|
|
|
|
mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology', |
|
'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'} |
|
df['modality'] = df['modality'].apply(lambda x: mod_names[x]) |
|
|
|
|
|
reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv') |
|
|
|
|
|
overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'], |
|
suffixes=('_biomedparse', '_baseline')) |
|
|
|
non_overlap_df = df[~df['task'].isin(overlap_df['task'])] |
|
|
|
|
|
|
|
baseline = 'sam' |
|
metric = 'IRI' |
|
|
|
baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'} |
|
metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio', |
|
'IRI': 'Inversed Rotational Inertia'} |
|
|
|
non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice'] |
|
|
|
fig, ax = plt.subplots(figsize=(7,5)) |
|
sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80) |
|
|
|
|
|
sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False, |
|
color='k', line_kws={'linestyle':'--', 'linewidth':1}) |
|
|
|
|
|
ax.spines['top'].set_visible(False) |
|
ax.spines['right'].set_visible(False) |
|
ax.spines['left'].set_visible(False) |
|
ax.spines['bottom'].set_visible(False) |
|
|
|
|
|
|
|
xlim = [0, 1.05] |
|
ylim = [-0.06, 0.79] |
|
ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5)) |
|
ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5)) |
|
ax.set_xlim(xlim) |
|
ax.set_ylim(ylim) |
|
|
|
ax.xaxis.set_tick_params(width=1.5) |
|
ax.yaxis.set_tick_params(width=1.5) |
|
|
|
|
|
plt.xticks(fontsize=18) |
|
plt.yticks(fontsize=18) |
|
|
|
|
|
from scipy.stats import linregress |
|
slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff']) |
|
x_text = 0.4 |
|
plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes) |
|
plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes) |
|
plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes) |
|
|
|
plt.title('') |
|
plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20) |
|
plt.xlabel(f'{metric_names[metric]}', fontsize=22) |
|
|
|
plt.tight_layout() |
|
|
|
|
|
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png') |
|
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight') |
|
|
|
|