import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
device = torch.device("cpu")
class VGGBlock(nn.Module):
def __init__(self, in_channels, out_channels, batch_norm=False):
super().__init__()
conv2_params = {'kernel_size': (3, 3),
'stride' : (1, 1),
'padding' : 1}
        self._batch_norm = batch_norm
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
@property
def batch_norm(self):
return self._batch_norm
    def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.max_pooling(x)
return x
class VGG16(nn.Module):
def __init__(self, input_size, num_classes=10, batch_norm=False):
        super().__init__()
self.in_channels, self.in_width, self.in_height = input_size
self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)
self.classifier = nn.Sequential(
nn.Linear(2048, 4096),
nn.ReLU(True),
nn.Dropout(p=0.65),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=0.65),
nn.Linear(4096, num_classes)
)
@property
def input_size(self):
return self.in_channels, self.in_width, self.in_height
def forward(self, x):
x = self.block_1(x)
x = self.block_2(x)
x = self.block_3(x)
x = self.block_4(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
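
# Note on the classifier's input size: with a 1x32x32 input, each VGG block halves
# the spatial dimensions (32 -> 16 -> 8 -> 4 -> 2), so the flattened feature vector
# is 512 * 2 * 2 = 2048, matching the first Linear layer above.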
model = VGG16((1, 32, 32), batch_norm=True)
model.to(device)
# Load the saved checkpoint
model.load_state_dict(torch.load('model.pth', map_location=device))
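
# Fashion-MNIST class names; index 8 (normally 'Bag') has been replaced by the challenge flag.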
label_map = {
0: 'T-shirt/top',
1: 'Trouser',
2: 'Pullover',
3: 'Dress',
4: 'Coat',
5: 'Sandal',
6: 'Shirt',
7: 'Sneaker',
8: 'FLAG{3883}',
9: 'Ankle boot'
}
def predict_from_local_image(image_path: str):
    # Define the transformation to match the model's input requirements
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize to the input size of the model
        transforms.ToTensor(),        # Convert the image to a tensor
    ])
    # Load the image file and convert it to grayscale to match the model's single input channel
    image = Image.open(image_path).convert('L')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    # Move the image to the specified device
    image = image.to(device)
# Set the model to evaluation mode
model.eval()
# Make a prediction
with torch.no_grad():
output = model(image)
_, predicted_label = torch.max(output, 1)
        confidence = F.softmax(output, dim=1)[0] * 100
# Get the predicted class label and confidence
predicted_class = label_map[predicted_label.item()]
predicted_confidence = confidence[predicted_label.item()].item()
return predicted_class, predicted_confidence
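
# The interface below asks the player to apply DeepFool. As a reference, here is a
# minimal, untargeted DeepFool sketch (after Moosavi-Dezfooli et al., 2016). It is
# illustrative only, not part of the original challenge code: the function name,
# parameters, and the [0, 1] pixel-range clamp are all assumptions of this sketch.
def deepfool(model, image, num_classes=10, max_iter=50, overshoot=0.02):
    """Return a minimally perturbed copy of `image` (shape 1xCxHxW) that the
    linearized model classifies differently from its original prediction."""
    model.eval()
    orig_label = model(image).argmax(dim=1).item()
    r_total = torch.zeros_like(image)  # accumulated perturbation
    pert_image = image.clone()
    for _ in range(max_iter):
        x = pert_image.clone().detach().requires_grad_(True)
        logits = model(x)[0]
        if logits.argmax().item() != orig_label:
            break  # the prediction has flipped; stop perturbing
        grad_orig = torch.autograd.grad(logits[orig_label], x, retain_graph=True)[0]
        best_dist, best_w = float('inf'), None
        for k in range(num_classes):
            if k == orig_label:
                continue
            grad_k = torch.autograd.grad(logits[k], x, retain_graph=True)[0]
            w_k = grad_k - grad_orig                # normal of the linearized boundary
            f_k = (logits[k] - logits[orig_label]).item()
            dist = abs(f_k) / (w_k.norm() + 1e-8)   # distance to boundary with class k
            if dist < best_dist:
                best_dist, best_w = dist, w_k
        # Step just past the closest linearized decision boundary.
        r_total = r_total + (best_dist + 1e-4) * best_w / (best_w.norm() + 1e-8)
        pert_image = (image + (1 + overshoot) * r_total).clamp(0, 1)
    return pert_image.detach()
# Hypothetical usage: adv = deepfool(model, tensor_image). To aim at the flag class
# specifically, restrict the inner loop to k == 8 (a simple targeted variant).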
# Gradio interface
iface = gr.Interface(
    fn=predict_from_local_image,  # Function to call for prediction
    inputs=gr.Image(type='filepath', label="Upload an image"),  # Input: image file upload
    outputs=[
        gr.Textbox(label="Predicted Class"),  # Output: predicted class name
        gr.Textbox(label="Confidence (%)"),   # Output: softmax confidence of the prediction
    ],
    title="Vault Challenge 4 - DeepFool",  # Title of the interface
    description="Upload an image, and the model will predict the class. Try to fool the model into predicting the FLAG using DeepFool! Tip: apply the DeepFool attack to the image so the model classifies it as a BAG."
)
# Launch the Gradio interface
iface.launch()