"""Gradio demo for bilateral (two-view) mammography classification.

A PHCResNet18 two-view model is built and loaded once at import time. The web
UI accepts one PNG containing both views stacked vertically. It returns the two
views, each view with the model's activation map overlaid, and a label with a
confidence.
"""
from typing import Any, Dict, List, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from models import phc_models
from utils import utils

# NOTE: the historical spelling ("BILATERIAL") is kept so that any external
# code importing this constant keeps working.
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'

# channels=2: predict_bilateral feeds a tensor with one channel per view.
# The model is created once here and reused by every request.
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2, n=2, num_classes=1, visualize=True)
BILATERAL_MODEL.add_top_blocks(num_classes=1)
BILATERAL_MODEL.load_state_dict(torch.load(
    BILATERIAL_WEIGHT, map_location='cpu'))
BILATERAL_MODEL = BILATERAL_MODEL.to('cpu')
BILATERAL_MODEL.eval()

OUTPUT_GALLERY = gr.Gallery(
    label='Highlighted Area').style(grid=[2], height='auto')


def _load_views(path: str) -> np.ndarray:
    """Read the uploaded image and split it into its two stacked views.

    Pixel values are divided by 257, which maps the 16-bit range 0..65535
    onto 0..255. The rows are then split into a top half and a bottom half.

    Parameters
    ----------
    path : str
        Path of the uploaded image file.

    Returns
    -------
    np.ndarray
        float32 array of shape (2, H // 2, W).

    Raises
    ------
    ValueError
        If the image is not single-channel or its height is odd.
    """
    # The context manager closes the underlying file once pixels are copied.
    with Image.open(path) as img:
        pixels = np.array(img).astype(np.float32)
    if pixels.ndim != 2 or pixels.shape[0] % 2:
        raise ValueError(
            'Expected a single-channel image with the two views stacked '
            f'vertically (even height); got array of shape {pixels.shape}.')
    # float32 matches torch's default parameter dtype. The model is never
    # cast to another dtype, so a float64 input would be rejected.
    pixels /= 257
    return pixels.reshape(2, pixels.shape[0] // 2, pixels.shape[1])


def _gray_to_rgb(view: np.ndarray) -> np.ndarray:
    """Min-max stretch one view to uint8 and replicate it into 3 RGB channels."""
    view_u8 = cv2.normalize(
        view, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    return cv2.cvtColor(view_u8, cv2.COLOR_GRAY2RGB)


def _activation_heatmap(
        activation: np.ndarray, width: int, height: int) -> np.ndarray:
    """Render an activation map as an RGB JET heatmap of size (height, width)."""
    act_u8 = cv2.normalize(
        activation, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    heat_bgr = cv2.applyColorMap(act_u8, cv2.COLORMAP_JET)
    # applyColorMap produces BGR. Convert it so it blends correctly with the
    # RGB views and displays with the intended colors in the gallery.
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    return cv2.resize(
        heat_rgb, (width, height), interpolation=cv2.INTER_LINEAR)


def predict_bilateral(
        file: Any,
) -> Tuple[List[Tuple[np.ndarray, str]], Dict[str, float]]:
    """Predict Bilateral Mammography.

    Parameters
    ----------
    file : TemporaryFileWrapper
        Uploaded file object supplied by ``gr.File``. Only its ``.name``
        (a filesystem path) is used. The top half of the image is treated as
        the CC view and the bottom half as the MLO view.

    Returns
    -------
    Tuple[List[Tuple[np.ndarray, str]], Dict[str, float]]
        The gallery items and the label dict. Gallery items are
        (RGB uint8 image, caption) pairs: the CC view, the MLO view, and each
        view with the activation heatmap overlaid. The label dict maps the
        predicted class name to the model's confidence in that class.

    Raises
    ------
    ValueError
        If the uploaded image does not contain two vertically stacked
        single-channel views.
    """
    views = _load_views(file.name)
    im_h, im_w = views[0].shape[:2]

    # Add a batch dimension: (1, 2, H // 2, W).
    image_t = torch.from_numpy(views).unsqueeze(0)

    # Inference only. no_grad skips autograd bookkeeping, so .numpy() can be
    # called on the activation tensor without detaching it first.
    with torch.no_grad():
        out, _, out_refiner = BILATERAL_MODEL(image_t)
        activation = utils.mean_activations(out_refiner).numpy()
    probability = torch.sigmoid(out).cpu().item()

    is_malignant = probability > 0.5
    label_name = 'Malignant' if is_malignant else 'Normal/Benign'
    # gr.Label shows the value as the confidence of label_name. For a benign
    # prediction that confidence is 1 - P(malignant), not P(malignant).
    confidence = probability if is_malignant else 1.0 - probability
    labels_dict = {label_name: confidence}

    cc_rgb = _gray_to_rgb(views[0])
    mlo_rgb = _gray_to_rgb(views[1])
    # NOTE(review): the same refined activation map is overlaid on both views,
    # as in the original code. Confirm this is intended rather than one map
    # per view.
    heatmap = _activation_heatmap(activation, im_w, im_h)

    displays_imgs = [
        (cc_rgb, 'CC'),
        (mlo_rgb, 'MLO'),
        (cv2.addWeighted(cc_rgb, 1.0, heatmap, 0.5, 0), 'CC Interest Area'),
        (cv2.addWeighted(mlo_rgb, 1.0, heatmap, 0.5, 0), 'MLO Interest Area'),
    ]
    return displays_imgs, labels_dict


def run():
    """Run Gradio App.

    Serves the interface on all interfaces, port 7860. ``launch`` blocks until
    the server stops, after which the app is closed.
    """
    demo = gr.Interface(
        fn=predict_bilateral,
        inputs=gr.File(file_count='single', file_types=['.png']),
        outputs=[OUTPUT_GALLERY, gr.Label(label='Cancer Type')]
    )
    demo.launch(server_name='0.0.0.0', server_port=7860)
    demo.close()


if __name__ == '__main__':
    run()