import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

device = torch.device("cpu")


class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        conv2_params = {'kernel_size': (3, 3), 'stride': (1, 1), 'padding': 1}
        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        return self._batch_norm

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.max_pooling(x)
        return x


class VGG16(nn.Module):
    def __init__(self, input_size, num_classes=10, batch_norm=False):
        super().__init__()
        self.in_channels, self.in_width, self.in_height = input_size

        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)

        # Four 2x2 max-pools reduce a 32x32 input to 2x2, so the flattened
        # feature vector is 512 * 2 * 2 = 2048.
        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, num_classes)
        )

    @property
    def input_size(self):
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


model = VGG16((1, 32, 32), batch_norm=True)
model.to(device)

# Load the saved checkpoint
model.load_state_dict(torch.load('model.pth', map_location=device))

label_map = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'FLAG{3883}',
    9: 'Ankle boot'
}


def predict_from_local_image(image: str):
    # Define the transformation to match the model's input requirements
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize to the model's input size
        transforms.ToTensor(),        # Convert the image to a tensor
    ])

    # Load the image from disk, convert it to grayscale, and add a batch dimension
    image = Image.open(image).convert('L')
    image = transform(image).unsqueeze(0)

    # Move the image to the specified device
    image = image.to(device)

    # Set the model to evaluation mode
    model.eval()

    # Make a prediction
    with torch.no_grad():
        output = model(image)
        _, predicted_label = torch.max(output, 1)
        confidence = F.softmax(output, dim=1)[0] * 100

    # Get the predicted class label and confidence
    predicted_class = label_map[predicted_label.item()]
    predicted_confidence = confidence[predicted_label.item()].item()

    return predicted_class, f"{predicted_confidence:.2f}%"


# Gradio interface
iface = gr.Interface(
    fn=predict_from_local_image,                                # Function to call for prediction
    inputs=gr.Image(type='filepath', label="Upload an image"),  # Input: image file upload
    outputs=[
        gr.Textbox(label="Predicted Class"),                    # Output: predicted class
        gr.Textbox(label="Confidence"),                         # Output: prediction confidence
    ],
    title="Vault Challenge 4 - DeepFool",
    description=(
        "Upload an image, and the model will predict the class. "
        "Try to fool the model into predicting the FLAG using DeepFool! "
        "Tip: apply the DeepFool attack to the image so the model classifies it as a BAG."
    ),
)

# Launch the Gradio interface
iface.launch()
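

# ---------------------------------------------------------------------------
# Hedged sketch of the DeepFool attack the description refers to. This is not
# part of the served app or the repo's intended solution; it is a minimal
# local helper following Moosavi-Dezfooli et al., "DeepFool: a simple and
# accurate method to fool deep neural networks" (CVPR 2016). The function
# name `deepfool`, the hyperparameters (overshoot, max_iter), and the sample
# path in the usage note are illustrative assumptions, not values taken from
# this challenge.
# ---------------------------------------------------------------------------
def deepfool(image, net, num_classes=10, overshoot=0.02, max_iter=50):
    """Return a minimally perturbed copy of `image` whose predicted label differs."""
    net.eval()
    image = image.clone().detach().to(device)
    with torch.no_grad():
        orig_label = net(image).argmax(dim=1).item()

    r_tot = torch.zeros_like(image)  # accumulated perturbation
    x = image.clone().requires_grad_(True)

    for _ in range(max_iter):
        logits = net(x)
        if logits.argmax(dim=1).item() != orig_label:
            break  # the label has flipped; stop perturbing

        # Gradient of the original class score w.r.t. the input
        grad_orig = torch.autograd.grad(logits[0, orig_label], x, retain_graph=True)[0]

        # Find the closest (linearized) decision boundary among the other classes
        min_pert, w = float('inf'), None
        for k in range(num_classes):
            if k == orig_label:
                continue
            grad_k = torch.autograd.grad(logits[0, k], x, retain_graph=True)[0]
            w_k = grad_k - grad_orig
            f_k = (logits[0, k] - logits[0, orig_label]).item()
            pert_k = abs(f_k) / (w_k.norm() + 1e-8)
            if pert_k < min_pert:
                min_pert, w = pert_k, w_k

        # Step just past the nearest boundary, clamp to valid pixel range, re-evaluate
        r_tot = r_tot + (min_pert + 1e-4) * w / (w.norm() + 1e-8)
        x = (image + (1 + overshoot) * r_tot).clamp(0, 1).detach().requires_grad_(True)

    with torch.no_grad():
        final_label = net(x).argmax(dim=1).item()
    return x.detach(), final_label

# Vanilla DeepFool is untargeted; since class 8 here replaces Fashion-MNIST's
# "Bag" label, one common tweak matching the challenge tip is to consider only
# class 8 as the candidate boundary (replace `range(num_classes)` with `[8]`).
#
# Example usage (hypothetical local test; assumes a 'sample.png' exists):
#   img = transforms.Compose([transforms.Resize((32, 32)),
#                             transforms.ToTensor()])(
#       Image.open('sample.png').convert('L')).unsqueeze(0)
#   adv, new_label = deepfool(img, model)
#   print(label_map[new_label])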