# Copyright (c) Meta Platforms, Inc. and affiliates
# usage:
# torchrun --nproc_per_node=4 src/training/main.py b16_400m my-experiment-name <path-to-metaclip-pretrained-checkpoint>
from dataclasses import dataclass
from configs import Config
@dataclass
class b32_400m(Config):
    """Training config: ViT-B-32 (QuickGELU) on the full 13-part UniMed dataset mix.

    Launched via ``torchrun ... src/training/main.py b32_400m ...`` (see file header).
    """

    inmem = True
    engine = "train_one_epoch_ex"
    eval_steps = 5000
    save_frequency = 1
    # First prepare UniMed-Dataset using instructions in docs/PREPARE-UniMed-DATA.md
    # and then provide paths for each sub-part of the UniMed dataset below.
    # 13 webdataset shard specs joined with '::'.
    train_data="/<dataset-path>/radimagenet_webdataset/dataset-{000001..001049}.tar::/<dataset-path>/chexpert_webdataset/dataset-{000001..000212}.tar::/<dataset-path>/openi_webdataset/dataset-{000001..000007}.tar::/<dataset-path>/chest_xray8_webdataset/dataset-{000001..000113}.tar::/<dataset-path>/mimic_cxr/dataset-{000001..000270}.tar::/<dataset-path>/roco_webdataset/dataset-{000001..000061}.tar::/<dataset-path>/pmc_clip_webdataset/dataset-{000001..001645}.tar::/<dataset-path>/llava_med_alignment_set_webdataset/dataset-{000001..000468}.tar::/<dataset-path>/llava_med_hq_60k_set_webdataset/dataset-{000001..000265}.tar::/<dataset-path>/quilt_webdataset/dataset-{000001..001018}.tar::/<dataset-path>/retina_part1_webdataset/dataset-{000001..000155}.tar::/<dataset-path>/retina_part2_webdataset/dataset-{000001..000013}.tar::/<dataset-path>/retina_part3_webdataset/dataset-{000001..000006}.tar"
    # train_num_samples = 1049000 (radimagenet) + 212000 (chexpert) + 7000 (openi) + 113000 (chest-xray8) + 270000 (mimic-cxr) + 61000 (rocov2) + 1645000 (pmc-clip) + 468000 (llavamed-alignment) + 265000 (llava-medhq) + 1018000 (quilt) + 155000 (retina part 1) + 13000 (retina part 2) + 6000 (retina part 3)
    # Total training samples must equal total dataset size.
    train_num_samples = 5282000
    # By default, we give equal weight to all dataset parts. FIX: this must have
    # exactly one factor per '::'-separated part of train_data. train_data has 13
    # parts; the original string had only 12 factors, silently mismatching the list.
    train_data_upsampling_factors = "1::1::1::1::1::1::1::1::1::1::1::1::1"
    # ----------------------------------------
    workers = 8
    batch_size = 128
    epochs = 10
    eval_freq = 1
    model = "ViT-B-32-quickgelu"
    name = "ViT-B-32"
    force_quick_gelu = True
    warmup = 2000
    seed = 0
    local_loss = True
    gather_with_grad = True
    nodes = 16
    ngpus = 4
    imagenet_val = None
    report_to = 'wandb'
    tokenizer_context_length = 256
    # NOTE(review): other configs in this file spell this attribute
    # 'text_encoder_model_name' — confirm which spelling the training code reads.
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
@dataclass
class b32_400m_eval(Config):
    """Evaluation config: ViT-B-32 (QuickGELU), loading MetaCLIP-pretrained weights."""

    # -- engine / checkpointing --
    inmem = True
    engine = "train_one_epoch_ex"
    eval_steps = 5000
    save_frequency = 1
    eval_freq = 1
    # -- data (no training shards for eval) --
    train_data = ""
    train_num_samples = 400000000
    workers = 8
    batch_size = 512
    epochs = 10
    # -- model --
    model = "ViT-B-32-quickgelu"
    name = "ViT-B-32"
    force_quick_gelu = True
    # Point this at the downloaded MetaCLIP checkpoint before running.
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b16_400m.pt'
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    tokenizer_context_length = 256
    # -- optimization / distributed --
    warmup = 2000
    seed = 0
    local_loss = True
    gather_with_grad = True
    nodes = 16
    ngpus = 4
    imagenet_val = None
