"""Gradio demo: ConvNeXt classifier for diabetic retinopathy (DR) grading.

Loads a pickled model checkpoint, preprocesses an uploaded retinal image
(green channel -> CLAHE -> circular crop -> resize) and displays the softmax
probability of each of the five grades listed in CLASSES.
"""
import copy
import datetime
import os
import pickle
import time
import urllib.request
import zipfile

import torch
import cv2
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
from PIL import Image
import gdown
import gradio as gr

IMG_SIZE = 512
CLASSES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

checkpoint = "./demo_checkpoint_convnext.pth"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint holds a whole pickled model (the result supports `.to`), so
# torch.load runs pickle: only ever point this at a trusted file.
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles; pass weights_only=False there if loading fails.
model = torch.load(checkpoint, map_location=device).to(device)
# Inference only: switch to eval mode once rather than on every request.
model.eval()

# preprocess() returns an HxWx3 uint8 array whose three channels are identical
# copies, which is why the same mean/std is used for every channel.
global_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.2786802, 0.2786802, 0.2786802],
                         [0.16637428, 0.16637428, 0.16637428]),
])


def crop_image_from_gray(img, tol=7):
    """Trim dark borders from a 2-D image and replicate it into 3 channels.

    Rows and columns in which every pixel is <= ``tol`` are removed.

    Args:
        img: 2-D (single-channel) array.
        tol: Pixels must be strictly greater than this to count as content.

    Returns:
        An ``H' x W' x 3`` array holding three identical copies of the crop.
        If no pixel exceeds ``tol`` the full frame is kept, because an empty
        crop would make the later ``cv2.resize`` call fail.
    """
    mask = img > tol
    rows = mask.any(axis=1)
    cols = mask.any(axis=0)
    cropped = img[np.ix_(rows, cols)] if rows.any() else img
    return np.stack([cropped, cropped, cropped], axis=-1)


def circle_crop(img):
    """Zero everything outside the largest centred circle, then trim borders.

    Args:
        img: 2-D uint8 array (single channel).

    Returns:
        The 3-channel array produced by ``crop_image_from_gray``.
    """
    height, width = img.shape
    x = int(width / 2)
    y = int(height / 2)
    r = min(x, y)
    circle_img = np.zeros((height, width), np.uint8)
    # thickness=-1 draws a filled disc of 1s, used below as the keep-mask.
    cv2.circle(circle_img, (x, y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    return crop_image_from_gray(img)


def preprocess(img):
    """Turn an uploaded image into the model's ``IMG_SIZE x IMG_SIZE x 3`` input.

    Steps: take the green channel, apply CLAHE, circle-crop and trim the dark
    border, then resize.

    Args:
        img: PIL image or NumPy array, either ``H x W x C`` (channel index 1 is
            used) or already single-channel ``H x W``. Assumed uint8, as CLAHE
            requires. TODO confirm for non-Gradio callers.

    Returns:
        ``IMG_SIZE x IMG_SIZE x 3`` uint8 array with identical channels.
    """
    img = np.asarray(img)
    # Extract the green channel (index 1 in both RGB and BGR ordering).
    green = img if img.ndim == 2 else img[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # CLAHE wants a contiguous single-channel buffer; a channel slice is strided.
    img = clahe.apply(np.ascontiguousarray(green))
    img = circle_crop(img)
    return cv2.resize(img, (IMG_SIZE, IMG_SIZE))


def do_inference(img):
    """Classify one image and return per-class probabilities.

    Args:
        img: The uploaded image. The Gradio input below delivers a PIL image
            (``type="pil"``); NumPy arrays are accepted as well.

    Returns:
        Dict mapping every name in CLASSES to its softmax probability rounded
        to 2 decimals, inserted from most to least likely.
    """
    # preprocess() slices channels with NumPy indexing, which PIL images lack.
    img = preprocess(np.asarray(img))
    batch_t = torch.unsqueeze(global_transforms(img), 0).to(device)

    # No gradients are needed for inference; no_grad saves memory.
    with torch.no_grad():
        output = model(batch_t)
        # assumes output is (1, len(CLASSES)) logits -- TODO confirm
        probs = torch.nn.functional.softmax(output, dim=1)
        order = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
        probs = probs.cpu().numpy()[0]

    return {CLASSES[i]: round(float(probs[i]), 2) for i in order}


# NOTE(review): gr.inputs / gr.outputs and `interpretation=` belong to the
# legacy Gradio API (removed in Gradio 4); pin a compatible gradio version.
im = gr.inputs.Image(
    shape=(IMG_SIZE, IMG_SIZE),
    image_mode='RGB',
    invert_colors=False,
    source="upload",
    type="pil",
)

title = "ConvNeXt for Diabetic Retinopathy Detection"
description = ""
examples = [['./noDr.png'], ['./severe.png']]

iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=len(CLASSES)),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    examples=examples,
)
iface.launch()