'''
ART-JATIC Gradio Example App

To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
- navigate to the local URL, e.g. http://127.0.0.1:7860
'''
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from carbon_theme import Carbon
css = """ | |
.input-image { margin: auto !important } | |
.plot-padding { padding: 20px; } | |
""" | |
def extract_predictions(predictions_, conf_thresh):
    coco_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
                   'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                   'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                   'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
                   'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                   'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                   'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
                   'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                   'teddy bear', 'hair drier', 'toothbrush']

    # Get the predicted class names
    predictions_class = [coco_labels[i] for i in list(predictions_["labels"])]
    if len(predictions_class) < 1:
        return [], [], []

    # Get the predicted bounding boxes as (top-left, bottom-right) corner pairs
    predictions_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(predictions_["boxes"])]

    # Get the prediction scores
    predictions_score = list(predictions_["scores"])

    # Indices of predictions with a score above the confidence threshold.
    # enumerate() is used instead of list.index() so that duplicate scores
    # do not all resolve to the first matching index.
    predictions_t = [i for i, score in enumerate(predictions_score) if score > conf_thresh]
    if len(predictions_t) == 0:
        # No predictions exceed the threshold
        return [], [], []

    # Keep only the predictions above the threshold (in score order)
    predictions_boxes = [predictions_boxes[i] for i in predictions_t]
    predictions_class = [predictions_class[i] for i in predictions_t]
    predictions_scores = [predictions_score[i] for i in predictions_t]
    return predictions_class, predictions_boxes, predictions_scores
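
# Example usage (a sketch; the dict below mimics torchvision-style detector
# output with illustrative values, not real predictions):
#
#   sample = {"labels": [0, 2], "boxes": [[10, 20, 50, 80], [5, 5, 40, 40]], "scores": [0.9, 0.3]}
#   classes, boxes, scores = extract_predictions(sample, conf_thresh=0.5)
#   # -> (['person'], [[(10, 20), (50, 80)]], [0.9])
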
def plot_image_with_boxes(img, boxes, pred_cls, title):
    import cv2

    text_size = 1
    text_th = 2
    rect_th = 1

    for i in range(len(boxes)):
        # Draw the bounding box
        cv2.rectangle(img, (int(boxes[i][0][0]), int(boxes[i][0][1])),
                      (int(boxes[i][1][0]), int(boxes[i][1][1])),
                      color=(0, 255, 0), thickness=rect_th)
        # Write the predicted class at the top-left corner of the box
        cv2.putText(img, pred_cls[i], (int(boxes[i][0][0]), int(boxes[i][0][1])),
                    cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)

    return img.astype(np.uint8)
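
# Typically chained with extract_predictions, e.g. (a sketch; `preds` is assumed
# to be a detector output dict and `image` a PIL image):
#
#   classes, boxes, scores = extract_predictions(preds, conf_thresh=0.5)
#   annotated = plot_image_with_boxes(np.array(image).copy(), boxes, classes, title="Predictions")
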
def filter_boxes(predictions, conf_thresh):
    dictionary = {}

    boxes_list = []
    scores_list = []
    labels_list = []

    for i in range(len(predictions[0]["boxes"])):
        score = predictions[0]["scores"][i]
        if score >= conf_thresh:
            boxes_list.append(predictions[0]["boxes"][i])
            scores_list.append(predictions[0]["scores"][[i]])
            labels_list.append(predictions[0]["labels"][[i]])

    if len(boxes_list) == 0:
        # Guard against np.vstack/np.hstack raising on empty lists when no
        # detection clears the threshold
        dictionary["boxes"] = np.empty((0, 4))
        dictionary["scores"] = np.empty((0,))
        dictionary["labels"] = np.empty((0,), dtype=int)
    else:
        dictionary["boxes"] = np.vstack(boxes_list)
        dictionary["scores"] = np.hstack(scores_list)
        dictionary["labels"] = np.hstack(labels_list)

    y = [dictionary]
    return y
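
# Example usage (a sketch, assuming numpy arrays in torchvision-style detection format):
#
#   preds = [{"boxes": np.array([[10., 20., 50., 80.], [5., 5., 40., 40.]]),
#             "scores": np.array([0.9, 0.3]),
#             "labels": np.array([1, 3])}]
#   filtered = filter_boxes(preds, conf_thresh=0.5)
#   # filtered[0] keeps only the first detection: boxes (1, 4), scores [0.9], labels [1]
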
def basic_cifar10_model(overfit=False):
    '''
    Load an example CIFAR10 model
    '''
    from art.estimators.classification.pytorch import PyTorchClassifier

    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    path = './'

    class Model(torch.nn.Module):
        """
        Create model for pytorch.

        Here the model does not use maxpooling. Needed for certification tests.
        """

        def __init__(self):
            super(Model, self).__init__()

            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=(4, 4), dilation=(1, 1), padding=(0, 0), stride=(3, 3)
            )
            self.fullyconnected = torch.nn.Linear(in_features=1600, out_features=10)
            self.relu = torch.nn.ReLU()

            # Load the pretrained weights and biases for the conv and dense layers
            w_conv2d = np.load(
                os.path.join(
                    os.path.dirname(path),
                    "utils/resources/models",
                    "W_CONV2D_NO_MPOOL_CIFAR10.npy",
                )
            )
            b_conv2d = np.load(
                os.path.join(
                    os.path.dirname(path),
                    "utils/resources/models",
                    "B_CONV2D_NO_MPOOL_CIFAR10.npy",
                )
            )
            w_dense = np.load(
                os.path.join(
                    os.path.dirname(path),
                    "utils/resources/models",
                    "W_DENSE_NO_MPOOL_CIFAR10.npy",
                )
            )
            b_dense = np.load(
                os.path.join(
                    os.path.dirname(path),
                    "utils/resources/models",
                    "B_DENSE_NO_MPOOL_CIFAR10.npy",
                )
            )

            self.conv.weight = torch.nn.Parameter(torch.Tensor(w_conv2d))
            self.conv.bias = torch.nn.Parameter(torch.Tensor(b_conv2d))
            self.fullyconnected.weight = torch.nn.Parameter(torch.Tensor(w_dense))
            self.fullyconnected.bias = torch.nn.Parameter(torch.Tensor(b_dense))

        # pylint: disable=W0221
        # disable pylint because of API requirements for function
        def forward(self, x):
            """
            Forward function to evaluate the model

            :param x: Input to the model
            :return: Prediction of the model
            """
            x = self.conv(x)
            x = self.relu(x)
            x = x.reshape(-1, 1600)
            x = self.fullyconnected(x)
            return x

    # Define the network
    model = Model()

    # Define a loss function and optimizer. Note: weight_decay=0.0 is Adam's
    # default, so both branches are currently equivalent; the overfit flag is
    # kept as a hook for configuring a deliberately overfitted model.
    if overfit:
        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Get classifier
    jptc = PyTorchClassifier(
        model=model, loss=loss_fn, optimizer=optimizer, input_shape=(3, 32, 32), nb_classes=10, clip_values=(0, 1),
        labels=labels
    )
    return jptc
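
# Example usage (a sketch; assumes the .npy weight files exist under utils/resources/models):
#
#   jptc = basic_cifar10_model()
#   x = np.random.rand(1, 3, 32, 32).astype(np.float32)  # dummy CIFAR10-shaped batch
#   logits = jptc.predict(x)  # ART classifiers return an array of shape (1, 10)
#   pred_idx = int(np.argmax(logits, axis=1)[0])
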
def det_evasion_evaluate(*args):
    '''
    Run a detection task evaluation
    '''


def clf_evasion_evaluate(*args):
    '''
    Run a classification task evaluation
    '''
def show_model_params(model_type):
    '''
    Show model parameters based on selected model type
    '''
    if model_type not in ("Example CIFAR10", "Example XView", "CIFAR10 Overfit"):
        return gr.Column(visible=True)
    return gr.Column(visible=False)
def show_dataset_params(dataset_type):
    '''
    Show dataset parameters based on dataset type
    '''
    if dataset_type == "Example CIFAR10":
        return [gr.Column(visible=False), gr.Row(visible=False), gr.Row(visible=False)]
    elif dataset_type == "local":
        return [gr.Column(visible=True), gr.Row(visible=True), gr.Row(visible=False)]
    return [gr.Column(visible=True), gr.Row(visible=False), gr.Row(visible=True)]
def pgd_show_label_output(dataset_type):
    '''
    Show PGD output components based on dataset type
    '''
    if dataset_type == "local":
        return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False), gr.Number(visible=False), gr.Number(visible=True)]
    return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True), gr.Number(visible=True), gr.Number(visible=True)]
def pgd_update_epsilon(clip_values):
    '''
    Update max value of the PGD epsilon slider based on the model's clip values
    '''
    if clip_values == 255:
        return gr.Slider(minimum=0.0001, maximum=255, label="Epsilon", value=55)
    return gr.Slider(minimum=0.0001, maximum=1, label="Epsilon", value=0.05)
def patch_show_label_output(dataset_type):
    '''
    Show adversarial patch output components based on dataset type
    '''
    if dataset_type == "local":
        return [gr.Label(visible=True), gr.Label(visible=True), gr.Number(visible=False), gr.Number(visible=False), gr.Number(visible=True)]
    return [gr.Label(visible=False), gr.Label(visible=False), gr.Number(visible=True), gr.Number(visible=True), gr.Number(visible=True)]
# Use the local Carbon theme defined in carbon_theme.py
carbon_theme = Carbon()
with gr.Blocks(css=css, theme=carbon_theme) as demo:
    import art
    text = art.__version__
    gr.Markdown(f"<h1>ART (v{text}) Gradio Example</h1>")

    with gr.Tab("Info"):
        gr.Markdown('This is step 1. Using the tabs, select a task for evaluation.')

    with gr.Tab("Classification", elem_classes="task-tab"):
        gr.Markdown("Classifying images with a set of categories.")

        # Model and Dataset Selection
        with gr.Row():
            # Model and Dataset type, e.g. Torchvision, HuggingFace, local etc.
            with gr.Column():
                model_type = gr.Radio(label="Model type", choices=["Example CIFAR10", "Huggingface", "torchvision"],
                                      value="Example CIFAR10")
                dataset_type = gr.Radio(label="Dataset", choices=["Example CIFAR10", "Huggingface", "local"],
                                        value="Example CIFAR10")

            # Model parameters, e.g. RESNET, VIT, input dimensions, clipping values etc.
            with gr.Column(visible=False) as model_params:
                model_path = gr.Textbox(placeholder="URL", label="Model path")
                with gr.Row():
                    with gr.Column():
                        model_channels = gr.Textbox(placeholder="Integer, 3 for RGB images", label="Input Channels", value=3)
                    with gr.Column():
                        model_width = gr.Textbox(placeholder="Integer", label="Input Width", value=640)
                with gr.Row():
                    with gr.Column():
                        model_height = gr.Textbox(placeholder="Integer", label="Input Height", value=480)
                    with gr.Column():
                        model_clip = gr.Radio(choices=[1, 255], label="Pixel clip", value=1)

            # Dataset parameters, e.g. Torchvision, HuggingFace, local etc.
            with gr.Column(visible=False) as dataset_params:
                with gr.Row() as local_image:
                    image = gr.Image(sources=['upload'], type="pil", height=150, width=150, elem_classes="input-image")
                with gr.Row() as hosted_image:
                    dataset_path = gr.Textbox(placeholder="URL", label="Dataset path")
                    dataset_split = gr.Textbox(placeholder="test", label="Dataset split")

        # Toggle visibility of the parameter panels when selections change
        model_type.change(show_model_params, model_type, model_params)
        dataset_type.change(show_dataset_params, dataset_type, [dataset_params, local_image, hosted_image])
        # Attack Selection
        with gr.Row():
            with gr.Tab("Info"):
                gr.Markdown("This is step 2. Select the type of attack for evaluation.")

            with gr.Tab("White Box"):
                gr.Markdown("White box attacks assume the attacker has __full access__ to the model.")

                with gr.Tab("Info"):
                    gr.Markdown("This is step 3. Select the type of white-box attack to evaluate.")

                with gr.Tab("Evasion"):
                    gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.")

                    with gr.Tab("Info"):
                        gr.Markdown("This is step 4. Select the type of evasion attack to evaluate.")

                    with gr.Tab("Projected Gradient Descent"):
                        gr.Markdown("This attack uses PGD to identify adversarial examples.")

                        with gr.Row():
                            with gr.Column():
                                attack = gr.Textbox(visible=True, value="PGD", label="Attack", interactive=False)
                                max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=1000)
                                eps = gr.Slider(minimum=0.0001, maximum=1, label="Epsilon", value=0.05)
                                eps_steps = gr.Slider(minimum=0.001, maximum=1000, label="Epsilon steps", value=0.1)
                                targeted = gr.Textbox(placeholder="Target label (integer)", label="Target")
                                eval_btn_pgd = gr.Button("Evaluate")
                                model_clip.change(pgd_update_epsilon, model_clip, eps)
                            # Evaluation Output. Visualisations of success/failures of running evaluation attacks.
                            with gr.Column():
                                with gr.Row():
                                    with gr.Column():
                                        original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True)
                                        benign_output = gr.Label(num_top_classes=3, visible=False)
                                        clean_accuracy = gr.Number(label="Clean Accuracy", precision=2)
                                        quality_plot = gr.LinePlot(label="Gradient Quality", x='iteration', y='value', color='metric',
                                                                   x_title='Iteration', y_title='Avg in Gradients (%)',
                                                                   caption="Illustrates the average percent of zero, infinity "
                                                                           "or NaN gradients identified in images "
                                                                           "across all batches.",
                                                                   elem_classes="plot-padding", visible=False)
                                    with gr.Column():
                                        adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True)
                                        adversarial_output = gr.Label(num_top_classes=3, visible=False)
                                        robust_accuracy = gr.Number(label="Robust Accuracy", precision=2)
                                        perturbation_added = gr.Number(label="Perturbation Added", precision=2)

                        dataset_type.change(pgd_show_label_output, dataset_type, [benign_output, adversarial_output,
                                                                                  clean_accuracy, robust_accuracy, perturbation_added])
                        eval_btn_pgd.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width,
                                                                         model_clip, max_iter, eps, eps_steps, targeted,
                                                                         dataset_type, dataset_path, dataset_split, image],
                                           outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy,
                                                    robust_accuracy, perturbation_added, quality_plot], api_name='pgd')

                        with gr.Row():
                            clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy,
                                                        adversarial_gallery, adversarial_output, robust_accuracy, perturbation_added])
with gr.Tab("Adversarial Patch"): | |
gr.Markdown("This attack crafts an adversarial patch that facilitates evasion.") | |
with gr.Row(): | |
with gr.Column(): | |
attack = gr.Textbox(visible=True, value="Adversarial Patch", label="Attack", interactive=False) | |
max_iter = gr.Slider(minimum=1, maximum=5000, label="Max iterations", value=100) | |
x_location = gr.Slider(minimum=1, maximum=640, label="Location (x)", value=18) | |
y_location = gr.Slider(minimum=1, maximum=480, label="Location (y)", value=18) | |
patch_height = gr.Slider(minimum=1, maximum=640, label="Patch height", value=18) | |
patch_width = gr.Slider(minimum=1, maximum=480, label="Patch width", value=18) | |
targeted = gr.Textbox(placeholder="Target label (integer)", label="Target") | |
eval_btn_patch = gr.Button("Evaluate") | |
model_clip.change() | |
# Evaluation Output. Visualisations of success/failures of running evaluation attacks. | |
with gr.Column(): | |
with gr.Row(): | |
with gr.Column(): | |
original_gallery = gr.Gallery(label="Original", preview=True, show_download_button=True) | |
benign_output = gr.Label(num_top_classes=3, visible=False) | |
clean_accuracy = gr.Number(label="Clean Accuracy", precision=2) | |
with gr.Column(): | |
adversarial_gallery = gr.Gallery(label="Adversarial", preview=True, show_download_button=True) | |
adversarial_output = gr.Label(num_top_classes=3, visible=False) | |
robust_accuracy = gr.Number(label="Robust Accuracy", precision=2) | |
patch_image = gr.Image(label="Adversarial Patch") | |
dataset_type.change(patch_show_label_output, dataset_type, [benign_output, adversarial_output, | |
clean_accuracy, robust_accuracy, patch_image]) | |
eval_btn_patch.click(clf_evasion_evaluate, inputs=[attack, model_type, model_path, model_channels, model_height, model_width, | |
model_clip, max_iter, x_location, y_location, patch_height, patch_width, targeted, | |
dataset_type, dataset_path, dataset_split, image], | |
outputs=[original_gallery, benign_output, adversarial_gallery, adversarial_output, clean_accuracy, | |
robust_accuracy, patch_image]) | |
with gr.Row(): | |
clear_btn = gr.ClearButton([image, targeted, original_gallery, benign_output, clean_accuracy, | |
adversarial_gallery, adversarial_output, robust_accuracy, patch_image]) | |
with gr.Tab("Poisoning"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("Black Box"): | |
gr.Markdown("Black box attacks assume the attacker __does not__ have full access to the model but can query it for predictions.") | |
with gr.Tab("Info"): | |
gr.Markdown("This is step 3. Select the type of black-box attack to evaluate.") | |
with gr.Tab("Evasion"): | |
gr.Markdown("Evasion attacks are deployed to cause a model to incorrectly classify or detect items/objects in an image.") | |
with gr.Tab("Info"): | |
gr.Markdown("This is step 4. Select the type of Evasion attack to evaluate.") | |
with gr.Tab("HopSkipJump"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("Square Attack"): | |
gr.Markdown("Coming soon.") | |
with gr.Tab("AutoAttack"): | |
gr.Markdown("Coming soon.") | |
if __name__ == "__main__":
    # During development, set debug=True
    demo.launch(show_api=False, debug=True, share=False,
                server_name="0.0.0.0",
                server_port=7777,
                ssl_verify=False,
                max_threads=20)