Model checkpoints for PMF

NOTE: for DINO-small, peak VRAM is about 32GB; for DINO-base, peak VRAM is about 42GB.

Meta-testing with dino_small_patch16 trained on the full Meta-Dataset:

python -m torch.distributed.launch --nproc_per_node=8 --use_env test_meta_dataset.py --data-path ../../datasets/meta_dataset --dataset meta_dataset --arch dino_small_patch16 --deploy finetune --output outputs/md_full_dinosmall --resume md_full_128x128_dinosmall_fp16_lr5e-5/best.pth --dist-eval --ada_steps 100 --ada_lr 0.0001

Meta-testing with dino_small_patch16 trained on the ImageNet domain of Meta-Dataset:

python -m torch.distributed.launch --nproc_per_node=8 --use_env test_meta_dataset.py --data-path ../../datasets/meta_dataset --dataset meta_dataset --arch dino_small_patch16 --deploy finetune --output outputs/md_inet_dinosmall_6gpus --resume pmf_metadataset_dino/md_inet_128x128_dinosmall_fp16_lr5e-5/best.pth --dist-eval --ada_steps 100 --ada_lr 0.0001

Results

The meta-test learning rate for each domain, selected by validation on 5 episodes, is shown in parentheses after each result.
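The per-domain selection described above can be sketched as a small grid search. This is a minimal illustration, not the repository's actual code: `select_ada_lr` and `toy_evaluate_episode` are hypothetical names, the candidate list is only inferred from the values that appear in the results table, and the toy evaluator stands in for fine-tuning and scoring a real validation episode.

```python
import random
from statistics import mean

# Assumption: candidates mirror the learning rates that appear in the results table.
CANDIDATE_LRS = (0.0, 0.0001, 0.001, 0.01)


def select_ada_lr(evaluate_episode, candidate_lrs=CANDIDATE_LRS, num_episodes=5):
    """Pick the learning rate with the highest mean accuracy over validation episodes.

    `evaluate_episode(lr, episode_idx)` is a hypothetical callback that would
    fine-tune on one validation episode with `lr` and return its accuracy.
    """
    scores = {
        lr: mean(evaluate_episode(lr, ep) for ep in range(num_episodes))
        for lr in candidate_lrs
    }
    best_lr = max(scores, key=scores.get)
    return best_lr, scores


def toy_evaluate_episode(lr, episode_idx):
    """Hypothetical stand-in for a real episode evaluation (no model involved)."""
    rng = random.Random(episode_idx)   # same noise for every lr within an episode
    noise = rng.uniform(-0.5, 0.5)
    # Toy accuracy curve that peaks at lr=0.001.
    return 80.0 - 1000.0 * abs(lr - 0.001) + noise


if __name__ == "__main__":
    best_lr, scores = select_ada_lr(toy_evaluate_episode)
    print("selected lr:", best_lr)
```

With a learning rate of 0 the gradient steps leave the weights unchanged, so including 0 among the candidates lets the search fall back to using the checkpoint without fine-tuning. The selected value would then be passed to the test script through `--ada_lr`.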

| Method | ILSVRC (test) | Omniglot | Aircraft | Birds | Textures | QuickDraw | Fungi | VGG Flower | Traffic signs | MSCOCO |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| md_full_128x128_dinosmall_fp16_lr5e-5 | 73.52±0.80 (lr=0.0001) | 92.17±0.57 (lr=0.0001) | 89.49±0.52 (lr=0.001) | 91.04±0.37 (lr=0.0001) | 85.73±0.62 (lr=0.001) | 79.43±0.67 (lr=0.0001) | 74.99±0.94 (lr=0) | 95.30±0.44 (lr=0.001) | 89.85±0.76 (lr=0.01) | 59.69±1.02 (lr=0.001) |
| md_inet_128x128_dinosmall_fp16_lr2e-4 | 75.51±0.72 (lr=0.001) | 82.81±1.10 (lr=0.01) | 78.38±1.09 (lr=0.01) | 85.18±0.77 (lr=0.001) | 86.95±0.60 (lr=0.001) | 74.47±0.83 (lr=0.01) | 55.16±1.09 (lr=0) | 94.66±0.48 (lr=0) | 90.04±0.81 (lr=0.01) | 62.60±0.96 (lr=0.001) |
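
Each result cell above follows the pattern `accuracy±margin (lr=value)`. For readers who want to compare the two checkpoints programmatically, a cell could be parsed roughly as follows; `parse_cell` is a hypothetical helper written for this illustration, not part of the PMF code.

```python
import re

# Matches cells such as "73.52±0.80 (lr=0.0001)" or "74.99±0.94 (lr=0)".
CELL_PATTERN = re.compile(
    r"(?P<acc>\d+\.\d+)±(?P<margin>\d+\.\d+) \(lr=(?P<lr>[\d.]+)\)"
)


def parse_cell(cell):
    """Split one table cell into (accuracy, margin, learning_rate) floats."""
    match = CELL_PATTERN.fullmatch(cell.strip())
    if match is None:
        raise ValueError(f"unrecognised cell format: {cell!r}")
    return float(match["acc"]), float(match["margin"]), float(match["lr"])


if __name__ == "__main__":
    # Two cells copied from the ILSVRC (test) column of the table.
    full = parse_cell("73.52±0.80 (lr=0.0001)")
    inet = parse_cell("75.51±0.72 (lr=0.001)")
    print("md_full:", full)
    print("md_inet:", inet)
```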