import torch
import cv2
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import zipfile
import urllib.request
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import gradio as gr
IMG_SIZE = 512
CLASSES = [ "No DR", "Mild", "Moderate", "Severe", "Proliferative DR" ]
checkpoint = "./demo_checkpoint_convnext.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
# The checkpoint is expected to be a fully serialized model object
# (not a state_dict), so torch.load returns a ready-to-use module.
model = torch.load(checkpoint, map_location=device)
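
# Transforms applied to the preprocessed (green-channel, CLAHE-enhanced,
# circle-cropped) image before it is fed to the model.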
global_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.2786802, 0.2786802, 0.2786802], [0.16637428, 0.16637428, 0.16637428])
])

def crop_image_from_gray(img, tol=7):
    # Crop away the dark border around the fundus and replicate the single
    # channel into a 3-channel image.
    mask = img > tol
    cropped = img[np.ix_(mask.any(1), mask.any(0))]
    return np.stack([cropped, cropped, cropped], axis=-1)

def circle_crop(img):
    height, width = img.shape
    x = int(width / 2)
    y = int(height / 2)
    r = np.amin((x, y))
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x, y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    img = crop_image_from_gray(img)
    return img

def preprocess(img):
    # Extract green channel
    img = img[:, :, 1]
    # CLAHE contrast enhancement
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img = clahe.apply(img)
    # Circle crop
    img = circle_crop(img)
    # Resize
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    return img

def grad_campp(img):
    img = np.float32(img) / 255
    input_tensor = preprocess_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]).to(device)
    # Set target layers
    target_layers = [model.features[-1]]
    # GradCAM++ (only enable CUDA when a GPU is actually available)
    gradcampp = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=(device == "cuda"))
    grayscale_gradcampp = gradcampp(input_tensor=input_tensor, targets=None, eigen_smooth=False, aug_smooth=False)
    grayscale_gradcampp = grayscale_gradcampp[0, :]
    gradcampp_image = show_cam_on_image(img, grayscale_gradcampp)
    return gradcampp_image

def do_inference(img):
    img = preprocess(img)
    img_t = global_transforms(img)
    batch_t = torch.unsqueeze(img_t, 0)
    model.eval()
    # We don't need gradients for inference, so wrap the forward pass
    # in no_grad to save memory
    with torch.no_grad():
        batch_t = batch_t.to(device)
        # forward propagation
        output = model(batch_t)
        # get prediction
        probs = torch.nn.functional.softmax(output, dim=1)
    output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
    probs = probs.cpu().numpy()[0]
    probs = probs[output]
    labels = np.array(CLASSES)[output]
    # Grad-CAM++ heatmap, converted to PIL to match the declared output type
    gradcam_img = Image.fromarray(grad_campp(img))
    return {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))}, gradcam_img
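
# Illustrative standalone usage, not part of the Gradio app; "sample_fundus.png"
# is a hypothetical file name:
#   img = cv2.cvtColor(cv2.imread("sample_fundus.png"), cv2.COLOR_BGR2RGB)
#   predictions, cam = do_inference(img)
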
im = gr.inputs.Image(shape=(512, 512), image_mode='RGB',
                     invert_colors=False, source="upload",
                     type="numpy")
title = "ConvNeXt for Diabetic Retinopathy Detection"
description = ""
examples = [['./noDr.png'],['./severe.png']]
#article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab Recipes for Computer Vision - Dr. Mohamed Elawady</a></p>"
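# Note: gr.inputs / gr.outputs and the `interpretation` argument belong to the
# legacy Gradio 2.x/3.x API and were removed in Gradio 4, so this script assumes
# an older gradio version is pinned for the Space.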
iface = gr.Interface(
    do_inference,
    im,
    outputs=[gr.outputs.Label(num_top_classes=5), gr.outputs.Image(label='Output image', type='pil')],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    examples=examples
)
#iface.test_launch()
iface.launch()