|
|
|
|
|
import argparse |
|
import datetime |
|
import numpy as np |
|
import time |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
import json |
|
|
|
from pathlib import Path |
|
from utils import get_root_logger |
|
from timm.models import create_model |
|
import models_v2 |
|
import requests |
|
|
|
import utils |
|
import time |
|
import sys |
|
import datetime |
|
import os |
|
from snnet import SNNet, SNNetv2 |
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
from fvcore.nn import FlopCountAnalysis |
|
|
|
from PIL import Image |
|
import gradio as gr |
|
import plotly.express as px |
|
|
|
|
|
def get_args_parser():
    """Build the command-line parser for this script.

    The parser is created with ``add_help=False`` so it can be passed as a
    ``parents=[...]`` entry to another ``argparse.ArgumentParser`` (the
    script entry point does exactly that).

    Returns:
        argparse.ArgumentParser: parser with every option below registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model options
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # --model-ema / --no-model-ema share one destination; it defaults to True.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer options
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning-rate schedule options
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation options
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The help text used to embed a stray `" + \` continuation inside the
    # literal (and the call ended in a comma, forming a discarded tuple);
    # adjacent string literals give the intended single clean sentence.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')

    parser.add_argument('--src', action='store_true')

    # Random-erase options
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup / cutmix options
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation options
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160"')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # Fine-tuning options
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset options
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    # Run-time / checkpoint options
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # Distributed options
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # Experiment / stitching options
    parser.add_argument('--exp_name', default='deit', type=str, help='experiment name')
    parser.add_argument('--config', default=None, type=str, help='configuration')
    parser.add_argument('--scoring', action='store_true', default=False, help='configuration')
    parser.add_argument('--proxy', default='synflow', type=str, help='configuration')
    parser.add_argument('--snnet_name', default='snnetv2', type=str, help='configuration')
    parser.add_argument('--get_flops', action='store_true')
    # NOTE(review): the help texts of the next four options look copy-pasted
    # from --eval-crop-ratio, --dist-eval and --world_size; they are kept
    # verbatim here and should be reworded by someone who knows their meaning.
    parser.add_argument('--flops_sampling_k', default=None, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--low_rank', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--lora_r', default=64, type=int,
                        help='number of distributed processes')
    parser.add_argument('--flops_gap', default=1.0, type=float,
                        help='number of distributed processes')

    return parser
|
|
|
|
|
def initialize_model_stitching_layer(model, mixup_fn, data_loader, device):
    """Initialise the model's stitching weights from the first batch only.

    The first ``(samples, targets)`` pair produced by ``data_loader`` is moved
    to ``device``, optionally passed through ``mixup_fn``, and the samples are
    handed to ``model.initialize_stitching_weights`` under CUDA autocast.
    Nothing happens when the loader yields no batch.

    Args:
        model: object exposing ``initialize_stitching_weights(samples)``.
        mixup_fn: optional callable ``(samples, targets) -> (samples, targets)``.
        data_loader: iterable of ``(samples, targets)`` pairs.
        device: target passed to ``.to(...)`` for both tensors.
    """
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        # Empty loader: there is nothing to initialise from.
        return

    images, labels = first_batch
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    if mixup_fn is not None:
        images, labels = mixup_fn(images, labels)

    with torch.cuda.amp.autocast():
        model.initialize_stitching_weights(images)
|
|
|
|
|
@torch.no_grad()
def analyse_flops_for_all(model, config_name):
    """Count FLOPs for every stitch configuration and dump them to JSON.

    For each id in ``model.all_cfgs`` the model is switched with
    ``model.reset_stitch_id`` and ``FlopCountAnalysis(...).total()`` is taken
    on a random ``1 x 3 x 224 x 224`` input. The mapping ``cfg_id -> flops``
    is written to ``./model_flops/flops_<config_name>.json``.

    Args:
        model: stitched model exposing ``all_cfgs`` and ``reset_stitch_id``.
        config_name: suffix used in the output file name.
    """
    # Feed the dummy input on whatever device the model lives on instead of
    # unconditionally calling .cuda(), which breaks CPU-only runs. Fall back
    # to CUDA (the previous behaviour) if the device cannot be determined.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    stitch_results = {}
    for cfg_id in model.all_cfgs:
        model.reset_stitch_id(cfg_id)
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        stitch_results[cfg_id] = FlopCountAnalysis(model, dummy_input).total()

    save_dir = './model_flops'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, f'flops_{config_name}.json'), 'w') as f:
        json.dump(stitch_results, f, indent=4)
|
|
|
|
|
def main(args):
    """Load the stitched model from a checkpoint and serve the Gradio demo.

    Steps, in order: initialise distributed mode and logging, seed torch and
    numpy, build the evaluation transform, create the anchor models named in
    ``args.anchors`` and wrap them in ``SNNetv2``, load ``args.resume``,
    read per-stitch accuracy/FLOPs from ``stitches_res_s_l.txt``, download
    the label list, then build and launch the Gradio page.

    Args:
        args: parsed namespace; besides the parser options it must carry
            ``anchors`` and ``anchor_drop_path`` (not defined by the parser
            in this file, so presumably supplied by the JSON config).

    Raises:
        NotImplementedError: distillation combined with finetuning outside
            of eval mode.
        requests.RequestException: the label list could not be downloaded.
    """
    utils.init_distributed_mode(args)

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(os.path.join(args.output_dir, f'{timestamp}.log'))

    logger.info(str(args))

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Seed is offset by the process rank.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    from datasets import build_transform

    # First argument False: request the non-training transform.
    transform = build_transform(False, args)

    anchors = []
    for i, anchor_name in enumerate(args.anchors):
        logger.info(f"Creating model: {anchor_name}")
        anchor = create_model(
            anchor_name,
            pretrained=False,
            pretrained_deit=None,
            num_classes=1000,
            drop_path_rate=args.anchor_drop_path[i],
            img_size=args.input_size
        )
        anchors.append(anchor)

    model = SNNetv2(anchors, lora_r=args.lora_r)

    checkpoint = torch.load(args.resume, map_location='cpu')

    logger.info(f"load checkpoint from {args.resume}")
    model.load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    # One JSON object per line: cfg_id, acc1 and flops (stored here in GFLOPs).
    eval_res = {}
    flops_res = {}
    with open('stitches_res_s_l.txt', 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing
                # inside json.loads.
                continue
            epoch_stat = json.loads(line)
            eval_res[epoch_stat['cfg_id']] = epoch_stat['acc1']
            flops_res[epoch_stat['cfg_id']] = epoch_stat['flops'] / 1e9

    def visualize_stitch_pos(stitch_id):
        """Return a FLOPs-vs-accuracy scatter with ``stitch_id`` highlighted."""
        if stitch_id == 13:
            # NOTE(review): id 13 is remapped to 0 before the lookup --
            # presumably the two stitches are equivalent or 13 is missing
            # from the results file; confirm with the authors.
            stitch_id = 0

        names = [f'ID {key}' for key in flops_res.keys()]

        fig = px.scatter(x=flops_res.values(), y=eval_res.values(), hover_name=names)
        fig.update_layout(
            title=f"SN-Netv2 - Stitch ID - {stitch_id}",
            title_x=0.5,
            xaxis_title="GFLOPs",
            yaxis_title="Top-1 Acc (%)",
            font=dict(
                family="Courier New, monospace",
                size=18,
                color="RebeccaPurple"
            ),
            legend=dict(
                yanchor="bottom",
                y=0.99,
                xanchor="left",
                x=0.01),
        )

        fig.update_traces(marker=dict(size=10,
                                      line=dict(width=2)),
                          selector=dict(mode='markers'))

        # Extra, larger marker for the currently selected stitch.
        fig.add_scatter(x=[flops_res[stitch_id]], y=[eval_res[stitch_id]], mode='markers', marker=dict(size=15),
                        name='Current Stitch')
        return fig

    # Label list, one name per line (presumably the ImageNet class names).
    # A timeout keeps start-up from hanging forever, and raise_for_status
    # stops an HTTP error page from being used as the label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    labels = response.text.split("\n")

    def process_image(image, stitch_id):
        """Classify ``image`` with the selected stitch; return scores and plot."""
        inp = transform(image).unsqueeze(0).to(device)
        model.reset_stitch_id(stitch_id)
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
            confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        fig = visualize_stitch_pos(stitch_id)
        return confidences, fig

    with gr.Blocks() as main_page:
        with gr.Column():
            gr.HTML("""
                <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">Stitched ViTs are Flexible Vision Backbones</h1>
                <div align="center"> <img align="center" src='file/gradio_banner.png' width="70%"> </div>
                <h3 align="center" >This is the classification demo page of SN-Netv2, a flexible vision backbone that allows for 100+ runtime speed and performance trade-offs. You can also run this gradio demo on your local GPUs at <a href="https://github.com/ziplab/SN-Netv2">https://github.com/ziplab/SN-Netv2</a>, Paper link: <a href="https://arxiv.org/abs/2307.00154">https://arxiv.org/abs/2307.00154</a>.</h3>
                """)
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil')
                    stitch_slider = gr.Slider(minimum=0, maximum=134, step=1, label="Stitch ID")
                    with gr.Row():
                        clear_button = gr.ClearButton()
                        submit_button = gr.Button()
                with gr.Column():
                    label_output = gr.Label(num_top_classes=5)
                    stitch_plot = gr.Plot(label='Stitch Position')

        # Submit runs the classifier and refreshes both outputs.
        submit_button.click(
            fn=process_image,
            inputs=[image_input, stitch_slider],
            outputs=[label_output, stitch_plot],
        )

        # Moving the slider only redraws the plot; no inference is run.
        stitch_slider.change(
            fn=visualize_stitch_pos,
            inputs=[stitch_slider],
            outputs=[stitch_plot],
            show_progress=False
        )

        # Reset image, slider, label and plot in that order.
        clear_button.click(
            lambda: [None, 0, None, None],
            outputs=[image_input, stitch_slider, label_output, stitch_plot],
        )

        gr.Examples(
            [
                ['demo.jpg', 0],
            ],
            inputs=[
                image_input,
                stitch_slider
            ],
            outputs=[
                label_output,
                stitch_plot
            ],
        )

    main_page.launch(allowed_paths=['./'])
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, merge the JSON config, resolve the
    # output directory / checkpoint, then start the demo.
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    # Use the bundled demo configuration unless --config was given explicitly
    # (it used to be overwritten unconditionally, making --config a no-op).
    if args.config is None:
        args.config = 'gradio_demo.json'

    # Context manager so the config file handle is closed deterministically.
    with open(args.config) as config_file:
        config_args = json.load(config_file)

    # Options typed on the command line win over the config file. Flags are
    # collected both verbatim and with '-' -> '_' so that e.g. '--batch-size'
    # matches the 'batch_size' key used by argparse and the config.
    cli_flags = {arg[2:].split('=')[0] for arg in sys.argv[1:]
                 if arg.startswith('--')}
    override_keys = cli_flags | {flag.replace('-', '_') for flag in cli_flags}
    for k, v in config_args.items():
        if k not in override_keys:
            setattr(args, k, v)

    output_dir = os.path.join('outputs', args.exp_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'checkpoint.pth')
    # Pick up an existing checkpoint only when --resume / config gave none.
    if os.path.exists(checkpoint_path) and not args.resume:
        setattr(args, 'resume', checkpoint_path)

    setattr(args, 'output_dir', output_dir)

    main(args)
|
|