@dataclass
class b16_400m(b32_400m):
    """ViT-B-16 variant of the b32_400m training recipe."""

    # Change below: point at the downloaded MetaCLIP checkpoint.
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b16_400m.pt'
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    model = "ViT-B-16-quickgelu"
    name = "ViT-B-16"
    grad_checkpointing = True
@dataclass
class b16_400m_eval(b32_400m_eval):
    """Evaluation config: ViT-B-16 variant of b32_400m_eval."""

    model = "ViT-B-16-quickgelu"
    name = "ViT-B-16"
    grad_checkpointing = True
    # Point this at the downloaded MetaCLIP checkpoint before running.
    pretrained = '<path-to-metaclip-pretrained-weights-file>/b16_400m.pt'
    # FIX: the original set only 'text_encoder_model_name', but the base class
    # (b32_400m_eval) spells the attribute 'text_encode_model_name'. Define both
    # spellings with the same value so the intended encoder is used regardless of
    # which attribute the training code reads.
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
@dataclass
class l14_400m(b32_400m):
    """Training config: ViT-L-14-336 with the *large* BiomedBERT text encoder."""

    model = "ViT-L-14-336-quickgelu"
    name = "ViT-L-14"
    lr = 0.0004
    grad_checkpointing = True
    batch_size = 128
    nodes = 16
    ngpus = 8
    # FIX: the original set only 'text_encoder_model_name', but the base class
    # (b32_400m) spells the attribute 'text_encode_model_name' — so code reading the
    # base spelling would silently fall back to BiomedBERT-*base* instead of the
    # intended *large* encoder. Define both spellings to guarantee the intent.
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'
@dataclass
class l14_400m_eval(b32_400m_eval):
    """Evaluation config: ViT-L-14-336 with the *large* BiomedBERT text encoder."""

    model = "ViT-L-14-336-quickgelu"
    name = "ViT-L-14"
    lr = 0.0004
    grad_checkpointing = True
    batch_size = 256
    nodes = 16
    ngpus = 8
    # FIX: the original set only 'text_encoder_model_name', but the base class
    # (b32_400m_eval) spells the attribute 'text_encode_model_name' — so code reading
    # the base spelling would silently fall back to BiomedBERT-*base* instead of the
    # intended *large* encoder. Define both spellings to guarantee the intent.
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'
@dataclass
class l14_400m_base_text_encoder(b32_400m):
    """Training config: ViT-L-14-336 paired with the *base* BiomedBERT text encoder."""

    # -- vision tower --
    model = "ViT-L-14-336-quickgelu"
    name = "ViT-L-14"
    grad_checkpointing = True
    # -- schedule / scale --
    lr = 0.0004
    batch_size = 128
    nodes = 16
    ngpus = 8
    # NOTE(review): the base class spells this 'text_encode_model_name'; under either
    # spelling the value here equals the inherited one — confirm which the code reads.
    text_encoder_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
@dataclass
class l14_400m_base_text_encoder_eval(b32_400m_eval):
    """Evaluation config: ViT-L-14-336 paired with the *base* BiomedBERT text encoder."""

    # -- vision tower --
    model = "ViT-L-14-336-quickgelu"
    name = "ViT-L-14"
    grad_checkpointing = True
    # -- schedule / scale --
    lr = 0.0004
    batch_size = 256
    nodes = 16
    ngpus = 8
    text_encode_model_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
if __name__ == "__main__":
    # Print the config names defined in this module so a user can pick one for
    # `torchrun ... main.py <config-name> ...`.
    # FIX: the original filtered with inspect.isfunction, which matches nothing —
    # every config here is a dataclass, so the loop printed nothing. List classes
    # defined in this module instead (the __module__ check excludes imports
    # such as Config).
    import inspect
    import sys

    for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass):
        if obj.__module__ == __name__:
            print(name)