MohamedAtta-AI's picture
Update app.py
f6f9f84
raw
history blame
1.52 kB
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr
# Build the ResNet-50 architecture. No pretrained weights are requested here
# because the full fine-tuned state_dict is loaded below.
model = resnet50()
# Freeze every backbone parameter (equivalent to setting requires_grad=False
# on each one).
model.requires_grad_(False)
# Replace the 1000-way classifier with a 2-way head (0 -> Fake, 1 -> Authentic
# per the mapping used in predict()). Parameters of newly constructed modules
# have requires_grad=True by default, so this head is not frozen.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# Load the fine-tuned parameters onto the CPU. weights_only=True restricts the
# unpickler to tensors and plain containers, which is all a state_dict needs
# (requires torch >= 1.13).
model.load_state_dict(
    torch.load(
        './weights/tuned_resnet50.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# Put the model in inference mode once at startup so BatchNorm layers use their
# running statistics.
model.eval()
# Preprocessing applied to every uploaded image before inference:
# resize the shorter side to 256 px, take the central 224x224 crop, convert the
# PIL image to a float tensor, then standardize each channel with the standard
# ImageNet mean/std.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)
def predict(image):
    """Classify an uploaded image as 'Fake' or 'Authentic'.

    Args:
        image: The image array passed in by the Gradio ``Image`` input
            (presumably an ``H x W x C`` uint8 numpy array -- TODO confirm).

    Returns:
        The predicted class label, ``'Fake'`` or ``'Authentic'``.
    """
    # Force three channels: the Normalize step uses a 3-element mean/std, so an
    # RGBA or grayscale upload would otherwise fail with a shape mismatch.
    pil_image = Image.fromarray(image).convert('RGB')
    batch = torch.unsqueeze(transform(pil_image), 0)  # add batch dimension
    # eval() is idempotent; kept so prediction never runs in training mode.
    model.eval()
    # The replaced fc head still has requires_grad=True, so without no_grad()
    # every request would build an autograd graph it never uses.
    with torch.no_grad():
        output = model(batch)
    prediction = output.argmax(dim=1)  # index of the highest logit per sample
    mapping = {0: 'Fake', 1: 'Authentic'}
    return mapping[int(prediction.item())]
# Gradio UI: one image input, classified by predict() into a text label.
api = gr.Interface(
    fn=predict,
    # NOTE(review): shape=(224, 224) makes Gradio resize/crop the upload before
    # `transform` applies Resize(256) + CenterCrop(224) again, trimming borders.
    # The `shape` argument exists only in Gradio 3.x -- confirm this is intended.
    inputs=gr.Image(shape=(224, 224), label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description="This system checks whether an image was deepfaked. Input an image to be checked.",
)

if __name__ == "__main__":
    # Start the server only when executed as a script (e.g. `python app.py`),
    # so importing this module does not launch a web server as a side effect.
    api.launch()