File size: 1,414 Bytes
65abdbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torch.cuda import amp

from .base import Attacker


class FGSM(Attacker):
    """Fast Gradient Sign Method (FGSM) attack.

    Each :meth:`step` moves the images by ``alpha()`` along the sign of the
    loss gradient (increasing the loss for a positive ``alpha``), clamps the
    total perturbation relative to ``self.ori_images`` to ``[-eps, eps]``,
    and clamps the result to ``[0, 255]`` in the space produced by
    ``img_transform[1]`` before mapping back with ``img_transform[0]``.

    NOTE(review): ``self.ori_images``, ``self.eps``, ``self.alpha`` and
    ``self.forward`` are not assigned in this class; they are presumably set
    by the ``Attacker`` base class (e.g. via ``set_para`` / a data setter) —
    confirm in ``.base``.
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x), use_amp=False):
        """
        Args:
            model: network under attack; calling it must return an object
                with a ``.logits`` attribute.
            img_transform: pair of callables. ``img_transform[1]`` maps
                model-input images into the space where the ``[0, 255]``
                clamp is applied; ``img_transform[0]`` maps back. Both
                default to the identity.
            use_amp: run the forward pass under CUDA autocast and scale the
                loss with a ``GradScaler`` before backprop.
        """
        super().__init__(model, img_transform)
        self.use_amp = use_amp

        if use_amp:
            self.scaler = amp.GradScaler()

    def set_para(self, eps=8, alpha=lambda: 8, **kwargs):
        """Set attack hyper-parameters.

        Args:
            eps: bound on each element of the total perturbation
                (``|adv - ori| <= eps``).
            alpha: zero-argument callable returning the step size; it is
                called once per :meth:`step`.
            **kwargs: forwarded unchanged to ``Attacker.set_para``.
        """
        super().set_para(eps=eps, alpha=alpha, **kwargs)

    def step(self, images, labels, loss):
        """Take one signed-gradient step and project back into the eps-box.

        Args:
            images: current adversarial images in model-input space.
            labels: targets passed to ``loss``.
            loss: callable ``loss(logits, labels)`` returning a scalar.

        Returns:
            Detached adversarial images in model-input space.
        """
        # Differentiate w.r.t. a fresh leaf that shares storage with ``images``
        # instead of flipping ``requires_grad`` on the caller's tensor: this
        # leaves no gradient/flag behind on the caller's input and also works
        # when ``images`` is a non-leaf tensor (setting ``requires_grad`` on a
        # non-leaf raises a RuntimeError).
        images = images.detach().requires_grad_(True)

        with amp.autocast(enabled=self.use_amp):
            outputs = self.model(images).logits

            self.model.zero_grad()
            cost = loss(outputs, labels)

        if self.use_amp:
            # The scaler multiplies by a positive factor, which leaves the
            # gradient's sign (the only thing used below) unchanged.
            self.scaler.scale(cost).backward()
        else:
            cost.backward()

        adv_images = (images + self.alpha() * images.grad.sign()).detach_()
        # Project the accumulated perturbation back into [-eps, eps].
        eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
        images = self.img_transform[0](
            torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=255).detach_())

        return images

    def attack(self, images, labels):
        """Run the attack with the model temporarily in eval mode.

        The model's previous train/eval mode is restored afterwards, and its
        parameter gradients are cleared, even if the attack raises.

        Args:
            images: clean input images in model-input space.
            labels: targets for the loss.

        Returns:
            Adversarial images produced by ``self.forward``.
        """
        # Remember the caller's mode so an eval-mode model is not silently
        # switched to train mode (dropout / batch-norm updates) afterwards.
        was_training = getattr(self.model, 'training', True)
        self.model.eval()
        try:
            # ``self`` is passed explicitly: ``forward`` is presumably stored by
            # the base class as an unbound callable — confirm in ``.base``.
            images = self.forward(self, images, labels)
        finally:
            self.model.zero_grad()
            self.model.train(was_training)

        return images