HakimAiV2 / figures /main_figure_2a.py
scdrand23's picture
not working version
814a594
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json, os
from statannot import add_stat_annotation
from statannotations.Annotator import Annotator
# Metric reported by this figure; it selects the CSV columns below and is
# reused later for the axis label and the output file names.
metric = 'dice'

# Evaluation results table. Only the 'modality' and 'task' columns and the
# per-model metric columns renamed below are used from here on.
df = pd.read_csv('results/all_eval/all_metrics_median.csv')

# Raw metric column -> display name shown in the legend.
model_names = {
    metric: 'BiomedParse',
    f'medsam_{metric}': 'MedSAM (oracle box)',
    f'sam_{metric}': 'SAM (oracle box)',
    f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)',
    f'dino_sam_{metric}': 'SAM (Grounding DINO)',
}
score_vars = list(model_names.values())
df = df.rename(columns=model_names)

# Display order of the modality groups on the x axis.
modality_list = ['CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound', 'Fundus', 'Endoscope', 'Dermoscopy', 'OCT']

# Raw modality label -> display label (all MRI sequences collapse into 'MRI').
mod_names = {
    'CT': 'CT',
    'MRI': 'MRI',
    'MRI-T2': 'MRI',
    'MRI-ADC': 'MRI',
    'MRI-FLAIR': 'MRI',
    'MRI-T1-Gd': 'MRI',
    'X-Ray': 'X-Ray',
    'pathology': 'Pathology',
    'ultrasound': 'Ultrasound',
    'fundus': 'Fundus',
    'endoscope': 'Endoscope',
    'dermoscopy': 'Dermoscopy',
    'OCT': 'OCT',
    'All': 'All',
}
# Direct indexing (not .get) so an unmapped modality label raises KeyError
# instead of silently disappearing from the plot.
df['modality'] = df['modality'].map(lambda raw_label: mod_names[raw_label])

# Duplicate every row under a synthetic "All" modality so the pooled
# distribution is drawn as its own group next to the per-modality ones.
all_df = df.assign(modality='All')
df = pd.concat([df, all_df])

# Long format: one row per (modality, task, model) with its score.
df_long = df.melt(
    id_vars=['modality', 'task'],
    value_vars=score_vars,
    var_name='Model',
    value_name='Performance',
)
# Grouped box plot: one cluster per modality ("All" first), one box per model.
fig, ax = plt.subplots(figsize=(9, 6))
ax = sns.boxplot(
    data=df_long, x='modality', y='Performance', hue='Model', ax=ax,
    palette='Set2',
    order=['All'] + modality_list,
    # whis=2 extends each whisker to the farthest point within 2x the IQR of
    # the quartiles. It is an IQR multiple, not a percentile range; whiskers
    # at the 5th/95th percentiles would be whis=(5, 95).
    # NOTE(review): confirm which of the two the figure is meant to show.
    whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5,
)

# No frame: hide the top/right/left spines ...
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
# ... and draw an arrow along the left edge in place of the y-axis spine.
ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01),
            arrowprops=dict(arrowstyle='->', lw=1, color='black'),
            xycoords='axes fraction')

plt.title('')
# Metrics other than 'dice'/'assd' keep the default y label.
if metric == 'dice':
    plt.ylabel('Dice score', fontsize=18)
elif metric == 'assd':
    plt.ylabel('ASSD', fontsize=18)
plt.xlabel('')
plt.xticks(rotation=45, fontsize=16)
plt.yticks(fontsize=14)

# Axis thickness.
ax.spines['bottom'].set_linewidth(1)
ax.spines['left'].set_linewidth(1)

# ASSD is plotted on a log scale.
if metric == 'assd':
    plt.yscale('log')

# Legend above the axes, two columns, without a frame. A single call is
# enough: a preceding ax.legend(labels) would be discarded as soon as the
# legend is rebuilt here from the labelled hue handles.
ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
# Statistical annotations: within every modality group (including "All"),
# compare BiomedParse against MedSAM (oracle box).
annotated_groups = ['All'] + modality_list
box_pairs = [
    ((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))
    for modality in annotated_groups
]

annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
                      order=annotated_groups)
annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)

# One-sided test in the direction "BiomedParse is better". Dice is
# higher-is-better, so the first sample (BiomedParse) must be tested as
# 'greater'; for error-style metrics such as ASSD, lower is better, so 'less'.
# Hard-coding 'less' while plotting Dice tests whether BiomedParse is *worse*
# and, with hide_non_significant=True, suppresses the stars.
# NOTE(review): assumes statannotations passes the left box of each pair as
# the first sample to the paired t-test -- confirm.
alternative = 'greater' if metric == 'dice' else 'less'
annotator.apply_test(alternative=alternative)
annotator.annotate()
plt.tight_layout()

# Save the plot. Create the output directory first so savefig does not fail
# on a fresh checkout.
os.makedirs('plots', exist_ok=True)
out_fig = ax.get_figure()
# bbox_inches='tight' for both formats: the legend is anchored above the axes
# (bbox_to_anchor y=1.4), so it can be clipped in an output saved without it.
for ext in ('png', 'pdf'):
    out_fig.savefig(f'plots/{metric}_comparison.{ext}', bbox_inches='tight')