MohamedAtta-AI commited on
Commit
f852fd4
·
1 Parent(s): 3f020e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.transforms import transforms
4
+ from torchvision.models import resnet50
5
+ import gradio as gr
6
+
7
+ # Import model
8
+ model = resnet50()
9
+
10
+ # Freeze all layers
11
+ for param in model.parameters():
12
+ param.requires_grad = False
13
+
14
+ # Replace FC
15
+ # Parameters of newly constructed modules have requires_grad=True by default
16
+ num_ftrs = model.fc.in_features
17
+ model.fc = nn.Linear(num_ftrs, 2)
18
+
19
+ # Load parameters
20
+ model.load_state_dict(torch.load('./weights/tuned_resnet50.pth'))
21
+
22
+ # Define the transformations to be applied to each iamge
23
+ transform = transforms.Compose([
24
+ transforms.Resize(256),
25
+ transforms.CenterCrop(224),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(
28
+ mean=[0.485, 0.456, 0.406],
29
+ std=[0.229, 0.224, 0.225]
30
+ )
31
+ ])
32
+
33
+ def predict(image):
34
+ # Preprocess
35
+ image = transform(Image.fromarray(image))
36
+
37
+ # Model prediction
38
+ model.eval()
39
+ output = model(torch.unsqueeze(image,0))
40
+
41
+ # Cast to desired
42
+ _, prediction = torch.max(output, 1) # argmax
43
+
44
+ # Prediction mapping
45
+ mapping = {0: 'Fake', 1: 'Authentic'}
46
+
47
+ return mapping[int(prediction.item())]
48
+
49
+
50
+ api = gr.Interface(
51
+ fn=predict,
52
+ inputs=gr.Image(shape=(224, 224),label="Upload an Image"),
53
+ outputs=gr.Textbox(label="Predicted Class"),
54
+ title="Image Forgery Detection System",
55
+ description= "This system checks whether an image was deepfaked. Input an image to be checked."
56
+ )
57
+
58
+ api.launch(share=True)