muzairkhattak
first commit for the demo
37b3db0
raw
history blame
5.96 kB
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
import json
import os
from constants import CHEXPERT_CLASS_PROMPTS, CHEXPERT_CLASS_PROMPTS_webdataset, RSNA_CLASS_PROMPTS_webdataset, \
RSNA_CLASS_PROMPTS, thyroid_us_prompts, breast_us_prompts, meniscus_mri_prompts, acl_mri_prompts, \
radimagenet_all_prompts, ct_scan_labels, diabetic_retinopathy_prompts, PCAM, LC25000_lung, \
LC25000_colon, NCK_CRC_prompts, BACH_prompts, Osteo_prompts, skin_cancer_prompts, \
skin_tumor_prompts, SICAPv2_prompts, refuge_prompts, five_prompts, odir_retina_prompts
from tqdm import tqdm
from collections import defaultdict
def is_global_master(args):
    """Return True when ``args.rank`` equals 0, i.e. this is the global master process."""
    global_rank = args.rank
    return global_rank == 0
def is_local_master(args):
    """Return True when ``args.local_rank`` equals 0, i.e. this is the node-local master process."""
    node_rank = args.local_rank
    return node_rank == 0
def is_master(args, local=False):
    """Report whether this process is a master process.

    Args:
        args: Namespace carrying ``rank`` and ``local_rank``.
        local: When truthy, check the node-local rank; otherwise check the
            global rank.

    Returns:
        The result of ``is_local_master(args)`` or ``is_global_master(args)``.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)
# Catalog entries that slip_evaluate() skips without loading or evaluating them.
SKIP_DATASETS = [
    "imagenet",
    "radimagenet",
    "SICAPv2",
]
# Alternative, much broader skip list kept for reference (disabled):
# SKIP_DATASETS = ["radimagenet", "imagenet", "rsna_pneumonia", "meniscal_mri", "breast_us", "acl_mri", "thyroid_us", 'PCAM', 'LC25000_lung', 'LC25000_colon', 'CT_axial', 'CT_coronal', 'CT_sagittal', 'dr_uwf', 'dr_regular', 'SICAPv2', 'NCK_CRC', 'skin_tumor', 'skin_cancer', 'BACH'] # 'BACH'
# Datasets whose metric object is handed to wandb unchanged under ``val/<name>``.
# NOTE(review): the shape of each metric is decided by eval_zeroshot.evaluate,
# which is not visible here; this set only mirrors the original logging rules.
_DIRECT_LOG_DATASETS = frozenset({
    'chexpert-5x200', 'CT_sagittal', 'CT_axial', 'CT_coronal',
    'dr_uwf', 'dr_regular', 'PCAM', 'LC25000_lung', 'LC25000_colon',
    'NCK_CRC', 'BACH', 'Osteo', 'skin_cancer', 'skin_tumor', 'SICAPv2',
    'five_retina', 'odir_retina',
})


def _custom_prompt_table():
    """Return the mapping from dataset name to its hand-written prompt set.

    The table is built inside a function (rather than at import time) so the
    prompt constants are only looked up when an evaluation actually runs.
    Datasets listed here are evaluated with ``labels=None``.
    """
    return {
        # CHEXPERT_CLASS_PROMPTS / RSNA_CLASS_PROMPTS are the non-webdataset
        # alternatives for the first two entries.
        'chexpert-5x200': CHEXPERT_CLASS_PROMPTS_webdataset,
        'rsna_pneumonia': RSNA_CLASS_PROMPTS_webdataset,
        'thyroid_us': thyroid_us_prompts,
        'breast_us': breast_us_prompts,
        'meniscal_mri': meniscus_mri_prompts,
        'acl_mri': acl_mri_prompts,
        'radimagenet': radimagenet_all_prompts,
        'CT_axial': ct_scan_labels,
        'CT_coronal': ct_scan_labels,
        'CT_sagittal': ct_scan_labels,
        'dr_regular': diabetic_retinopathy_prompts,
        'dr_uwf': diabetic_retinopathy_prompts,
        'LC25000_lung': LC25000_lung,
        'LC25000_colon': LC25000_colon,
        'PCAM': PCAM,
        'NCK_CRC': NCK_CRC_prompts,
        'BACH': BACH_prompts,
        'Osteo': Osteo_prompts,
        'skin_cancer': skin_cancer_prompts,
        'SICAPv2': SICAPv2_prompts,
        'skin_tumor': skin_tumor_prompts,
        'refuge_retina': refuge_prompts,
        'five_retina': five_prompts,
        'odir_retina': odir_retina_prompts,
    }


def _log_metrics_to_wandb(metrics, epoch):
    """Send every collected metric to wandb, tagged with ``epoch``.

    Three layouts are used, selected by dataset name:
    * ``radimagenet``: only ``val["acc"]`` is logged.
    * names in ``_DIRECT_LOG_DATASETS``: the metric is logged as returned.
    * everything else: the ``acc``, ``auc_roc``, ``precision_score``,
      ``f1_score`` and ``recall_score`` entries are logged separately.
    """
    for name, val in metrics.items():
        if name == 'radimagenet':
            wandb.log({f"val/{name}": val["acc"], 'epoch': epoch})
        elif name in _DIRECT_LOG_DATASETS:
            wandb.log({f"val/{name}": val, 'epoch': epoch})
        else:
            wandb.log({f"val/{name}/acc": val['acc'],
                       f"val/{name}/auc_roc": val['auc_roc'],
                       f"val/{name}/precision_score": val['precision_score'],
                       f"val/{name}/f1_score": val['f1_score'],
                       f"val/{name}/recall_score": val['recall_score'],
                       'epoch': epoch})


@torch.no_grad()
def slip_evaluate(args, model, val_transform, tokenizer, epoch=0):
    """Run zero-shot evaluation over the ``clipeval`` catalog on the global master.

    For every catalog entry not listed in ``SKIP_DATASETS`` the downstream
    validation split is loaded, evaluated with ``eval_zeroshot.evaluate``, and
    the result is printed and appended to ``<args.output_dir>/slip.txt``.
    When ``args.wandb`` is truthy the collected metrics are also sent to wandb.

    Args:
        args: Run configuration; reads ``rank``, ``batch_size``, ``workers``,
            ``output_dir`` and ``wandb``.
        model: Model to evaluate. If it has a ``module`` attribute, that inner
            object is evaluated instead.
        val_transform: Transform passed to the downstream dataset factory.
        tokenizer: Tokenizer forwarded to ``eval_zeroshot.evaluate``.
        epoch: Epoch number recorded in ``slip.txt`` and in the wandb logs.

    Returns:
        Dict mapping dataset name to the metric returned by
        ``eval_zeroshot.evaluate``; empty on every non-master rank.
    """
    metrics = {}
    if not is_master(args):
        return metrics

    from clipeval import datasets, eval_zeroshot

    catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval")
    if hasattr(model, "module"):
        model = model.module

    custom_prompts = _custom_prompt_table()
    results_path = os.path.join(args.output_dir, "slip.txt")
    # Evaluation runs at half the configured batch size; never let it reach 0,
    # which DataLoader rejects.
    eval_batch_size = max(1, args.batch_size // 2)

    for d in catalog:
        if d in SKIP_DATASETS:
            continue
        val_dataset = datasets.get_downstream_dataset(
            catalog, d, is_train=False, transform=val_transform)

        if d in custom_prompts:
            # Custom prompt sets are always evaluated without a label list.
            # Several branches previously left ``labels`` unassigned, which
            # either raised UnboundLocalError or reused a stale value from the
            # previous dataset.
            templates, labels = custom_prompts[d], None
        else:
            templates, labels = all_templates[d], all_labels[d]

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, drop_last=False)
        metric = eval_zeroshot.evaluate(d, val_loader, templates, labels, model, tokenizer)
        metrics[d] = metric

        # Only the global master reaches this point (see the early return), so
        # no further rank check is needed before printing and writing.
        json_str = json.dumps({"task": d, "acc": metric})
        print(json_str)
        with open(results_path, mode="a+", encoding="utf-8") as f:
            f.write(f"Saving results for Epoch {epoch}" + "\n")
            f.write(json_str + "\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        _log_metrics_to_wandb(metrics, epoch)
    return metrics