|
from pathlib import Path |
|
from typing import List |
|
|
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
from models import phc_models |
|
from utils import utils |
|
|
|
# Path to the trained two-view (CC + MLO) PHC-ResNet18 checkpoint.
BILATERAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
# Backward-compatible alias for the original (misspelled) constant name.
BILATERIAL_WEIGHT = BILATERAL_WEIGHT

# channels=2: predict_bilateral stacks the CC and MLO images as the two
# input channels. visualize=True is presumably what makes the forward pass
# return the (out, _, refiner) triple that predict_bilateral unpacks
# -- TODO confirm against phc_models.
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2, n=2, num_classes=1, visualize=True)
BILATERAL_MODEL.add_top_blocks(num_classes=1)

# NOTE(review): torch.load unpickles the file, so only load a trusted
# checkpoint. If the installed torch is >= 1.13, passing weights_only=True
# would be a safer way to load a plain state_dict.
BILATERAL_MODEL.load_state_dict(
    torch.load(BILATERAL_WEIGHT, map_location='cpu'))
BILATERAL_MODEL = BILATERAL_MODEL.to('cpu')

# Switch to evaluation mode for inference.
BILATERAL_MODEL.eval()
|
|
|
# Output gallery for the input views and their heatmap overlays, laid out
# on a grid of 2.
OUTPUT_GALLERY = gr.Gallery(label='Highlighted Area')
OUTPUT_GALLERY = OUTPUT_GALLERY.style(grid=[2], height='auto')

# Lower-case file extensions accepted by the upload widget and filter_files.
SUPPORTED_IMG_EXT = ['.png', '.jpg', '.jpeg']
|
|
|
|
|
def _name_has_tag(path: Path, tag: str) -> bool:
    """Return True if any '_'-separated part of the file name contains ``tag``.

    Matching is a case-insensitive substring test on each part of
    ``path.name`` (extension included), e.g. 'cc' matches 'P_01_LEFT_CC.png'.
    """
    return any(tag in part.lower() for part in path.name.split('_'))


def filter_files(files: List) -> List:
    """Validate the uploaded files and order them as a CC/MLO pair.

    The model requires one CC and one MLO view of the same breast side.
    View and side are detected from each file name using case-insensitive
    substring matches on its '_'-separated parts ('cc', 'mlo', 'left',
    'right'; 'left' wins if a name contains both sides).

    Rejected uploads:
    - not exactly two files (including no upload at all),
    - an extension outside SUPPORTED_IMG_EXT (compared case-insensitively),
    - a missing CC or MLO view, or the two files not being on the same side.

    Parameters
    ----------
    files : List[tempfile._TemporaryFileWrapper] or None
        The uploaded files; only their ``.name`` (on-disk path) is used.
        ``None`` is treated as an empty upload.

    Returns
    -------
    List[pathlib.Path]
        The two file paths. If the first file is not detected as CC, the
        list is reversed so the CC view comes first; predict_bilateral
        captions the first image 'CC'.

    Raises
    ------
    gr.Error
        If the number of files is not exactly 2.
    gr.Error
        If an extension is unsupported.
    gr.Error
        If a view (CC/MLO) is missing or the two files are not on the same side.
    """
    # Gradio passes None when the user submits without uploading anything.
    files = files or []
    if len(files) != 2:
        raise gr.Error(
            f'Need exactly 2 images. Currently have {len(files)} images!')

    file_paths = [Path(file.name) for file in files]

    # SUPPORTED_IMG_EXT holds lower-case extensions; normalise the suffix so
    # that e.g. 'scan.PNG' is accepted too.
    if not all(path.suffix.lower() in SUPPORTED_IMG_EXT for path in file_paths):
        raise gr.Error(
            'There is a file with unsupported type. '
            f'Make sure all files are in {SUPPORTED_IMG_EXT}!')

    has_cc = any(_name_has_tag(path, 'cc') for path in file_paths)
    has_mlo = any(_name_has_tag(path, 'mlo') for path in file_paths)
    # A file counts as 'right' only if it is not already counted as 'left'.
    n_left = sum(_name_has_tag(path, 'left') for path in file_paths)
    n_right = sum(
        not _name_has_tag(path, 'left') and _name_has_tag(path, 'right')
        for path in file_paths)

    # Put the CC view first; the model input order is CC then MLO.
    if not _name_has_tag(file_paths[0], 'cc'):
        file_paths.reverse()

    # Both views must be present and both files must be on one side.
    same_side = n_left == 2 or n_right == 2
    if not (has_cc and has_mlo and same_side):
        raise gr.Error('Missing bilateral-view pair for Left or Right side.')

    return file_paths
