File size: 5,962 Bytes
37b3db0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
import json
import os
from constants import CHEXPERT_CLASS_PROMPTS, CHEXPERT_CLASS_PROMPTS_webdataset, RSNA_CLASS_PROMPTS_webdataset, \
RSNA_CLASS_PROMPTS, thyroid_us_prompts, breast_us_prompts, meniscus_mri_prompts, acl_mri_prompts, \
radimagenet_all_prompts, ct_scan_labels, diabetic_retinopathy_prompts, PCAM, LC25000_lung, \
LC25000_colon, NCK_CRC_prompts, BACH_prompts, Osteo_prompts, skin_cancer_prompts, \
skin_tumor_prompts, SICAPv2_prompts, refuge_prompts, five_prompts, odir_retina_prompts
from tqdm import tqdm
from collections import defaultdict
def is_global_master(args):
    """Return whether this process holds global rank 0.

    Args:
        args: Namespace-like object exposing an integer ``rank`` attribute.
    """
    global_rank = args.rank
    return global_rank == 0
def is_local_master(args):
    """Return whether this process holds local (per-node) rank 0.

    Args:
        args: Namespace-like object exposing an integer ``local_rank`` attribute.
    """
    node_rank = args.local_rank
    return node_rank == 0
def is_master(args, local=False):
    """Return whether this process is the master worker.

    Args:
        args: Namespace-like object with ``rank`` and/or ``local_rank``.
        local: When true, test the per-node ``local_rank``; otherwise test the
            global ``rank``.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)
# Catalog entries that ``slip_evaluate`` skips entirely (no dataset is built
# and no metric is reported for them).
SKIP_DATASETS = [
    "imagenet",
    "radimagenet",
    'SICAPv2',
]
# A broader skip list kept for reference (swap in to evaluate only a subset):
# SKIP_DATASETS = ["radimagenet", "imagenet", "rsna_pneumonia", "meniscal_mri", "breast_us", "acl_mri", "thyroid_us", 'PCAM', 'LC25000_lung', 'LC25000_colon', 'CT_axial', 'CT_coronal', 'CT_sagittal', 'dr_uwf', 'dr_regular', 'SICAPv2', 'NCK_CRC', 'skin_tumor', 'skin_cancer', 'BACH']
def _select_prompts(d, all_templates, all_labels):
    """Return the ``(templates, labels)`` pair used to zero-shot evaluate dataset *d*.

    Datasets present in the custom-prompt table use a prompt collection imported
    from ``constants`` and are evaluated with ``labels=None``. Every other dataset
    falls back to the templates and labels loaded from the ``clipeval`` metadata.

    Raises:
        KeyError: if *d* is neither a custom-prompt dataset nor present in the
            ``clipeval`` metadata.
    """
    custom_prompts = {
        # The non-webdataset variants (CHEXPERT_CLASS_PROMPTS / RSNA_CLASS_PROMPTS)
        # are also available in ``constants`` if needed.
        'chexpert-5x200': CHEXPERT_CLASS_PROMPTS_webdataset,
        'rsna_pneumonia': RSNA_CLASS_PROMPTS_webdataset,
        'thyroid_us': thyroid_us_prompts,
        'breast_us': breast_us_prompts,
        'meniscal_mri': meniscus_mri_prompts,
        'acl_mri': acl_mri_prompts,
        'radimagenet': radimagenet_all_prompts,
        'CT_axial': ct_scan_labels,
        'CT_coronal': ct_scan_labels,
        'CT_sagittal': ct_scan_labels,
        'dr_regular': diabetic_retinopathy_prompts,
        'dr_uwf': diabetic_retinopathy_prompts,
        'LC25000_lung': LC25000_lung,
        'LC25000_colon': LC25000_colon,
        'PCAM': PCAM,
        'NCK_CRC': NCK_CRC_prompts,
        'BACH': BACH_prompts,
        'Osteo': Osteo_prompts,
        'skin_cancer': skin_cancer_prompts,
        'SICAPv2': SICAPv2_prompts,
        'skin_tumor': skin_tumor_prompts,
        'refuge_retina': refuge_prompts,
        'five_retina': five_prompts,
        'odir_retina': odir_retina_prompts,
    }
    if d in custom_prompts:
        # Always pass labels=None for custom prompts. Previously several of these
        # datasets (NCK_CRC, BACH, ...) never assigned ``labels``. They silently
        # reused the previous dataset's labels, or raised NameError if evaluated
        # first.
        return custom_prompts[d], None
    return all_templates[d], all_labels[d]


# Datasets whose metric object is logged to wandb as-is under ``val/<name>``.
# ``radimagenet`` is handled separately. All remaining datasets are expected to
# expose acc / auc_roc / precision_score / f1_score / recall_score keys.
_WANDB_RAW_METRIC_DATASETS = frozenset({
    'chexpert-5x200', 'CT_sagittal', 'CT_axial', 'CT_coronal',
    'dr_uwf', 'dr_regular', 'PCAM', 'LC25000_lung', 'LC25000_colon',
    'NCK_CRC', 'BACH', 'Osteo', 'skin_cancer', 'skin_tumor', 'SICAPv2',
    'five_retina', 'odir_retina',
})


def _log_metrics_to_wandb(metrics, epoch):
    """Log each entry of *metrics* (dataset name -> metric) to wandb with *epoch*.

    Raises:
        ImportError: if the optional ``wandb`` package is not installed.
    """
    # An explicit check rather than ``assert`` so it still fires under ``python -O``.
    if wandb is None:
        raise ImportError('Please install wandb.')
    for name, val in metrics.items():
        if name == 'radimagenet':
            wandb.log({f"val/{name}": val["acc"], 'epoch': epoch})
        elif name in _WANDB_RAW_METRIC_DATASETS:
            wandb.log({f"val/{name}": val, 'epoch': epoch})
        else:
            wandb.log({f"val/{name}/acc": val['acc'],
                       f"val/{name}/auc_roc": val['auc_roc'],
                       f"val/{name}/precision_score": val['precision_score'],
                       f"val/{name}/f1_score": val['f1_score'],
                       f"val/{name}/recall_score": val['recall_score'],
                       'epoch': epoch})


@torch.no_grad()
def slip_evaluate(args, model, val_transform, tokenizer, epoch=0):
    """Run zero-shot evaluation on every ``clipeval`` catalog dataset not in SKIP_DATASETS.

    Only the global master process (``args.rank == 0``) evaluates; every other
    process immediately returns an empty dict.

    Args:
        args: Run configuration. Reads ``rank``, ``output_dir``, ``batch_size``,
            ``workers`` and ``wandb``.
        model: Model passed to ``eval_zeroshot.evaluate``. If it has a ``module``
            attribute (e.g. a DDP/DataParallel wrapper), that inner module is used.
        val_transform: Transform applied to each downstream validation dataset.
        tokenizer: Tokenizer passed to ``eval_zeroshot.evaluate``.
        epoch: Epoch number recorded in ``slip.txt`` and in the wandb logs.

    Returns:
        dict mapping dataset name to the value returned by
        ``eval_zeroshot.evaluate`` for that dataset.

    Side effects:
        Prints one JSON line per dataset. Appends the results to
        ``<output_dir>/slip.txt``. When ``args.wandb`` is set, logs every metric
        to wandb.
    """
    metrics = {}
    if not is_master(args):
        return metrics

    from clipeval import datasets, eval_zeroshot
    catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval")
    if hasattr(model, "module"):
        model = model.module

    results_path = os.path.join(args.output_dir, "slip.txt")
    # Halve the training batch size for evaluation, but never drop below 1:
    # DataLoader rejects batch_size=0, which ``1 // 2`` would produce.
    eval_batch_size = max(1, args.batch_size // 2)

    for d in catalog:
        if d in SKIP_DATASETS:
            continue
        val_dataset = datasets.get_downstream_dataset(
            catalog, d, is_train=False, transform=val_transform)
        templates, labels = _select_prompts(d, all_templates, all_labels)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, drop_last=False)
        metric = eval_zeroshot.evaluate(d, val_loader, templates, labels, model, tokenizer)
        metrics[d] = metric

        # The early is_master() return already guarantees rank 0 here.
        json_str = json.dumps({"task": d, "acc": metric})
        print(json_str)
        with open(results_path, mode="a+", encoding="utf-8") as f:
            f.write(f"Saving results for Epoch {epoch}" + "\n")
            f.write(json_str + "\n")

    if args.wandb:
        _log_metrics_to_wandb(metrics, epoch)
    return metrics
|