File size: 5,962 Bytes
37b3db0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) Meta Platforms, Inc. and affiliates

import math
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None
import json
import os
from constants import CHEXPERT_CLASS_PROMPTS, CHEXPERT_CLASS_PROMPTS_webdataset, RSNA_CLASS_PROMPTS_webdataset, \
    RSNA_CLASS_PROMPTS, thyroid_us_prompts, breast_us_prompts, meniscus_mri_prompts, acl_mri_prompts, \
    radimagenet_all_prompts, ct_scan_labels, diabetic_retinopathy_prompts, PCAM, LC25000_lung, \
    LC25000_colon, NCK_CRC_prompts, BACH_prompts, Osteo_prompts, skin_cancer_prompts, \
    skin_tumor_prompts, SICAPv2_prompts, refuge_prompts, five_prompts, odir_retina_prompts

from tqdm import tqdm
from collections import defaultdict


def is_global_master(args):
    """Return True when ``args.rank`` equals 0, i.e. this is the global master process."""
    global_rank = args.rank
    return global_rank == 0


def is_local_master(args):
    """Return True when ``args.local_rank`` equals 0, i.e. this is the master process on its node."""
    node_rank = args.local_rank
    return node_rank == 0


def is_master(args, local=False):
    """Tell whether the current process is a master process.

    Args:
        args: Namespace carrying ``rank`` and ``local_rank``.
        local: When truthy, check ``local_rank`` (per-node master); otherwise
            check ``rank`` (global master).

    Returns:
        The result of ``is_local_master(args)`` or ``is_global_master(args)``.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)


# Catalog entries that slip_evaluate skips instead of evaluating.
SKIP_DATASETS = [
    "imagenet",
    "radimagenet",
    "SICAPv2",
]
# Alternative (much broader) skip list kept for reference:
# SKIP_DATASETS = ["radimagenet", "imagenet", "rsna_pneumonia", "meniscal_mri", "breast_us", "acl_mri", "thyroid_us", 'PCAM', 'LC25000_lung', 'LC25000_colon', 'CT_axial', 'CT_coronal', 'CT_sagittal', 'dr_uwf', 'dr_regular', 'SICAPv2', 'NCK_CRC', 'skin_tumor', 'skin_cancer', 'BACH'] # 'BACH'


# Datasets whose metric is sent to wandb exactly as returned by
# eval_zeroshot.evaluate, under the single key "val/<name>".
# NOTE(review): 'refuge_retina' is not listed here, so it takes the multi-metric
# branch below - confirm that is intended.
_WANDB_RAW_METRIC_DATASETS = frozenset({
    'chexpert-5x200', 'CT_sagittal', 'CT_axial', 'CT_coronal',
    'dr_uwf', 'dr_regular', 'PCAM', 'LC25000_lung', 'LC25000_colon',
    'NCK_CRC', 'BACH', 'Osteo', 'skin_cancer', 'skin_tumor', 'SICAPv2',
    'five_retina', 'odir_retina',
})


def _prompt_overrides():
    """Return the dataset-name -> prompt-set table for datasets with dedicated prompts.

    Built inside a function so the prompt constants are only looked up when an
    evaluation actually runs.
    """
    return {
        # The non-webdataset variants (CHEXPERT_CLASS_PROMPTS / RSNA_CLASS_PROMPTS)
        # are imported by this module but not used here.
        'chexpert-5x200': CHEXPERT_CLASS_PROMPTS_webdataset,
        'rsna_pneumonia': RSNA_CLASS_PROMPTS_webdataset,
        'thyroid_us': thyroid_us_prompts,
        'breast_us': breast_us_prompts,
        'meniscal_mri': meniscus_mri_prompts,
        'acl_mri': acl_mri_prompts,
        'radimagenet': radimagenet_all_prompts,
        'CT_axial': ct_scan_labels,
        'CT_coronal': ct_scan_labels,
        'CT_sagittal': ct_scan_labels,
        'dr_regular': diabetic_retinopathy_prompts,
        'dr_uwf': diabetic_retinopathy_prompts,
        'LC25000_lung': LC25000_lung,
        'LC25000_colon': LC25000_colon,
        'PCAM': PCAM,
        'NCK_CRC': NCK_CRC_prompts,
        'BACH': BACH_prompts,
        'Osteo': Osteo_prompts,
        'skin_cancer': skin_cancer_prompts,
        'SICAPv2': SICAPv2_prompts,
        'skin_tumor': skin_tumor_prompts,
        'refuge_retina': refuge_prompts,
        'five_retina': five_prompts,
        'odir_retina': odir_retina_prompts,
    }


def _wandb_payload(name, val, epoch):
    """Build the dict passed to ``wandb.log`` for one dataset's metric."""
    if name == 'radimagenet':
        return {f"val/{name}": val["acc"], 'epoch': epoch}
    if name in _WANDB_RAW_METRIC_DATASETS:
        return {f"val/{name}": val, 'epoch': epoch}
    return {f"val/{name}/acc": val['acc'],
            f"val/{name}/auc_roc": val['auc_roc'],
            f"val/{name}/precision_score": val['precision_score'],
            f"val/{name}/f1_score": val['f1_score'],
            f"val/{name}/recall_score": val['recall_score'],
            'epoch': epoch}


