import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr
# Build the ResNet-50 architecture; weights are loaded from the fine-tuned checkpoint below
model = resnet50()
# Freeze all layers
for param in model.parameters():
    param.requires_grad = False
# Replace the final fully connected layer with a 2-class head
# (parameters of newly constructed modules have requires_grad=True by default)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# Load the fine-tuned weights
model.load_state_dict(torch.load('./weights/tuned_resnet50.pth', map_location=torch.device('cpu')))
# Transformations applied to each image (standard ImageNet preprocessing for ResNet-50)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
def predict(image):
    # Preprocess the incoming numpy array
    image = transform(Image.fromarray(image))
    # Run the model in inference mode
    model.eval()
    with torch.no_grad():
        output = model(torch.unsqueeze(image, 0))
    # Take the argmax over the two class logits
    _, prediction = torch.max(output, 1)
    # Map the class index to a human-readable label
    mapping = {0: 'Fake', 1: 'Authentic'}
    return mapping[int(prediction.item())]
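
# Optional local smoke test (a minimal sketch, not part of the Space itself):
# 'sample.jpg' is a hypothetical image path you would supply yourself.
# Uncomment to call predict() directly without launching the Gradio UI.
# import numpy as np
# print(predict(np.array(Image.open('sample.jpg').convert('RGB'))))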
api = gr.Interface(
    fn=predict,
    # `shape` was removed from gr.Image in recent Gradio releases;
    # resizing is already handled by the transform above
    inputs=gr.Image(type="numpy", label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description="Checks whether an uploaded image is authentic or a deepfake."
)
api.launch()