Spaces:
Runtime error
Runtime error
work on trainin and dashboard statistics
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- app.py +254 -57
- data_mnist +1 -0
- mnist_c/brightness/test_images.npy +3 -0
- mnist_c/brightness/test_labels.npy +3 -0
- mnist_c/brightness/train_images.npy +3 -0
- mnist_c/brightness/train_labels.npy +3 -0
- mnist_c/canny_edges/test_images.npy +3 -0
- mnist_c/canny_edges/test_labels.npy +3 -0
- mnist_c/canny_edges/train_images.npy +3 -0
- mnist_c/canny_edges/train_labels.npy +3 -0
- mnist_c/dotted_line/test_images.npy +3 -0
- mnist_c/dotted_line/test_labels.npy +3 -0
- mnist_c/dotted_line/train_images.npy +3 -0
- mnist_c/dotted_line/train_labels.npy +3 -0
- mnist_c/fog/test_images.npy +3 -0
- mnist_c/fog/test_labels.npy +3 -0
- mnist_c/fog/train_images.npy +3 -0
- mnist_c/fog/train_labels.npy +3 -0
- mnist_c/glass_blur/test_images.npy +3 -0
- mnist_c/glass_blur/test_labels.npy +3 -0
- mnist_c/glass_blur/train_images.npy +3 -0
- mnist_c/glass_blur/train_labels.npy +3 -0
- mnist_c/identity/test_images.npy +3 -0
- mnist_c/identity/test_labels.npy +3 -0
- mnist_c/identity/train_images.npy +3 -0
- mnist_c/identity/train_labels.npy +3 -0
- mnist_c/impulse_noise/test_images.npy +3 -0
- mnist_c/impulse_noise/test_labels.npy +3 -0
- mnist_c/impulse_noise/train_images.npy +3 -0
- mnist_c/impulse_noise/train_labels.npy +3 -0
- mnist_c/motion_blur/test_images.npy +3 -0
- mnist_c/motion_blur/test_labels.npy +3 -0
- mnist_c/motion_blur/train_images.npy +3 -0
- mnist_c/motion_blur/train_labels.npy +3 -0
- mnist_c/rotate/test_images.npy +3 -0
- mnist_c/rotate/test_labels.npy +3 -0
- mnist_c/rotate/train_images.npy +3 -0
- mnist_c/rotate/train_labels.npy +3 -0
- mnist_c/scale/test_images.npy +3 -0
- mnist_c/scale/test_labels.npy +3 -0
- mnist_c/scale/train_images.npy +3 -0
- mnist_c/scale/train_labels.npy +3 -0
- mnist_c/shear/test_images.npy +3 -0
- mnist_c/shear/test_labels.npy +3 -0
- mnist_c/shear/train_images.npy +3 -0
- mnist_c/shear/train_labels.npy +3 -0
- mnist_c/shot_noise/test_images.npy +3 -0
- mnist_c/shot_noise/test_labels.npy +3 -0
- mnist_c/shot_noise/train_images.npy +3 -0
.gitattributes
CHANGED
@@ -26,3 +26,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
*ubyte* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
*ubyte* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -2,51 +2,163 @@ import os
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
import torchvision
|
|
|
5 |
from utils import *
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
import torch.optim as optim
|
9 |
from huggingface_hub import Repository, upload_file
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
|
13 |
-
|
14 |
-
|
|
|
15 |
batch_size_test = 1000
|
16 |
learning_rate = 0.01
|
17 |
momentum = 0.5
|
18 |
log_interval = 10
|
19 |
random_seed = 1
|
20 |
-
|
|
|
|
|
21 |
REPOSITORY_DIR = "data"
|
22 |
LOCAL_DIR = 'data_local'
|
23 |
os.makedirs(LOCAL_DIR,exist_ok=True)
|
24 |
|
25 |
|
|
|
|
|
26 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
27 |
|
28 |
HF_DATASET ="mnist-adversarial-dataset"
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
torch.backends.cudnn.enabled = False
|
31 |
torch.manual_seed(random_seed)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
torchvision.transforms.ToTensor(),
|
37 |
torchvision.transforms.Normalize(
|
38 |
(0.1307,), (0.3081,))
|
39 |
-
])
|
|
|
|
|
|
|
|
|
|
|
40 |
batch_size=batch_size_train, shuffle=True)
|
|
|
41 |
|
42 |
-
test_loader = torch.utils.data.DataLoader(
|
43 |
-
|
44 |
-
transform=torchvision.transforms.Compose([
|
45 |
-
torchvision.transforms.ToTensor(),
|
46 |
-
torchvision.transforms.Normalize(
|
47 |
-
(0.1307,), (0.3081,))
|
48 |
-
])),
|
49 |
-
batch_size=batch_size_test, shuffle=True)
|
50 |
|
51 |
|
52 |
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
|
@@ -69,7 +181,7 @@ class MNIST_Model(nn.Module):
|
|
69 |
return F.log_softmax(x)
|
70 |
|
71 |
|
72 |
-
def train(epochs,network,optimizer):
|
73 |
|
74 |
train_losses=[]
|
75 |
network.train()
|
@@ -102,10 +214,11 @@ def test():
|
|
102 |
correct += pred.eq(target.data.view_as(pred)).sum()
|
103 |
test_loss /= len(test_loader.dataset)
|
104 |
test_losses.append(test_loss)
|
|
|
|
|
105 |
test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
|
106 |
-
test_loss, correct, len(test_loader.dataset),
|
107 |
-
|
108 |
-
return test_metric
|
109 |
|
110 |
|
111 |
|
@@ -156,15 +269,41 @@ def image_classifier(inp):
|
|
156 |
|
157 |
def train_and_test():
|
158 |
# Train for one epoch and test
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
metadata_name = get_unique_name()
|
169 |
SAVE_FILE_DIR = os.path.join(LOCAL_DIR,metadata_name)
|
170 |
os.makedirs(SAVE_FILE_DIR,exist_ok=True)
|
@@ -182,7 +321,7 @@ def flag(input_image,correct_result,train):
|
|
182 |
|
183 |
dump_json(metadata,json_file_path)
|
184 |
|
185 |
-
# Simply upload the
|
186 |
# Upload the image
|
187 |
repo_image_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'image.png'))
|
188 |
|
@@ -201,43 +340,84 @@ def flag(input_image,correct_result,train):
|
|
201 |
repo_type='dataset',
|
202 |
token=HF_TOKEN
|
203 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
|
206 |
-
train=True
|
207 |
-
if train:
|
208 |
-
output = f'<div> ✔ Successfully saved to flagged dataset. Training the model on adversarial data! </div>'
|
209 |
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
|
|
|
|
212 |
|
|
|
|
|
|
|
|
|
213 |
|
214 |
-
|
215 |
-
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
|
216 |
-
description = """This project is about dynamic adversarial data collection (DADC).
|
217 |
-
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
|
218 |
-
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
|
219 |
-
|
220 |
-
### What to do:
|
221 |
-
- Draw a number from 0-9.
|
222 |
-
- Click `Submit` and see the model's prediciton.
|
223 |
-
- If the model misclassifies it, Flag that example.
|
224 |
-
- This will add your (adversarial) example to a dataset on which the model will be trained later.
|
225 |
-
"""
|
226 |
|
227 |
-
MODEL_IS_WRONG = """
|
228 |
-
---
|
229 |
|
230 |
-
|
231 |
-
|
|
|
232 |
#block = gr.Blocks(css=BLOCK_CSS)
|
233 |
block = gr.Blocks()
|
234 |
|
235 |
with block:
|
236 |
gr.Markdown(TITLE)
|
|
|
237 |
|
238 |
with gr.Tabs():
|
239 |
-
gr.Markdown(description)
|
240 |
with gr.TabItem('MNIST'):
|
|
|
|
|
241 |
with gr.Row():
|
242 |
|
243 |
|
@@ -249,15 +429,32 @@ def main():
|
|
249 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
250 |
|
251 |
flag_btn = gr.Button("Flag")
|
|
|
252 |
output_result = gr.outputs.HTML()
|
253 |
-
|
|
|
|
|
254 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
255 |
-
flag_btn.click(flag,inputs=[image_input,number_dropdown,
|
256 |
-
if to_train.value:
|
257 |
-
import pdb;pdb.set_trace()
|
258 |
-
train_and_test()
|
259 |
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
block.launch()
|
262 |
|
263 |
|
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
import torchvision
|
5 |
+
from PIL import Image
|
6 |
from utils import *
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
9 |
import torch.optim as optim
|
10 |
from huggingface_hub import Repository, upload_file
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
import numpy as np
|
13 |
+
from collections import Counter
|
14 |
|
15 |
|
16 |
|
17 |
+
|
18 |
+
n_epochs = 10
|
19 |
+
batch_size_train = 128
|
20 |
batch_size_test = 1000
|
21 |
learning_rate = 0.01
|
22 |
momentum = 0.5
|
23 |
log_interval = 10
|
24 |
random_seed = 1
|
25 |
+
TRAIN_CUTOFF = 5
|
26 |
+
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
|
27 |
+
METRIC_PATH = './metrics.json'
|
28 |
REPOSITORY_DIR = "data"
|
29 |
LOCAL_DIR = 'data_local'
|
30 |
os.makedirs(LOCAL_DIR,exist_ok=True)
|
31 |
|
32 |
|
33 |
+
|
34 |
+
|
35 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
36 |
|
37 |
HF_DATASET ="mnist-adversarial-dataset"
|
38 |
+
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
|
39 |
+
repo = Repository(
|
40 |
+
local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
|
41 |
+
)
|
42 |
+
repo.git_pull()
|
43 |
|
44 |
torch.backends.cudnn.enabled = False
|
45 |
torch.manual_seed(random_seed)
|
46 |
|
47 |
+
|
48 |
+
class MNISTAdversarial_Dataset(Dataset):
|
49 |
+
|
50 |
+
def __init__(self,data_dir,transform):
|
51 |
+
repo.git_pull()
|
52 |
+
self.data_dir = os.path.join(data_dir,'data')
|
53 |
+
self.transform = transform
|
54 |
+
files = [f.name for f in os.scandir(self.data_dir)]
|
55 |
+
self.images = []
|
56 |
+
self.numbers = []
|
57 |
+
for f in files:
|
58 |
+
self.FOLDER = os.path.join(os.path.join(self.data_dir,f))
|
59 |
+
|
60 |
+
metadata_path = os.path.join(self.FOLDER,'metadata.jsonl')
|
61 |
+
|
62 |
+
image_path =os.path.join(self.FOLDER,'image.png')
|
63 |
+
if os.path.exists(image_path) and os.path.exists(metadata_path):
|
64 |
+
img = Image.open(image_path)
|
65 |
+
self.images.append(img)
|
66 |
+
metadata = read_json_lines(metadata_path)
|
67 |
+
self.numbers.append(metadata[0]['correct_number'])
|
68 |
+
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
|
69 |
+
def __len__(self):
|
70 |
+
return len(self.images)
|
71 |
+
|
72 |
+
def __getitem__(self,idx):
|
73 |
+
img, label = self.images[idx], self.numbers[idx]
|
74 |
+
img = self.transform(img)
|
75 |
+
return img, label
|
76 |
+
|
77 |
+
class MNISTCorrupted_By_Digit(Dataset):
|
78 |
+
def __init__(self,transform,digit,limit=30):
|
79 |
+
self.transform = transform
|
80 |
+
self.digit = digit
|
81 |
+
corrupted_dir="./mnist_c"
|
82 |
+
files = [f.name for f in os.scandir(corrupted_dir)]
|
83 |
+
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy')) for f in files]
|
84 |
+
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy')) for f in files]
|
85 |
+
self.data = np.vstack(images)
|
86 |
+
self.labels = np.hstack(labels)
|
87 |
+
|
88 |
+
assert (self.data.shape[0] == self.labels.shape[0])
|
89 |
+
|
90 |
+
mask = self.labels == self.digit
|
91 |
+
|
92 |
+
data_masked = self.data[mask]
|
93 |
+
# Just to be on the safe side, ensure limit is more than the minimum
|
94 |
+
limit = min(limit,data_masked.shape[0])
|
95 |
+
|
96 |
+
self.data_for_use = data_masked[:limit]
|
97 |
+
self.labels_for_use = self.labels[mask][:limit]
|
98 |
+
assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.data_for_use)
|
102 |
+
def __getitem__(self,idx):
|
103 |
+
if torch.is_tensor(idx):
|
104 |
+
idx = idx.tolist()
|
105 |
+
|
106 |
+
image = self.data_for_use[idx]
|
107 |
+
label = self.labels_for_use[idx]
|
108 |
+
if self.transform:
|
109 |
+
image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms
|
110 |
+
image = self.transform(image_pil)
|
111 |
+
|
112 |
+
return image, label
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
class MNISTCorrupted(Dataset):
|
119 |
+
def __init__(self,transform):
|
120 |
+
self.transform = transform
|
121 |
+
corrupted_dir="./mnist_c"
|
122 |
+
files = [f.name for f in os.scandir(corrupted_dir)]
|
123 |
+
images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy')) for f in files]
|
124 |
+
labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy')) for f in files]
|
125 |
+
self.data = np.vstack(images)
|
126 |
+
self.labels = np.hstack(labels)
|
127 |
+
|
128 |
+
assert (self.data.shape[0] == self.labels.shape[0])
|
129 |
+
|
130 |
+
def __len__(self):
|
131 |
+
return len(self.data)
|
132 |
+
|
133 |
+
def __getitem__(self, idx):
|
134 |
+
if torch.is_tensor(idx):
|
135 |
+
idx = idx.tolist()
|
136 |
+
|
137 |
+
image = self.data[idx]
|
138 |
+
label = self.labels[idx]
|
139 |
+
if self.transform:
|
140 |
+
image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms
|
141 |
+
image = self.transform(image_pil)
|
142 |
+
|
143 |
+
return image, label
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
TRAIN_TRANSFORM = torchvision.transforms.Compose([
|
148 |
torchvision.transforms.ToTensor(),
|
149 |
torchvision.transforms.Normalize(
|
150 |
(0.1307,), (0.3081,))
|
151 |
+
])
|
152 |
+
|
153 |
+
'''
|
154 |
+
train_loader = torch.utils.data.DataLoader(
|
155 |
+
torchvision.datasets.MNIST('files/', train=True, download=True,
|
156 |
+
transform=TRAIN_TRANSFORM),
|
157 |
batch_size=batch_size_train, shuffle=True)
|
158 |
+
'''
|
159 |
|
160 |
+
test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
|
161 |
+
batch_size=batch_size_test, shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
|
164 |
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
|
|
|
181 |
return F.log_softmax(x)
|
182 |
|
183 |
|
184 |
+
def train(epochs,network,optimizer,train_loader):
|
185 |
|
186 |
train_losses=[]
|
187 |
network.train()
|
|
|
214 |
correct += pred.eq(target.data.view_as(pred)).sum()
|
215 |
test_loss /= len(test_loader.dataset)
|
216 |
test_losses.append(test_loss)
|
217 |
+
acc = 100. * correct / len(test_loader.dataset)
|
218 |
+
acc = acc.item()
|
219 |
test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
|
220 |
+
test_loss, correct, len(test_loader.dataset),acc )
|
221 |
+
return test_metric,acc
|
|
|
222 |
|
223 |
|
224 |
|
|
|
269 |
|
270 |
def train_and_test():
|
271 |
# Train for one epoch and test
|
272 |
+
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
|
273 |
+
|
274 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True
|
275 |
+
)
|
276 |
+
train(n_epochs,network,optimizer,train_loader)
|
277 |
+
test_metric,test_acc = test()
|
278 |
+
|
279 |
+
if os.path.exists(METRIC_PATH):
|
280 |
+
metric_dict = read_json(METRIC_PATH)
|
281 |
+
metric_dict['all'] = metric_dict['all'] if 'all' in metric_dict else [] + [test_acc]
|
282 |
+
else:
|
283 |
+
metric_dict={}
|
284 |
+
metric_dict['all'] = [test_acc]
|
285 |
+
|
286 |
+
for i in range(10):
|
287 |
+
data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM,i)
|
288 |
+
dataloader_per_digit = torch.utils.data.DataLoader(data_per_digit,batch_size=len(data_per_digit), shuffle=False)
|
289 |
+
data_per_digit, label_per_digit = iter(dataloader_per_digit).next()
|
290 |
+
output = network(data_per_digit)
|
291 |
+
pred = output.data.max(1, keepdim=True)[1]
|
292 |
+
correct = pred.eq(label_per_digit.data.view_as(pred)).sum()
|
293 |
+
acc = 100. * correct / len(data_per_digit)
|
294 |
+
acc=acc.item()
|
295 |
+
if os.path.exists(METRIC_PATH):
|
296 |
+
metric_dict[str(i)].append(acc)
|
297 |
+
else:
|
298 |
+
metric_dict[str(i)] = [acc]
|
299 |
+
|
300 |
+
dump_json(thing=metric_dict,file=METRIC_PATH)
|
301 |
+
return test_metric
|
302 |
+
|
303 |
+
def flag(input_image,correct_result,adversarial_number):
|
304 |
+
|
305 |
+
adversarial_number = 0 if None else adversarial_number
|
306 |
+
|
307 |
metadata_name = get_unique_name()
|
308 |
SAVE_FILE_DIR = os.path.join(LOCAL_DIR,metadata_name)
|
309 |
os.makedirs(SAVE_FILE_DIR,exist_ok=True)
|
|
|
321 |
|
322 |
dump_json(metadata,json_file_path)
|
323 |
|
324 |
+
# Simply upload the image file and metadata using the hub's upload_file
|
325 |
# Upload the image
|
326 |
repo_image_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'image.png'))
|
327 |
|
|
|
340 |
repo_type='dataset',
|
341 |
token=HF_TOKEN
|
342 |
)
|
343 |
+
adversarial_number+=1
|
344 |
+
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'
|
345 |
+
repo.git_pull()
|
346 |
+
length_of_dataset = len([f for f in os.scandir("./data_mnist/data")])
|
347 |
+
test_metric = f"<html> {DEFAULT_TEST_METRIC} </html>"
|
348 |
+
if length_of_dataset % TRAIN_CUTOFF ==0:
|
349 |
+
test_metric_ = train_and_test()
|
350 |
+
test_metric = f"<html> {test_metric_} </html>"
|
351 |
+
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
|
352 |
+
return output,test_metric,adversarial_number
|
353 |
+
|
354 |
+
def get_number_dict(DATA_DIR):
|
355 |
+
files = [f.name for f in os.scandir(DATA_DIR)]
|
356 |
+
numbers = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl'))[0]['correct_number'] for f in files]
|
357 |
+
numbers_count = Counter(numbers)
|
358 |
+
numbers_count_keys = list(numbers_count.keys())
|
359 |
+
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
|
360 |
+
return numbers_count_keys,numbers_count_values
|
361 |
+
|
362 |
+
|
363 |
+
|
364 |
+
def get_statistics():
|
365 |
+
model_state_dict = 'model.pth'
|
366 |
+
optimizer_state_dict = 'optmizer.pth'
|
367 |
+
|
368 |
+
if os.path.exists(model_state_dict):
|
369 |
+
network_state_dict = torch.load(model_state_dict)
|
370 |
+
network.load_state_dict(network_state_dict)
|
371 |
+
|
372 |
+
if os.path.exists(optimizer_state_dict):
|
373 |
+
optimizer_state_dict = torch.load(optimizer_state_dict)
|
374 |
+
optimizer.load_state_dict(optimizer_state_dict)
|
375 |
+
repo.git_pull()
|
376 |
+
DATA_DIR = './data_mnist/data'
|
377 |
+
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
|
378 |
+
|
379 |
+
|
380 |
+
plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples over digits")
|
381 |
|
382 |
+
fig_d, ax_d = plt.subplots(figsize=(10,4),tight_layout=True)
|
|
|
|
|
|
|
383 |
|
384 |
+
if os.path.exists(METRIC_PATH):
|
385 |
+
metric_dict = read_json(METRIC_PATH)
|
386 |
+
for i in range(10):
|
387 |
+
try:
|
388 |
+
x_i = [i+1 for i in range(len(metric_dict[str(i)]))]
|
389 |
+
ax_d.plot(x_i, metric_dict[str(i)],label=str(i))
|
390 |
+
except Exception:
|
391 |
+
continue
|
392 |
+
dump_json(thing=metric_dict,file=METRIC_PATH)
|
393 |
+
else:
|
394 |
+
metric_dict={}
|
395 |
|
396 |
+
fig_d.legend()
|
397 |
+
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
|
398 |
|
399 |
+
done_html = """<div style="color: green">
|
400 |
+
<p> ✅ Statistics loaded successfully!</p>
|
401 |
+
</div>
|
402 |
+
"""
|
403 |
|
404 |
+
return plt_digits,fig_d,done_html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
|
|
|
|
|
406 |
|
407 |
+
|
408 |
+
|
409 |
+
def main():
|
410 |
#block = gr.Blocks(css=BLOCK_CSS)
|
411 |
block = gr.Blocks()
|
412 |
|
413 |
with block:
|
414 |
gr.Markdown(TITLE)
|
415 |
+
gr.Markdown(description)
|
416 |
|
417 |
with gr.Tabs():
|
|
|
418 |
with gr.TabItem('MNIST'):
|
419 |
+
gr.Markdown(WHAT_TO_DO)
|
420 |
+
test_metric = gr.outputs.HTML(DEFAULT_TEST_METRIC)
|
421 |
with gr.Row():
|
422 |
|
423 |
|
|
|
429 |
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
|
430 |
|
431 |
flag_btn = gr.Button("Flag")
|
432 |
+
|
433 |
output_result = gr.outputs.HTML()
|
434 |
+
adversarial_number = gr.Variable(value=0)
|
435 |
+
|
436 |
+
|
437 |
submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
|
438 |
+
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,test_metric,adversarial_number])
|
|
|
|
|
|
|
439 |
|
440 |
+
with gr.TabItem('Dashboard') as dashboard:
|
441 |
+
notification = gr.HTML("""<div style="color: green">
|
442 |
+
<p> ⌛ Creating statistics... </p>
|
443 |
+
</div>
|
444 |
+
""")
|
445 |
+
_,numbers_count_values_ = get_number_dict('./data_mnist/data')
|
446 |
+
|
447 |
+
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values_))
|
448 |
+
|
449 |
+
gr.Markdown(STATS_EXPLANATION_)
|
450 |
+
stat_adv_image =gr.Plot(type="matplotlib")
|
451 |
+
gr.Markdown(DASHBOARD_EXPLANATION)
|
452 |
+
test_results=gr.Plot(type="matplotlib")
|
453 |
+
|
454 |
+
dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification])
|
455 |
+
|
456 |
+
|
457 |
+
|
458 |
block.launch()
|
459 |
|
460 |
|
data_mnist
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit eb1e3cf9de597112c1da3b921ffcd07c8e4419c1
|
mnist_c/brightness/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0be26b927b2be43bc0b2430ccf6aa6048a5fcd8bcd087d97576972f009c1e8a9
|
3 |
+
size 7840128
|
mnist_c/brightness/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/brightness/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:466c8f7365f37d14012ec66e2d1203ef07acae08240bb1af1863f09318293255
|
3 |
+
size 47040128
|
mnist_c/brightness/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/canny_edges/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e86b158456cbaf0bf621092259f186a223b68e31128c18cceedb8a2f1b8baf1f
|
3 |
+
size 7840128
|
mnist_c/canny_edges/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/canny_edges/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d6454058f0025bd15f34f9d14df4fdf2cca8f47b38661da5b650bee9c77aac5
|
3 |
+
size 47040128
|
mnist_c/canny_edges/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/dotted_line/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02f5a8bd11e1bed6ca5f31418c6804a4f89ab5abdaba38f880e4743294516c6b
|
3 |
+
size 7840128
|
mnist_c/dotted_line/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/dotted_line/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d2d1ad4696a96af19f6bcd07a2ea9f2fb0990910397ef41b969d26080aef23cb
|
3 |
+
size 47040128
|
mnist_c/dotted_line/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/fog/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f48ea89b4793c45bedda39bd27b42710f4f0cbf5bced559597c015157188488
|
3 |
+
size 7840128
|
mnist_c/fog/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/fog/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5c92a6df9c3703b3df29fdb272d829736e706150104a7f7c81f37387940120e9
|
3 |
+
size 47040128
|
mnist_c/fog/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/glass_blur/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:300a4829b7fe6e4cfa5ad0e134cf19136e2c23e1e6bc29b6c574848b6a493388
|
3 |
+
size 7840128
|
mnist_c/glass_blur/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/glass_blur/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:255b6832620840d2db4423212d02c903d4edf5b31b94e16ac130f2539c73275f
|
3 |
+
size 47040128
|
mnist_c/glass_blur/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/identity/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c544e9053023b30ffd401a9825696f84af6fb1eb822bba5fcffc6992808357c3
|
3 |
+
size 7840128
|
mnist_c/identity/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/identity/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f323590fab80abce68b33a35fb0a176bed275ab73b934621f7583bd1c7c1b1ba
|
3 |
+
size 47040128
|
mnist_c/identity/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/impulse_noise/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63ad2d4c0860dca3814da6034943dda6e9fe267ba6e12136341c52f2c13ca3d5
|
3 |
+
size 7840128
|
mnist_c/impulse_noise/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/impulse_noise/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5921469548b8d051544f29b984112d0680054bba619b7a69a1479cbd082b9563
|
3 |
+
size 47040128
|
mnist_c/impulse_noise/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/motion_blur/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7767b48be95a63abcdf851a4baedc0491d5643a192072e4f6de0a14c75cb089e
|
3 |
+
size 7840128
|
mnist_c/motion_blur/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/motion_blur/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f11f24aa857845517b700892d5dcbeeb8c2ee3a7b0028e27a6231c9aaa4f32db
|
3 |
+
size 47040128
|
mnist_c/motion_blur/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/rotate/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8fb565f28d858733ec94d4e6d976d9f2cb3e6aa75deb9a7a11b0504b9ae64520
|
3 |
+
size 7840128
|
mnist_c/rotate/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/rotate/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:505bcc0035d1f36fe6560e39cbb3888ed0130d429bca92019f29b0cfe5ab8724
|
3 |
+
size 47040128
|
mnist_c/rotate/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/scale/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:282dfee12607c13ea76cc82985e3e5d78f36b44d56d6ad9f12a2298e20ea2200
|
3 |
+
size 7840128
|
mnist_c/scale/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/scale/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f558258555337a45d82b572a6a2ed48c3117985213937ded3fc6a738832b4f1b
|
3 |
+
size 47040128
|
mnist_c/scale/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/shear/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:38412620554c8d775e38d6b9d771b4825605998179e88bd004d8d5d0a7afaa7d
|
3 |
+
size 7840128
|
mnist_c/shear/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/shear/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:133295f31cf838b4905b8e05083570e8189e37839168f37e3e83831244e54579
|
3 |
+
size 47040128
|
mnist_c/shear/train_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10226b938104d9231ead4e9a0627f17e66cd18baf31a28784f1b72147057decf
|
3 |
+
size 480128
|
mnist_c/shot_noise/test_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2c7026fb754f234a1563711576266f179f47b6f5e4a8e43dfbbb182e1299d764
|
3 |
+
size 7840128
|
mnist_c/shot_noise/test_labels.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996074233ccae69ee65b8551233f608adc1bce84b6b440e3531478a847958149
|
3 |
+
size 80128
|
mnist_c/shot_noise/train_images.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67dc136c87501002f971be5dabd14c8e56615af61c699a7707faf4185dfbfdf6
|
3 |
+
size 47040128
|