r"""Pre-training ViT on ILSVRC-2012 with GSAM in https://arxiv.org/abs/2203.08065 |
|
|
|
Run training of a B/32 model: |
|
|
|
big_vision.trainers.proj.gsam.train \ |
|
--config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \ |
|
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` |
|
|
|
""" |
|
|
|
import big_vision.configs.common as bvcc |
|
from big_vision.configs.common_fewshot import get_fewshot_lsr |
|
import ml_collections as mlc |
|
|
|
def get_config(arg=None): |
|
"""Config for training.""" |
|
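  # `arg` is parsed from the string appended to the config path on the command
  # line (big_vision's usual `config.py:key=value,...` convention), e.g. a
  # hypothetical `...vit_i1k_gsam_no_aug.py:variant=B/16,runlocal=True`.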
  arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False)
  config = mlc.ConfigDict()
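
  # Input section: ImageNet-1k. `cache_raw` keeps the raw examples in host
  # RAM, so it assumes a machine with plenty of memory.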
  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'
  config.cache_raw = not arg.runlocal
  config.shuffle_buffer_size = 250_000
  config.num_classes = 1000
  config.loss = 'sigmoid_xent'
  config.batch_size = 4096
  config.num_epochs = 300
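
  # Preprocessing: `{lbl}` is filled in later, so the same pipeline string can
  # be reused with either the original or the ReaL labels.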
  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
      pp_common.format(lbl='label')
  )
  pp = 'decode|resize_small(256)|central_crop(224)' + pp_common
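
  # Prefetch aggressively so the (small) model is not input-bound.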
  config.prefetch_to_host = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.checkpoint_steps = 1000
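
  # Model section.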
  config.model_name = 'vit'
  config.model = dict(
      variant=arg.variant,
      rep_size=False,
      pool_type='gap',
  )
  config.init_head_bias = -10.0
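
  # Optimizer section.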
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='float32')
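
  # Linear warmup for 10k steps, then linear decay of the learning rate down
  # to `linear_end` times its peak value.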
  config.lr = 0.003
  config.wd = 0.001
  config.schedule = dict(
      warmup_steps=10_000,
      decay_type='linear',
      linear_end=0.01,
  )
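
  # GSAM section (see https://arxiv.org/abs/2203.08065). In the paper the
  # perturbation radius rho is scheduled alongside the learning rate between
  # rho_min and rho_max, hence the lr_max/lr_min references below; alpha
  # scales the surrogate-gap term.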
  config.gsam = dict(
      rho_max=0.6,
      rho_min=0.1,
      alpha=0.6,
      lr_max=config.get_ref('lr'),
      lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
  )
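
  # Evaluation section.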
  eval_common = dict(
      type='classification',
      dataset='imagenet2012',
      pp_fn=pp.format(lbl='label'),
      loss_name=config.loss,
      log_steps=2500,
  )
  config.evals = {}
  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
  config.evals.val = {**eval_common, 'split': 'validation'}
  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}
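
  # ImageNet-ReaL: same validation images, re-assessed labels.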
  config.evals.real = {**eval_common}
  config.evals.real.dataset = 'imagenet2012_real'
  config.evals.real.split = 'validation'
  config.evals.real.pp_fn = pp.format(lbl='real_label')
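
  # Few-shot linear-probe evaluations (least-squares regression probes).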
  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.fewshot.log_steps = 10_000
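
  # Shrink everything for quick local debugging runs.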
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8
    config.evals.minival.split = 'train[:16]'
    config.evals.val.split = 'validation[:16]'
    config.evals.real.split = 'validation[:16]'
    config.evals.v2.split = 'test[:16]'

  return config