|
from functools import partial |
|
import argparse |
|
from torchvision import models |
|
import multiprocessing |
|
from datasets import DS_LIST |
|
from methods import METHOD_LIST |
|
|
|
|
|
def get_cfg(argv=None):
    """Build the command-line parser and return the parsed configuration.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what a plain ``get_cfg()`` call does.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``lr_warmup``, ``add_bn`` and ``norm`` are True unless the matching
        ``--no_*`` flag is passed; ``fname`` and ``T0`` are None unless given.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="")

    # --- method / logging -------------------------------------------------
    parser.add_argument(
        "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
    )
    parser.add_argument(
        "--wandb",
        type=str,
        default="ssl-sota",
        help="name of the project for logging at https://wandb.ai",
    )
    parser.add_argument(
        "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2,
        help="number of samples (d) generated from each image",
    )

    # --- augmentation parameters (all floats) -----------------------------
    addf = partial(parser.add_argument, type=float)
    addf("--cj0", default=0.4, help="color jitter brightness")
    addf("--cj1", default=0.4, help="color jitter contrast")
    addf("--cj2", default=0.4, help="color jitter saturation")
    addf("--cj3", default=0.1, help="color jitter hue")
    addf("--cj_p", default=0.8, help="color jitter probability")
    addf("--gs_p", default=0.1, help="grayscale probability")
    addf("--crop_s0", default=0.2, help="crop size from")
    addf("--crop_s1", default=1.0, help="crop size to")
    addf("--crop_r0", default=0.75, help="crop ratio from")
    addf("--crop_r1", default=(4 / 3), help="crop ratio to")
    addf("--hf_p", default=0.5, help="horizontal flip probability")

    # --- optimisation / schedule / head -----------------------------------
    # The "--no_*" switches use store_false, so their destinations default
    # to True and are turned off only when the flag is present.
    parser.add_argument(
        "--no_lr_warmup",
        dest="lr_warmup",
        action="store_false",
        help="do not use learning rate warmup",
    )
    parser.add_argument(
        "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
    )
    parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
    parser.add_argument("--fname", type=str, help="load model from file")
    parser.add_argument(
        "--lr_step",
        type=str,
        choices=["cos", "step", "none"],
        default="step",
        help="learning rate schedule type",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)"
    )
    parser.add_argument(
        "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
    )
    parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
    parser.add_argument(
        "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
    )
    parser.add_argument(
        "--w_eps", type=float, default=1e-4, help="eps for stability for whitening"
    )
    parser.add_argument(
        "--head_layers", type=int, default=2, help="number of FC layers in head"
    )
    parser.add_argument(
        "--head_size", type=int, default=1024, help="size of FC layers in head"
    )

    # --- whitening (W-MSE) -------------------------------------------------
    parser.add_argument(
        "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
    )
    parser.add_argument(
        "--w_iter",
        type=int,
        default=1,
        help="iterations for whitening matrix estimation",
    )

    # --- contrastive loss --------------------------------------------------
    parser.add_argument(
        "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
    )
    parser.add_argument(
        "--tau", type=float, default=0.5, help="contrastive loss temperature"
    )

    # --- training loop / evaluation / data --------------------------------
    parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
    parser.add_argument(
        "--eval_every_drop",
        type=int,
        default=5,
        help="how often to evaluate after learning rate drop",
    )
    parser.add_argument(
        "--eval_every", type=int, default=20, help="how often to evaluate"
    )
    parser.add_argument("--emb", type=int, default=64, help="embedding size")
    parser.add_argument(
        "--bs", type=int, default=384, help="number of original images in batch N",
    )
    parser.add_argument(
        "--drop",
        type=int,
        nargs="*",
        default=[50, 25],
        help="milestones for learning rate decay (0 = last epoch)",
    )
    parser.add_argument(
        "--drop_gamma",
        type=float,
        default=0.2,
        help="multiplicative factor of learning rate decay",
    )

    # dir(models) also exposes non-constructor attributes whose name contains
    # "resn" (e.g. the `resnet` submodule). Keep only callables so argparse
    # does not accept a value that cannot build a model.
    # NOTE(review): presumably the consumer instantiates
    # getattr(models, cfg.arch) -- confirm against the caller.
    arch_choices = [
        name
        for name in dir(models)
        if "resn" in name and callable(getattr(models, name))
    ]
    parser.add_argument(
        "--arch",
        type=str,
        choices=arch_choices,
        default="resnet18",
        help="encoder architecture",
    )
    parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="dataset workers number",
    )
    parser.add_argument(
        "--clf",
        type=str,
        default="sgd",
        choices=["sgd", "knn", "lbfgs"],
        help="classifier for test.py",
    )
    parser.add_argument(
        "--eval_head", action="store_true", help="eval head output instead of model",
    )
    parser.add_argument("--imagenet_path", type=str, default="~/IN100/")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|