Spaces:
Running
Running
from copy import deepcopy | |
import torch | |
from torch import nn | |
from torch.cuda import amp | |
from tqdm import tqdm | |
from .base import Attacker, Empty | |
class PGD(Attacker):
    """Projected Gradient Descent (PGD) adversarial attacker (L-inf).

    Each step adds ``alpha() * sign(grad)`` of the supplied loss to the current
    images (i.e. it moves the images to increase the loss). It then clamps the
    accumulated perturbation to ``[-eps, eps]`` around the original images and
    clamps the result to ``[0, 1]`` in the space produced by
    ``img_transform[1]``.

    Two ways to drive the attack:

    * :meth:`attack` runs ``iters`` steps in one call and returns the final
      adversarial images.
    * :meth:`set_data` followed by iteration
      (``for ori, adv, labels in attacker: ...``) performs one step per item.
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x), use_amp=False):
        """
        Args:
            model: network under attack; it is called as ``model(images).logits``.
            img_transform: pair of callables. ``img_transform[1]`` is applied
                before the ``[0, 1]`` clamp and ``img_transform[0]`` after it.
                Presumably they map model-input space to pixel space and back
                (TODO confirm against callers). Defaults to two identities.
            use_amp: run forward and loss under ``torch.cuda.amp.autocast`` and
                scale the loss with a ``GradScaler`` before backward.
        """
        super().__init__(model, img_transform)
        self.use_amp = use_amp
        # Optional fn(ori_images, adv_images, labels), called after every attack() step.
        self.call_back = None
        # Stored by set_img_loader(); not read by any method of this class.
        self.img_loader = None
        # Optional fn(ori_images, adv_images) -> images, applied after every attack() step.
        self.img_hook = None
        self.scaler = amp.GradScaler(enabled=use_amp)

    def set_para(self, eps=8, alpha=lambda: 8, iters=20, **kwargs):
        """Configure the attack hyper-parameters (forwarded to ``Attacker.set_para``).

        Args:
            eps: radius of the L-inf ball the perturbation is clamped to.
            alpha: zero-argument callable returning the step size. It is called
                once per step, so it can implement a schedule.
            iters: number of steps run by :meth:`attack` or by iteration.
            **kwargs: extra options passed through to the base class.

        NOTE(review): presumably the base class stores these as ``self.eps``,
        ``self.alpha`` and ``self.iters``, which the methods below read.
        """
        super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)

    def set_call_back(self, call_back):
        """Register ``call_back(ori_images, adv_images, labels)``, run after each :meth:`attack` step."""
        self.call_back = call_back

    def set_img_loader(self, img_loader):
        """Store ``img_loader`` on the instance (not used within this class)."""
        self.img_loader = img_loader

    def _sync_context(self):
        """Return ``model.no_sync()`` for a DDP-wrapped model, otherwise a no-op ``Empty()``.

        The attack zeroes parameter gradients after every step. All-reducing
        them across ranks would therefore be wasted work, and would block if
        not every rank runs the attack.
        """
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return self.model.no_sync()
        return Empty()

    def step(self, images, labels, loss):
        """Perform one PGD update.

        Args:
            images: current adversarial images (tensor).
            labels: targets passed to ``loss``.
            loss: callable ``loss(logits, labels)`` returning a scalar to maximize.

        Returns:
            The new adversarial images, built from detached tensors.
        """
        # Use a fresh leaf tensor. This avoids mutating the caller's tensor and
        # prevents a stale ``.grad`` (e.g. from attacking the same tensor
        # twice) from being accumulated into this step's gradient.
        images = images.detach().requires_grad_(True)
        with amp.autocast(enabled=self.use_amp):
            outputs = self.model(images).logits
            self.model.zero_grad()
            cost = loss(outputs, labels)

        # GradScaler multiplies the loss by a positive factor. Only the sign of
        # the gradient is used below, so no unscale step is needed.
        self.scaler.scale(cost).backward()

        adv_images = (images + self.alpha() * images.grad.sign()).detach_()
        # Clamp the total perturbation to the L-inf ball of radius eps.
        eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
        pixels = torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach_()
        return self.img_transform[0](pixels)

    def set_data(self, images, labels):
        """Prepare a batch for step-by-step iteration.

        ``images`` is deep-copied as the clean reference that perturbations are
        measured against; the working copy starts equal to it.
        """
        self.ori_images = deepcopy(images)
        self.images = images
        self.labels = labels

    def __iter__(self):
        """Reset the step counter and return ``self`` as the iterator."""
        self.atk_step = 0
        return self

    def __next__(self):
        """Run one attack step on the data given to :meth:`set_data`.

        Returns:
            Tuple ``(ori_images, adv_images_detached, labels)``.

        Raises:
            StopIteration: after ``self.iters`` steps.
        """
        self.atk_step += 1
        if self.atk_step > self.iters:
            raise StopIteration

        # Remember the caller's mode so it can be restored, not forced to train().
        was_training = self.model.training
        with self._sync_context():
            self.model.eval()
            try:
                # NOTE(review): ``self.forward`` comes from the base class and is
                # called with ``self`` explicitly. Presumably it dispatches to
                # ``step`` with a loss; confirm in ``.base``.
                self.images = self.forward(self, self.images, self.labels)
                self.model.zero_grad()
            finally:
                self.model.train(was_training)

        return self.ori_images, self.images.detach(), self.labels

    def attack(self, images, labels, step_func=None):
        """Run ``self.iters`` PGD steps and return the final adversarial images.

        Args:
            images: clean input images. They are deep-copied as the reference
                that perturbations are measured against.
            labels: targets forwarded to ``self.forward``.
            step_func: optional callable, invoked with the 1-based step index
                after each step.

        Returns:
            The adversarial images after the last step (after ``img_hook``, if
            one is set). If ``iters`` is 0, the input ``images`` is returned.
        """
        self.ori_images = deepcopy(images)
        # Restore the caller's train/eval mode afterwards. Forcing train() here
        # would leave an inference model with dropout/BN updates enabled.
        was_training = self.model.training

        for i in tqdm(range(self.iters)):
            with self._sync_context():
                self.model.eval()
                try:
                    images = self.forward(self, images, labels)
                    self.model.zero_grad()
                finally:
                    self.model.train(was_training)

            if self.call_back:
                self.call_back(self.ori_images, images.detach(), labels)
            if self.img_hook is not None:
                images = self.img_hook(self.ori_images, images.detach())
            if step_func:
                step_func(i + 1)

        return images