import warnings

import numpy as np
import torch
import torchvision
from torchvision import transforms
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from resnet import ResNet18

# NOTE(review): the 32x32 inputs suggest a CIFAR-10 model. These labels and
# normalization statistics are the standard CIFAR-10 values and are an
# ASSUMPTION -- confirm they match how model.pth was trained (label order,
# and whether/which Normalize was used during training).
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)

device = torch.device("cpu")

model = ResNet18()
# map_location is an argument of torch.load (not of load_state_dict); it lets a
# checkpoint saved on GPU be loaded by this CPU-only app.
_incompatible = model.load_state_dict(
    torch.load("model.pth", map_location=device), strict=False
)
# strict=False silently skips mismatched keys; surface them so a wrong
# checkpoint does not quietly produce a partly-untrained model.
if _incompatible.missing_keys or _incompatible.unexpected_keys:
    warnings.warn(
        "model.pth does not match ResNet18 exactly: "
        f"missing={_incompatible.missing_keys}, "
        f"unexpected={_incompatible.unexpected_keys}"
    )
model.to(device)
# Inference mode: BatchNorm uses its running statistics (and is not updated
# by user requests); dropout, if any, is disabled.
model.eval()

_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(MEAN, STD)


def inference(input_img, transparency):
    """Classify an image and return a Grad-CAM overlay for the predicted class.

    Args:
        input_img: Image supplied by the Gradio ``gr.Image`` input (presumably
            an HxWx3 uint8 RGB array resized to 32x32 -- confirm against the
            installed Gradio version).
        transparency: Slider value in [0, 1], forwarded to
            ``show_cam_on_image`` as ``image_weight`` (the weight of the
            original image in the blended overlay).

    Returns:
        A ``(label, visualization)`` tuple: the predicted class name and the
        RGB Grad-CAM overlay produced by ``show_cam_on_image``.
    """
    # (C, H, W) float32 tensor; for uint8 input ToTensor scales to [0, 1].
    image = _to_tensor(input_img)
    input_tensor = _normalize(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
    prediction = outputs.argmax(dim=1).item()

    target_layers = [model.layer2[-2]]
    # The context manager releases the hooks GradCAM registers on the target
    # layers; without it they accumulate on the model with every request.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(prediction)],
        )
    grayscale_cam = grayscale_cam[0, :]

    # show_cam_on_image expects an HxWxC float image in [0, 1]; use the
    # un-normalized tensor directly instead of inverting the normalization.
    rgb_img = image.permute(1, 2, 0).numpy()
    visualization = show_cam_on_image(
        rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
    )
    return classes[prediction], visualization


demo = gr.Interface(
    inference,
    [gr.Image(shape=(32, 32)), gr.Slider(0, 1)],
    ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)],
)
demo.launch()