"""Gradio demo that classifies an uploaded image as 'Fake' or 'Authentic'.

A ResNet-50 with a 2-way linear head is restored from
``./weights/tuned_resnet50.pth`` and served through a Gradio interface.
"""
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torchvision.models import resnet50
from PIL import Image
import gradio as gr

# Output index of the 2-way head -> label shown to the user.
CLASS_NAMES = {0: 'Fake', 1: 'Authentic'}

# Import model (no pretrained weights requested; fine-tuned weights are loaded below)
model = resnet50()

# Freeze all layers (kept from the fine-tuning setup; it does not affect inference)
for param in model.parameters():
    param.requires_grad = False

# Replace FC with a 2-way classifier.
# Parameters of newly constructed modules have requires_grad=True by default.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# Load the fine-tuned parameters onto the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True to restrict this).
model.load_state_dict(
    torch.load('./weights/tuned_resnet50.pth', map_location=torch.device('cpu'))
)
# Switch to inference behaviour (e.g. BatchNorm running statistics) once,
# rather than on every request.
model.eval()

# Define the transformations to be applied to each image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict(image):
    """Classify a single uploaded image as 'Fake' or 'Authentic'.

    Args:
        image: Image array as delivered by ``gr.Image`` (it is passed to
            ``PIL.Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        The predicted label from ``CLASS_NAMES``, or a prompt asking the
        user to upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Force 3 channels so the 3-element Normalize mean/std always matches
    # (RGBA or grayscale inputs would otherwise fail).
    tensor = transform(Image.fromarray(image).convert('RGB'))

    # Add a batch dimension. no_grad skips autograd bookkeeping: the new fc
    # layer still has requires_grad=True, so a graph would otherwise be built.
    with torch.no_grad():
        output = model(tensor.unsqueeze(0))

    prediction = int(output.argmax(dim=1).item())
    return CLASS_NAMES[prediction]


api = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(224, 224), label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Forgery Detection System",
    description="This system checks whether an image was deepfaked. Input an image to be checked.",
)

api.launch()