from pathlib import Path
from typing import List

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from models import phc_models
from utils import utils, page_utils

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')

BILATERAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2, n=2, num_classes=1, visualize=True)
BILATERAL_MODEL.add_top_blocks(num_classes=1)
BILATERAL_MODEL.load_state_dict(torch.load(
    BILATERAL_WEIGHT, map_location='cpu'))
BILATERAL_MODEL = BILATERAL_MODEL.to(device)
BILATERAL_MODEL.eval()

INPUT_HEIGHT, INPUT_WIDTH = 600, 500
SUPPORTED_IMG_EXT = ['.png', '.jpg', '.jpeg']
EXAMPLE_IMAGES = [
    ['examples/f4b2d377f43ba0bd_left_cc.png',
     'examples/f4b2d377f43ba0bd_left_mlo.jpg'],
    ['examples/f4b2d377f43ba0bd_right_cc.png',
     'examples/f4b2d377f43ba0bd_right_mlo.jpeg'],
    ['examples/P_00001_LEFT_cc.jpg',
     'examples/P_00001_LEFT_mlo.jpeg'],
]

# Model warmup on random data so the first real request is not slowed down
# by lazy initialization. The model expects float inputs, hence the cast.
test_images = np.random.randint(
    0, 255, (2, INPUT_HEIGHT, INPUT_WIDTH)).astype(np.float32)
test_images = torch.from_numpy(test_images).to(device)
test_images = test_images.unsqueeze(0)  # Add batch dimension
with torch.no_grad():
    for _ in range(10):
        _, _, _ = BILATERAL_MODEL(test_images)
test_images = None


def filter_files(files: List) -> List:
    """Filter uploaded files.

    The model requires a CC-MLO view pair of the breast scan.
    This function filters the inputs and ensures they are as expected.

    Rejected inputs:
    - Wrong number of files
    - Unsupported extensions
    - Missing required pair or part

    Parameters
    ----------
    files : List[tempfile._TemporaryFileWrapper]
        List of paths to the uploaded files

    Returns
    -------
    List[pathlib.Path]
        List of paths to the uploaded files, CC view first

    Raises
    ------
    gr.Error
        If the number of files is not exactly 2.
    gr.Error
        If a file has an unsupported extension.
    gr.Error
        If a specific view or side of the mammography is missing.
    """
    if len(files) != 2:
        raise gr.Error(
            f'Need exactly 2 images. Currently have {len(files)} images!')

    file_paths = [Path(file.name) for file in files]
    if not all(path.suffix in SUPPORTED_IMG_EXT for path in file_paths):
        raise gr.Error(
            f'There is a file with an unsupported type. '
            f'Make sure all files are in {SUPPORTED_IMG_EXT}!')

    # Table to store availability: view (row) x side (column)
    table = np.zeros((2, 2), dtype=bool)
    bin_left = 0
    bin_right = 0
    cc_first = False
    for idx, file in enumerate(file_paths):
        splits = file.name.split('_')
        # Check if the view is present in the file name
        if any('cc' in part.lower() for part in splits):
            table[0, :] = [True, True]
            if idx == 0:
                cc_first = True
        if any('mlo' in part.lower() for part in splits):
            table[1, :] = [True, True]
        # Check if the side is present in the file name
        if any('left' in part.lower() for part in splits):
            table[:, 0] &= True
            bin_left += 1
        elif any('right' in part.lower() for part in splits):
            table[:, 1] &= True
            bin_right += 1

    # Ensure the CC view comes first
    if not cc_first:
        file_paths.reverse()

    # Reset any side that does not have enough files
    if bin_left < 2:
        table[:, 0] &= False
    if bin_right < 2:
        table[:, 1] &= False

    if not any([all(table[:, 0]), all(table[:, 1])]):
        raise gr.Error('Missing bilateral-view pair for Left or Right side.')
    return file_paths


