example / app.py
Kieran Fraser
First commit.
d2635ec
raw
history blame
22.2 kB
'''
ART-JATIC Gradio Example App
To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
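- navigate to the local URL printed in the console, e.g. http://127.0.0.1:7860
  (when started with `python`, the server listens on port 7777, as set in demo.launch at the bottom of this file)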
'''
import torch
import numpy as np
import pandas as pd
from carbon_theme import Carbon
import gradio as gr
import os
import matplotlib.pyplot as plt
# Extra CSS passed to gr.Blocks(css=...). The classes are attached through elem_classes:
# "input-image" on the uploaded-image component and "plot-padding" on the gradient-quality plot.
css = """
.input-image { margin: auto !important }
.plot-padding { padding: 20px; }
"""
# Class names indexed by the integer labels produced by the detector.
# NOTE(review): this is the 80-class COCO list; a detector that emits the 91-id COCO category
# scheme (with background / N/A slots) would need a different table -- confirm for the model used.
_COCO_LABELS = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
                'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
                'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
                'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush')


def extract_predictions(predictions_, conf_thresh):
    '''
    Convert a raw detection dict into parallel lists of class names, boxes and scores.

    Only detections whose score is strictly greater than ``conf_thresh`` are kept, and they
    keep their original order (they are not re-sorted by score).

    :param predictions_: Mapping with ``"labels"``, ``"boxes"`` and ``"scores"`` entries,
        one element per detection. Each label indexes ``_COCO_LABELS``; the first four values
        of each box are paired as ``(b[0], b[1])`` and ``(b[2], b[3])`` -- presumably the two
        corner points (x1, y1), (x2, y2).
    :param conf_thresh: Detections with a score ``<= conf_thresh`` are dropped.
    :return: ``(classes, boxes, scores)`` where each box is ``[(b[0], b[1]), (b[2], b[3])]``;
        three empty lists when there are no detections or none exceed the threshold.
    '''
    predictions_class = [_COCO_LABELS[i] for i in predictions_["labels"]]
    if not predictions_class:
        return [], [], []
    predictions_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in predictions_["boxes"]]
    predictions_score = list(predictions_["scores"])
    # Take positions from enumerate rather than list.index(score): index() returns the first
    # match, so detections with equal scores would all resolve to the first one's box/class.
    keep = [i for i, score in enumerate(predictions_score) if score > conf_thresh]
    if not keep:
        # No prediction exceeds the threshold.
        return [], [], []
    return (
        [predictions_class[i] for i in keep],
        [predictions_boxes[i] for i in keep],
        [predictions_score[i] for i in keep],
    )
def plot_image_with_boxes(img, boxes, pred_cls, title):
    '''
    Draw each bounding box and its class name onto an image with OpenCV.

    :param img: Image array passed to ``cv2.rectangle`` / ``cv2.putText``. NOTE(review): OpenCV
        drawing functions write into the array they are given, so the caller's ``img`` is
        presumably modified as well.
    :param boxes: Sequence of boxes, each indexable as ``[(x1, y1), (x2, y2)]`` (the format built
        by ``extract_predictions``); coordinates are truncated with ``int``.
    :param pred_cls: Class name for each box, drawn at the box's first corner. Must have at least
        as many entries as ``boxes``.
    :param title: Currently unused; kept for interface compatibility.
    :return: The annotated image converted with ``astype(np.uint8)``.
    '''
    import cv2

    text_size = 1
    text_th = 2
    rect_th = 1
    colour = (0, 255, 0)
    for i, box in enumerate(boxes):
        first_corner = (int(box[0][0]), int(box[0][1]))
        second_corner = (int(box[1][0]), int(box[1][1]))
        cv2.rectangle(img, first_corner, second_corner, color=colour, thickness=rect_th)
        # Write the prediction class at the box's first corner.
        cv2.putText(img, pred_cls[i], first_corner, cv2.FONT_HERSHEY_SIMPLEX, text_size, colour, thickness=text_th)
    return img.astype(np.uint8)
def filter_boxes(predictions, conf_thresh):
    '''
    Keep the detections of the first prediction dict whose score is at least ``conf_thresh``.

    :param predictions: Sequence of detection dicts with ``"boxes"``, ``"scores"`` and
        ``"labels"`` entries. Only ``predictions[0]`` is read; further entries are ignored.
    :param conf_thresh: Minimum score (inclusive) for a detection to be kept.
    :return: A one-element list holding a dict with the kept ``"boxes"`` (one row per box),
        ``"scores"`` and ``"labels"`` as NumPy arrays. When nothing passes the threshold the
        arrays are empty rather than an exception being raised.
    '''
    first = predictions[0]
    scores = np.asarray(first["scores"])
    # Filter with a boolean mask instead of collecting rows and np.vstack-ing them:
    # np.vstack([]) raises ValueError when no detection passes the threshold.
    keep = scores >= conf_thresh
    boxes = np.asarray(first["boxes"])
    if boxes.ndim == 1 and boxes.size == 0:
        # An empty box list has shape (0,); make it (0, 4) so it stays 2-D like the non-empty case.
        boxes = boxes.reshape(0, 4)
    labels = np.asarray(first["labels"])
    return [{"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]}]
def basic_cifar10_model(overfit=False):
    '''
    Load an example CIFAR10 model wrapped in an ART ``PyTorchClassifier``.

    The network is a single 4x4, stride-3 convolution with 16 output channels, a ReLU and a
    1600 -> 10 fully-connected layer (no max-pooling). Its weights and biases are loaded from
    ``.npy`` files in ``utils/resources/models``, resolved relative to the current working
    directory.

    :param overfit: Accepted for backward compatibility. NOTE(review): it currently changes
        nothing. Both settings produce ``CrossEntropyLoss(reduction="sum")`` and Adam with
        ``lr=0.01``, and an explicit ``weight_decay=0.0`` equals Adam's default. Confirm what
        the flag is meant to do.
    :return: ``PyTorchClassifier`` with input shape (3, 32, 32), 10 classes, clip values (0, 1)
        and the CIFAR10 class names as ``labels``.
    '''
    from art.estimators.classification.pytorch import PyTorchClassifier

    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    path = './'

    def load_resource(filename):
        """Load one pretrained parameter array from the model resources directory."""
        return np.load(os.path.join(os.path.dirname(path), "utils/resources/models", filename))

    class Model(torch.nn.Module):
        """
        Create model for pytorch.
        Here the model does not use maxpooling. Needed for certification tests.
        """

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=(4, 4), dilation=(1, 1), padding=(0, 0), stride=(3, 3)
            )
            self.fullyconnected = torch.nn.Linear(in_features=1600, out_features=10)
            self.relu = torch.nn.ReLU()

            w_conv2d = load_resource("W_CONV2D_NO_MPOOL_CIFAR10.npy")
            b_conv2d = load_resource("B_CONV2D_NO_MPOOL_CIFAR10.npy")
            w_dense = load_resource("W_DENSE_NO_MPOOL_CIFAR10.npy")
            b_dense = load_resource("B_DENSE_NO_MPOOL_CIFAR10.npy")

            self.conv.weight = torch.nn.Parameter(torch.Tensor(w_conv2d))
            self.conv.bias = torch.nn.Parameter(torch.Tensor(b_conv2d))
            self.fullyconnected.weight = torch.nn.Parameter(torch.Tensor(w_dense))
            self.fullyconnected.bias = torch.nn.Parameter(torch.Tensor(b_dense))

        # pylint: disable=W0221
        # disable pylint because of API requirements for function
        def forward(self, x):
            """
            Forward function to evaluate the model
            :param x: Input to the model
            :return: Prediction of the model
            """
            x = self.conv(x)
            x = self.relu(x)
            # A 32x32 input gives 16 channels x 10 x 10 positions = 1600 features per sample.
            x = x.reshape(-1, 1600)
            x = self.fullyconnected(x)
            return x

    # Define the network
    model = Model()
    # Loss and optimizer. `overfit` does not affect either (see docstring).
    loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Get classifier
    jptc = PyTorchClassifier(
        model=model, loss=loss_fn, optimizer=optimizer, input_shape=(3, 32, 32), nb_classes=10, clip_values=(0, 1), labels=labels
    )
    return jptc
def det_evasion_evaluate(*args):
    '''
    Run a detection task evaluation.

    Placeholder: the body performs no work. Every positional argument is ignored and the
    function returns ``None``.
    '''
    return None
def clf_evasion_evaluate(*args):
    '''
    Run a classification task evaluation.

    Placeholder: the body performs no work. Every positional argument is ignored and the
    function returns ``None``.

    NOTE(review): this is the click handler of both "Evaluate" buttons, whose bindings declare
    8 (PGD) and 7 (Adversarial Patch) outputs. Confirm how gradio treats a ``None`` return for
    those outputs.
    '''
    return None
def show_model_params(model_type):
    '''
    Toggle the custom model parameter column for the selected model type.

    :param model_type: Selected value of the "Model type" radio.
    :return: ``gr.Column`` update that is hidden for the built-in example names
        ("Example CIFAR10", "Example XView", "CIFAR10 Overfit") and visible for anything else.
    '''
    builtin_models = ("Example CIFAR10", "Example XView", "CIFAR10 Overfit")
    return gr.Column(visible=model_type not in builtin_models)
def show_dataset_params(dataset_type):
    '''
    Toggle the dataset parameter panels for the selected dataset type.

    :param dataset_type: Selected value of the "Dataset" radio.
    :return: Visibility updates for ``[dataset column, local-image row, hosted-dataset row]``:
        all hidden for "Example CIFAR10", the upload row for "local", and the path/split row
        for any other value.
    '''
    uses_example = dataset_type == "Example CIFAR10"
    uses_local = dataset_type == "local"
    return [
        gr.Column(visible=not uses_example),
        gr.Row(visible=uses_local),
        gr.Row(visible=not (uses_example or uses_local)),
    ]
def pgd_show_label_output(dataset_type):
    '''
    Choose which PGD result components are visible for the selected dataset type.

    For a "local" dataset the two ``Label`` outputs are shown and the two accuracy ``Number``
    outputs hidden; for any other dataset type the reverse. The fifth ``Number`` is always shown.

    :param dataset_type: Selected value of the "Dataset" radio.
    :return: Five component updates: two ``gr.Label`` followed by three ``gr.Number``.
    '''
    is_local = dataset_type == "local"
    return [
        gr.Label(visible=is_local),
        gr.Label(visible=is_local),
        gr.Number(visible=not is_local),
        gr.Number(visible=not is_local),
        gr.Number(visible=True),
    ]
def pgd_update_epsilon(clip_values):
    '''
    Rescale the PGD epsilon slider to the selected pixel clip range.

    :param clip_values: Selected "Pixel clip" value. ``255`` selects the 0-255 scale; any other
        value selects the 0-1 scale.
    :return: ``gr.Slider`` update with maximum 255 and default 55 for ``255``, otherwise
        maximum 1 and default 0.05.
    '''
    # NOTE(review): the label text "Epslion" is a typo, but it matches the slider's initial label.
    upper, initial = (255, 55) if clip_values == 255 else (1, 0.05)
    return gr.Slider(minimum=0.0001, maximum=upper, label="Epslion", value=initial)
def patch_show_label_output(dataset_type):
    '''
    Choose which adversarial patch result components are visible for the selected dataset type.

    For a "local" dataset the two ``Label`` outputs are shown and the two accuracy ``Number``
    outputs hidden; for any other dataset type the reverse. The adversarial patch image is
    always shown.

    :param dataset_type: Selected value of the "Dataset" radio.
    :return: Five component updates, in the order of the outputs bound to this handler:
        benign label, adversarial label, clean accuracy, robust accuracy, patch image.
    '''
    is_local = dataset_type == "local"
    return [
        gr.Label(visible=is_local),
        gr.Label(visible=is_local),
        gr.Number(visible=not is_local),
        gr.Number(visible=not is_local),
        # The fifth output is the gr.Image that shows the adversarial patch, so it is updated
        # with an Image (not a Number) carrying the same visibility.
        gr.Image(visible=True),
    ]
# Theme for the demo: Carbon comes from the carbon_theme module imported at the top of the file.
carbon_theme = Carbon()
with gr.Blocks(css=css, theme=carbon_theme) as demo:
    # Imported here only to show the installed ART version in the page title.
    import art
    text = art.__version__
    gr.Markdown(f"<h1>ART (v{text}) Gradio Example</h1>")
    with gr.Tab("Info"):
        gr.Markdown('This is step 1. Using the tabs, select a task for evaluation.')
    with gr.Tab("Classification", elem_classes="task-tab"):
        gr.Markdown("Classifying images with a set of categories.")
        # Model and Dataset Selection
        with gr.Row():
            # Model and Dataset type e.g. Torchvision, HuggingFace, local etc.
            with gr.Column():
                model_type = gr.Radio(label="Model type", choices=["Example CIFAR10", "Huggingface", "torchvision"],
                                      value="Example CIFAR10")
                dataset_type = gr.Radio(label="Dataset", choices=["Example CIFAR10", "Huggingface", "local"],
                                        value="Example CIFAR10")
            # Model parameters e.g. RESNET, VIT, input dimensions, clipping values etc.
            # Hidden initially; show_model_params toggles it when the model type changes.
            with gr.Column(visible=False) as model_params:
                model_path = gr.Textbox(placeholder="URL", label="Model path")
                with gr.Row():
                    with gr.Column():
                        model_channels = gr.Textbox(placeholder="Integer, 3 for RGB images", label="Input Channels", value=3)
                    with gr.Column():
                        model_width = gr.Textbox(placeholder="Integer", label="Input Width", value=640)
                with gr.Row():
                    with gr.Column():
                        model_height = gr.Textbox(placeholder="Integer", label="Input Height", value=480)
                    with gr.Column():
                        model_clip = gr.Radio(choices=[1, 255], label="Pixel clip", value=1)
            # Dataset parameters e.g. Torchvision, HuggingFace, local etc.
            # Hidden initially; show_dataset_params picks the upload row or the path/split row.
            with gr.Column(visible=False) as dataset_params:
                with gr.Row() as local_image:
                    image = gr.Image(sources=['upload'], type="pil", height=150, width=150, elem_classes="input-image")
                with gr.Row() as hosted_image:
                    dataset_path = gr.Textbox(placeholder="URL", label="Dataset path")
                    dataset_split = gr.Textbox(placeholder="test", label="Dataset split")
            model_type.change(show_model_params, model_type, model_params)
            dataset_type.change(show_dataset_params, dataset_type, [dataset_params, local_image, hosted_image])
        # Attack Selection
        with gr.Row():
            with gr.Tab("Info"):
                gr.Markdown("This is step 2. Select the type of attack for evaluation.")
            with gr.Tab("White Box"):
                gr.Markdown("White box attacks assume the attacker has __full access__ to the model.")
                with gr.Tab("Info"):
                    gr.Markdown("This is step 3. Select the type of white-box attack to evaluate.")
                with gr.Tab("Evasion"):
                    gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
                    with gr.Tab("Info"):
                        gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
                    with gr.Tab("Projected Gradient Descent"):
                        gr.Markdown("This attack uses PGD to identify adversarial examples.")
                        with gr.Row():
                            with gr.Column():
                                attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
                                max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=1000)
                                # NOTE(review): "Epslion" is a typo for "Epsilon". pgd_update_epsilon emits the
                                # same label, so both places should be fixed together.
                                eps = gr.Slider(minimum=0.0001, maximum=1, label="Epslion", value=0.05)
                                eps_steps = gr.Slider(minimum=0.001, maximum=1000, label="Epsilon steps", value=0.1)
                                targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
                                eval_btn_pgd = gr.Button("Evaluate")
                                # Rescale the epsilon slider when the pixel clip range (1 or 255) changes.
                                model_clip.change(pgd_update_epsilon, model_clip, eps)
                            # Evaluation Output. Visualisations of success/failures of running evaluation attacks.
                            with gr.Column():
                                with gr.Row():
                                    with gr.Column():
                                        original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True)
                                        benign_output = gr.Label(num_top_classes=3, visible=False)
                                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                        quality_plot = gr.LinePlot(label="Gradient Quality", x='iteration', y='value', color='metric',
                                                                   x_title='Iteration', y_title='Avg in Gradients (%)',
                                                                   caption=("Illustrates the average percent of zero, infinity\n"
                                                                            "or NaN gradients identified in images\n"
                                                                            "across all batches."),
                                                                   elem_classes="plot-padding", visible=False)
                                    with gr.Column():
                                        adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True)
                                        adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                        robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
                                        perturbation_added = gr.Number(label="Perturbation Added", precision=2)
                        dataset_type.change(pgd_show_label_output, dataset_type, [benign_output, adversarial_output,
                                                                                  clean_accuracy, robust_accuracy, perturbation_added])
                        # NOTE(review): this PGD handler is registered with api_name='patch'; confirm the intended endpoint name.
                        eval_btn_pgd.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
                                                                         model_clip, max_iter, eps, eps_steps, targeted,
                                                                         dataset_type, dataset_path, dataset_split, image],
                                           outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
                                                    robust_accuracy, perturbation_added, quality_plot], api_name='patch')
                        with gr.Row():
                            clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
                                                        adversarial_gallery, adversarial_output, robust_accuracy, perturbation_added])
                    with gr.Tab("Adversarial Patch"):
                        gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.")
                        with gr.Row():
                            with gr.Column():
                                attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False)
                                max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=100)
                                # Maxima follow the default 640x480 input size: x and width up to 640,
                                # y and height up to 480.
                                x_location = gr.Slider(minimum=1, maximum=640, label="Location (x)", value=18)
                                y_location = gr.Slider(minimum=1, maximum=480, label="Location (y)", value=18)
                                patch_height = gr.Slider(minimum=1, maximum=480, label="Patch height", value=18)
                                patch_width = gr.Slider(minimum=1, maximum=640, label="Patch width", value=18)
                                targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
                                eval_btn_patch = gr.Button("Evaluate")
                            # Evaluation Output. Visualisations of success/failures of running evaluation attacks.
                            with gr.Column():
                                with gr.Row():
                                    with gr.Column():
                                        original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True)
                                        benign_output = gr.Label(num_top_classes=3, visible=False)
                                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                    with gr.Column():
                                        adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True)
                                        adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                        robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
                                        patch_image = gr.Image(label="Adversarial Patch")
                        dataset_type.change(patch_show_label_output, dataset_type, [benign_output, adversarial_output,
                                                                                    clean_accuracy, robust_accuracy, patch_image])
                        eval_btn_patch.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
                                                                           model_clip, max_iter, x_location, y_location, patch_height, patch_width, targeted,
                                                                           dataset_type, dataset_path, dataset_split, image],
                                             outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
                                                      robust_accuracy, patch_image])
                        with gr.Row():
                            clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
                                                        adversarial_gallery, adversarial_output, robust_accuracy, patch_image])
                with gr.Tab("Poisoning"):
                    gr.Markdown("Coming soon.")
            with gr.Tab("Black Box"):
                gr.Markdown("Black box attacks assume the attacker __does not__ have full access to the model but can query it for predictions.")
                with gr.Tab("Info"):
                    gr.Markdown("This is step 3. Select the type of black-box attack to evaluate.")
                with gr.Tab("Evasion"):
                    gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")
                    with gr.Tab("Info"):
                        gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.")
                    with gr.Tab("HopSkipJump"):
                        gr.Markdown("Coming soon.")
                    with gr.Tab("Square Attack"):
                        gr.Markdown("Coming soon.")
                    with gr.Tab("AutoAttack"):
                        gr.Markdown("Coming soon.")
if __name__ == "__main__":
    # Development settings: debug=True is enabled; consider turning it off outside development.
    # server_name "0.0.0.0" listens on all network interfaces, not only localhost.
    demo.launch(
        show_api=False,
        debug=True,
        share=False,
        server_name="0.0.0.0",
        server_port=7777,
        ssl_verify=False,
        max_threads=20,
    )