# Copyright (c) Meta Platforms, Inc. and affiliates

# usage:
# torchrun --nproc_per_node=4 src/training/main.py b16_400m my-experiment-name <path-to-metaclip-pretrained-checkpoint>
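# For multi-node runs (the configs below set nodes=16, ngpus=4), an illustrative
# torchrun invocation (rendezvous host/port are placeholders) would be:
# torchrun --nnodes=16 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=<host>:<port> src/training/main.py b16_400m my-experiment-name <path-to-metaclip-pretrained-checkpoint>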
from dataclasses import dataclass
from configs import Config


@dataclass
class b32_400m(Config):
    inmem=True
    engine="train_one_epoch_ex"
    eval_steps=5000
    save_frequency=1
    # First prepare the UniMed dataset following the instructions in docs/PREPARE-UniMed-DATA.md,
    # then provide the path for each sub-part of the UniMed dataset below.
    train_data="/<dataset-path>/radimagenet_webdataset/dataset-{000001..001049}.tar::/<dataset-path>/chexpert_webdataset/dataset-{000001..000212}.tar::/<dataset-path>/openi_webdataset/dataset-{000001..000007}.tar::/<dataset-path>/chest_xray8_webdataset/dataset-{000001..000113}.tar::/<dataset-path>/mimic_cxr/dataset-{000001..000270}.tar::/<dataset-path>/roco_webdataset/dataset-{000001..000061}.tar::/<dataset-path>/pmc_clip_webdataset/dataset-{000001..001645}.tar::/<dataset-path>/llava_med_alignment_set_webdataset/dataset-{000001..000468}.tar::/<dataset-path>/llava_med_hq_60k_set_webdataset/dataset-{000001..000265}.tar::/<dataset-path>/quilt_webdataset/dataset-{000001..001018}.tar::/<dataset-path>/retina_part1_webdataset/dataset-{000001..000155}.tar::/<dataset-path>/retina_part2_webdataset/dataset-{000001..000013}.tar::/<dataset-path>/retina_part3_webdataset/dataset-{000001..000006}.tar"
    # train_num_samples = 1049000 (radimagenet) + 212000 (chexpert) + 7000 (openi) + 113000 (chest-xray8) + 270000 (mimic-cxr) + 61000 (rocov2) + 1645000 (pmc-clip) + 468000 (llavamed-alignment) + 265000 (llava-medhq) + 1018000 (quilt) + 155000 (retina part 1) + 13000 (retina part 2) + 6000 (retina part 3)
    # train_num_samples must equal the total dataset size, i.e. the sum of the per-part counts above.
    train_num_samples = 5282000
    # By default, we give equal weight to all dataset parts: one factor per ::-separated part in train_data.
    train_data_upsampling_factors = "1::1::1::1::1::1::1::1::1::1::1::1::1"
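    # Sanity check (illustration only; not executed by the training code). The
    # number of upsampling factors should match the number of ::-separated parts
    # in train_data, and train_num_samples should equal the sum of the per-part
    # counts listed above:
    #
    #   parts = train_data.split("::")          # 13 webdataset shard specs
    #   factors = train_data_upsampling_factors.split("::")
    #   assert len(parts) == len(factors) == 13
    #   assert train_num_samples == 5_282_000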
    # ----------------------------------------
    workers=8
    batch_size=128
    epochs=10
    eval_freq = 1
    model="ViT-B-32-quickgelu"
    name="ViT-B-32"
    force_quick_gelu=True
    warmup=2000
    seed=0
    local_loss=True
    gather_with_grad=True
    nodes=16
    ngpus=4
    imagenet_val = None
    report_to = 'wandb'
    tokenizer_context_length = 256
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'

@dataclass
class b32_400m_eval(Config):
    inmem=True
    engine="train_one_epoch_ex"
    eval_steps=5000
    save_frequency=1
    train_data=""
    workers=8
    eval_freq = 1
    train_num_samples=400000000
    batch_size=512
    epochs=10
    model="ViT-B-32-quickgelu"
    name="ViT-B-32"
    force_quick_gelu=True
    warmup=2000
    seed=0
    local_loss=True
    gather_with_grad=True
    nodes=16
    ngpus=4
    imagenet_val = None
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b32_400m.pt'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    tokenizer_context_length = 256


@dataclass
class b16_400m(b32_400m):
    model="ViT-B-16-quickgelu"
    name="ViT-B-16"
    grad_checkpointing=True
    # Change the path below to point at your MetaCLIP pretrained checkpoint.
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b16_400m.pt'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'

@dataclass
class b16_400m_eval(b32_400m_eval):
    model="ViT-B-16-quickgelu"
    name="ViT-B-16"
    grad_checkpointing=True
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b16_400m.pt'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'


@dataclass
class l14_400m(b32_400m):
    model="ViT-L-14-336-quickgelu"
    name="ViT-L-14"
    lr=0.0004
    grad_checkpointing=True
    batch_size=128
    nodes=16
    ngpus=8
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'

@dataclass
class l14_400m_eval(b32_400m_eval):
    model="ViT-L-14-336-quickgelu"
    name="ViT-L-14"
    lr=0.0004
    grad_checkpointing=True
    batch_size=256
    nodes=16
    ngpus=8
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'

@dataclass
class l14_400m_base_text_encoder(b32_400m):
    model="ViT-L-14-336-quickgelu"
    name="ViT-L-14"
    lr=0.0004
    grad_checkpointing=True
    batch_size=128
    nodes=16
    ngpus=8
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'

@dataclass
class l14_400m_base_text_encoder_eval(b32_400m_eval):
    model="ViT-L-14-336-quickgelu"
    name="ViT-L-14"
    lr=0.0004
    grad_checkpointing=True
    batch_size=256
    nodes=16
    ngpus=8
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'

if __name__ == "__main__":
    import inspect
    import sys

    # Print the names of the config classes defined in this module.
    # (The configs above are dataclasses, not functions, so inspect.isclass
    # is the right predicate here.)
    for name, obj in inspect.getmembers(sys.modules[__name__]):
        if inspect.isclass(obj) and issubclass(obj, Config) and obj is not Config:
            print(name)
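
# Example (illustrative): running this config file directly, e.g.
#   python src/training/<this-config-file>.py
# prints the available config names; pass one of them to main.py as shown in
# the usage comment at the top of this file.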