|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import json, os |
|
|
|
from statannot import add_stat_annotation |
|
from statannotations.Annotator import Annotator |
|
|
|
# ---------------------------------------------------------------------------
# Load the evaluation table and reshape it into long form for plotting.
# ---------------------------------------------------------------------------
df = pd.read_csv('results/all_eval/all_metrics_median.csv')

# Metric whose per-model columns are compared in the figure.
metric = 'assd'

# CSV column name -> model label used in the figure.
model_names = {
    metric: 'BiomedParse',
    f'medsam_{metric}': 'MedSAM (oracle box)',
    f'sam_{metric}': 'SAM (oracle box)',
    f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)',
    f'dino_sam_{metric}': 'SAM (Grounding DINO)',
}
df = df.rename(columns=model_names)

# Model labels in dictionary order; these are the value columns that get melted.
score_vars = [*model_names.values()]

# Keep only rows whose 'MedSAM (oracle box)' value is below 1e10.
# NOTE(review): presumably this drops sentinel/infinite scores - confirm.
df = df.loc[df['MedSAM (oracle box)'] < 1e10]

# Modality groups in the order they appear along the x-axis.
modality_list = [
    'CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound',
    'Fundus', 'Endoscope', 'Dermoscopy', 'OCT',
]

# Raw modality value in the CSV -> display name (all MRI variants share 'MRI').
mod_names = {
    'CT': 'CT',
    'MRI': 'MRI',
    'MRI-T2': 'MRI',
    'MRI-ADC': 'MRI',
    'MRI-FLAIR': 'MRI',
    'MRI-T1-Gd': 'MRI',
    'X-Ray': 'X-Ray',
    'pathology': 'Pathology',
    'ultrasound': 'Ultrasound',
    'fundus': 'Fundus',
    'endoscope': 'Endoscope',
    'dermoscopy': 'Dermoscopy',
    'OCT': 'OCT',
    'All': 'All',
}
# Strict lookup: a modality missing from mod_names raises KeyError.
df['modality'] = df['modality'].map(lambda raw_name: mod_names[raw_name])

# Append a second copy of every row labelled 'All' so the pooled data forms
# its own group next to the per-modality groups.
df = pd.concat([df, df.assign(modality='All')])

# Long form: one row per (modality, task, Model) with its Performance value.
id_columns = ['modality', 'task']
df_long = pd.melt(
    df[id_columns + score_vars],
    id_vars=id_columns,
    var_name='Model',
    value_name='Performance',
)
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Grouped box plot: one group per modality (plus the pooled 'All' group),
# one box per model.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(9, 6))
ax = sns.boxplot(
    data=df_long, x='modality', y='Performance', hue='Model', ax=ax,
    palette='Set2',
    order=['All'] + modality_list,
    whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5,
)

# Hide the frame on three sides; the arrow drawn below stands in for the
# left spine.
for side in ('top', 'right', 'left'):
    ax.spines[side].set_visible(False)

# Arrow along the left edge of the axes (axes-fraction coordinates), running
# from just below the bottom edge to slightly above the top edge.
ax.annotate(
    '', xy=(0, 1.05), xytext=(0, -0.01),
    arrowprops=dict(arrowstyle='->', lw=1, color='black'),
    xycoords='axes fraction',
)

# All styling goes through `ax` rather than pyplot's implicit "current axes",
# so it keeps targeting this plot even if another figure is created later.
ax.set_title('')
if metric == 'dice':
    ax.set_ylabel('Dice score', fontsize=18)
elif metric == 'assd':
    ax.set_ylabel('ASSD', fontsize=18)
ax.set_xlabel('')

# tick_params stores the settings on the axis itself, so they also apply to
# ticks created after the scale or the limits change.
ax.tick_params(axis='x', labelrotation=45, labelsize=16)
ax.tick_params(axis='y', labelsize=14)

# Only the bottom spine is still visible (the left one is hidden above).
ax.spines['bottom'].set_linewidth(1)

# ASSD is drawn on a logarithmic y-axis.
if metric == 'assd':
    ax.set_yscale('log')

# Single legend above the axes, built from the labelled artists already on
# `ax`. The former label-only `ax.legend(score_vars, ...)` call was removed:
# this call replaced it immediately, and passing labels without handles can
# silently mismatch labels and artists.
ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2,
          fontsize=14, frameon=False)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Significance stars: BiomedParse vs. MedSAM (oracle box) within every
# modality group, including the pooled 'All' group.
# ---------------------------------------------------------------------------
group_order = ['All'] + modality_list
box_pairs = [
    ((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))
    for modality in group_order
]

annotator = Annotator(
    ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
    order=group_order,
)
annotator.configure(test='t-test_paired', text_format='star', loc='inside',
                    hide_non_significant=True)

# One-sided paired test in the direction where the first model of each pair
# (BiomedParse) is the better one. The direction used to be hard-coded to
# 'less', which would star the cases where BiomedParse is *worse* when
# metric == 'dice'.
# NOTE(review): assumes higher is better for Dice and lower is better for
# every other metric (e.g. ASSD) - confirm when adding new metrics.
alternative = 'greater' if metric == 'dice' else 'less'
annotator.apply_test(alternative=alternative)
annotator.annotate()
|
|
|
# ---------------------------------------------------------------------------
# Finalise the layout and write the figure as PNG and PDF.
# ---------------------------------------------------------------------------
plt.tight_layout()

# savefig does not create missing directories, so make sure the output folder
# exists before writing.
os.makedirs('plots', exist_ok=True)

figure = ax.get_figure()
# Both formats use bbox_inches='tight' so the legend, which is anchored above
# the axes, is kept inside the saved image. Previously only the PDF did.
figure.savefig(f'plots/{metric}_comparison.png', bbox_inches='tight')
figure.savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')