@torch.no_grad()
def slip_evaluate(args, model, val_transform, tokenizer, epoch=0):
    """Run zero-shot evaluation over every non-skipped dataset in the clipeval catalog.

    Only the global master process (``is_master(args)``) evaluates; every other
    rank returns an empty dict immediately. For each dataset the result is
    printed as JSON, appended to ``<args.output_dir>/slip.txt`` and, when
    ``args.wandb`` is truthy, logged to wandb.

    Args:
        args: Run configuration. Reads ``rank`` (through ``is_master``),
            ``batch_size``, ``workers``, ``output_dir`` and ``wandb``.
        model: Model to evaluate. If it has a ``module`` attribute, that inner
            object is evaluated instead.
        val_transform: Transform handed to ``datasets.get_downstream_dataset``.
        tokenizer: Tokenizer forwarded to ``eval_zeroshot.evaluate``.
        epoch: Epoch number recorded in the results file and the wandb payload.

    Returns:
        Dict mapping dataset name to whatever ``eval_zeroshot.evaluate``
        returned for it; empty on non-master ranks.

    Raises:
        AssertionError: If ``args.wandb`` is set but wandb is not installed.
    """
    metrics = {}
    if not is_master(args):
        return metrics
    from clipeval import datasets, eval_zeroshot

    catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval")

    if hasattr(model, "module"):
        model = model.module

    # Loop invariants.
    prompt_overrides = _prompt_overrides()
    # Evaluate with half the training batch size, but never 0: DataLoader
    # rejects a non-positive batch_size (args.batch_size == 1 would give 0).
    eval_batch_size = max(1, args.batch_size // 2)
    results_path = os.path.join(args.output_dir, "slip.txt")

    for d in catalog:
        if d in SKIP_DATASETS:
            continue
        val_dataset = datasets.get_downstream_dataset(
            catalog, d, is_train=False, transform=val_transform)

        if d in prompt_overrides:
            # Always reset labels for override datasets. Previously several of
            # them left `labels` unassigned, so the value leaked from the prior
            # dataset (or raised UnboundLocalError on the first iteration).
            templates = prompt_overrides[d]
            labels = None
        else:
            templates = all_templates[d]
            labels = all_labels[d]

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, drop_last=False)

        metric = eval_zeroshot.evaluate(d, val_loader, templates, labels, model, tokenizer)
        metrics[d] = metric
        json_str = json.dumps({"task": d, "acc": metric})

        # Non-master ranks returned above, so no further rank check is needed.
        print(json_str)
        with open(results_path, mode="a+", encoding="utf-8") as f:
            f.write(f"Saving results for Epoch {epoch}" + "\n")
            f.write(json_str + "\n")

        if args.wandb:
            assert wandb is not None, 'Please install wandb.'
            # Log only the dataset just evaluated; iterating over all of
            # `metrics` here re-logged every earlier dataset on each pass.
            wandb.log(_wandb_payload(d, metric, epoch))
    return metrics