Christina Theodoris commited on
Commit
ae4867d
·
1 Parent(s): 5760b30

make plot roc compatible with eval-only metrics

Browse files
Files changed (1) hide show
  1. geneformer/evaluation_utils.py +9 -5
geneformer/evaluation_utils.py CHANGED
@@ -182,14 +182,18 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
182
  for model_name in roc_metric_dict.keys():
183
  mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
184
  mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
185
- roc_auc = roc_metric_dict[model_name]["roc_auc"]
186
- roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
187
  color = model_style_dict[model_name]["color"]
188
  linestyle = model_style_dict[model_name]["linestyle"]
189
- if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
190
- label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
 
191
  else:
192
- label = f"{model_name} (AUC {roc_auc:0.2f})"
 
 
 
 
 
193
  plt.plot(
194
  mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
195
  )
 
182
  for model_name in roc_metric_dict.keys():
183
  mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
184
  mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
 
 
185
  color = model_style_dict[model_name]["color"]
186
  linestyle = model_style_dict[model_name]["linestyle"]
187
+ if "roc_auc" not in roc_metric_dict[model_name].keys():
188
+ all_roc_auc = roc_metric_dict[model_name]["all_roc_auc"]
189
+ label = f"{model_name} (AUC {all_roc_auc:0.2f})"
190
  else:
191
+ roc_auc = roc_metric_dict[model_name]["roc_auc"]
192
+ roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
193
+ if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
194
+ label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
195
+ else:
196
+ label = f"{model_name} (AUC {roc_auc:0.2f})"
197
  plt.plot(
198
  mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
199
  )