File size: 1,517 Bytes
f852fd4
 
 
 
f6f9f84
f852fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f450cd
f852fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d128a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr

# Build the ResNet-50 architecture. No pretrained weights are requested here
# because every parameter is overwritten by the fine-tuned checkpoint below.
model = resnet50()

# Replace the classification head with a 2-way linear layer so the module
# layout matches the fine-tuned checkpoint (predict() maps index 0 -> 'Fake',
# index 1 -> 'Authentic').
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# Load the fine-tuned parameters, mapping any GPU-saved tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only ship checkpoints from a
# trusted source; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('./weights/tuned_resnet50.pth', map_location=torch.device('cpu')))

# This app only runs inference. Freeze *every* parameter, including the new
# head, whose parameters default to requires_grad=True. Otherwise each forward
# pass records an autograd graph for the head. Also switch BatchNorm/Dropout to
# inference behaviour once at load time.
model.requires_grad_(False)
model.eval()

# Preprocessing applied to every input image before it reaches the model:
# scale the shorter side to 256 px, keep the central 224x224 patch, convert to
# a float tensor in [0, 1], then standardize each RGB channel using the
# widely used ImageNet channel mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def predict(image):
    """Classify an uploaded image as 'Fake' or 'Authentic'.

    Args:
        image: The image delivered by the Gradio ``Image`` input. Presumably an
            ``H x W x C`` uint8 numpy array (Gradio's numpy image type) -- TODO
            confirm. It is ``None`` when the user submits without uploading.

    Returns:
        ``'Fake'`` when the model's highest-scoring class is index 0, or
        ``'Authentic'`` when it is index 1.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of an opaque traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Normalize uses 3-channel statistics, so a grayscale or RGBA array would
    # fail there; coerce every input to 3-channel RGB first.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add a leading batch dimension of 1

    # Make sure BatchNorm/Dropout use inference behaviour, and skip autograd
    # bookkeeping: the replaced head's parameters may still require grad, and
    # no backward pass ever happens here.
    model.eval()
    with torch.no_grad():
        logits = model(batch)

    predicted_index = int(logits.argmax(dim=1).item())
    labels = {0: 'Fake', 1: 'Authentic'}
    return labels[predicted_index]


# Web UI: a single image upload feeding predict(), whose label is shown in a
# text box.
# NOTE(review): `shape=` resizes/crops the upload in Gradio 3.x but was removed
# from gr.Image in Gradio 4 -- verify against the pinned gradio version.
api = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(224, 224), label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description="This system checks whether an image was deepfaked. Input an image to be checked.",
)

# Start the server only when run as a script. Importing this module (e.g. to
# exercise predict() from a test) then does not block on launch().
if __name__ == "__main__":
    api.launch()