|
|
|
|
|
def _load_grayscale_uint8(path: Path) -> np.ndarray:
    """Load one image file as a 2-D uint8 array min-max scaled to 0-255.

    Multi-band images (e.g. RGB/RGBA uploads) and palette images are first
    converted to a single luminance band, so every view is 2-D and the pair
    can be stacked into the model's two input channels. Single-band images
    (8-bit, 16-bit or float grayscale) are read unchanged and only rescaled.
    The file handle is closed before returning.
    """
    with Image.open(str(path)) as opened:
        single_band = opened
        if len(opened.getbands()) > 1 or opened.mode == 'P':
            single_band = opened.convert('L')
        pixels = np.array(single_band)
    return cv2.normalize(
        pixels, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def _colorize_heatmap(activations: np.ndarray, width: int,
                      height: int) -> np.ndarray:
    """Render an activation map as an RGB JET heatmap of shape (height, width, 3).

    ``activations`` is assumed to be a 2-D map (as required by
    cv2.applyColorMap for a single-channel input) -- TODO confirm the shape
    returned by utils.mean_activations.
    """
    scaled = cv2.normalize(
        activations, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    # applyColorMap returns BGR, but the images it is blended onto (and the
    # gallery) are RGB. Without this swap, high activation would show blue.
    colored = cv2.cvtColor(
        cv2.applyColorMap(scaled, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    return cv2.resize(
        colored, (width, height), interpolation=cv2.INTER_LINEAR)


def predict_bilateral(files):
    """Predict malignancy for a CC/MLO mammography pair and build visuals.

    Parameters
    ----------
    files : List[tempfile._TemporaryFileWrapper]
        The uploaded files. They are validated and ordered (CC first) by
        ``filter_files``.

    Returns
    -------
    Tuple[List[Tuple[np.ndarray, str]], Dict[str, float]]
        - Gallery items: the CC and MLO images as RGB uint8 arrays, followed
          by the same two images with the heatmap overlaid, each paired with
          a caption.
        - Label confidences ``{'Normal/Benign': 1 - p, 'Malignant': p}``,
          where ``p`` is the sigmoid of the model output.

    Raises
    ------
    gr.Error
        If ``filter_files`` rejects the upload, or if the two images do not
        have the same resolution (they must be stacked as channels).
    """
    views = [_load_grayscale_uint8(path) for path in filter_files(files)]
    if views[0].shape != views[1].shape:
        raise gr.Error(
            'The CC and MLO images must have the same resolution, '
            f'got {views[0].shape} and {views[1].shape}.')
    im_h, im_w = views[0].shape

    # (2, H, W) float32 -> (1, 2, H, W): one sample, one channel per view.
    images_t = torch.from_numpy(
        np.stack(views).astype(np.float32)).unsqueeze(0)

    # Pure inference: no autograd graph is needed. This also lets .numpy()
    # be called on the activations.
    with torch.no_grad():
        out, _, out_refiner = BILATERAL_MODEL(images_t)
        activations = utils.mean_activations(out_refiner).numpy()
    probability = torch.sigmoid(out).detach().cpu().item()

    # gr.Label treats the dict as class -> confidence and shows the top
    # entry as the prediction. Report both complementary classes so a
    # benign result shows its own confidence (1 - p), not the malignancy
    # score p. 'Normal/Benign' is listed first so that, assuming a stable
    # sort, a 0.5 tie resolves to it, matching a `p > 0.5` malignancy rule.
    labels_dict = {
        'Normal/Benign': 1.0 - probability,
        'Malignant': probability,
    }

    heatmap = _colorize_heatmap(activations, im_w, im_h)
    # NOTE(review): one heatmap is overlaid on both views; presumably the
    # refiner output is a joint map for the pair -- TODO confirm.
    rgb_views = [cv2.cvtColor(view, cv2.COLOR_GRAY2RGB) for view in views]
    overlays = [
        cv2.addWeighted(rgb, 1.0, heatmap, 0.5, 0) for rgb in rgb_views]

    display_imgs = [
        (rgb_views[0], 'CC'),
        (rgb_views[1], 'MLO'),
        (overlays[0], 'CC Interest Area'),
        (overlays[1], 'MLO Interest Area'),
    ]
    return display_imgs, labels_dict
|
|
|
|
|
def run(server_name: str = '0.0.0.0', server_port: int = 7860):
    """Build the Gradio interface and serve it until it is stopped.

    Parameters
    ----------
    server_name : str
        Host address to bind. The default '0.0.0.0' listens on all network
        interfaces.
    server_port : int
        TCP port to serve on.
    """
    demo = gr.Interface(
        fn=predict_bilateral,
        inputs=gr.File(file_count='multiple', file_types=SUPPORTED_IMG_EXT),
        outputs=[OUTPUT_GALLERY, gr.Label(label='Cancer Type')],
    )

    try:
        # Blocks the main thread while the app is being served.
        demo.launch(server_name=server_name, server_port=server_port)
    finally:
        # Release the server even if launch() raises or is interrupted.
        demo.close()


if __name__ == '__main__':
    run()
|
|