|
|
|
|
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
try: |
|
import wandb |
|
except ImportError: |
|
wandb = None |
|
import json |
|
import os |
|
from constants import CHEXPERT_CLASS_PROMPTS, CHEXPERT_CLASS_PROMPTS_webdataset, RSNA_CLASS_PROMPTS_webdataset, \ |
|
RSNA_CLASS_PROMPTS, thyroid_us_prompts, breast_us_prompts, meniscus_mri_prompts, acl_mri_prompts, \ |
|
radimagenet_all_prompts, ct_scan_labels, diabetic_retinopathy_prompts, PCAM, LC25000_lung, \ |
|
LC25000_colon, NCK_CRC_prompts, BACH_prompts, Osteo_prompts, skin_cancer_prompts, \ |
|
skin_tumor_prompts, SICAPv2_prompts, refuge_prompts, five_prompts, odir_retina_prompts |
|
|
|
from tqdm import tqdm |
|
from collections import defaultdict |
|
|
|
|
|
def is_global_master(args):
    """Return True when ``args.rank`` equals 0, i.e. the global master process.

    Args:
        args: Object exposing a ``rank`` attribute.

    Returns:
        The result of ``args.rank == 0``.
    """
    return args.rank == 0
|
|
|
|
|
def is_local_master(args):
    """Return True when ``args.local_rank`` equals 0, i.e. the local master.

    Args:
        args: Object exposing a ``local_rank`` attribute.

    Returns:
        The result of ``args.local_rank == 0``.
    """
    return args.local_rank == 0
|
|
|
|
|
def is_master(args, local=False):
    """Return whether the current process is a master process.

    Args:
        args: Object exposing ``rank`` and/or ``local_rank``; only the
            attribute needed for the selected check is read.
        local: When truthy, delegate to ``is_local_master``; otherwise
            delegate to ``is_global_master``.

    Returns:
        The result of the selected master check.
    """
    return is_local_master(args) if local else is_global_master(args)
|
|
|
|
|
# Catalog entries that slip_evaluate skips entirely (no dataset is built and
# no metric is recorded for them).
SKIP_DATASETS = ["imagenet", "radimagenet", "SICAPv2"]
|
|
|
|
|
|
|
# Datasets whose metric is logged to W&B as a single value under
# ``val/<name>`` (the metric object itself is passed to ``wandb.log``).
# NOTE(review): the original condition listed 'chexpert-5x200' twice; the
# duplicate may have been meant to be another dataset name -- confirm.
_WANDB_SCALAR_DATASETS = frozenset({
    'chexpert-5x200', 'CT_sagittal', 'CT_axial', 'CT_coronal',
    'dr_uwf', 'dr_regular', 'PCAM', 'LC25000_lung', 'LC25000_colon',
    'NCK_CRC', 'BACH', 'Osteo', 'skin_cancer', 'skin_tumor', 'SICAPv2',
    'five_retina', 'odir_retina',
})


def _zeroshot_prompts(d, all_templates, all_labels):
    """Return the ``(templates, labels)`` pair used to evaluate dataset ``d``.

    Datasets with a dedicated prompt table use that table as ``templates``
    together with ``labels = None``; every other dataset falls back to the
    metadata entries ``all_templates[d]`` and ``all_labels[d]``.
    """
    prompt_tables = {
        'chexpert-5x200': CHEXPERT_CLASS_PROMPTS_webdataset,
        'rsna_pneumonia': RSNA_CLASS_PROMPTS_webdataset,
        'thyroid_us': thyroid_us_prompts,
        'breast_us': breast_us_prompts,
        'meniscal_mri': meniscus_mri_prompts,
        'acl_mri': acl_mri_prompts,
        'radimagenet': radimagenet_all_prompts,
        'CT_axial': ct_scan_labels,
        'CT_coronal': ct_scan_labels,
        'CT_sagittal': ct_scan_labels,
        'dr_regular': diabetic_retinopathy_prompts,
        'dr_uwf': diabetic_retinopathy_prompts,
        'LC25000_lung': LC25000_lung,
        'LC25000_colon': LC25000_colon,
        'PCAM': PCAM,
        'NCK_CRC': NCK_CRC_prompts,
        'BACH': BACH_prompts,
        'Osteo': Osteo_prompts,
        'skin_cancer': skin_cancer_prompts,
        'SICAPv2': SICAPv2_prompts,
        'skin_tumor': skin_tumor_prompts,
        'refuge_retina': refuge_prompts,
        'five_retina': five_prompts,
        'odir_retina': odir_retina_prompts,
    }
    if d in prompt_tables:
        # Several of these datasets previously left ``labels`` unassigned,
        # which raised UnboundLocalError when one of them was evaluated first
        # and otherwise silently reused the previous dataset's labels.
        return prompt_tables[d], None
    return all_templates[d], all_labels[d]


