File size: 3,124 Bytes
814a594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json, os

from statannot import add_stat_annotation
from statannotations.Annotator import Annotator

# Load the per-task mean metrics table.
df = pd.read_csv('results/all_eval/all_metrics_mean.csv')

# Raw modality label -> display label. Every MRI variant is folded into 'MRI'.
mod_names = {
    'CT': 'CT',
    'MRI': 'MRI',
    'MRI-T2': 'MRI',
    'MRI-ADC': 'MRI',
    'MRI-FLAIR': 'MRI',
    'MRI-T1-Gd': 'MRI',
    'X-Ray': 'X-Ray',
    'pathology': 'Pathology',
    'ultrasound': 'Ultrasound',
    'fundus': 'Fundus',
    'endoscope': 'Endoscope',
    'dermoscopy': 'Dermoscopy',
    'OCT': 'OCT',
    'All': 'All',
}
# Direct indexing (not .get) so an unlisted modality raises KeyError.
df['modality'] = df['modality'].map(lambda raw_label: mod_names[raw_label])

# Tasks listed in the reported-baseline table (MedSAM reported tasks).
reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')

# Rows present in both tables, matched on task, modality, site and target.
# Columns that exist in both frames get the '_biomedparse' / '_baseline' suffixes.
# NOTE(review): df['modality'] has already been renamed above, so rows only
# match if reported_baseline_tasks.csv uses the same display names — confirm.
overlap_df = df.merge(
    reported_baseline_df,
    on=['task', 'modality', 'site', 'target'],
    suffixes=('_biomedparse', '_baseline'),
)

# Keep every row of df whose task never appears in the overlap.
non_overlap_df = df.loc[~df['task'].isin(overlap_df['task'])]



# Baseline to compare against and the shape metric plotted on the x-axis.
# `baseline` selects the '<baseline>_dice' column and indexes `baseline_names`;
# `metric` selects the x column and indexes `metric_names`.
baseline = 'sam'
metric = 'IRI'

# Display names used for the axis labels.
baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
                'IRI': 'Inversed Rotational Inertia'}

# Per-row Dice difference: this model's 'dice' minus the baseline's Dice.
# `assign` returns a new frame instead of writing a column into a frame that
# was derived from `df` by boolean indexing, which avoids pandas'
# SettingWithCopyWarning.
non_overlap_df = non_overlap_df.assign(
    diff=non_overlap_df['dice'] - non_overlap_df[f'{baseline}_dice']
)
# scatter plot
# Scatter plot of the per-task improvement ('diff') against the selected shape
# metric, with a least-squares fit line and its statistics written on the axes.

# Use the same NaN-free rows for the scatter, the fit line and the statistics:
# seaborn drops missing values on its own, while linregress would otherwise
# return NaN statistics that do not match the drawn line.
plot_df = non_overlap_df.dropna(subset=[metric, 'diff'])

fig, ax = plt.subplots(figsize=(7, 5))
# `marker` (singular) sets the glyph; `markers` is only used together with a
# `style` mapping and was silently ignored here.
sns.scatterplot(data=plot_df, x=metric, y='diff', ax=ax, marker='o', s=80)

# add linear regression line (dashed, black); points are already drawn above
sns.regplot(data=plot_df, x=metric, y='diff', ax=ax, scatter=False,
            color='k', line_kws={'linestyle': '--', 'linewidth': 1})

# remove all spines; the axes are replaced by the arrows drawn below
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)

# add arrow on x-axis and y-axis, running along the lower and left axis limits
xlim = [0, 1.05]
ylim = [-0.06, 0.79]
arrow_style = dict(arrowstyle='->', lw=1.5)
ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=arrow_style)
ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=arrow_style)
ax.set_xlim(xlim)
ax.set_ylim(ylim)

# tick marks as thick as the arrows, with enlarged tick labels
ax.tick_params(axis='both', width=1.5, labelsize=18)

# show R^2 value, p value, and equation of the line
from scipy.stats import linregress
slope, intercept, r_value, p_value, std_err = linregress(plot_df[metric], plot_df['diff'])
x_text = 0.4  # horizontal text position in axes coordinates
ax.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
ax.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
# '+' format flag prints the intercept's own sign, so a negative intercept
# renders as 'x-0.05' rather than 'x+-0.05'.
ax.text(x_text, 0.7, f'$y={slope:.2f}x{intercept:+.2f}$', fontsize=20, transform=ax.transAxes)

ax.set_title('')
ax.set_ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
ax.set_xlabel(f'{metric_names[metric]}', fontsize=22)

fig.tight_layout()

# save the plot; create the output directory first so savefig cannot fail on a
# missing 'plots/' folder
os.makedirs('plots', exist_ok=True)
# NOTE(review): only the PDF is saved with bbox_inches='tight' — confirm the
# PNG is meant to keep the full figure canvas.
fig.savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
fig.savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')