MohamedAtta-AI's picture
Update app.py
f6f9f84
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr
# Build a ResNet-50 with torchvision's default constructor (no pretrained
# weights requested); every parameter is overwritten by the checkpoint below.
model = resnet50()
# Freeze all backbone parameters, mirroring the fine-tuning setup in which only
# the new classification head was trained.
for param in model.parameters():
    param.requires_grad = False
# Replace the default fully connected head with a 2-class head
# (0 = 'Fake', 1 = 'Authentic', see `predict`). Parameters of newly
# constructed modules have requires_grad=True by default.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# Load the fine-tuned parameters onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers, so loading the checkpoint cannot
# execute arbitrary code (a state_dict needs nothing more; torch >= 1.13).
state_dict = torch.load(
    './weights/tuned_resnet50.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)
# Put the network in inference mode once, right after loading, so BatchNorm
# uses its running statistics instead of per-batch statistics.
model.eval()
# Per-channel normalisation statistics (the standard ImageNet values used by
# torchvision's pretrained models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image. Steps: resize the shorter side
# to 256 px, take the central 224x224 crop, convert to a float tensor in
# [0, 1], then normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
def predict(image):
    """Classify an uploaded image as 'Fake' or 'Authentic'.

    Args:
        image: The value delivered by the ``gr.Image`` input component. It is
            handed to ``Image.fromarray``, so it is presumably a NumPy array
            (Gradio's default ``type="numpy"``) -- TODO confirm if the
            component's ``type`` is ever changed.

    Returns:
        ``'Fake'`` when the model's class-0 logit is the largest,
        ``'Authentic'`` when class 1 wins.

    Raises:
        ValueError: If no image was supplied (Gradio passes ``None``).
    """
    if image is None:
        raise ValueError("No image was provided; please upload an image.")
    # Force three channels: grayscale or RGBA arrays would otherwise fail the
    # 3-channel Normalize step inside `transform`.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = torch.unsqueeze(transform(pil_image), 0)
    # Idempotent and cheap; guarantees BatchNorm runs in inference mode.
    model.eval()
    # Inference only: skip building an autograd graph (the replaced `fc`
    # layer's parameters still have requires_grad=True).
    with torch.no_grad():
        logits = model(batch)
    class_index = int(logits.argmax(dim=1).item())
    mapping = {0: 'Fake', 1: 'Authentic'}
    return mapping[class_index]
# Web UI: one uploaded image in, one predicted label out.
# gr.Image is given no `shape=`: `transform` already resizes to 256 px and
# center-crops to 224, so pre-shrinking the upload in Gradio resampled the
# image twice before it reached the model. (`shape` was also removed from
# gr.Image in Gradio 4, where passing it raises a TypeError.)
api = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description="This system checks whether an image was deepfaked. Input an image to be checked.",
)

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported without side effects.
if __name__ == "__main__":
    api.launch()