# distildire/utils/trainer.py
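"""Training utilities for DistilDIRE.

Defines BaseModel (checkpoint save/load helpers) and Trainer, which runs the
knowledge-distillation training loop for the DistilDIRE / DIRE /
DistilDIREOnlyEPS students against a frozen ResNet-50 teacher, with
validation metrics and optional wandb logging.
"""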
import os
import torch
import torch.nn as nn
import torchvision.models as TVM
from collections import OrderedDict
from sklearn.metrics import accuracy_score, average_precision_score, precision_score
from torchvision.io import decode_jpeg, encode_jpeg
import torch.distributed as dist
from tqdm.auto import tqdm
import numpy as np
from utils.config import CONFIGCLASS
from networks.distill_model import DistilDIRE, DIRE, DistilDIREOnlyEPS
from utils.warmup import GradualWarmupScheduler
import os.path as osp
from guided_diffusion.compute_dire_eps import dire_get_first_step_noise, create_dicts_for_static_init
from guided_diffusion.guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    dict_parse,
    args_to_dict,
)

class BaseModel(nn.Module):
    def __init__(self, cfg: CONFIGCLASS):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.nepoch = cfg.nepoch
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.student: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch: int):
        save_filename = f"model_epoch_{epoch}.pth"
        save_path = os.path.join(self.save_dir, save_filename)
        # serialize model and optimizer to dict
        state_dict = {
            "model": self.student.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }
        torch.save(state_dict, save_path)

    # load models from the disk
    def load_networks(self, epoch: int):
        load_filename = f"model_epoch_{epoch}.pth"
        load_path = os.path.join(self.save_dir, load_filename)
        print(f"loading the model from {load_path}")
        checkpoint = torch.load(load_path, map_location=self.device)

        # strip the DDP "module." prefix before loading the student weights
        state_dict = {k.replace("module.", ""): v for k, v in checkpoint["model"].items()}
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata
        self.student.load_state_dict(state_dict)
        self.total_steps = checkpoint["total_steps"]

        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            # move optimizer state to GPU
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        self.student.eval()

    def test(self):
        with torch.no_grad():
            self.forward()


class Trainer(BaseModel):
    def name(self):
        return "DistilDIRE Trainer"

    def __init__(self, cfg: CONFIGCLASS, train_loader, val_loader, run, rank=0, distributed=True, world_size=1, kd=True):
        super().__init__(cfg)
        self.arch = cfg.arch
        self.reproduce_dire = cfg.reproduce_dire
        self.only_eps = cfg.only_eps
        self.only_img = cfg.only_img
        self.test_name = osp.basename(cfg.dataset_test_root)
        self.rank = rank
        self.device = torch.device("cuda")
        self.distributed = distributed
        self.world_size = world_size
        self.kd = kd
        self.kd_weight = cfg.kd_weight
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.val_every = cfg.val_every
        self.cur_epoch = 0
        # wandb run for logging (skipped if None)
        self.run = run
        self.adm = None

        if self.reproduce_dire:
            self.student = DIRE(self.device).to(self.device)
        else:
            if self.only_eps or self.only_img:
                self.student = DistilDIREOnlyEPS(self.device).to(self.device)
                if self.only_img:
                    # load the ADM diffusion model so eps can be computed from raw images on the fly
                    adm_args = create_dicts_for_static_init()
                    adm_args['timestep_respacing'] = 'ddim20'
                    adm_model, diffusion = create_model_and_diffusion(**dict_parse(adm_args, model_and_diffusion_defaults().keys()))
                    adm_model.load_state_dict(torch.load(adm_args['model_path'], map_location="cpu"))
                    print("ADM model loaded...")
                    self.adm = adm_model
                    self.adm.convert_to_fp16()
                    self.adm.to(self.device)
                    self.adm.eval()
                    self.diffusion = diffusion
                    self.adm_args = adm_args
            else:
                self.student = DistilDIRE(self.device).to(self.device)
        # self.student.convert_to_fp16_student()

        # teacher: ImageNet-pretrained ResNet-50 trunk; drop the final avgpool and fc (classifier) layers
        __backbone = TVM.resnet50(weights=TVM.ResNet50_Weights.DEFAULT)
        self.teacher = nn.Sequential(OrderedDict([*(list(__backbone.named_children())[:-2])]))
        self.teacher.eval().to(self.device)
        # Freeze teacher model
        for param in self.teacher.parameters():
            param.requires_grad = False

        self.kd_criterion = nn.MSELoss(reduction='mean')
        self.cls_criterion = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=torch.tensor(0.3))

        # initialize optimizers
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.student.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.student.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")

        if self.distributed:
            from torch.nn.parallel import DistributedDataParallel as DDP
            self.student = DDP(self.student, device_ids=[rank], output_device=rank)

        if cfg.warmup:
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            self.scheduler.step()

    def adjust_learning_rate(self, min_lr=1e-6):
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True

    def set_input(self, input, istrain=True):
        if self.reproduce_dire:
            dire, label = input
            self.dire = dire.to(self.device)
            self.label = label.to(self.device).float()
        else:
            img, dire, eps, label = input  # if len(input) == 3 else (input[0], input[1], {})
            H, W = img.shape[-2:]
            B = img.shape[0]

            # random jpeg compression: map [-1, 1] -> uint8, round-trip through JPEG, map back
            if istrain:
                comp_quality = torch.randint(10, 100, (B,))
                img = (img + 1) / 2
                img = (img * 255).to(torch.uint8)
                img = torch.stack([decode_jpeg(encode_jpeg(img[i], quality=int(comp_quality[i]))) for i in range(B)], dim=0)
                img = img / 255.
                img = (img * 2) - 1

            # only-img
            if self.only_img:
                img = img.to(self.device)
                # calc eps from img
                eps = dire_get_first_step_noise(img, self.adm, self.diffusion, self.adm_args, self.device)

            # cutmix
            # if torch.rand(1) < 0.3 and istrain:
            #     c_lambda = torch.rand(1)
            #     r_x = torch.randint(0, W, (1,))
            #     r_y = torch.randint(0, H, (1,))
            #     r_w = int(torch.sqrt(1-c_lambda)*W)
            #     r_h = int(torch.sqrt(1-c_lambda)*H)
            #     img[:, :, r_y:r_y+r_h, r_x:r_x+r_w] = img[0:1, :, r_y:r_y+r_h, r_x:r_x+r_w].repeat(B, 1, 1, 1)
            #     dire[:, :, r_y:r_y+r_h, r_x:r_x+r_w] = dire[0:1, :, r_y:r_y+r_h, r_x:r_x+r_w].repeat(B, 1, 1, 1)
            #     eps[:, :, r_y:r_y+r_h, r_x:r_x+r_w] = eps[0:1, :, r_y:r_y+r_h, r_x:r_x+r_w].repeat(B, 1, 1, 1)
            #     label = c_lambda * label + (1-c_lambda) * label[0:1]

            self.input = img.to(self.device)
            self.dire = dire.to(self.device)
            self.eps = eps.to(self.device)
            self.label = label.to(self.device).float()

    def forward(self):
        if self.reproduce_dire:
            self.output = self.student(self.dire)
        else:
            if self.only_eps or self.only_img:
                self.output = self.student(self.eps)
            else:
                self.output = self.student(self.input, self.eps)
            with torch.no_grad():
                self.teacher_feature = self.teacher(self.dire)

    def get_loss(self, kd=True):
        loss = self.cls_criterion(self.output['logit'].squeeze(), self.label)
        if kd and (not self.reproduce_dire):
            loss2 = self.kd_criterion(self.output['feature'], self.teacher_feature)
            loss = loss + loss2 * self.kd_weight
        return loss

    def optimize_parameters(self):
        self.optimizer.zero_grad()
        self.forward()
        self.loss = self.get_loss(self.kd)
        self.loss.backward()
        self.optimizer.step()

    def load_networks(self, model_path):
        state_dict = torch.load(model_path, map_location=self.device)
        model_state_dict = state_dict["model"]
        optimizer_state_dict = state_dict["optimizer"]
        model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
        optimizer_state_dict = {k.replace("module.", ""): v for k, v in optimizer_state_dict.items()}
        if self.distributed:
            self.student.module.load_state_dict(model_state_dict)
        else:
            self.student.load_state_dict(model_state_dict)
        self.optimizer.load_state_dict(optimizer_state_dict)
        print(f"Model loaded from {model_path}")
        return True

    @torch.no_grad()
    def validate(self, gather=False, save=False, save_name=""):
        self.student.eval()
        y_pred = []
        y_true = []
        N_FAKE, N_REAL = 0, 0
        for data, path in tqdm(self.val_loader, desc=f"Validation after {self.cur_epoch} epoch..."):
            self.set_input(data, istrain=False)
            self.forward()
            pred = self.output['logit'].sigmoid()
            if gather:
                # gather predictions and labels from every rank (dist is imported at module level);
                # each list entry must be a distinct tensor so all_gather does not overwrite itself
                pred_gather = [torch.zeros_like(pred) for _ in range(self.world_size)]
                label_gather = [torch.zeros_like(self.label) for _ in range(self.world_size)]
                dist.all_gather(pred_gather, pred)
                dist.all_gather(label_gather, self.label)
            else:
                pred_gather = [pred]
                label_gather = [self.label]
            N_FAKE += sum([(label == 1).sum().item() for label in label_gather])
            N_REAL += sum([(label == 0).sum().item() for label in label_gather])
            y_pred.extend(torch.cat(pred_gather).flatten().detach().cpu().tolist())
            y_true.extend(torch.cat(label_gather).flatten().detach().cpu().tolist())

        y_true, y_pred = np.array(y_true), np.array(y_pred)
        acc = accuracy_score(y_true, y_pred > 0.5)
        ap = average_precision_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred > 0.5)
        if self.run:
            self.run.log({"val_acc": acc, "val_ap": ap, "val_precision": precision})
            self.run.log({"N_FAKE": N_FAKE, "N_REAL": N_REAL})
        print(f"Validation: acc: {acc}, ap: {ap}, precision: {precision}")
        print(f"N_FAKE: {N_FAKE}, N_REAL: {N_REAL}")
        if save:
            with open(save_name, "w") as f:
                f.write(f"Validation: acc: {acc}, ap: {ap}, precision: {precision}\n")
                f.write(f"N_FAKE: {N_FAKE}, N_REAL: {N_REAL}")

    def train(self):
        for epoch in range(self.cfg.nepoch):
            if self.run:
                self.run.log({"epoch": epoch})
            if epoch % self.val_every == 0 and epoch != 0:
                self.validate(gather=self.distributed)

            self.student.train()
            for data_and_paths in tqdm(self.train_loader, desc=f"Training {epoch} epoch..."):
                data, path = data_and_paths
                self.total_steps += 1
                self.set_input(data)
                self.optimize_parameters()
                if self.total_steps % 100 == 0 and self.run:
                    print(f"step: {self.total_steps}, loss: {self.loss}")
                    self.run.log({"loss": self.loss, "step": self.total_steps})

            if self.run:
                self.save_networks(epoch)
            if self.cfg.warmup:
                self.scheduler.step()
            if self.distributed:
                dist.barrier()
            self.cur_epoch = epoch
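
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module). It shows how a single-GPU,
# non-distributed Trainer might be wired up, assuming a CONFIGCLASS-style `cfg`
# exposing the fields read above (isTrain, ckpt_dir, nepoch, lr, beta1, optim,
# warmup, warmup_epoch, val_every, kd_weight, new_optim, arch, reproduce_dire,
# only_eps, only_img, dataset_test_root) and DataLoaders that yield
# ((img, dire, eps, label), path) batches as consumed by set_input().
# `build_config` and `create_dataloader` are hypothetical helper names, not
# functions defined in this repository.
#
# if __name__ == "__main__":
#     cfg = build_config()                                    # hypothetical config factory
#     train_loader = create_dataloader(cfg, split="train")    # hypothetical loader helper
#     val_loader = create_dataloader(cfg, split="val")        # hypothetical loader helper
#     trainer = Trainer(cfg, train_loader, val_loader, run=None,
#                       rank=0, distributed=False, world_size=1, kd=True)
#     trainer.train()
#     trainer.validate(gather=False)
# ---------------------------------------------------------------------------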