Spaces:
Runtime error
Runtime error
File size: 1,517 Bytes
f852fd4 f6f9f84 f852fd4 4f450cd f852fd4 3d128a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr
# ResNet-50 backbone built without pretrained weights; every parameter is
# overwritten by the fine-tuned checkpoint loaded at the end of this block.
model = resnet50()

# Freeze all existing parameters in one call. The classifier head created
# below is a brand-new module, so its parameters keep requires_grad=True.
model.requires_grad_(False)

# Replace the final fully connected layer with a two-class head that keeps
# the backbone's feature width.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# Restore the fine-tuned weights. map_location forces every tensor onto the
# CPU so a checkpoint saved from a GPU run still loads on a CPU-only host.
model.load_state_dict(
    torch.load(
        './weights/tuned_resnet50.pth',
        map_location=torch.device('cpu'),
    )
)
# Define the transformations to be applied to each iamge
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def predict(image):
    """Classify an uploaded image as ``'Fake'`` or ``'Authentic'``.

    Args:
        image: The uploaded image as a NumPy array. This is Gradio's default
            ``type="numpy"`` for ``gr.Image``. It may be ``None`` if the
            user submits without uploading anything.

    Returns:
        The predicted label: ``'Fake'`` for class index 0, ``'Authentic'``
        for class index 1.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No image was provided; please upload an image.")

    # Force three channels. RGBA or greyscale arrays would otherwise reach
    # Normalize with a channel count that does not match its 3-value mean/std.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension of size 1 before the forward pass.
    batch = transform(pil_image).unsqueeze(0)

    model.eval()
    # Inference only. The replacement fc head still has requires_grad=True,
    # so without no_grad an autograd graph would be built on every request.
    with torch.no_grad():
        logits = model(batch)

    class_index = int(logits.argmax(dim=1).item())
    labels = {0: 'Fake', 1: 'Authentic'}
    return labels[class_index]
# Gradio UI: a single image upload in, the predicted class label out.
# `shape=` is no longer passed to gr.Image. Gradio 4 removed that argument,
# and passing it raises TypeError at startup. Sizing is already done by the
# Resize/CenterCrop steps in `transform`, so the upload is passed through at
# its native resolution. NOTE(review): under Gradio 3, `shape` pre-resized
# the upload to 224x224, so framing may differ slightly from the old
# behaviour.
api = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description= "This system checks whether an image was deepfaked. Input an image to be checked."
)
api.launch()