def predict_bilateral(cc_file, mlo_file):
    """Predict Bilateral Mammography.
    Parameters
    ----------
    cc_file : tempfile._TemporaryFileWrapper
        Temporary file object for the uploaded CC-view image
    mlo_file : tempfile._TemporaryFileWrapper
        Temporary file object for the uploaded MLO-view image

    Returns
    -------
    Tuple[List, Dict]
        Objects that will be used to display the result
    """
    filtered_files = filter_files([cc_file, mlo_file])

    displays_imgs = []
    images = []
    for path in filtered_files:
        image = np.array(Image.open(str(path)))
        image = cv2.normalize(
            image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        image = cv2.resize(
            image, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_LINEAR)
        images.append(image)
    images = np.asarray(images).astype(np.float32)
    im_h, im_w = images[0].shape[:2]

    images_t = torch.from_numpy(images)
    images_t = images_t.unsqueeze(0)  # Add batch dimension
    images_t = images_t.to(device)

    out, _, out_refiner = BILATERAL_MODEL(images_t)
    out_refiner = utils.mean_activations(out_refiner).detach().cpu().numpy()

    probability = torch.sigmoid(out).detach().cpu().item()
    label_name = 'Malignant' if probability > 0.5 else 'Normal/Benign'
    labels_dict = {label_name: probability}

    # Build a heatmap from the refiner activations and overlay it on both views
    refined_view_norm = cv2.normalize(
        out_refiner, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    refined_view = cv2.applyColorMap(refined_view_norm, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert to RGB so the gallery shows true colors
    refined_view = cv2.cvtColor(refined_view, cv2.COLOR_BGR2RGB)
    refined_view = cv2.resize(
        refined_view, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

    image0_colored = cv2.normalize(
        images[0], None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    image0_colored = cv2.cvtColor(image0_colored, cv2.COLOR_GRAY2RGB)
    image1_colored = cv2.normalize(
        images[1], None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    image1_colored = cv2.cvtColor(image1_colored, cv2.COLOR_GRAY2RGB)

    heatmap0_overlay = cv2.addWeighted(
        image0_colored, 1.0, refined_view, 0.5, 0)
    heatmap1_overlay = cv2.addWeighted(
        image1_colored, 1.0, refined_view, 0.5, 0)

    displays_imgs += [(image0_colored, 'CC'), (image1_colored, 'MLO')]
    displays_imgs.append((heatmap0_overlay, 'CC Interest Area'))
    displays_imgs.append((heatmap1_overlay, 'MLO Interest Area'))

    return displays_imgs, labels_dict


def run():
    """Run Gradio App."""
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()

    with gr.Blocks(theme=gr.themes.Default(
            primary_hue=page_utils.KALBE_THEME_COLOR,
            secondary_hue=page_utils.KALBE_THEME_COLOR).set(
                button_primary_background_fill='*primary_600',
                button_primary_background_fill_hover='*primary_500',
                button_primary_text_color='white',
            )) as demo:
        with gr.Column():
            gr.HTML(html_content)
            with gr.Row():
                with gr.Column():
                    cc_file = gr.File(file_count='single',
                                      file_types=SUPPORTED_IMG_EXT,
                                      label='CC View')
                    mlo_file = gr.File(file_count='single',
                                       file_types=SUPPORTED_IMG_EXT,
                                       label='MLO View')
                    with gr.Row():
                        clear_btn = gr.Button('Clear')
                        process_btn = gr.Button('Process', variant='primary')
                with gr.Column():
                    output_gallery = gr.Gallery(
                        label='Highlighted Area').style(grid=[2], height='auto')
                    cancer_type = gr.Label(label='Cancer Type')
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=[cc_file, mlo_file],
            )
            gr.Markdown(
                'Note that this method is sensitive to the input image type. '
                'The current pipeline expects pixel values between 0.0 and 255.0.')

        process_btn.click(
            fn=predict_bilateral,
            inputs=[cc_file, mlo_file],
            outputs=[output_gallery, cancer_type],
        )
        clear_btn.click(
            # No inputs are passed to this callback, so the lambda takes no
            # arguments; it simply resets both inputs and both outputs.
            lambda: (
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
            ),
            inputs=None,
            outputs=[
                cc_file,
                mlo_file,
                output_gallery,
                cancer_type,
            ],
        )

    demo.launch(server_name='0.0.0.0', server_port=7860)  # nosec B104


if __name__ == '__main__':
    run()