haritsahm
Add main deployment script
cc64157
raw
history blame
2.88 kB
from typing import List
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from models import phc_models
from utils import utils
# Path to the trained two-view PHC-ResNet18 checkpoint (loaded as a state_dict
# below). The "BILATERIAL" spelling is kept: the name is module-level and may be
# referenced elsewhere.
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
# Model built once at import time. channels=2 matches predict_bilateral, which
# feeds the two mammography views as two input channels; num_classes=1 means a
# single logit (passed through a sigmoid in predict_bilateral).
# visualize=True makes the forward pass also return the maps used for the
# heat-map overlay (presumably activations - confirm in models.phc_models).
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2, n=2, num_classes=1, visualize=True)
# Added before loading the weights, presumably so the checkpoint's keys for
# these layers exist in the module - TODO confirm against phc_models.
BILATERAL_MODEL.add_top_blocks(num_classes=1)
# NOTE(review): torch.load unpickles arbitrary objects. The checkpoint is a
# local file here, but consider weights_only=True (torch >= 1.13) if it could
# ever come from an untrusted source.
BILATERAL_MODEL.load_state_dict(torch.load(
    BILATERIAL_WEIGHT, map_location='cpu'))
BILATERAL_MODEL = BILATERAL_MODEL.to('cpu')
# Evaluation mode: affects layers such as dropout / batch-norm, if present.
BILATERAL_MODEL.eval()
# Gallery component that shows the images returned by predict_bilateral.
# grid=[2] requests two images per row.
# NOTE(review): Component.style() is the Gradio 3.x styling API; it was
# deprecated and removed in Gradio 4 - confirm the pinned Gradio version.
OUTPUT_GALLERY = gr.Gallery(
    label='Highlighted Area').style(grid=[2], height='auto')
def _to_uint8(array: np.ndarray) -> np.ndarray:
    """Min-max rescale *array* into the full 0..255 range as uint8 for display."""
    return cv2.normalize(array, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def predict_bilateral(file: object) -> tuple:
    """Predict Bilateral Mammography.

    Parameters
    ----------
    file : str or TemporaryFileWrapper
        Path of the uploaded PNG, or an object exposing that path as its
        ``.name`` attribute (the temporary-file wrapper Gradio passes for
        ``gr.File``). The image must be single-channel, with the CC view
        stacked on top of the MLO view (equal heights).

    Returns
    -------
    tuple[list, dict]
        ``(gallery_items, label_confidence)``. ``gallery_items`` is a list of
        ``(rgb_image, caption)`` pairs: CC, MLO, then the two heat-map
        overlays. ``label_confidence`` is a one-entry ``{label: confidence}``
        dict for ``gr.Label``, where the confidence is that of the predicted
        class.

    Raises
    ------
    ValueError
        If the image is not single-channel or its height is odd, so it cannot
        be split into two equally tall views.
    """
    path = file if isinstance(file, str) else file.name
    # Context manager so the uploaded file's handle is released.
    with Image.open(path) as img:
        pixels = np.array(img)
    if pixels.ndim != 2 or pixels.shape[0] % 2:
        raise ValueError(
            'Expected a single-channel image with two equally tall views '
            f'stacked vertically, got an array of shape {pixels.shape}')

    # 65535 / 257 == 255, so this maps a 16-bit range onto 0..255.
    # NOTE(review): assumes 16-bit PNGs; 8-bit uploads end up in 0..~1.
    # Confirm this against the training pipeline.
    views = np.reshape(
        pixels / 257, (2, pixels.shape[0] // 2, pixels.shape[1]))
    im_h, im_w = views.shape[1:]

    # The division yields float64, but PyTorch modules are float32 by default,
    # so cast before the forward pass. unsqueeze adds the batch dimension.
    batch = torch.from_numpy(views).float().unsqueeze(0)
    # No gradients are needed for inference. Assumes the visualization maps
    # are plain activations, not gradient-based (confirm in phc_models).
    with torch.no_grad():
        logits, _, refiner_maps = BILATERAL_MODEL(batch)
    activation = utils.mean_activations(refiner_maps).numpy()
    probability = torch.sigmoid(logits).detach().cpu().item()

    is_malignant = probability > 0.5
    label_name = 'Malignant' if is_malignant else 'Normal/Benign'
    # gr.Label shows the value as the confidence of *this* label, so report
    # the predicted class's probability rather than always P(malignant).
    confidence = probability if is_malignant else 1.0 - probability
    label_confidence = {label_name: confidence}

    # applyColorMap returns BGR; convert so the JET colours are correct when
    # blended with the RGB views (red = high activation).
    heatmap = cv2.applyColorMap(_to_uint8(activation), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(
        heatmap, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

    cc_rgb, mlo_rgb = (
        cv2.cvtColor(_to_uint8(view), cv2.COLOR_GRAY2RGB) for view in views)
    cc_overlay = cv2.addWeighted(cc_rgb, 1.0, heatmap, 0.5, 0)
    mlo_overlay = cv2.addWeighted(mlo_rgb, 1.0, heatmap, 0.5, 0)

    gallery_items = [
        (cc_rgb, 'CC'),
        (mlo_rgb, 'MLO'),
        (cc_overlay, 'CC Interest Area'),
        (mlo_overlay, 'MLO Interest Area'),
    ]
    return gallery_items, label_confidence
def run(server_name: str = '0.0.0.0', server_port: int = 7860) -> None:
    """Run Gradio App.

    Builds an interface that takes a single PNG upload and passes it to
    ``predict_bilateral``. It shows the returned images in ``OUTPUT_GALLERY``
    and the prediction in a ``gr.Label``, then serves the app.

    Parameters
    ----------
    server_name : str, optional
        Host/interface to bind. The default ``'0.0.0.0'`` listens on all
        interfaces (e.g. when running inside a container).
    server_port : int, optional
        TCP port to serve on; 7860 by default.
    """
    demo = gr.Interface(
        fn=predict_bilateral,
        inputs=gr.File(file_count='single', file_types=['.png']),
        outputs=[OUTPUT_GALLERY, gr.Label(label='Cancer Type')],
    )
    try:
        # When started from a script, launch() normally blocks while the
        # server is running.
        demo.launch(server_name=server_name, server_port=server_port)
    finally:
        # Free the port even if launch() raised.
        demo.close()
# Start the web app only when executed as a script, not when imported.
if __name__ == '__main__':
    run()