def _wandb_payload(name, val, epoch):
    """Build the dictionary passed to ``wandb.log`` for one dataset metric."""
    if name == 'radimagenet':
        return {f"val/{name}": val["acc"], 'epoch': epoch}
    if name in _WANDB_SCALAR_DATASETS:
        return {f"val/{name}": val, 'epoch': epoch}
    # Remaining datasets are indexed by these five keys.
    # NOTE(review): this assumes ``eval_zeroshot.evaluate`` returns a mapping
    # with these keys for such datasets -- confirm against clipeval.
    return {
        f"val/{name}/acc": val['acc'],
        f"val/{name}/auc_roc": val['auc_roc'],
        f"val/{name}/precision_score": val['precision_score'],
        f"val/{name}/f1_score": val['f1_score'],
        f"val/{name}/recall_score": val['recall_score'],
        'epoch': epoch,
    }


@torch.no_grad()
def slip_evaluate(args, model, val_transform, tokenizer, epoch=0):
    """Run zero-shot evaluation over the ``clipeval`` catalog.

    Only the global master process (``is_master(args)``) evaluates; every
    other process returns an empty dict immediately.

    Args:
        args: Run configuration. Reads ``rank`` (through ``is_master``),
            ``batch_size``, ``workers``, ``output_dir`` and ``wandb``.
        model: Model to evaluate; if it has a ``module`` attribute, that
            attribute is evaluated instead.
        val_transform: Transform handed to
            ``datasets.get_downstream_dataset``.
        tokenizer: Tokenizer handed to ``eval_zeroshot.evaluate``.
        epoch: Epoch number written to ``slip.txt`` and logged to W&B.

    Returns:
        Dict mapping each evaluated dataset name to the value returned by
        ``eval_zeroshot.evaluate``; empty on non-master processes.

    Side effects:
        Prints one JSON line per dataset, appends the same line (preceded by
        an epoch header) to ``<args.output_dir>/slip.txt`` and, when
        ``args.wandb`` is truthy, logs the metrics with ``wandb.log``.
    """
    metrics = {}
    if not is_master(args):
        return metrics

    # Imported lazily so that only the evaluating process needs clipeval.
    from clipeval import datasets, eval_zeroshot

    catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval")

    if hasattr(model, "module"):
        model = model.module

    results_path = os.path.join(args.output_dir, "slip.txt")

    for d in catalog:
        if d in SKIP_DATASETS:
            continue

        val_dataset = datasets.get_downstream_dataset(
            catalog, d, is_train=False, transform=val_transform)
        templates, labels = _zeroshot_prompts(d, all_templates, all_labels)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size // 2, shuffle=False,
            num_workers=args.workers, pin_memory=False, drop_last=False)

        metric = eval_zeroshot.evaluate(
            d, val_loader, templates, labels, model, tokenizer)
        metrics[d] = metric
        json_str = json.dumps({"task": d, "acc": metric})

        # The early return above already guarantees args.rank == 0 here, so
        # no additional rank check is required before reporting.
        print(json_str)
        with open(results_path, mode="a+", encoding="utf-8") as f:
            f.write(f"Saving results for Epoch {epoch}" + "\n")
            f.write(json_str + "\n")

    # NOTE(review): the original nesting of this section was lost; it is
    # placed after the dataset loop so each metric is logged exactly once.
    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log(_wandb_payload(name, val, epoch))

    return metrics
|
|