Upload 24 files
- README.md +72 -3
- cellseg_time_eval.py +97 -0
- classification/train_classification.py +181 -0
- classification/unsup_classification.py +91 -0
- compute_metric.py +201 -0
- finetune_convnext_stardist.py +388 -0
- models/__init__.py +10 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/convnext.cpython-38.pyc +0 -0
- models/__pycache__/flexible_unet.cpython-38.pyc +0 -0
- models/__pycache__/flexible_unet_convext.cpython-38.pyc +0 -0
- models/__pycache__/swin_unetr.cpython-38.pyc +0 -0
- models/__pycache__/unetr2d.cpython-38.pyc +0 -0
- models/convnext.py +220 -0
- models/flexible_unet.py +312 -0
- models/flexible_unet_convext.py +451 -0
- overlay.py +116 -0
- predict.py +157 -0
- predict_unet_convnext.py +123 -0
- requirements.txt +37 -0
- train_convnext_hover..py +516 -0
- train_convnext_stardist.py +417 -0
- utils.py +868 -0
- utils_modify.py +369 -0
README.md
CHANGED
@@ -1,3 +1,72 @@
# Solution of Team Sribd-med for NeurIPS-CellSeg Challenge

This repository provides the solution of team Sribd-med for the [NeurIPS-CellSeg](https://neurips22-cellseg.grand-challenge.org/) Challenge. The details of our method are described in our paper [Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images]. Some parts of the code are adapted from the baseline code of the [NeurIPS-CellSeg-Baseline](https://github.com/JunMa11/NeurIPS-CellSeg) repository.

You can reproduce our method step by step as follows:

## Environments and Requirements

Install the requirements with

```shell
python -m pip install -r requirements.txt
```

## Dataset

The competition training and tuning data can be downloaded from https://neurips22-cellseg.grand-challenge.org/dataset/
In addition, you can download three public datasets from the following links:
- Cellpose: https://www.cellpose.org/dataset
- Omnipose: http://www.cellpose.org/dataset_omnipose
- Sartorius: https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview

## Automatic cell classification

In this step the cells are classified into four classes.
Put all the images (competition + Cellpose + Omnipose + Sartorius) into one folder (data/allimages).
Run the classification code:

```shell
python classification/unsup_classification.py
```
The results are stored in data/classification_results/.

## CNN-based classification model training

Using the classified images in data/classification_results/, a ResNet-18 is trained:
```shell
python classification/train_classification.py
```
## Segmentation Training

Pre-train convnext-stardist using all the images (data/allimages):
```shell
python train_convnext_stardist.py
```
For classes 0, 2 and 3, finetune on the classified data (taking class 1 as an example):
```shell
python finetune_convnext_stardist.py model_dir=(the pretrained convnext-stardist model) data_dir='data/classification_results/class1'
```
For class 1, train convnext-hover from scratch using the classified class 3 data:
```shell
python train_convnext_hover.py data_dir='data/classification_results/class3'
```

In the end, four segmentation models are trained.

## Trained models

The models can be downloaded from this link:
https://drive.google.com/drive/folders/1MkEOpgmdkg5Yqw6Ng5PoOhtmo9xPPwIj?usp=sharing

## Inference

The inference process includes classification and segmentation:
```shell
python predict.py -i input_path -o output_path --model_path './models'
```

## Evaluation

Calculate the F-score for evaluation:
```shell
python compute_metric.py --gt_path path_to_labels --seg_path output_path
```

## Results

The tuning-set F1 score of our method is 0.8795. The rank running time of our method on all 101 cases in the tuning set is zero on our local workstation.

## Acknowledgement

We thank the contributors of the public datasets.
cellseg_time_eval.py
ADDED
@@ -0,0 +1,97 @@
"""
The code was adapted from the MICCAI FLARE Challenge
https://flare22.grand-challenge.org/

The testing images will be evaluated one by one.
To compensate for the Docker container startup time, we give a time tolerance for the running time.
https://neurips22-cellseg.grand-challenge.org/metrics/
"""

import os
join = os.path.join
import sys
import shutil
import time
import torch
import argparse
from collections import OrderedDict
from skimage import io
import tifffile as tif
import numpy as np
import pandas as pd

parser = argparse.ArgumentParser('Segmentation efficiency evaluation for docker containers', add_help=False)
parser.add_argument('-i', '--test_img_path', default='./val-imgs-30/', type=str, help='testing data path')
parser.add_argument('-o', '--save_path', default='./val_team_seg', type=str, help='segmentation output path')
parser.add_argument('-d', '--docker_folder_path', default='./team_docker', type=str, help='team docker path')
args = parser.parse_args()

test_img_path = args.test_img_path
save_path = args.save_path
docker_path = args.docker_folder_path

input_temp = './inputs/'
output_temp = './outputs'
os.makedirs(save_path, exist_ok=True)

dockers = sorted(os.listdir(docker_path))
test_cases = sorted(os.listdir(test_img_path))

for docker in dockers:
    try:
        # create temp folders for inference one-by-one
        if os.path.exists(input_temp):
            shutil.rmtree(input_temp)
        if os.path.exists(output_temp):
            shutil.rmtree(output_temp)
        os.makedirs(input_temp)
        os.makedirs(output_temp)
        # load docker and create a new folder to save segmentation results
        teamname = docker.split('.')[0].lower()
        print('teamname docker: ', docker)
        # os.system('docker image load < {}'.format(join(docker_path, docker)))
        team_outpath = join(save_path, teamname)
        if os.path.exists(team_outpath):
            shutil.rmtree(team_outpath)
        os.mkdir(team_outpath)
        metric = OrderedDict()
        metric['Img Name'] = []
        metric['Real Running Time'] = []
        metric['Rank Running Time'] = []
        # To obtain the running time for each case, we run inference on the testing cases one-by-one
        for case in test_cases:
            shutil.copy(join(test_img_path, case), input_temp)
            if case.endswith('.tif') or case.endswith('.tiff'):
                img = tif.imread(join(input_temp, case))
            else:
                img = io.imread(join(input_temp, case))
            pix_num = img.shape[0] * img.shape[1]
            cmd = 'docker container run --gpus="device=0" -m 28g --name {} --rm -v $PWD/inputs/:/workspace/inputs/ -v $PWD/outputs/:/workspace/outputs/ {}:latest /bin/bash -c "sh predict.sh" '.format(teamname, teamname)
            print(teamname, ' docker command:', cmd, '\n', 'testing image name:', case)
            start_time = time.time()
            os.system(cmd)
            real_running_time = time.time() - start_time
            print(f"{case} finished! Inference time: {real_running_time}")
            # save metrics
            metric['Img Name'].append(case)
            metric['Real Running Time'].append(real_running_time)
            if pix_num <= 1000000:
                rank_running_time = np.max([0, real_running_time-10])
            else:
                rank_running_time = np.max([0, real_running_time-10*(pix_num/1000000)])
            metric['Rank Running Time'].append(rank_running_time)
            os.remove(join(input_temp, case))
            seg_name = case.split('.')[0] + '_label.tiff'
            try:
                os.rename(join(output_temp, seg_name), join(team_outpath, seg_name))
            except:
                print(f"{join(output_temp, seg_name)}, {join(team_outpath, seg_name)}")
                print("Wrong segmentation name!!! It should be image_name.split('.')[0] + '_label.tiff' ")
        metric_df = pd.DataFrame(metric)
        metric_df.to_csv(join(team_outpath, teamname + '_running_time.csv'), index=False)
        torch.cuda.empty_cache()
        # os.system("docker rmi {}:latest".format(teamname))
        shutil.rmtree(input_temp)
        shutil.rmtree(output_temp)
    except Exception as e:
        print(e)
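A note on the ranking tolerance used above: a 10-second allowance is subtracted, scaled with the image size once it exceeds one million pixels. Below is a minimal sketch of that rule in isolation; the threshold and 10-second base are taken from the script above, while the function name and example image sizes are ours.

```python
import numpy as np

def rank_running_time(real_running_time: float, pix_num: int) -> float:
    """Tolerance rule from cellseg_time_eval.py: images up to 1e6 pixels get a
    flat 10 s allowance; larger images get 10 s per million pixels."""
    if pix_num <= 1_000_000:
        return float(np.max([0, real_running_time - 10]))
    return float(np.max([0, real_running_time - 10 * (pix_num / 1_000_000)]))

# hypothetical examples: a 1000x1000 image and a 4000x4000 image
print(rank_running_time(12.5, 1000 * 1000))   # 2.5
print(rank_running_time(12.5, 4000 * 4000))   # 0.0 (tolerance is 160 s)
```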
classification/train_classification.py
ADDED
@@ -0,0 +1,181 @@
import os, sys, glob, time, random, shutil, copy
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchsummary import summary
from matplotlib import pyplot as plt
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image, ImageFile
from skimage import io
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Set the train and validation directory paths
train_directory = 'dataset/train'
valid_directory = 'dataset/val'

# Batch size
bs = 64
# Number of epochs
num_epochs = 20
# Number of classes
num_classes = 4
# Number of workers
num_cpu = 8

# Applying transforms to the data
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

# Load data from folders
dataset = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}

# Size of train and validation data
dataset_sizes = {
    'train': len(dataset['train']),
    'valid': len(dataset['valid'])
}

# Create iterators for data loading
dataloaders = {
    'train': data.DataLoader(dataset['train'], batch_size=bs, shuffle=True,
                             num_workers=num_cpu, pin_memory=True, drop_last=False),
    'valid': data.DataLoader(dataset['valid'], batch_size=bs, shuffle=False,
                             num_workers=num_cpu, pin_memory=True, drop_last=False)
}

# Class names or target labels
class_names = dataset['train'].classes
print("Classes:", class_names)

# Print the train and validation data sizes
print("Training-set size:", dataset_sizes['train'],
      "\nValidation-set size:", dataset_sizes['valid'])

modelname = 'resnet18'

# Set default device as gpu, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)


# Transfer the model to GPU
model = model.to(device)

# Print model summary
print('Model Summary:-\n')
for num, (name, param) in enumerate(model.named_parameters()):
    print(num, name, param.requires_grad)
summary(model, input_size=(3, 224, 224))

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Learning rate decay
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

since = time.time()

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(1, num_epochs+1):
    print('Epoch {}/{}'.format(epoch, num_epochs))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'valid']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        n = 0
        stream = tqdm(dataloaders[phase])
        for i, (inputs, labels) in enumerate(stream, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            n += inputs.shape[0]
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            stream.set_description(f'Batch {i}/{len(dataloaders[phase])} | Loss: {running_loss/n:.4f}, Acc: {running_corrects/n:.4f}')

        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('Epoch {}-{} Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, phase, epoch_loss, epoch_acc))

        # deep copy the model
        if phase == 'valid' and epoch_acc >= best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            print('Update best model!')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

# load best model weights
model.load_state_dict(best_model_wts)
torch.save(model, 'logs/resnet18_4class.pth')
torch.save(model.state_dict(), 'logs/resnet18_4class.tar')
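For reference, a minimal sketch of how the saved 4-class checkpoint might be used at inference time. It assumes the 'logs/resnet18_4class.tar' state dict written above exists and reuses the validation preprocessing; the input image path is hypothetical.

```python
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image

# same preprocessing as the 'valid' transform above
preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 4)            # 4 modality classes
state = torch.load('logs/resnet18_4class.tar', map_location='cpu')
model.load_state_dict(state)
model.eval()

img = Image.open('some_cell_image.png').convert('RGB')  # hypothetical input
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0))
print('predicted class:', logits.argmax(1).item())
```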
classification/unsup_classification.py
ADDED
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import numpy as np
import shutil
import torch
import torch.nn
import torchvision.models as models
from torch.autograd import Variable
import torch.cuda
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances_argmin_min
from scipy.spatial.distance import pdist, squareform
from skimage import io, segmentation, morphology, exposure
from skimage.color import rgb2hsv
img_to_tensor = transforms.ToTensor()
import random
import tifffile as tif
path = '/data1/partitionA/CUHKSZ/histopath_2022/grand_competition/Train_Labeled/images/'
files = os.listdir(path)
binary_path = '0/'
gray_path = '1/'
colored_path = 'colored/'
os.makedirs(binary_path, exist_ok=True)
os.makedirs(colored_path, exist_ok=True)
os.makedirs(gray_path, exist_ok=True)
for img_name in files:
    img_path = path + str(img_name)
    if img_name.endswith('.tif') or img_name.endswith('.tiff'):
        img_data = tif.imread(img_path)
    else:
        img_data = io.imread(img_path)
    if len(img_data.shape) == 2 or (len(img_data.shape) == 3 and img_data.shape[-1] == 1):
        shutil.copyfile(path + img_name, binary_path + img_name)
    elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
        shutil.copyfile(path + img_name, colored_path + img_name)
    else:
        hsv_img = rgb2hsv(img_data)
        s = hsv_img[:, :, 1]
        v = hsv_img[:, :, 2]
        print(img_name, s.mean(), v.mean())
        if s.mean() > 0.1 or (v.mean() < 0.1 or v.mean() > 0.6):
            shutil.copyfile(path + img_name, colored_path + img_name)
        else:
            shutil.copyfile(path + img_name, gray_path + img_name)


# In[3]:


#### Phase 2: clustering by cell size
from skimage import measure
colored_path = 'colored/'
label_path = 'allimages/tif/'
big_path = '2/'
small_path = '3/'
files = os.listdir(colored_path)
os.makedirs(big_path, exist_ok=True)
os.makedirs(small_path, exist_ok=True)
for img_name in files:
    label = tif.imread(label_path + img_name.split('.')[0] + '.tif')
    props = measure.regionprops(label)
    num_pix = []
    for idx in range(len(props)):
        num_pix.append(props[idx].area)
    max_area = max(num_pix)
    print(max_area)
    if max_area > 30000:
        shutil.copyfile(path + img_name, big_path + img_name)
    else:
        shutil.copyfile(path + img_name, small_path + img_name)
compute_metric.py
ADDED
@@ -0,0 +1,201 @@
"""
Created on Thu Mar 31 18:10:52 2022
adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
Thanks to the authors of StarDist for sharing the great code

"""

import argparse
import numpy as np
from numba import jit
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
import pandas as pd
from skimage import segmentation
import tifffile as tif
import os
join = os.path.join
from tqdm import tqdm

def _intersection_over_union(masks_true, masks_pred):
    """ intersection over union of all mask pairs

    Parameters
    ------------

    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels
    """
    overlap = _label_overlap(masks_true, masks_pred)
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
    iou[np.isnan(iou)] = 0.0
    return iou

@jit(nopython=True)
def _label_overlap(x, y):
    """ fast function to get pixel overlaps between masks in x and y

    Parameters
    ------------

    x: ND-array, int
        where 0=NO masks; 1,2... are mask labels
    y: ND-array, int
        where 0=NO masks; 1,2... are mask labels

    Returns
    ------------

    overlap: ND-array, int
        matrix of pixel overlaps of size [x.max()+1, y.max()+1]

    """
    x = x.ravel()
    y = y.ravel()

    # preallocate a 'contact map' matrix
    overlap = np.zeros((1+x.max(), 1+y.max()), dtype=np.uint)

    # loop over the labels in x and add to the corresponding
    # overlap entry. If label A in x and label B in y share P
    # pixels, then the resulting overlap is P
    # len(x)=len(y), the number of pixels in the whole image
    for i in range(len(x)):
        overlap[x[i], y[i]] += 1
    return overlap

def _true_positive(iou, th):
    """ true positive at threshold th

    Parameters
    ------------

    iou: float, ND-array
        array of IOU pairs
    th: float
        threshold on IOU for positive label

    Returns
    ------------

    tp: float
        number of true positives at threshold
    """
    n_min = min(iou.shape[0], iou.shape[1])
    costs = -(iou >= th).astype(float) - iou / (2*n_min)
    true_ind, pred_ind = linear_sum_assignment(costs)
    match_ok = iou[true_ind, pred_ind] >= th
    tp = match_ok.sum()
    return tp

def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5):
    num_inst_gt = np.max(masks_true)
    num_inst_seg = np.max(masks_pred)
    if num_inst_seg > 0:
        iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
        # for k,th in enumerate(threshold):
        tp = _true_positive(iou, threshold)
        fp = num_inst_seg - tp
        fn = num_inst_gt - tp
    else:
        print('No segmentation results!')
        tp = 0
        fp = 0
        fn = 0

    return tp, fp, fn

def remove_boundary_cells(mask):
    W, H = mask.shape
    bd = np.ones((W, H))
    bd[2:W-2, 2:H-2] = 0
    bd_cells = np.unique(mask*bd)
    for i in bd_cells[1:]:
        mask[mask==i] = 0
    new_label, _, _ = segmentation.relabel_sequential(mask)
    return new_label

def main():
    parser = argparse.ArgumentParser('Compute F1 score for cell segmentation results', add_help=False)
    # Dataset parameters
    parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True)
    parser.add_argument('--seg_path', type=str, help='path to segmentation results; file names are the same as ground truth', required=True)
    parser.add_argument('--save_path', default='./', help='path where to save metrics')
    args = parser.parse_args()

    gt_path = args.gt_path
    seg_path = args.seg_path
    names = sorted(os.listdir(seg_path))
    seg_metric = OrderedDict()
    seg_metric['Names'] = []
    seg_metric['F1_Score'] = []
    for name in tqdm(names):
        assert name.endswith('_label.tiff'), 'The suffix of label name should be _label.tiff'

        # Load the images for this case
        gt = tif.imread(join(gt_path, name))
        seg = tif.imread(join(seg_path, name))

        # Score the cases
        # do not consider cells on the boundaries during evaluation
        if np.prod(gt.shape) < 25000000:
            gt = remove_boundary_cells(gt.astype(np.int32))
            seg = remove_boundary_cells(seg.astype(np.int32))
            tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
        else:  # for large images (>5000x5000), the F1 score is computed in a patch-based way
            H, W = gt.shape
            roi_size = 2000

            if H % roi_size != 0:
                n_H = H // roi_size + 1
                new_H = roi_size * n_H
            else:
                n_H = H // roi_size
                new_H = H

            if W % roi_size != 0:
                n_W = W // roi_size + 1
                new_W = roi_size * n_W
            else:
                n_W = W // roi_size
                new_W = W

            gt_pad = np.zeros((new_H, new_W), dtype=gt.dtype)
            seg_pad = np.zeros((new_H, new_W), dtype=gt.dtype)
            gt_pad[:H, :W] = gt
            seg_pad[:H, :W] = seg

            tp = 0
            fp = 0
            fn = 0
            for i in range(n_H):
                for j in range(n_W):
                    gt_roi = remove_boundary_cells(gt_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)])
                    seg_roi = remove_boundary_cells(seg_pad[roi_size*i:roi_size*(i+1), roi_size*j:roi_size*(j+1)])
                    tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=0.5)
                    tp += tp_i
                    fp += fp_i
                    fn += fn_i

        if tp == 0:
            precision = 0
            recall = 0
            f1 = 0
        else:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2*(precision * recall) / (precision + recall)
        seg_metric['Names'].append(name)
        seg_metric['F1_Score'].append(np.round(f1, 4))


    seg_metric_df = pd.DataFrame(seg_metric)
    seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False)
    print('mean F1 Score:', np.mean(seg_metric['F1_Score']))

if __name__ == '__main__':
    main()
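As a quick sanity check, the matching functions above can be exercised on toy masks. This sketch assumes compute_metric.py is importable from the working directory; the toy masks are made up.

```python
import numpy as np
from compute_metric import eval_tp_fp_fn

# two 8x8 toy label maps: one ground-truth cell, one slightly shifted prediction
gt = np.zeros((8, 8), dtype=np.int32)
gt[2:6, 2:6] = 1
seg = np.zeros((8, 8), dtype=np.int32)
seg[3:7, 3:7] = 1

tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)   # IoU = 9/23 < 0.5 -> no match
precision = tp / (tp + fp) if tp else 0
recall = tp / (tp + fn) if tp else 0
f1 = 2 * precision * recall / (precision + recall) if tp else 0
print(tp, fp, fn, f1)   # 0 1 1 0
```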
finetune_convnext_stardist.py
ADDED
@@ -0,0 +1,388 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Adapted from MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
"""

import argparse
import os

join = os.path.join

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from stardist import star_dist, edt_prob
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap, ray_angles
import monai
from collections import OrderedDict
from compute_metric import eval_tp_fp_fn, remove_boundary_cells
from monai.data import decollate_batch, PILReader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    SpatialPadd,
    RandSpatialCropd,
    RandRotate90d,
    ScaleIntensityd,
    RandAxisFlipd,
    RandZoomd,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandHistogramShiftd,
    EnsureTyped,
    EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt
from datetime import datetime
import shutil
import tqdm
from models.unetr2d import UNETR2D
from models.swin_unetr import SwinUNETR
from models.flexible_unet import FlexibleUNet
from models.flexible_unet_convext import FlexibleUNetConvext
print("Successfully imported all requirements!")
torch.backends.cudnn.enabled = False
#os.environ["OMP_NUM_THREADS"] = "1"
#os.environ["MKL_NUM_THREADS"] = "1"
def main():
    parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
    # Dataset parameters
    parser.add_argument(
        "--data_path",
        default="",
        type=str,
        help="training data path; subfolders: images, labels",
    )
    parser.add_argument(
        "--work_dir", default="/mntnfs/med_data5/louwei/nips_comp/stardist_finetune1/", help="path where to save models and logs"
    )
    parser.add_argument(
        "--model_dir", default="/", help="path where to load pretrained model"
    )
    parser.add_argument("--seed", default=2022, type=int)
    # parser.add_argument("--resume", default=False, help="resume from checkpoint")
    parser.add_argument("--num_workers", default=4, type=int)
    #parser.add_argument("--local_rank", type=int)
    # Model parameters
    parser.add_argument(
        "--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
    )
    parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
    parser.add_argument(
        "--input_size", default=512, type=int, help="segmentation classes"
    )
    # Training parameters
    parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
    parser.add_argument("--max_epochs", default=2000, type=int)
    parser.add_argument("--val_interval", default=10, type=int)
    parser.add_argument("--epoch_tolerance", default=100, type=int)
    parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")

    args = parser.parse_args()
    #torch.cuda.set_device(args.local_rank)
    #torch.distributed.init_process_group(backend='nccl')
    monai.config.print_config()
    n_rays = 32
    pre_trained = True
    #%% set training/validation split
    np.random.seed(args.seed)
    pre_trained_path = args.model_dir
    model_path = join(args.work_dir, args.model_name + "_3class")
    os.makedirs(model_path, exist_ok=True)
    run_id = datetime.now().strftime("%Y%m%d-%H%M")
    # This must be changed every running time ! ! ! ! ! ! ! ! ! ! !
    model_file = "models/flexible_unet_convext.py"
    shutil.copyfile(
        __file__, join(model_path, os.path.basename(__file__))
    )
    shutil.copyfile(
        model_file, join(model_path, os.path.basename(model_file))
    )
    img_path = join(args.data_path, "train/images")
    gt_path = join(args.data_path, "train/tif")
    val_img_path = join(args.data_path, "valid/images")
    val_gt_path = join(args.data_path, "valid/tif")
    img_names = sorted(os.listdir(img_path))
    gt_names = [img_name.split(".")[0] + ".tif" for img_name in img_names]
    img_num = len(img_names)
    val_frac = 0.1
    val_img_names = sorted(os.listdir(val_img_path))
    val_gt_names = [img_name.split(".")[0] + ".tif" for img_name in val_img_names]
    #indices = np.arange(img_num)
    #np.random.shuffle(indices)
    #val_split = int(img_num * val_frac)
    #train_indices = indices[val_split:]
    #val_indices = indices[:val_split]

    train_files = [
        {"img": join(img_path, img_names[i]), "label": join(gt_path, gt_names[i])}
        for i in range(len(img_names))
    ]
    val_files = [
        {"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i])}
        for i in range(len(val_img_names))
    ]
    print(
        f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
    )
    #%% define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(
                keys=["img", "label"], reader=PILReader, dtype=np.float32
            ),  # image three channels (H, W, 3); label: (H, W)
            AddChanneld(keys=["label"], allow_missing_keys=True),  # label: (1, H, W)
            AsChannelFirstd(
                keys=["img"], channel_dim=-1, allow_missing_keys=True
            ),  # image: (3, H, W)
            #ScaleIntensityd(
            #keys=["img"], allow_missing_keys=True
            #),  # Do not scale label
            SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
            RandSpatialCropd(
                keys=["img", "label"], roi_size=args.input_size, random_size=False
            ),
            RandAxisFlipd(keys=["img", "label"], prob=0.5),
            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
            # # intensity transform
            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
            RandZoomd(
                keys=["img", "label"],
                prob=0.15,
                min_zoom=0.5,
                max_zoom=2,
                mode=["area", "nearest"],
            ),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            #ScaleIntensityd(keys=["img"], allow_missing_keys=True),
            # AsDiscreted(keys=['label'], to_onehot=3),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    #% define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
    check_data = monai.utils.misc.first(check_loader)
    print(
        "sanity check:",
        check_data["img"].shape,
        torch.max(check_data["img"]),
        check_data["label"].shape,
        torch.max(check_data["label"]),
    )

    #%% create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)

    dice_metric = DiceMetric(
        include_background=False, reduction="mean", get_not_nans=False
    )

    post_pred = Compose(
        [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
    )
    post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model_name.lower() == "efficientunet":
        model = FlexibleUNetConvext(
            in_channels=3,
            out_channels=n_rays+1,
            backbone='convnext_small',
            pretrained=False,
        ).to(device)

    #loss_masked_dice = monai.losses.DiceCELoss(softmax=True)
    loss_dice = monai.losses.DiceLoss(squared_pred=True, jaccard=True)
    loss_bce = nn.BCELoss()
    loss_dist_mae = nn.L1Loss()
    activatation = nn.ReLU()
    sigmoid = nn.Sigmoid()
    #loss_dist_mae = monai.losses.DiceCELoss(softmax=True)
    initial_lr = args.initial_lr
    encoder = list(map(id, model.encoder.parameters()))
    base_params = filter(lambda p: id(p) not in encoder, model.parameters())
    params = [
        {"params": base_params, "lr": initial_lr},
        {"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
    ]
    optimizer = torch.optim.AdamW(params, initial_lr)
    if pre_trained == True:

        checkpoint = torch.load(pre_trained_path, map_location=torch.device(device))
        model.load_state_dict(checkpoint['model_state_dict'])
        print('Load pretrained weights...')
    max_epochs = args.max_epochs
    epoch_tolerance = args.epoch_tolerance
    val_interval = args.val_interval
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(model_path)
    max_f1 = 0
    for epoch in range(0, max_epochs):
        model.train()
        epoch_loss = 0
        epoch_loss_prob = 0
        epoch_loss_dist_2 = 0
        epoch_loss_dist_1 = 0
        for step, batch_data in enumerate(train_loader, 1):
            print(step)
            inputs, labels = batch_data["img"], batch_data["label"]

            processes_labels = []

            for i in range(labels.shape[0]):
                label = labels[i][0]
                distances = star_dist(label, n_rays, mode='opencl')
                distances = np.transpose(distances, (2, 0, 1))
                #print(distances.shape)
                obj_probabilities = edt_prob(label.astype(int))
                obj_probabilities = np.expand_dims(obj_probabilities, 0)
                #print(obj_probabilities.shape)
                final_label = np.concatenate((distances, obj_probabilities), axis=0)
                #print(final_label.shape)
                processes_labels.append(final_label)

            labels = np.stack(processes_labels)

            #print(inputs.shape,labels.shape)
            inputs, labels = torch.tensor(inputs).to(device), torch.tensor(labels).to(device)
            #print(inputs.shape,labels.shape)
            optimizer.zero_grad()
            output_dist, output_prob = model(inputs)
            #print(outputs.shape)
            dist_output = output_dist
            prob_output = output_prob
            dist_label = labels[:, :n_rays, :, :]
            prob_label = torch.unsqueeze(labels[:, -1, :, :], 1)
            #print(dist_output.shape,prob_output.shape,dist_label.shape)
            #labels_onehot = monai.networks.one_hot(
            #labels, args.num_class
            #)  # (b,cls,256,256)
            #print(prob_label.max(),prob_label.min())
            loss_dist_1 = loss_dice(dist_output*prob_label, dist_label*prob_label)
            #print(loss_dist_1)
            loss_prob = loss_bce(prob_output, prob_label)
            #print(prob_label.shape,dist_output.shape)
            loss_dist_2 = loss_dist_mae(dist_output*prob_label, dist_label*prob_label)
            #print(loss_dist_2)
            loss = loss_prob + loss_dist_2*0.3 + loss_dist_1
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_loss_prob += loss_prob.item()
            epoch_loss_dist_2 += loss_dist_2.item()
            epoch_loss_dist_1 += loss_dist_1.item()
            epoch_len = len(train_ds) // train_loader.batch_size

        epoch_loss /= step
        epoch_loss_prob /= step
        epoch_loss_dist_2 /= step
        epoch_loss_dist_1 /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
        writer.add_scalar("train_loss", epoch_loss, epoch)
        print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": epoch_loss_values,
        }
        if epoch < 40:
            continue
        if epoch > 1 and epoch % val_interval == 0:
            torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                seg_metric = OrderedDict()
                seg_metric['F1_Score'] = []
                for val_data in tqdm.tqdm(val_loader):
                    val_images, val_labels = val_data["img"].to(device), val_data[
                        "label"
                    ].to(device)
                    roi_size = (512, 512)
                    sw_batch_size = 4
                    output_dist, output_prob = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model
                    )
                    val_labels = val_labels[0][0].cpu().numpy()
                    prob = output_prob[0][0].cpu().numpy()
                    dist = output_dist[0].cpu().numpy()
                    #print(val_labels.shape,prob.shape,dist.shape)
                    dist = np.transpose(dist, (1, 2, 0))
                    dist = np.maximum(1e-3, dist)
                    points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

                    coord = dist_to_coord(disti, points)

                    star_label = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
                    gt = remove_boundary_cells(val_labels.astype(np.int32))
                    seg = remove_boundary_cells(star_label.astype(np.int32))
                    tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
                    if tp == 0:
                        precision = 0
                        recall = 0
                        f1 = 0
                    else:
                        precision = tp / (tp + fp)
                        recall = tp / (tp + fn)
                        f1 = 2*(precision * recall) / (precision + recall)
                    f1 = np.round(f1, 4)
                    seg_metric['F1_Score'].append(np.round(f1, 4))
                avg_f1 = np.mean(seg_metric['F1_Score'])
                writer.add_scalar("val_f1score", avg_f1, epoch)
                if avg_f1 > max_f1:
                    max_f1 = avg_f1
                    print(str(epoch) + ' f1 score: ' + str(max_f1))
                    torch.save(checkpoint, join(model_path, "best_model.pth"))
            np.savez_compressed(
                join(model_path, "train_log.npz"),
                val_dice=metric_values,
                epoch_loss=epoch_loss_values,
            )


if __name__ == "__main__":
    main()
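The training loop above supervises two heads: a 32-ray StarDist distance map and an object-probability map, combined as dice(dist) + 0.3 * L1(dist) + BCE(prob), with both distance terms masked by the probability target. Below is a minimal sketch of the per-image target construction using the same stardist calls; mode='cpp' is substituted for 'opencl' here only so the sketch does not require an OpenCL device, and the helper name and toy mask are ours.

```python
import numpy as np
from stardist import star_dist, edt_prob

n_rays = 32

def stardist_targets(label: np.ndarray) -> np.ndarray:
    """Build the (n_rays+1, H, W) training target used above:
    n_rays radial distances per pixel plus one object-probability channel."""
    distances = star_dist(label, n_rays, mode='cpp')          # (H, W, n_rays)
    distances = np.transpose(distances, (2, 0, 1))            # (n_rays, H, W)
    prob = np.expand_dims(edt_prob(label.astype(int)), 0)     # (1, H, W)
    return np.concatenate((distances, prob), axis=0)

# toy instance mask with a single square cell
label = np.zeros((64, 64), dtype=np.uint16)
label[20:40, 20:40] = 1
target = stardist_targets(label)
print(target.shape)   # (33, 64, 64)
```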
models/__init__.py
ADDED
@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 20 14:23:55 2022

@author: jma
"""

from .unetr2d import UNETR2D
from .swin_unetr import SwinUNETR
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (282 Bytes)

models/__pycache__/convnext.cpython-38.pyc
ADDED
Binary file (9.12 kB)

models/__pycache__/flexible_unet.cpython-38.pyc
ADDED
Binary file (10.3 kB)

models/__pycache__/flexible_unet_convext.cpython-38.pyc
ADDED
Binary file (10.4 kB)

models/__pycache__/swin_unetr.cpython-38.pyc
ADDED
Binary file (30 kB)

models/__pycache__/unetr2d.cpython-38.pyc
ADDED
Binary file (3.74 kB)
models/convnext.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
from functools import partial
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from timm.models.layers import trunc_normal_, DropPath
|
13 |
+
from timm.models.registry import register_model
|
14 |
+
from monai.networks.layers.factories import Act, Conv, Pad, Pool
|
15 |
+
from monai.networks.layers.utils import get_norm_layer
|
16 |
+
from monai.utils.module import look_up_option
|
17 |
+
from typing import List, NamedTuple, Optional, Tuple, Type, Union
|
18 |
+
class Block(nn.Module):
|
19 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
20 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
21 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
22 |
+
We use (2) as we find it slightly faster in PyTorch
|
23 |
+
|
24 |
+
Args:
|
25 |
+
dim (int): Number of input channels.
|
26 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
27 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
28 |
+
"""
|
29 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
|
30 |
+
super().__init__()
|
31 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
32 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
33 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
34 |
+
self.act = nn.GELU()
|
35 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
36 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
37 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
38 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
input = x
|
42 |
+
x = self.dwconv(x)
|
43 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
44 |
+
x = self.norm(x)
|
45 |
+
x = self.pwconv1(x)
|
46 |
+
x = self.act(x)
|
47 |
+
x = self.pwconv2(x)
|
48 |
+
if self.gamma is not None:
|
49 |
+
x = self.gamma * x
|
50 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
51 |
+
|
52 |
+
x = input + self.drop_path(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
class ConvNeXt(nn.Module):
|
56 |
+
r""" ConvNeXt
|
57 |
+
A PyTorch impl of : `A ConvNet for the 2020s` -
|
58 |
+
https://arxiv.org/pdf/2201.03545.pdf
|
59 |
+
|
60 |
+
Args:
|
61 |
+
in_chans (int): Number of input image channels. Default: 3
|
62 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
63 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
64 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
65 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
66 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
67 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
68 |
+
"""
|
69 |
+
def __init__(self, in_chans=3, num_classes=21841,
|
70 |
+
depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
|
71 |
+
layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3],
|
72 |
+
):
|
73 |
+
super().__init__()
|
74 |
+
# conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2]
|
75 |
+
# self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False)
|
76 |
+
# self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
|
77 |
+
|
78 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
79 |
+
stem = nn.Sequential(
|
80 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
81 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
82 |
+
)
|
83 |
+
self.downsample_layers.append(stem)
|
84 |
+
for i in range(3):
|
85 |
+
downsample_layer = nn.Sequential(
|
86 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
87 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
88 |
+
)
|
89 |
+
self.downsample_layers.append(downsample_layer)
|
90 |
+
|
91 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
92 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
93 |
+
cur = 0
|
94 |
+
for i in range(4):
|
95 |
+
stage = nn.Sequential(
|
96 |
+
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return tuple(outs)

    def forward(self, x):
        x = self.forward_features(x)
        return x


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
    return model


@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
    return model


@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
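For reference, a minimal sketch of how the ConvNeXt feature extractor above can be exercised. It assumes `models/convnext.py` is importable as `models.convnext` (as `flexible_unet_convext.py` does) and that the constructor's default `out_indices` covers all four stages; `pretrained=False` avoids the checkpoint download.

```python
import torch
from models import convnext

# Build the ConvNeXt-Small backbone without downloading weights.
encoder = convnext.convnext_small(pretrained=False)
encoder.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)   # NCHW input
    feats = encoder(x)                # forward_features returns one map per stage

for f in feats:
    print(f.shape)  # channel counts follow dims=[96, 192, 384, 768]
```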
models/flexible_unet.py
ADDED
@@ -0,0 +1,312 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn

from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetBNFeatures
from monai.networks.nets.basic_unet import UpCat
from monai.utils import InterpolateMode

__all__ = ["FlexibleUNet"]

encoder_feature_channel = {
    "efficientnet-b0": (16, 24, 40, 112, 320),
    "efficientnet-b1": (16, 24, 40, 112, 320),
    "efficientnet-b2": (16, 24, 48, 120, 352),
    "efficientnet-b3": (24, 32, 48, 136, 384),
    "efficientnet-b4": (24, 32, 56, 160, 448),
    "efficientnet-b5": (24, 40, 64, 176, 512),
    "efficientnet-b6": (32, 40, 72, 200, 576),
    "efficientnet-b7": (32, 48, 80, 224, 640),
    "efficientnet-b8": (32, 56, 88, 248, 704),
    "efficientnet-l2": (72, 104, 176, 480, 1376),
}


def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
    """
    Get the encoder output channels by given backbone name.

    Args:
        backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
        in_channels: channel of input tensor, default to 3.

    Returns:
        A tuple of output feature map channels.
    """
    encoder_channel_tuple = encoder_feature_channel[backbone]
    encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
    encoder_channel = tuple(encoder_channel_list)
    return encoder_channel


class UNetDecoder(nn.Module):
    """
    UNet Decoder.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        encoder_channels: number of output channels for all feature maps in encoder.
            `len(encoder_channels)` should be no less than 2.
        decoder_channels: number of output channels for all feature maps in decoder.
            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
        act: activation type and arguments.
        norm: feature normalization type and arguments.
        dropout: dropout ratio.
        bias: whether to have a bias term in convolution blocks in this decoder.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pre_conv: a conv block applied before upsampling.
            Only used in the "nontrainable" or "pixelshuffle" mode.
        interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            Only used in the "nontrainable" mode.
        align_corners: set the align_corners parameter for upsample. Defaults to True.
            Only used in the "nontrainable" mode.
        is_pad: whether to pad upsampling features to fit the encoder spatial dims.

    """

    def __init__(
        self,
        spatial_dims: int,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        act: Union[str, tuple],
        norm: Union[str, tuple],
        dropout: Union[float, tuple],
        bias: bool,
        upsample: str,
        pre_conv: Optional[str],
        interp_mode: str,
        align_corners: Optional[bool],
        is_pad: bool,
    ):
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("the length of `encoder_channels` should be no less than 2.")
        if len(decoder_channels) != len(encoder_channels) - 1:
            raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")

        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
        halves = [True] * (len(skip_channels) - 1)
        halves.append(False)
        blocks = []
        for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
            blocks.append(
                UpCat(
                    spatial_dims=spatial_dims,
                    in_chns=in_chn,
                    cat_chns=skip_chn,
                    out_chns=out_chn,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                    bias=bias,
                    upsample=upsample,
                    pre_conv=pre_conv,
                    interp_mode=interp_mode,
                    align_corners=align_corners,
                    halves=halve,
                    is_pad=is_pad,
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
        skips = features[:-1][::-1]
        features = features[1:][::-1]

        x = features[0]
        for i, block in enumerate(self.blocks):
            if i < skip_connect:
                skip = skips[i]
            else:
                skip = None
            x = block(x, skip)

        return x


class SegmentationHead(nn.Sequential):
    """
    Segmentation head.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels for the block.
        out_channels: number of output channels for the block.
        kernel_size: kernel size for the conv layer.
        act: activation type and arguments.
        scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        act: Optional[Union[Tuple, str]] = None,
        scale_factor: float = 1.0,
    ):
        conv_layer = Conv[Conv.CONV, spatial_dims](
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        up_layer: nn.Module = nn.Identity()
        if scale_factor > 1.0:
            up_layer = UpSample(
                spatial_dims=spatial_dims,
                scale_factor=scale_factor,
                mode="nontrainable",
                pre_conv=None,
                interp_mode=InterpolateMode.LINEAR,
            )
        if act is not None:
            act_layer = get_act_layer(act)
        else:
            act_layer = nn.Identity()
        super().__init__(conv_layer, up_layer, act_layer)


class FlexibleUNet(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (256, 128, 64, 32, 16),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
    ) -> None:
        """
        A flexible implementation of UNet, in which the backbone/encoder can be replaced with
        any efficient network. Currently the input must have 2 or 3 spatial dimensions,
        and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
        parameter is False.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbones to initialize, only support efficientnet right now,
                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
            pretrained: whether to initialize pretrained ImageNet weights, only available
                for spatial_dims=2 and batch norm is used, default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default
                to (256, 128, 64, 32, 16).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to "True", the spatial dim of network input can be arbitrary
                size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        """
        super().__init__()

        if backbone not in encoder_feature_channel:
            raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")

        if spatial_dims not in (2, 3):
            raise ValueError("spatial_dims can only be 2 or 3.")

        adv_prop = "ap" in backbone

        self.backbone = backbone
        self.spatial_dims = spatial_dims
        model_name = backbone
        encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
        self.encoder = EfficientNetBNFeatures(
            model_name=model_name,
            pretrained=pretrained,
            in_channels=in_channels,
            spatial_dims=spatial_dims,
            norm=norm,
            adv_prop=adv_prop,
        )
        self.decoder = UNetDecoder(
            spatial_dims=spatial_dims,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=decoder_bias,
            upsample=upsample,
            interp_mode=interp_mode,
            pre_conv=None,
            align_corners=None,
            is_pad=is_pad,
        )
        self.dist_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=32,
            kernel_size=1,
            act='relu',
        )
        self.prob_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=1,
            kernel_size=1,
            act='sigmoid',
        )

    def forward(self, inputs: torch.Tensor):
        """
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        """
        x = inputs
        enc_out = self.encoder(x)
        decoder_out = self.decoder(enc_out)
        dist = self.dist_head(decoder_out)
        prob = self.prob_head(decoder_out)
        return dist, prob
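A minimal smoke test of the EfficientNet-backed `FlexibleUNet` above, assuming MONAI is installed as pinned in `requirements.txt`; `pretrained=False` avoids the weight download. In this variant the two heads, not `out_channels`, define the output sizes: a 32-channel distance map and a 1-channel probability map, both at input resolution.

```python
import torch
from models.flexible_unet import FlexibleUNet

model = FlexibleUNet(
    in_channels=3,
    out_channels=33,               # kept for API compatibility; the heads set the real output sizes
    backbone="efficientnet-b0",
    pretrained=False,
)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)   # is_pad=True, so sizes need not be multiples of 32
    dist, prob = model(x)

print(dist.shape)   # expected: (1, 32, 256, 256)
print(prob.shape)   # expected: (1, 1, 256, 256)
```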
models/flexible_unet_convext.py
ADDED
@@ -0,0 +1,451 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn
from . import convnext
from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetBNFeatures
from monai.networks.nets.basic_unet import UpCat
from monai.utils import InterpolateMode

__all__ = ["FlexibleUNet"]

encoder_feature_channel = {
    "efficientnet-b0": (16, 24, 40, 112, 320),
    "efficientnet-b1": (16, 24, 40, 112, 320),
    "efficientnet-b2": (16, 24, 48, 120, 352),
    "efficientnet-b3": (24, 32, 48, 136, 384),
    "efficientnet-b4": (24, 32, 56, 160, 448),
    "efficientnet-b5": (24, 40, 64, 176, 512),
    "efficientnet-b6": (32, 40, 72, 200, 576),
    "efficientnet-b7": (32, 48, 80, 224, 640),
    "efficientnet-b8": (32, 56, 88, 248, 704),
    "efficientnet-l2": (72, 104, 176, 480, 1376),
    "convnext_small": (96, 192, 384, 768),
    "convnext_base": (128, 256, 512, 1024),
    "van_b2": (64, 128, 320, 512),
    "van_b1": (64, 128, 320, 512),
}


def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
    """
    Get the encoder output channels by given backbone name.

    Args:
        backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
        in_channels: channel of input tensor, default to 3.

    Returns:
        A tuple of output feature map channels.
    """
    encoder_channel_tuple = encoder_feature_channel[backbone]
    encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
    encoder_channel = tuple(encoder_channel_list)
    return encoder_channel


class UNetDecoder(nn.Module):
    """
    UNet Decoder.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        encoder_channels: number of output channels for all feature maps in encoder.
            `len(encoder_channels)` should be no less than 2.
        decoder_channels: number of output channels for all feature maps in decoder.
            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
        act: activation type and arguments.
        norm: feature normalization type and arguments.
        dropout: dropout ratio.
        bias: whether to have a bias term in convolution blocks in this decoder.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pre_conv: a conv block applied before upsampling.
            Only used in the "nontrainable" or "pixelshuffle" mode.
        interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            Only used in the "nontrainable" mode.
        align_corners: set the align_corners parameter for upsample. Defaults to True.
            Only used in the "nontrainable" mode.
        is_pad: whether to pad upsampling features to fit the encoder spatial dims.

    """

    def __init__(
        self,
        spatial_dims: int,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        act: Union[str, tuple],
        norm: Union[str, tuple],
        dropout: Union[float, tuple],
        bias: bool,
        upsample: str,
        pre_conv: Optional[str],
        interp_mode: str,
        align_corners: Optional[bool],
        is_pad: bool,
    ):
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("the length of `encoder_channels` should be no less than 2.")
        if len(decoder_channels) != len(encoder_channels) - 1:
            raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")

        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
        halves = [True] * (len(skip_channels) - 1)
        halves.append(False)
        blocks = []
        for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
            blocks.append(
                UpCat(
                    spatial_dims=spatial_dims,
                    in_chns=in_chn,
                    cat_chns=skip_chn,
                    out_chns=out_chn,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                    bias=bias,
                    upsample=upsample,
                    pre_conv=pre_conv,
                    interp_mode=interp_mode,
                    align_corners=align_corners,
                    halves=halve,
                    is_pad=is_pad,
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor], skip_connect: int = 3):
        skips = features[:-1][::-1]
        features = features[1:][::-1]

        x = features[0]
        for i, block in enumerate(self.blocks):
            if i < skip_connect:
                skip = skips[i]
            else:
                skip = None
            x = block(x, skip)

        return x


class SegmentationHead(nn.Sequential):
    """
    Segmentation head.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels for the block.
        out_channels: number of output channels for the block.
        kernel_size: kernel size for the conv layer.
        act: activation type and arguments.
        scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        act: Optional[Union[Tuple, str]] = None,
        scale_factor: float = 1.0,
    ):
        conv_layer = Conv[Conv.CONV, spatial_dims](
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        up_layer: nn.Module = nn.Identity()
        # if scale_factor > 1.0:
        #     up_layer = UpSample(
        #         in_channels=out_channels,
        #         spatial_dims=spatial_dims,
        #         scale_factor=scale_factor,
        #         mode="deconv",
        #         pre_conv=None,
        #         interp_mode=InterpolateMode.LINEAR,
        #     )
        if scale_factor > 1.0:
            up_layer = UpSample(
                spatial_dims=spatial_dims,
                scale_factor=scale_factor,
                mode="nontrainable",
                pre_conv=None,
                interp_mode=InterpolateMode.LINEAR,
            )
        if act is not None:
            act_layer = get_act_layer(act)
        else:
            act_layer = nn.Identity()
        super().__init__(conv_layer, up_layer, act_layer)


class FlexibleUNetConvext(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (1024, 512, 256, 128),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
    ) -> None:
        """
        A flexible implementation of UNet, in which the backbone/encoder can be replaced with
        any efficient network. Currently the input must have 2 or 3 spatial dimensions,
        and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
        parameter is False.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbones to initialize, only support efficientnet right now,
                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
            pretrained: whether to initialize pretrained ImageNet weights, only available
                for spatial_dims=2 and batch norm is used, default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default
                to (256, 128, 64, 32, 16).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to "True", the spatial dim of network input can be arbitrary
                size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        """
        super().__init__()

        if backbone not in encoder_feature_channel:
            raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")

        if spatial_dims not in (2, 3):
            raise ValueError("spatial_dims can only be 2 or 3.")

        adv_prop = "ap" in backbone

        self.backbone = backbone
        self.spatial_dims = spatial_dims
        model_name = backbone
        encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)

        self.encoder = convnext.convnext_small(pretrained=True, in_22k=True)
        # self.encoder = VAN(embed_dims=[64, 128, 320, 512],
        #                    depths=[3, 3, 12, 3],
        #                    init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b2.pth'),
        #                    norm_cfg=dict(type='BN', requires_grad=True)
        #                    )
        # self.encoder = VAN(embed_dims=[64, 128, 320, 512],
        #                    depths=[2, 2, 4, 2],
        #                    init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b1.pth'),
        #                    norm_cfg=dict(type='BN', requires_grad=True)
        #                    )
        # self.encoder.init_weights()
        self.decoder = UNetDecoder(
            spatial_dims=spatial_dims,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=decoder_bias,
            upsample=upsample,
            interp_mode=interp_mode,
            pre_conv=None,
            align_corners=None,
            is_pad=is_pad,
        )
        self.dist_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=64,
            kernel_size=1,
            act='relu',
            scale_factor=2,
        )
        self.prob_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=1,
            kernel_size=1,
            act='sigmoid',
            scale_factor=2,
        )

    def forward(self, inputs: torch.Tensor):
        """
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        """
        x = inputs
        enc_out = self.encoder(x)
        decoder_out = self.decoder(enc_out)

        dist = self.dist_head(decoder_out)
        prob = self.prob_head(decoder_out)

        return dist, prob


class FlexibleUNet_hv(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (1024, 512, 256, 128),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
        n_rays: int = 32,
        prob_out_channels: int = 1,
    ) -> None:
        """
        A flexible implementation of UNet, in which the backbone/encoder can be replaced with
        any efficient network. Currently the input must have 2 or 3 spatial dimensions,
        and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
        parameter is False.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbones to initialize, only support efficientnet right now,
                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
            pretrained: whether to initialize pretrained ImageNet weights, only available
                for spatial_dims=2 and batch norm is used, default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default
                to (256, 128, 64, 32, 16).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to "True", the spatial dim of network input can be arbitrary
                size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        """
        super().__init__()

        if backbone not in encoder_feature_channel:
            raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")

        if spatial_dims not in (2, 3):
            raise ValueError("spatial_dims can only be 2 or 3.")

        adv_prop = "ap" in backbone

        self.backbone = backbone
        self.spatial_dims = spatial_dims
        model_name = backbone
        encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
        self.encoder = convnext.convnext_small(pretrained=True, in_22k=True)
        self.decoder = UNetDecoder(
            spatial_dims=spatial_dims,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=decoder_bias,
            upsample=upsample,
            interp_mode=interp_mode,
            pre_conv=None,
            align_corners=None,
            is_pad=is_pad,
        )
        self.dist_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=n_rays,
            kernel_size=1,
            act=None,
            scale_factor=2,
        )
        self.prob_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=prob_out_channels,
            kernel_size=1,
            act='sigmoid',
            scale_factor=2,
        )

    def forward(self, inputs: torch.Tensor):
        """
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        """
        x = inputs
        enc_out = self.encoder(x)
        decoder_out = self.decoder(enc_out)
        dist = self.dist_head(decoder_out)
        prob = self.prob_head(decoder_out)
        return dist, prob
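A minimal sketch of using the ConvNeXt-backed `FlexibleUNet_hv` above (`FlexibleUNetConvext` is analogous, with fixed 64-channel distance and 1-channel probability heads). Note that both classes construct their encoder with `convnext.convnext_small(pretrained=True, in_22k=True)` regardless of the `pretrained` argument, so instantiation needs the ImageNet-22k checkpoint to be downloadable. The StarDist-style `FlexibleUNet_star` imported by `predict.py` comes from `models.flexible_unet_convnext` and is not defined in this file.

```python
import torch
from models.flexible_unet_convext import FlexibleUNet_hv

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlexibleUNet_hv(
    in_channels=3,
    out_channels=33,            # unused by the heads; n_rays/prob_out_channels set the real sizes
    backbone="convnext_small",
    pretrained=False,           # the ConvNeXt encoder still loads 22k weights internally
    n_rays=32,
    prob_out_channels=1,
).to(device)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256, device=device)
    dist, prob = model(x)
print(dist.shape, prob.shape)   # expected: (1, 32, 256, 256) and (1, 1, 256, 256)
```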
overlay.py
ADDED
@@ -0,0 +1,116 @@
#!/usr/bin/env python
# coding: utf-8

### overlay
import cv2
import math
import random
import colorsys
import numpy as np
import itertools
import matplotlib.pyplot as plt
from matplotlib import cm
import os
import scipy.io as io


def get_bounding_box(img):
    """Get bounding box coordinate information."""
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # due to python indexing, need to add 1 to max
    # else accessing will be 1px in the box, not out
    rmax += 1
    cmax += 1
    return [rmin, rmax, cmin, cmax]


####
def colorize(ch, vmin, vmax):
    """Clamp values outside the provided range to vmax and vmin."""
    cmap = plt.get_cmap("jet")
    ch = np.squeeze(ch.astype("float32"))
    vmin = vmin if vmin is not None else ch.min()
    vmax = vmax if vmax is not None else ch.max()
    ch[ch > vmax] = vmax  # clamp value
    ch[ch < vmin] = vmin
    ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
    # take RGB from RGBA heat map
    ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
    return ch_cmap


####
def random_colors(N, bright=True):
    """Generate random colors.

    To get visually distinct colors, generate them in HSV space then
    convert to RGB.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors


####
def visualize_instances_map(
    input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
):
    """Overlays segmentation results on image as contours.

    Args:
        input_image: input image
        inst_map: instance mask with unique value for every object
        type_map: type mask with unique value for every class
        type_colour: a dict of {type : colour}, `type` is from 0-N
            and `colour` is a tuple of (R, G, B)
        line_thickness: line thickness of contours

    Returns:
        overlay: output image with segmentation overlay as contours
    """
    overlay = np.copy((input_image).astype(np.uint8))

    inst_list = list(np.unique(inst_map))  # get list of instances
    inst_list.remove(0)  # remove background

    inst_rng_colors = random_colors(len(inst_list))
    inst_rng_colors = np.array(inst_rng_colors) * 255
    inst_rng_colors = inst_rng_colors.astype(np.uint8)

    for inst_idx, inst_id in enumerate(inst_list):
        inst_map_mask = np.array(inst_map == inst_id, np.uint8)  # get single object
        y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
        y1 = y1 - 2 if y1 - 2 >= 0 else y1
        x1 = x1 - 2 if x1 - 2 >= 0 else x1
        x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
        y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
        inst_map_crop = inst_map_mask[y1:y2, x1:x2]
        contours_crop = cv2.findContours(
            inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # only has 1 instance per map, no need to check #contour detected by opencv
        # print(contours_crop)
        contours_crop = np.squeeze(
            contours_crop[0][0].astype("int32")
        )  # * opencv protocol format may break
        if len(contours_crop.shape) == 1:
            contours_crop = contours_crop.reshape(1, -1)
        # print(contours_crop.shape)
        contours_crop += np.asarray([[x1, y1]])  # index correction
        if type_map is not None:
            type_map_crop = type_map[y1:y2, x1:x2]
            type_id = np.unique(type_map_crop).max()  # non-zero
            inst_colour = type_colour[type_id]
        else:
            inst_colour = (inst_rng_colors[inst_idx]).tolist()
        cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
    return overlay


# In[ ]:
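A small, self-contained example of the overlay helper above, drawing contours of a toy instance mask onto a flat grey image; the shapes, sizes, and output filename are arbitrary illustration values.

```python
import numpy as np
import cv2
from overlay import visualize_instances_map

# Toy image and instance mask: two square "cells" with ids 1 and 2.
image = np.full((128, 128, 3), 200, dtype=np.uint8)
inst_map = np.zeros((128, 128), dtype=np.int32)
inst_map[20:60, 20:60] = 1
inst_map[70:110, 70:110] = 2

overlay = visualize_instances_map(image, inst_map, line_thickness=2)
cv2.imwrite("toy_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```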
predict.py
ADDED
@@ -0,0 +1,157 @@
import os
join = os.path.join
import argparse
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from torchvision import datasets, models, transforms
from classifiers import resnet10, resnet18

from utils_modify import sliding_window_inference, sliding_window_inference_large, __proc_np_hv
from PIL import Image
import torch.nn.functional as F
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif
from models.flexible_unet_convnext import FlexibleUNet_star, FlexibleUNet_hv
# from overlay import visualize_instances_map

def normalize_channel(img, lower=1, upper=99):
    non_zero_vals = img[np.nonzero(img)]
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
    else:
        img_norm = img
    return img_norm.astype(np.uint8)

# torch.cuda.synchronize()
parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False)
# Dataset parameters
parser.add_argument('-i', '--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels')
parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path')
parser.add_argument('--model_path', default='./models', help='path where to save models and segmentation results')
parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay')

# Model parameters
parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr')
parser.add_argument('--input_size', default=512, type=int, help='segmentation classes')
args = parser.parse_args()
input_path = args.input_path
output_path = args.output_path
model_path = args.model_path
os.makedirs(output_path, exist_ok=True)
# overlay_path = 'overlays/'
# print(input_path)

img_names = sorted(os.listdir(join(input_path)))
# print(img_names)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
roi_size = (512, 512)
overlap = 0.5
np_thres, ksize, overall_thres, obj_size_thres = 0.6, 15, 0.4, 100
n_rays = 32
sw_batch_size = 4
num_classes = 4
block_size = 2048
min_overlap = 128
context = 128
with torch.no_grad():
    for img_name in img_names:
        # print(img_name)
        if img_name.endswith('.tif') or img_name.endswith('.tiff'):
            img_data = tif.imread(join(input_path, img_name))
        else:
            img_data = io.imread(join(input_path, img_name))
        # normalize image data
        if len(img_data.shape) == 2:
            img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
        elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
            img_data = img_data[:, :, :3]
        else:
            pass
        pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
        for i in range(3):
            img_channel_i = img_data[:, :, i]
            if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
                pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
        inputs = preprocess(Image.fromarray(pre_img_data)).unsqueeze(0).to(device)
        cls_MODEL = model_path + '/cls/resnet18_4class_all_modified.tar'
        model = resnet18().to(device)
        model.load_state_dict(torch.load(cls_MODEL))
        model.eval()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        label = preds[0].cpu().numpy()
        # print(label)
        test_npy01 = pre_img_data
        if label in [0, 1, 2] or img_data.shape[0] > 4000:
            if label == 0:
                model = FlexibleUNet_star(in_channels=3, out_channels=n_rays+1, backbone='convnext_small', pretrained=False, n_rays=n_rays, prob_out_channels=1,).to(device)
                checkpoint = torch.load(model_path+'/0/best_model.pth', map_location=torch.device(device))
                model.load_state_dict(checkpoint['model_state_dict'])
                model.eval()

                output_label = sliding_window_inference_large(test_npy01, block_size, min_overlap, context, roi_size, sw_batch_size, predictor=model, device=device)
                tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)

            elif label == 1:
                model = FlexibleUNet_star(in_channels=3, out_channels=n_rays+1, backbone='convnext_small', pretrained=False, n_rays=n_rays, prob_out_channels=1,).to(device)
                checkpoint = torch.load(model_path+'/1/best_model.pth', map_location=torch.device(device))
                model.load_state_dict(checkpoint['model_state_dict'])
                model.eval()

                output_label = sliding_window_inference_large(test_npy01, block_size, min_overlap, context, roi_size, sw_batch_size, predictor=model, device=device)
                tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
            elif label == 2:
                model = FlexibleUNet_star(in_channels=3, out_channels=n_rays+1, backbone='convnext_small', pretrained=False, n_rays=n_rays, prob_out_channels=1,).to(device)
                checkpoint = torch.load(model_path+'/2/best_model.pth', map_location=torch.device(device))
                model.load_state_dict(checkpoint['model_state_dict'])
                model.eval()

                output_label = sliding_window_inference_large(test_npy01, block_size, min_overlap, context, roi_size, sw_batch_size, predictor=model, device=device)
                tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)

        else:
            model = FlexibleUNet_hv(in_channels=3, out_channels=2+2, backbone='convnext_small', pretrained=False, n_rays=2, prob_out_channels=2,).to(device)
            checkpoint = torch.load(model_path+'/3/best_model_converted.pth', map_location=torch.device(device))
            # model.load_state_dict(checkpoint['model_state_dict'])
            # od = OrderedDict()
            # for k, v in checkpoint['model_state_dict'].items():
            #     od[k.replace('module.', '')] = v
            model.load_state_dict(checkpoint)
            model.to(device)
            model.eval()
            test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            if isinstance(roi_size, tuple):
                roi = roi_size

            output_hv, output_np = sliding_window_inference(test_tensor, roi, sw_batch_size, model, overlap=overlap)
            pred_dict = {'np': output_np, 'hv': output_hv}
            pred_dict = OrderedDict(
                [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]  # NHWC
            )
            pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
            pred_output = torch.cat(list(pred_dict.values()), -1).cpu().numpy()  # NHW3
            pred_map = np.squeeze(pred_output)  # HW3
            pred_inst = __proc_np_hv(pred_map, np_thres, ksize, overall_thres, obj_size_thres)
            raw_pred_shape = pred_inst.shape[:2]
            output_label = pred_inst

            tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), output_label)
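As a usage note, `predict.py` loads one classifier checkpoint plus four segmentation checkpoints from fixed sub-folders of `--model_path` (paths taken from the script above; the `./models` default mirrors the script's own default). A small sanity-check sketch:

```python
import os

model_path = "./models"
expected = [
    "cls/resnet18_4class_all_modified.tar",   # ResNet-18 modality classifier
    "0/best_model.pth",                       # class-0 StarDist-style model
    "1/best_model.pth",                       # class-1 StarDist-style model
    "2/best_model.pth",                       # class-2 StarDist-style model
    "3/best_model_converted.pth",             # class-3 HoVer-style model
]
for rel in expected:
    path = os.path.join(model_path, rel)
    print(("ok   " if os.path.isfile(path) else "MISS ") + path)
```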
predict_unet_convnext.py
ADDED
@@ -0,0 +1,123 @@
import os
join = os.path.join
import argparse
import numpy as np
import torch
import monai
import torch.nn as nn

from utils import sliding_window_inference
# from baseline.models.unetr2d import UNETR2D
import time
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap, ray_angles
from stardist import star_dist, edt_prob
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif
import cv2
from overlay import visualize_instances_map
from models.flexible_unet import FlexibleUNet
from models.flexible_unet_convext import FlexibleUNetConvext

def normalize_channel(img, lower=1, upper=99):
    non_zero_vals = img[np.nonzero(img)]
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
    else:
        img_norm = img
    return img_norm.astype(np.uint8)

def main():
    parser = argparse.ArgumentParser('Baseline for Microscopy image segmentation', add_help=False)
    # Dataset parameters
    # parser.add_argument('-i', '--input_path', default='./inputs', type=str, help='training data path; subfolders: images, labels')
    # parser.add_argument("-o", '--output_path', default='./outputs', type=str, help='output path')
    parser.add_argument('--model_path', default='./work_dir/swinunetr_3class', help='path where to save models and segmentation results')
    parser.add_argument('--show_overlay', required=False, default=False, action="store_true", help='save segmentation overlay')

    # Model parameters
    parser.add_argument('--model_name', default='efficientunet', help='select mode: unet, unetr, swinunetr')
    parser.add_argument('--num_class', default=3, type=int, help='segmentation classes')
    parser.add_argument('--input_size', default=512, type=int, help='segmentation classes')
    args = parser.parse_args()

    input_path = '/home/data/TuningSet/'
    output_path = '/home/data/output/'
    overlay_path = '/home/data/overlay/'

    img_names = sorted(os.listdir(join(input_path)))
    n_rays = 32

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model_name.lower() == "efficientunet":
        model = FlexibleUNetConvext(
            in_channels=3,
            out_channels=n_rays+1,
            backbone='convnext_small',
            pretrained=True,
        ).to(device)

    sigmoid = nn.Sigmoid()
    checkpoint = torch.load('/home/louwei/stardist_convnext/efficientunet_3class/best_model.pth', map_location=torch.device(device))
    model.load_state_dict(checkpoint['model_state_dict'])
    #%%
    roi_size = (args.input_size, args.input_size)
    sw_batch_size = 4
    model.eval()
    with torch.no_grad():
        for img_name in img_names:
            print(img_name)
            if img_name.endswith('.tif') or img_name.endswith('.tiff'):
                img_data = tif.imread(join(input_path, img_name))
            else:
                img_data = io.imread(join(input_path, img_name))
            # normalize image data
            if len(img_data.shape) == 2:
                img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
            elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
                img_data = img_data[:, :, :3]
            else:
                pass
            pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
            for i in range(3):
                img_channel_i = img_data[:, :, i]
                if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
                    pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)

            t0 = time.time()
            # test_npy01 = pre_img_data/np.max(pre_img_data)
            test_npy01 = pre_img_data
            test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
            # test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1)  # (B, C, H, W)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()

            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

            coord = dist_to_coord(disti, points)

            star_label = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), star_label)
            overlay = visualize_instances_map(pre_img_data, star_label)
            cv2.imwrite(join(overlay_path, img_name.split('.')[0]+'.png'), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,37 @@
gputools==0.2.13
h5py==3.7.0
huggingface-hub==0.10.1
imagecodecs
imageio==2.22.2
importlib-metadata==5.0.0
kiwisolver==1.4.4
llvmlite==0.39.1
Mako==1.2.3
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.1
mkl-fft==1.3.1
mkl-service==2.4.0
monai==1.0.0
networkx==2.8.7
numba==0.56.3
numexpr
numpy
oauthlib==3.2.2
opencv-python==4.6.0.66
packaging
pandas==1.4.4
Pillow==9.2.0
scikit-image==0.19.3
scipy==1.9.2
stardist==0.8.3
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tifffile==2022.10.10
timm==0.6.11
torch==1.12.1
torchaudio==0.12.1
torchvision==0.13.1
tqdm==4.64.1
train_convnext_hover..py
ADDED
@@ -0,0 +1,516 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Adapted from the MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os, sys
|
9 |
+
|
10 |
+
join = os.path.join
|
11 |
+
#sys.path.append('/data2/yuxinyi/stardist_pytorch')
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn import DataParallel
|
20 |
+
from torch.utils.data import Dataset, DataLoader
|
21 |
+
from torch.utils.tensorboard import SummaryWriter
|
22 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
23 |
+
from stardist import star_dist, edt_prob
|
24 |
+
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
|
25 |
+
from stardist import random_label_cmap, ray_angles
|
26 |
+
import monai
|
27 |
+
from collections import OrderedDict
|
28 |
+
from compute_metric import eval_tp_fp_fn, remove_boundary_cells
|
29 |
+
from monai.data import decollate_batch, PILReader
|
30 |
+
from monai.inferers import sliding_window_inference
|
31 |
+
from monai.metrics import DiceMetric
|
32 |
+
from monai.transforms import (
|
33 |
+
Activations,
|
34 |
+
AsChannelFirstd,
|
35 |
+
AddChanneld,
|
36 |
+
AsDiscrete,
|
37 |
+
CenterSpatialCropd,
|
38 |
+
Compose,
|
39 |
+
Lambdad,
|
40 |
+
LoadImaged,
|
41 |
+
# LoadImaged_modified,
|
42 |
+
SpatialPadd,
|
43 |
+
RandSpatialCropd,
|
44 |
+
RandRotate90d,
|
45 |
+
ScaleIntensityd,
|
46 |
+
RandAxisFlipd,
|
47 |
+
RandZoomd,
|
48 |
+
RandGaussianNoised,
|
49 |
+
RandAdjustContrastd,
|
50 |
+
RandGaussianSmoothd,
|
51 |
+
RandHistogramShiftd,
|
52 |
+
EnsureTyped,
|
53 |
+
EnsureType,
|
54 |
+
apply_transform,
|
55 |
+
)
|
56 |
+
from monai.visualize import plot_2d_or_3d_image
|
57 |
+
import matplotlib.pyplot as plt
|
58 |
+
from datetime import datetime
|
59 |
+
import shutil
|
60 |
+
from skimage import io
|
61 |
+
from skimage.color import gray2rgb
|
62 |
+
|
63 |
+
from models.unetr2d import UNETR2D
|
64 |
+
from models.swin_unetr import SwinUNETR
|
65 |
+
from models.flexible_unet_convext import FlexibleUNet_hv
|
66 |
+
|
67 |
+
from utils import cropping_center, gen_targets, xentropy_loss, dice_loss, mse_loss, msge_loss
|
68 |
+
|
69 |
+
import warnings
|
70 |
+
warnings.filterwarnings("ignore")
|
71 |
+
|
72 |
+
print("Successfully imported all requirements!")
|
73 |
+
torch.backends.cudnn.enabled = False
|
74 |
+
|
75 |
+
def rm_n_mkdir(dir_path):
|
76 |
+
"""Remove and make directory."""
|
77 |
+
if os.path.isdir(dir_path):
|
78 |
+
shutil.rmtree(dir_path)
|
79 |
+
os.makedirs(dir_path)
|
80 |
+
|
81 |
+
class HoverDataset(Dataset):
|
82 |
+
def __init__(self, data, transform, mask_shape):
|
83 |
+
self.data = data
|
84 |
+
self.transform = transform
|
85 |
+
self.mask_shape = mask_shape
|
86 |
+
|
87 |
+
def __len__(self) -> int:
|
88 |
+
return len(self.data)
|
89 |
+
|
90 |
+
def _transform(self, index):
|
91 |
+
data_i = self.data[index]
|
92 |
+
return apply_transform(self.transform, data_i) if self.transform is not None else data_i
|
93 |
+
|
94 |
+
def __getitem__(self, index):
|
95 |
+
ret = self._transform(index)
|
96 |
+
# print(target_dict['img'].dtype, target_dict['label'].dtype)
|
97 |
+
# gen targets
|
98 |
+
inst_map = np.squeeze(ret['label'].numpy()).astype('int32') # 1HW -> HW
|
99 |
+
target_dict = gen_targets(inst_map, inst_map.shape[:2]) # original code: self.mask_shape -> current code: aug_size
|
100 |
+
np_map, hv_map = target_dict['np_map'], target_dict['hv_map']
|
101 |
+
np_map = cropping_center(np_map, self.mask_shape) # HW
|
102 |
+
hv_map = cropping_center(hv_map, self.mask_shape) # HW2
|
103 |
+
target_dict['np_map'] = torch.tensor(np_map)
|
104 |
+
target_dict['hv_map'] = torch.tensor(hv_map)
|
105 |
+
# centercrop img
|
106 |
+
img = cropping_center(ret['img'].permute(1,2,0), self.mask_shape).permute(2,0,1) # CHW -> HWC -> CHW
|
107 |
+
ret['img'] = img
|
108 |
+
ret.update(target_dict)
|
109 |
+
return ret
|
110 |
+
|
111 |
+
def valid_step(model, batch_data):
|
112 |
+
|
113 |
+
model.eval() # infer mode
|
114 |
+
|
115 |
+
####
|
116 |
+
imgs = batch_data["img"]
|
117 |
+
true_np = batch_data["np_map"]
|
118 |
+
true_hv = batch_data["hv_map"]
|
119 |
+
|
120 |
+
imgs_gpu = imgs.to("cuda").type(torch.float32) # NCHW
|
121 |
+
|
122 |
+
# HWC
|
123 |
+
true_np = torch.squeeze(true_np).type(torch.int64)
|
124 |
+
true_hv = torch.squeeze(true_hv).type(torch.float32)
|
125 |
+
|
126 |
+
true_dict = {
|
127 |
+
"np": true_np,
|
128 |
+
"hv": true_hv,
|
129 |
+
}
|
130 |
+
|
131 |
+
# --------------------------------------------------------------
|
132 |
+
with torch.no_grad(): # dont compute gradient
|
133 |
+
preds = model(imgs_gpu)
|
134 |
+
pred_dict = {'np': preds[1], 'hv': preds[0]}
|
135 |
+
pred_dict = OrderedDict(
|
136 |
+
[[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]
|
137 |
+
)
|
138 |
+
pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1]
|
139 |
+
|
140 |
+
# * Its up to user to define the protocol to process the raw output per step!
|
141 |
+
result_dict = { # protocol for contents exchange within `raw`
|
142 |
+
"raw": {
|
143 |
+
"imgs": imgs.numpy(),
|
144 |
+
"true_np": true_dict["np"].numpy(),
|
145 |
+
"true_hv": true_dict["hv"].numpy(),
|
146 |
+
"prob_np": pred_dict["np"].cpu().numpy(),
|
147 |
+
"pred_hv": pred_dict["hv"].cpu().numpy(),
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
151 |
+
return result_dict
|
152 |
+
|
153 |
+
def proc_valid_step_output(raw_data, nr_types=None):
|
154 |
+
|
155 |
+
track_dict = {}
|
156 |
+
|
157 |
+
def _dice_info(true, pred, label):
|
158 |
+
true = np.array(true == label, np.int32)
|
159 |
+
pred = np.array(pred == label, np.int32)
|
160 |
+
inter = (pred * true).sum()
|
161 |
+
total = (pred + true).sum()
|
162 |
+
return inter, total
|
163 |
+
|
164 |
+
over_inter = 0
|
165 |
+
over_total = 0
|
166 |
+
over_correct = 0
|
167 |
+
prob_np = raw_data["prob_np"]
|
168 |
+
true_np = raw_data["true_np"]
|
169 |
+
for idx in range(len(raw_data["true_np"])):
|
170 |
+
patch_prob_np = prob_np[idx]
|
171 |
+
patch_true_np = true_np[idx]
|
172 |
+
patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32)
|
173 |
+
inter, total = _dice_info(patch_true_np, patch_pred_np, 1)
|
174 |
+
correct = (patch_pred_np == patch_true_np).sum()
|
175 |
+
over_inter += inter
|
176 |
+
over_total += total
|
177 |
+
over_correct += correct
|
178 |
+
nr_pixels = len(true_np) * np.size(true_np[0])
|
179 |
+
acc_np = over_correct / nr_pixels
|
180 |
+
dice_np = 2 * over_inter / (over_total + 1.0e-8)
|
181 |
+
track_dict['np_acc'] = acc_np
|
182 |
+
track_dict['np_dice'] = dice_np
|
183 |
+
|
184 |
+
# * HV regression statistic
|
185 |
+
pred_hv = raw_data["pred_hv"]
|
186 |
+
true_hv = raw_data["true_hv"]
|
187 |
+
|
188 |
+
over_squared_error = 0
|
189 |
+
for idx in range(len(raw_data["true_np"])):
|
190 |
+
patch_pred_hv = pred_hv[idx]
|
191 |
+
patch_true_hv = true_hv[idx]
|
192 |
+
squared_error = patch_pred_hv - patch_true_hv
|
193 |
+
squared_error = squared_error * squared_error
|
194 |
+
over_squared_error += squared_error.sum()
|
195 |
+
mse = over_squared_error / nr_pixels
|
196 |
+
track_dict['hv_mse'] = mse
|
197 |
+
|
198 |
+
return track_dict
|
199 |
+
|
200 |
+
def main():
|
201 |
+
|
202 |
+
# class Args:
|
203 |
+
# def __init__(self, data_path, seed, num_workers, model_name, input_size, mask_size, batch_size, max_epochs,
|
204 |
+
# val_interval, save_interval, initial_lr, gpu_id, n_rays):
|
205 |
+
# self.data_path = data_path
|
206 |
+
# self.seed = seed
|
207 |
+
# self.num_workers = num_workers
|
208 |
+
# self.model_name = model_name
|
209 |
+
# self.input_size = input_size
|
210 |
+
# self.mask_size = mask_size
|
211 |
+
# self.batch_size = batch_size
|
212 |
+
# self.max_epochs = max_epochs
|
213 |
+
# self.val_interval = val_interval
|
214 |
+
# self.save_interval = save_interval
|
215 |
+
# self.initial_lr = initial_lr
|
216 |
+
# self.gpu_id = gpu_id
|
217 |
+
# self.n_rays = n_rays
|
218 |
+
|
219 |
+
# args = Args('/data2/yuxinyi/stardist_pytorch/dataset/class3_seed2', 2022, 4, 'efficientunet', 512, 256, 16, 600,
|
220 |
+
# 1, 10, 1e-4, '4', 32)
|
221 |
+
modelname = 'star-hover'
|
222 |
+
strategy = 'aug256_out256'
|
223 |
+
parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
|
224 |
+
# Dataset parameters
|
225 |
+
parser.add_argument(
|
226 |
+
"--data_path",
|
227 |
+
default=f"/mntnfs/med_data5/louwei/consep/",
|
228 |
+
type=str,
|
229 |
+
help="training data path; subfolders: images, labels",
|
230 |
+
)
|
231 |
+
parser.add_argument("--seed", default=10, type=int)
|
232 |
+
# parser.add_argument("--resume", default=False, help="resume from checkpoint")
|
233 |
+
parser.add_argument("--num_workers", default=4, type=int)
|
234 |
+
|
235 |
+
# Model parameters
|
236 |
+
parser.add_argument(
|
237 |
+
"--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
|
238 |
+
)
|
239 |
+
parser.add_argument("--input_size", default=512, type=int, help="after rand crop")
|
240 |
+
parser.add_argument("--mask_size", default=256, type=int, help="after gen target")
|
241 |
+
# Training parameters
|
242 |
+
parser.add_argument("--batch_size", default=12, type=int, help="Batch size per GPU")
|
243 |
+
parser.add_argument("--max_epochs", default=800, type=int)
|
244 |
+
parser.add_argument("--val_interval", default=1, type=int)
|
245 |
+
parser.add_argument("--save_interval", default=10, type=int)
|
246 |
+
parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
|
247 |
+
parser.add_argument('--gpu_id', type=str, default='0', help='gpu id')
|
248 |
+
|
249 |
+
args = parser.parse_args()
|
250 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
|
251 |
+
|
252 |
+
work_dir = f'/mntnfs/med_data5/louwei/hover_stardist/class_{modelname}_{strategy}'
|
253 |
+
|
254 |
+
# monai.config.print_config()
|
255 |
+
pre_trained = False
|
256 |
+
# %% set training/validation split
|
257 |
+
np.random.seed(args.seed)
|
258 |
+
model_path = join(work_dir)
|
259 |
+
rm_n_mkdir(model_path)
|
260 |
+
run_id = datetime.now().strftime("%Y%m%d-%H%M")
|
261 |
+
shutil.copyfile(
|
262 |
+
__file__, join(model_path, run_id + "_" + os.path.basename(__file__))
|
263 |
+
)
|
264 |
+
img_path = join(args.data_path, "Train/Images_3channels")
|
265 |
+
gt_path = join(args.data_path, "Train/tif")
|
266 |
+
val_img_path = join(args.data_path, "Test/Images_3channels")
|
267 |
+
val_gt_path = join(args.data_path, "Test/tif")
|
268 |
+
img_names = sorted(os.listdir(img_path))
|
269 |
+
gt_names = [img_name.replace('.png', '.tif') for img_name in img_names]
|
270 |
+
img_num = len(img_names)
|
271 |
+
val_frac = 0.1
|
272 |
+
val_img_names = sorted(os.listdir(val_img_path))
|
273 |
+
val_gt_names = [img_name.replace('.png', '.tif') for img_name in val_img_names]
|
274 |
+
|
275 |
+
train_files = [
|
276 |
+
{"img": join(img_path, img_names[i]), "label": join(gt_path, gt_names[i]), 'name': img_names[i]}
|
277 |
+
for i in range(len(img_names))
|
278 |
+
]
|
279 |
+
val_files = [
|
280 |
+
{"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i]),
|
281 |
+
'name': val_img_names[i]}
|
282 |
+
for i in range(len(val_img_names))
|
283 |
+
]
|
284 |
+
print(
|
285 |
+
f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
|
286 |
+
)
|
287 |
+
|
288 |
+
def load_img(img):
|
289 |
+
ret = io.imread(img)
|
290 |
+
if len(ret.shape) == 2:
|
291 |
+
ret = gray2rgb(ret)
|
292 |
+
return ret.astype('float32')
|
293 |
+
|
294 |
+
def load_ann(ann):
|
295 |
+
ret = np.squeeze(io.imread(ann)).astype('float32')
|
296 |
+
return ret
|
297 |
+
|
298 |
+
# %% define transforms for image and segmentation
|
299 |
+
train_transforms = Compose(
|
300 |
+
[
|
301 |
+
Lambdad(('img',), load_img),
|
302 |
+
Lambdad(('label',), load_ann),
|
303 |
+
# LoadImaged(
|
304 |
+
# keys=["img", "label"], reader=PILReader, dtype=np.float32
|
305 |
+
# ), # image three channels (H, W, 3); label: (H, W)
|
306 |
+
AddChanneld(keys=["label"], allow_missing_keys=True), # label: (1, H, W)
|
307 |
+
AsChannelFirstd(
|
308 |
+
keys=["img"], channel_dim=-1, allow_missing_keys=True
|
309 |
+
), # image: (3, H, W)
|
310 |
+
# ScaleIntensityd(
|
311 |
+
# keys=["img"], allow_missing_keys=True
|
312 |
+
# ), # Do not scale label
|
313 |
+
# SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
|
314 |
+
# RandSpatialCropd(
|
315 |
+
# keys=["img", "label"], roi_size=args.input_size, random_size=False
|
316 |
+
# ),
|
317 |
+
RandAxisFlipd(keys=["img", "label"], prob=0.5),
|
318 |
+
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
|
319 |
+
# # intensity transform
|
320 |
+
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
|
321 |
+
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
|
322 |
+
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
|
323 |
+
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
|
324 |
+
RandZoomd(
|
325 |
+
keys=["img", "label"],
|
326 |
+
prob=0.15,
|
327 |
+
min_zoom=0.5,
|
328 |
+
max_zoom=2.0,
|
329 |
+
mode=["area", "nearest"],
|
330 |
+
),
|
331 |
+
EnsureTyped(keys=["img", "label"]),
|
332 |
+
]
|
333 |
+
)
|
334 |
+
|
335 |
+
val_transforms = Compose(
|
336 |
+
[
|
337 |
+
Lambdad(('img',), load_img),
|
338 |
+
Lambdad(('label',), load_ann),
|
339 |
+
# LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
|
340 |
+
AddChanneld(keys=["label"], allow_missing_keys=True),
|
341 |
+
AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
|
342 |
+
# ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
343 |
+
# AsDiscreted(keys=['label'], to_onehot=3),
|
344 |
+
# CenterSpatialCropd(
|
345 |
+
# keys=["img", "label"], roi_size=args.input_size
|
346 |
+
# ),
|
347 |
+
EnsureTyped(keys=["img", "label"]),
|
348 |
+
]
|
349 |
+
)
|
350 |
+
|
351 |
+
# % define dataset, data loader
|
352 |
+
# check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
|
353 |
+
check_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=(args.mask_size, args.mask_size))
|
354 |
+
print(len(check_ds))
|
355 |
+
tmp = check_ds[0]
|
356 |
+
print(tmp['img'].shape, tmp['label'].shape, tmp['hv_map'].shape, tmp['np_map'].shape)
|
357 |
+
check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
|
358 |
+
check_data = monai.utils.misc.first(check_loader)
|
359 |
+
print(
|
360 |
+
"sanity check:",
|
361 |
+
check_data["img"].shape,
|
362 |
+
torch.max(check_data["img"]),
|
363 |
+
check_data["label"].shape,
|
364 |
+
torch.max(check_data["label"]),
|
365 |
+
check_data["hv_map"].shape,
|
366 |
+
torch.max(check_data["hv_map"]),
|
367 |
+
check_data["np_map"].shape,
|
368 |
+
torch.max(check_data["np_map"]),
|
369 |
+
)
|
370 |
+
|
371 |
+
# %% create a training data loader
|
372 |
+
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
|
373 |
+
train_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=(args.mask_size, args.mask_size))
|
374 |
+
print(len(train_ds))
|
375 |
+
# example = train_ds[0]
|
376 |
+
# plt.imshow(np.array(example['img']).transpose(1,2,0).astype('uint8'))
|
377 |
+
# plt.imshow(np.squeeze(example['np_map'].numpy()).astype('uint8'), 'gray')
|
378 |
+
# plt.imshow(example['hv_map'].numpy()[...,0])
|
379 |
+
# plt.imshow(example['hv_map'].numpy()[..., 1])
|
380 |
+
# plt.show()
|
381 |
+
# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
|
382 |
+
train_loader = DataLoader(
|
383 |
+
train_ds,
|
384 |
+
batch_size=args.batch_size,
|
385 |
+
shuffle=True,
|
386 |
+
num_workers=args.num_workers,
|
387 |
+
pin_memory=torch.cuda.is_available(),
|
388 |
+
)
|
389 |
+
# create a validation data loader
|
390 |
+
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
|
391 |
+
val_ds = HoverDataset(data=val_files, transform=val_transforms, mask_shape=(args.mask_size, args.mask_size))
|
392 |
+
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4)
|
393 |
+
|
394 |
+
model = FlexibleUNet_hv(
|
395 |
+
in_channels=3,
|
396 |
+
out_channels=2+2,
|
397 |
+
backbone='convnext_small',
|
398 |
+
pretrained=True,
|
399 |
+
n_rays=2,
|
400 |
+
prob_out_channels=2,
|
401 |
+
)
|
402 |
+
|
403 |
+
activation = nn.ReLU()
|
404 |
+
sigmoid = nn.Sigmoid()
|
405 |
+
initial_lr = args.initial_lr
|
406 |
+
optimizer = torch.optim.AdamW(model.parameters(), initial_lr)
|
407 |
+
scheduler = StepLR(optimizer, 100, 0.1)
|
408 |
+
#if pre_trained == True:
|
409 |
+
#print('Load pretrained weights...')
|
410 |
+
#checkpoint = torch.load('/data2/yuxinyi/stardist_pytorch/pretrained/overall/330.pth')
|
411 |
+
#model.load_state_dict(checkpoint['model_state_dict'])
|
412 |
+
# model = DataParallel(model)
|
413 |
+
model = model.to('cuda')
|
414 |
+
# start a typical PyTorch training
|
415 |
+
max_epochs = args.max_epochs
|
416 |
+
val_interval = args.val_interval
|
417 |
+
save_interval = args.save_interval
|
418 |
+
epoch_loss_values = []
|
419 |
+
writer = SummaryWriter(model_path)
|
420 |
+
|
421 |
+
#*# record loss and f1
|
422 |
+
loss_file = f'{work_dir}/train_loss.txt'
|
423 |
+
f1_file = f'{work_dir}/train_f1.txt'
|
424 |
+
if os.path.exists(loss_file):
|
425 |
+
os.remove(loss_file)
|
426 |
+
if os.path.exists(f1_file):
|
427 |
+
os.remove(f1_file)
|
428 |
+
#*#
|
429 |
+
|
430 |
+
for epoch in range(1, args.max_epochs):
|
431 |
+
model.train()
|
432 |
+
epoch_loss = 0
|
433 |
+
running_np_1, running_np_2, running_hv_1, running_hv_2 = 0.0, 0.0, 0.0, 0.0
|
434 |
+
stream = tqdm(train_loader)
|
435 |
+
for step, batch_data in enumerate(stream, start=1):
|
436 |
+
|
437 |
+
#*# hv map
|
438 |
+
inputs, true_np, true_hv = batch_data["img"], batch_data["np_map"], batch_data['hv_map']
|
439 |
+
true_np = true_np.to("cuda").type(torch.int64) # NHW
|
440 |
+
true_hv = true_hv.to("cuda").type(torch.float32) # NHWC
|
441 |
+
true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) # NHWC
|
442 |
+
inputs = torch.tensor(inputs).to('cuda')
|
443 |
+
# print(inputs.shape, true_np.shape, true_hv.shape)
|
444 |
+
|
445 |
+
optimizer.zero_grad()
|
446 |
+
pred_hv, pred_np = model(inputs) # NCHW
|
447 |
+
pred_hv = pred_hv.permute(0, 2, 3, 1).contiguous() # NHWC
|
448 |
+
pred_np = pred_np.permute(0, 2, 3, 1).contiguous() # NHWC
|
449 |
+
pred_np = F.softmax(pred_np, dim=-1)
|
450 |
+
|
451 |
+
# losses
|
452 |
+
loss_np_1 = xentropy_loss(true_np_onehot, pred_np) # bce
|
453 |
+
loss_np_2 = dice_loss(true_np_onehot, pred_np) # dice
|
454 |
+
loss_hv_1 = mse_loss(true_hv, pred_hv) # mse
|
455 |
+
loss_hv_2 = msge_loss(true_hv, pred_hv, true_np_onehot[...,1]) # msge
|
456 |
+
loss = loss_np_1 + loss_np_2 + loss_hv_1 + loss_hv_2
|
457 |
+
loss.backward()
|
458 |
+
optimizer.step()
|
459 |
+
epoch_loss += loss.item()
|
460 |
+
epoch_len = len(train_ds) // train_loader.batch_size
|
461 |
+
|
462 |
+
running_np_1 += loss_np_1.item()
|
463 |
+
running_np_2 += loss_np_2.item()
|
464 |
+
running_hv_1 += loss_hv_1.item()
|
465 |
+
running_hv_2 += loss_hv_2.item()
|
466 |
+
#*#
|
467 |
+
|
468 |
+
stream.set_description(
|
469 |
+
f'Epoch {epoch} | np bce: {running_np_1 / step:.4f}, np dice: {running_np_2 / step:.4f}, hv mse: {running_hv_1 / step:.4f}, hv msge: {running_hv_2 / step:.4f}')
|
470 |
+
|
471 |
+
epoch_loss /= step
|
472 |
+
epoch_loss_values.append(epoch_loss)
|
473 |
+
writer.add_scalar("train_loss", epoch_loss, epoch)
|
474 |
+
writer.add_scalar("np_bce", running_np_1 / step, epoch)
|
475 |
+
writer.add_scalar("np_dice", running_np_2 / step, epoch)
|
476 |
+
writer.add_scalar("hv_mse", running_hv_1 / step, epoch)
|
477 |
+
writer.add_scalar("hv_msge", running_hv_2 / step, epoch)
|
478 |
+
print(f"epoch {epoch} average loss: {epoch_loss:.4f}, lr: {optimizer.param_groups[0]['lr']}")
|
479 |
+
|
480 |
+
#*# record
|
481 |
+
with open(loss_file, 'a') as f:
|
482 |
+
f.write(f'Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_bce:{running_np_1/step:.4f}\tnp_dice:{running_np_2/step:.4f}\thv_mse:{running_hv_1/step:.4f}\thv_msge:{running_hv_2/step:.4f}\n')
|
483 |
+
#*#
|
484 |
+
|
485 |
+
checkpoint = {
|
486 |
+
"epoch": epoch,
|
487 |
+
"model_state_dict": model.state_dict(),
|
488 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
489 |
+
"loss": epoch_loss_values,
|
490 |
+
}
|
491 |
+
if epoch % save_interval == 0:
|
492 |
+
torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
|
493 |
+
|
494 |
+
running_np_acc, running_np_dice, running_hv_mse = 0.0, 0.0, 0.0
|
495 |
+
stream_val = tqdm(val_loader)
|
496 |
+
for step, batch_data in enumerate(stream_val, start=1):
|
497 |
+
raw_data = valid_step(model, batch_data)['raw']
|
498 |
+
track_dict = proc_valid_step_output(raw_data)
|
499 |
+
running_np_acc += track_dict['np_acc']
|
500 |
+
running_np_dice += track_dict['np_dice']
|
501 |
+
running_hv_mse += track_dict['hv_mse']
|
502 |
+
stream_val.set_description(f'Epoch {epoch} | np acc: {running_np_acc / step:.4f}, np dice: {running_np_dice / step:.4f}, hv mse: {running_hv_mse / step:.4f}')
|
503 |
+
writer.add_scalar("np_acc", running_np_acc / step, epoch)
|
504 |
+
writer.add_scalar("np_dice", running_np_dice / step, epoch)
|
505 |
+
writer.add_scalar("hv_mse", running_hv_mse / step, epoch)
|
506 |
+
print(f'Epoch {epoch} | np acc: {running_np_acc / step:.4f}, np dice: {running_np_dice / step:.4f}, hv mse: {running_hv_mse / step:.4f}')
|
507 |
+
|
508 |
+
#*# record
|
509 |
+
with open(f1_file, 'a') as f:
|
510 |
+
f.write(f'Validation | Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_acc:{running_np_acc/step:.4f}\tnp_dice:{running_np_dice/step:.4f}\thv_mse:{running_hv_mse/step:.4f}\n')
|
511 |
+
#*#
|
512 |
+
|
513 |
+
scheduler.step()
|
514 |
+
|
515 |
+
if __name__ == "__main__":
|
516 |
+
main()
|
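Side note on `HoverDataset` above: the `np_map`/`hv_map` targets come from `gen_targets` and `cropping_center` in `utils.py`. The sketch below is a minimal, self-contained illustration on a synthetic instance map (the two rectangular instances and the 64/32 sizes are arbitrary assumptions):

```python
# Minimal sketch of the HoVer target generation used by HoverDataset above.
# The instance map is synthetic; shapes and crop sizes are illustrative assumptions.
import numpy as np
from utils import gen_targets, cropping_center

inst_map = np.zeros((64, 64), dtype=np.int32)
inst_map[10:25, 10:25] = 1          # instance 1
inst_map[35:55, 30:50] = 2          # instance 2

targets = gen_targets(inst_map, inst_map.shape[:2])    # same call as in __getitem__
np_map = cropping_center(targets["np_map"], (32, 32))  # binary nucleus map, HW
hv_map = cropping_center(targets["hv_map"], (32, 32))  # horizontal/vertical maps, HW2
print(np_map.shape, hv_map.shape)   # (32, 32) (32, 32, 2)
```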
train_convnext_stardist.py
ADDED
@@ -0,0 +1,417 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Adapted from the MONAI Tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
|
10 |
+
join = os.path.join
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
from stardist import star_dist,edt_prob
|
18 |
+
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
|
19 |
+
from stardist import random_label_cmap,ray_angles
|
20 |
+
import monai
|
21 |
+
from collections import OrderedDict
|
22 |
+
from compute_metric import eval_tp_fp_fn,remove_boundary_cells
|
23 |
+
from monai.data import decollate_batch, PILReader
|
24 |
+
from monai.inferers import sliding_window_inference
|
25 |
+
from monai.metrics import DiceMetric
|
26 |
+
from monai.transforms import (
|
27 |
+
Activations,
|
28 |
+
AsChannelFirstd,
|
29 |
+
AddChanneld,
|
30 |
+
AsDiscrete,
|
31 |
+
Compose,
|
32 |
+
LoadImaged,
|
33 |
+
SpatialPadd,
|
34 |
+
RandSpatialCropd,
|
35 |
+
RandRotate90d,
|
36 |
+
ScaleIntensityd,
|
37 |
+
RandAxisFlipd,
|
38 |
+
RandZoomd,
|
39 |
+
RandGaussianNoised,
|
40 |
+
RandAdjustContrastd,
|
41 |
+
RandGaussianSmoothd,
|
42 |
+
RandHistogramShiftd,
|
43 |
+
EnsureTyped,
|
44 |
+
EnsureType,
|
45 |
+
)
|
46 |
+
from monai.visualize import plot_2d_or_3d_image
|
47 |
+
import matplotlib.pyplot as plt
|
48 |
+
from datetime import datetime
|
49 |
+
import shutil
|
50 |
+
import tqdm
|
51 |
+
from models.unetr2d import UNETR2D
|
52 |
+
from models.swin_unetr import SwinUNETR
|
53 |
+
from models.flexible_unet import FlexibleUNet
|
54 |
+
from models.flexible_unet_convext import FlexibleUNetConvext
|
55 |
+
print("Successfully imported all requirements!")
|
56 |
+
torch.backends.cudnn.enabled = False
|
57 |
+
|
58 |
+
def main():
|
59 |
+
parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")
|
60 |
+
# Dataset parameters
|
61 |
+
parser.add_argument(
|
62 |
+
"--data_path",
|
63 |
+
default="/data2/liuchenyu/external_processed/split",
|
64 |
+
type=str,
|
65 |
+
help="training data path; subfolders: images, labels",
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--work_dir", default="/data/louwei/nips_comp/convnext_fold0", help="path where to save models and logs"
|
69 |
+
)
|
70 |
+
parser.add_argument("--seed", default=2022, type=int)
|
71 |
+
# parser.add_argument("--resume", default=False, help="resume from checkpoint")
|
72 |
+
parser.add_argument("--num_workers", default=8, type=int)
|
73 |
+
parser.add_argument("--local_rank", type=int)
|
74 |
+
# Model parameters
|
75 |
+
parser.add_argument(
|
76 |
+
"--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
|
77 |
+
)
|
78 |
+
parser.add_argument("--num_class", default=3, type=int, help="segmentation classes")
|
79 |
+
parser.add_argument(
|
80 |
+
"--input_size", default=512, type=int, help="segmentation classes"
|
81 |
+
)
|
82 |
+
# Training parameters
|
83 |
+
parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU")
|
84 |
+
parser.add_argument("--max_epochs", default=2000, type=int)
|
85 |
+
parser.add_argument("--val_interval", default=5, type=int)
|
86 |
+
parser.add_argument("--epoch_tolerance", default=100, type=int)
|
87 |
+
parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
|
88 |
+
|
89 |
+
args = parser.parse_args()
|
90 |
+
torch.cuda.set_device(args.local_rank)
|
91 |
+
torch.distributed.init_process_group(backend='nccl')
|
92 |
+
monai.config.print_config()
|
93 |
+
n_rays = 32
|
94 |
+
pre_trained = True
|
95 |
+
#%% set training/validation split
|
96 |
+
np.random.seed(args.seed)
|
97 |
+
model_path = join(args.work_dir, args.model_name + "_3class")
|
98 |
+
os.makedirs(model_path, exist_ok=True)
|
99 |
+
run_id = datetime.now().strftime("%Y%m%d-%H%M")
|
100 |
+
# This must be changed for every run ! ! !
|
101 |
+
model_file = "models/flexible_unet_convext.py"
|
102 |
+
shutil.copyfile(
|
103 |
+
__file__, join(model_path, os.path.basename(__file__))
|
104 |
+
)
|
105 |
+
shutil.copyfile(
|
106 |
+
model_file, join(model_path, os.path.basename(model_file))
|
107 |
+
)
|
108 |
+
all_image_path = '/data/louwei/nips_comp/train_cellpose_multi0/'
|
109 |
+
all_img_path = join(all_image_path, "train/images")
|
110 |
+
all_gt_path = join(all_image_path, "train/tif")
|
111 |
+
|
112 |
+
all_img_names = sorted(os.listdir(all_img_path))
|
113 |
+
all_gt_names = [img_name.split(".")[0] + ".tif" for img_name in all_img_names]
|
114 |
+
all_img_files = [join(all_img_path, all_img_names[i]) for i in range(len(all_img_names))]
|
115 |
+
all_gt_files = [join(all_gt_path, all_gt_names[i]) for i in range(len(all_img_names))]
|
116 |
+
img_path = join(args.data_path, "train/images")
|
117 |
+
gt_path = join(args.data_path, "train/tif")
|
118 |
+
val_img_path = join(args.data_path, "test/images")
|
119 |
+
val_gt_path = join(args.data_path, "test/tif")
|
120 |
+
img_names = sorted(os.listdir(img_path))
|
121 |
+
gt_names = [img_name.split(".")[0] + ".tif" for img_name in img_names]
|
122 |
+
train_img_files = [join(img_path, img_names[i]) for i in range(len(img_names))]
|
123 |
+
train_gt_files = [join(gt_path, gt_names[i]) for i in range(len(img_names))]
|
124 |
+
cat_img_files = train_img_files + all_img_files
|
125 |
+
cat_gt_files = train_gt_files + all_gt_files
|
126 |
+
img_num = len(img_names)
|
127 |
+
val_frac = 0.1
|
128 |
+
val_img_names = sorted(os.listdir(val_img_path))
|
129 |
+
val_gt_names = [img_name.split(".")[0] + ".tif" for img_name in val_img_names]
|
130 |
+
#indices = np.arange(img_num)
|
131 |
+
#np.random.shuffle(indices)
|
132 |
+
#val_split = int(img_num * val_frac)
|
133 |
+
#train_indices = indices[val_split:]
|
134 |
+
#val_indices = indices[:val_split]
|
135 |
+
|
136 |
+
train_files = [
|
137 |
+
{"img": cat_img_files[i], "label": cat_gt_files[i]}
|
138 |
+
for i in range(len(cat_img_files))
|
139 |
+
]
|
140 |
+
val_files = [
|
141 |
+
{"img": join(val_img_path, val_img_names[i]), "label": join(val_gt_path, val_gt_names[i])}
|
142 |
+
for i in range(len(val_img_names))
|
143 |
+
]
|
144 |
+
print(
|
145 |
+
f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
|
146 |
+
)
|
147 |
+
#%% define transforms for image and segmentation
|
148 |
+
train_transforms = Compose(
|
149 |
+
[
|
150 |
+
LoadImaged(
|
151 |
+
keys=["img", "label"], reader=PILReader, dtype=np.float32
|
152 |
+
), # image three channels (H, W, 3); label: (H, W)
|
153 |
+
AddChanneld(keys=["label"], allow_missing_keys=True), # label: (1, H, W)
|
154 |
+
AsChannelFirstd(
|
155 |
+
keys=["img"], channel_dim=-1, allow_missing_keys=True
|
156 |
+
), # image: (3, H, W)
|
157 |
+
#ScaleIntensityd(
|
158 |
+
#keys=["img"], allow_missing_keys=True
|
159 |
+
#), # Do not scale label
|
160 |
+
SpatialPadd(keys=["img", "label"], spatial_size=args.input_size),
|
161 |
+
RandSpatialCropd(
|
162 |
+
keys=["img", "label"], roi_size=args.input_size, random_size=False
|
163 |
+
),
|
164 |
+
RandAxisFlipd(keys=["img", "label"], prob=0.5),
|
165 |
+
RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
|
166 |
+
# # intensity transform
|
167 |
+
RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
|
168 |
+
RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
|
169 |
+
RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
|
170 |
+
RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
|
171 |
+
RandZoomd(
|
172 |
+
keys=["img", "label"],
|
173 |
+
prob=0.15,
|
174 |
+
min_zoom=0.5,
|
175 |
+
max_zoom=2,
|
176 |
+
mode=["area", "nearest"],
|
177 |
+
),
|
178 |
+
EnsureTyped(keys=["img", "label"]),
|
179 |
+
]
|
180 |
+
)
|
181 |
+
|
182 |
+
val_transforms = Compose(
|
183 |
+
[
|
184 |
+
LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.float32),
|
185 |
+
AddChanneld(keys=["label"], allow_missing_keys=True),
|
186 |
+
AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
|
187 |
+
#ScaleIntensityd(keys=["img"], allow_missing_keys=True),
|
188 |
+
# AsDiscreted(keys=['label'], to_onehot=3),
|
189 |
+
EnsureTyped(keys=["img", "label"]),
|
190 |
+
]
|
191 |
+
)
|
192 |
+
|
193 |
+
#% define dataset, data loader
|
194 |
+
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
|
195 |
+
check_loader = DataLoader(check_ds, batch_size=1, num_workers=4)
|
196 |
+
check_data = monai.utils.misc.first(check_loader)
|
197 |
+
print(
|
198 |
+
"sanity check:",
|
199 |
+
check_data["img"].shape,
|
200 |
+
torch.max(check_data["img"]),
|
201 |
+
check_data["label"].shape,
|
202 |
+
torch.max(check_data["label"]),
|
203 |
+
)
|
204 |
+
|
205 |
+
#%% create a training data loader
|
206 |
+
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
|
207 |
+
# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
|
208 |
+
train_loader = DataLoader(
|
209 |
+
train_ds,
|
210 |
+
batch_size=args.batch_size,
|
211 |
+
shuffle=True,
|
212 |
+
num_workers=args.num_workers,
|
213 |
+
pin_memory=torch.cuda.is_available(),
|
214 |
+
)
|
215 |
+
# create a validation data loader
|
216 |
+
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
|
217 |
+
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)
|
218 |
+
|
219 |
+
dice_metric = DiceMetric(
|
220 |
+
include_background=False, reduction="mean", get_not_nans=False
|
221 |
+
)
|
222 |
+
|
223 |
+
post_pred = Compose(
|
224 |
+
[EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
|
225 |
+
)
|
226 |
+
post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])
|
227 |
+
# create UNet, DiceLoss and Adam optimizer
|
228 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
229 |
+
if args.model_name.lower() == "unet":
|
230 |
+
model = monai.networks.nets.UNet(
|
231 |
+
spatial_dims=2,
|
232 |
+
in_channels=3,
|
233 |
+
out_channels=args.num_class,
|
234 |
+
channels=(16, 32, 64, 128, 256),
|
235 |
+
strides=(2, 2, 2, 2),
|
236 |
+
num_res_units=2,
|
237 |
+
).to(device)
|
238 |
+
|
239 |
+
if args.model_name.lower() == "efficientunet":
|
240 |
+
model = FlexibleUNetConvext(
|
241 |
+
in_channels=3,
|
242 |
+
out_channels=n_rays+1,
|
243 |
+
backbone='convnext_small',
|
244 |
+
pretrained=True,
|
245 |
+
).to(device)
|
246 |
+
|
247 |
+
if args.model_name.lower() == "swinunetr":
|
248 |
+
model = SwinUNETR(
|
249 |
+
img_size=(args.input_size, args.input_size),
|
250 |
+
in_channels=3,
|
251 |
+
out_channels=n_rays+1,
|
252 |
+
feature_size=24, # should be divisible by 12
|
253 |
+
spatial_dims=2,
|
254 |
+
).to(device)
|
255 |
+
|
256 |
+
#loss_masked_dice = monai.losses.DiceCELoss(softmax=True)
|
257 |
+
loss_dice = monai.losses.DiceLoss(squared_pred=True,jaccard=True)
|
258 |
+
loss_bce = nn.BCELoss()
|
259 |
+
loss_dist_mae = nn.L1Loss()
|
260 |
+
activation = nn.ReLU()
|
261 |
+
sigmoid = nn.Sigmoid()
|
262 |
+
#loss_dist_mae = monai.losses.DiceCELoss(softmax=True)
|
263 |
+
initial_lr = args.initial_lr
|
264 |
+
encoder = list(map(id, model.encoder.parameters()))
|
265 |
+
base_params = filter(lambda p: id(p) not in encoder, model.parameters())
|
266 |
+
params = [
|
267 |
+
{"params": base_params, "lr":initial_lr},
|
268 |
+
{"params": model.encoder.parameters(), "lr": initial_lr * 0.1},
|
269 |
+
]
|
270 |
+
optimizer = torch.optim.AdamW(params, initial_lr)
|
271 |
+
#if pre_trained == True:
|
272 |
+
#print('Load pretrained weights...')
|
273 |
+
#checkpoint = torch.load('/mntnfs/med_data5/louwei/nips_comp/swin_stardist/swinunetr_3class/40.pth', map_location=torch.device(device))
|
274 |
+
#model.load_state_dict(checkpoint['model_state_dict'])
|
275 |
+
# start a typical PyTorch training
|
276 |
+
#checkpoint = torch.load("/data2/liuchenyu/log/convnextsmall/efficientunet_3class/510.pth", map_location=torch.device(device))
|
277 |
+
#model.load_state_dict(checkpoint['model_state_dict'])
|
278 |
+
print('distributed model')
|
279 |
+
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
|
280 |
+
print('successful model')
|
281 |
+
max_epochs = args.max_epochs
|
282 |
+
epoch_tolerance = args.epoch_tolerance
|
283 |
+
val_interval = args.val_interval
|
284 |
+
best_metric = -1
|
285 |
+
best_metric_epoch = -1
|
286 |
+
epoch_loss_values = list()
|
287 |
+
metric_values = list()
|
288 |
+
writer = SummaryWriter(model_path)
|
289 |
+
max_f1 = 0
|
290 |
+
for epoch in range(0, max_epochs):
|
291 |
+
model.train()
|
292 |
+
epoch_loss = 0
|
293 |
+
epoch_loss_prob = 0
|
294 |
+
epoch_loss_dist_2 = 0
|
295 |
+
epoch_loss_dist_1 = 0
|
296 |
+
for step, batch_data in enumerate(tqdm.tqdm(train_loader), 1):
|
297 |
+
inputs, labels = batch_data["img"],batch_data["label"]
|
298 |
+
print(step)
|
299 |
+
processes_labels = []
|
300 |
+
|
301 |
+
for i in range(labels.shape[0]):
|
302 |
+
label = labels[i][0]
|
303 |
+
distances = star_dist(label,n_rays)
|
304 |
+
distances = np.transpose(distances,(2,0,1))
|
305 |
+
#print(distances.shape)
|
306 |
+
obj_probabilities = edt_prob(label.astype(int))
|
307 |
+
obj_probabilities = np.expand_dims(obj_probabilities,0)
|
308 |
+
#print(obj_probabilities.shape)
|
309 |
+
final_label = np.concatenate((distances,obj_probabilities),axis=0)
|
310 |
+
#print(final_label.shape)
|
311 |
+
processes_labels.append(final_label)
|
312 |
+
|
313 |
+
labels = np.stack(processes_labels)
|
314 |
+
|
315 |
+
#print(inputs.shape,labels.shape)
|
316 |
+
inputs, labels = torch.tensor(inputs).to(device), torch.tensor(labels).to(device)
|
317 |
+
#print(inputs.shape,labels.shape)
|
318 |
+
optimizer.zero_grad()
|
319 |
+
output_dist,output_prob = model(inputs)
|
320 |
+
#print(outputs.shape)
|
321 |
+
dist_output = output_dist
|
322 |
+
prob_output = output_prob
|
323 |
+
dist_label = labels[:,:n_rays,:,:]
|
324 |
+
prob_label = torch.unsqueeze(labels[:,-1,:,:], 1)
|
325 |
+
#print(dist_output.shape,prob_output.shape,dist_label.shape)
|
326 |
+
#labels_onehot = monai.networks.one_hot(
|
327 |
+
#labels, args.num_class
|
328 |
+
#) # (b,cls,256,256)
|
329 |
+
#print(prob_label.max(),prob_label.min())
|
330 |
+
loss_dist_1 = loss_dice(dist_output*prob_label,dist_label*prob_label)
|
331 |
+
#print(loss_dist_1)
|
332 |
+
loss_prob = loss_bce(prob_output,prob_label)
|
333 |
+
#print(prob_label.shape,dist_output.shape)
|
334 |
+
loss_dist_2 = loss_dist_mae(dist_output*prob_label,dist_label*prob_label)
|
335 |
+
#print(loss_dist_2)
|
336 |
+
loss = loss_prob + loss_dist_2*0.3 + loss_dist_1
|
337 |
+
loss.backward()
|
338 |
+
optimizer.step()
|
339 |
+
epoch_loss += loss.item()
|
340 |
+
epoch_loss_prob += loss_prob.item()
|
341 |
+
epoch_loss_dist_2 += loss_dist_2.item()
|
342 |
+
epoch_loss_dist_1 += loss_dist_1.item()
|
343 |
+
epoch_len = len(train_ds) // train_loader.batch_size
|
344 |
+
|
345 |
+
epoch_loss /= step
|
346 |
+
epoch_loss_prob /= step
|
347 |
+
epoch_loss_dist_2 /= step
|
348 |
+
epoch_loss_dist_1 /= step
|
349 |
+
epoch_loss_values.append(epoch_loss)
|
350 |
+
print(f"epoch {epoch} average loss: {epoch_loss:.4f}")
|
351 |
+
writer.add_scalar("train_loss", epoch_loss, epoch)
|
352 |
+
print('dist dice: '+str(epoch_loss_dist_1)+' dist mae: '+str(epoch_loss_dist_2)+' prob bce: '+str(epoch_loss_prob))
|
353 |
+
checkpoint = {
|
354 |
+
"epoch": epoch,
|
355 |
+
"model_state_dict": model.module.state_dict(),
|
356 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
357 |
+
"loss": epoch_loss_values,
|
358 |
+
}
|
359 |
+
if epoch < 8:
|
360 |
+
continue
|
361 |
+
if epoch > 1 and epoch % val_interval == 0:
|
362 |
+
torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))
|
363 |
+
model.eval()
|
364 |
+
with torch.no_grad():
|
365 |
+
val_images = None
|
366 |
+
val_labels = None
|
367 |
+
val_outputs = None
|
368 |
+
seg_metric = OrderedDict()
|
369 |
+
seg_metric['F1_Score'] = []
|
370 |
+
for val_data in tqdm.tqdm(val_loader):
|
371 |
+
val_images, val_labels = val_data["img"].to(device), val_data[
|
372 |
+
"label"
|
373 |
+
].to(device)
|
374 |
+
roi_size = (512, 512)
|
375 |
+
sw_batch_size = 4
|
376 |
+
output_dist,output_prob = sliding_window_inference(
|
377 |
+
val_images, roi_size, sw_batch_size, model
|
378 |
+
)
|
379 |
+
val_labels = val_labels[0][0].cpu().numpy()
|
380 |
+
prob = output_prob[0][0].cpu().numpy()
|
381 |
+
dist = output_dist[0].cpu().numpy()
|
382 |
+
#print(val_labels.shape,prob.shape,dist.shape)
|
383 |
+
dist = np.transpose(dist,(1,2,0))
|
384 |
+
dist = np.maximum(1e-3, dist)
|
385 |
+
points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
|
386 |
+
|
387 |
+
coord = dist_to_coord(disti,points)
|
388 |
+
|
389 |
+
star_label = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
|
390 |
+
gt = remove_boundary_cells(val_labels.astype(np.int32))
|
391 |
+
seg = remove_boundary_cells(star_label.astype(np.int32))
|
392 |
+
tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
|
393 |
+
if tp == 0:
|
394 |
+
precision = 0
|
395 |
+
recall = 0
|
396 |
+
f1 = 0
|
397 |
+
else:
|
398 |
+
precision = tp / (tp + fp)
|
399 |
+
recall = tp / (tp + fn)
|
400 |
+
f1 = 2*(precision * recall)/ (precision + recall)
|
401 |
+
f1 = np.round(f1, 4)
|
402 |
+
seg_metric['F1_Score'].append(np.round(f1, 4))
|
403 |
+
avg_f1 = np.mean(seg_metric['F1_Score'])
|
404 |
+
writer.add_scalar("val_f1score", avg_f1, epoch)
|
405 |
+
if avg_f1 > max_f1:
|
406 |
+
max_f1 = avg_f1
|
407 |
+
print(str(epoch) + 'f1 score: ' + str(max_f1))
|
408 |
+
torch.save(checkpoint, join(model_path, "best_model.pth"))
|
409 |
+
np.savez_compressed(
|
410 |
+
join(model_path, "train_log.npz"),
|
411 |
+
val_dice=metric_values,
|
412 |
+
epoch_loss=epoch_loss_values,
|
413 |
+
)
|
414 |
+
|
415 |
+
|
416 |
+
if __name__ == "__main__":
|
417 |
+
main()
|
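Side note on the training loop above: each batch label is converted into `n_rays` star distances plus an EDT-based object probability map before the losses are computed. The sketch below reproduces that construction on a single synthetic mask (the square "cell" and the 64x64 size are illustrative assumptions; `n_rays = 32` matches the script):

```python
# Minimal sketch of the StarDist target construction from the training loop above.
# The label image is synthetic; sizes are arbitrary, n_rays matches the script.
import numpy as np
from stardist import star_dist, edt_prob

n_rays = 32
label = np.zeros((64, 64), dtype=np.uint16)
label[20:40, 20:40] = 1                          # one square "cell"

distances = star_dist(label, n_rays)             # (H, W, n_rays)
distances = np.transpose(distances, (2, 0, 1))   # (n_rays, H, W), as in the loop
obj_probabilities = edt_prob(label.astype(int))  # (H, W)
obj_probabilities = np.expand_dims(obj_probabilities, 0)

final_label = np.concatenate((distances, obj_probabilities), axis=0)
print(final_label.shape)                         # (n_rays + 1, H, W)
```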
utils.py
ADDED
@@ -0,0 +1,868 @@
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import cv2
|
16 |
+
import math
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import colorsys
|
21 |
+
import itertools
|
22 |
+
import matplotlib.pyplot as plt
|
23 |
+
from matplotlib import cm
|
24 |
+
|
25 |
+
from monai.data.meta_tensor import MetaTensor
|
26 |
+
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
|
27 |
+
from monai.transforms import Resize
|
28 |
+
from monai.utils import (
|
29 |
+
BlendMode,
|
30 |
+
PytorchPadMode,
|
31 |
+
convert_data_type,
|
32 |
+
convert_to_dst_type,
|
33 |
+
ensure_tuple,
|
34 |
+
fall_back_tuple,
|
35 |
+
look_up_option,
|
36 |
+
optional_import,
|
37 |
+
)
|
38 |
+
|
39 |
+
from scipy import ndimage
|
40 |
+
from scipy.ndimage.filters import gaussian_filter
|
41 |
+
from scipy.ndimage.interpolation import affine_transform, map_coordinates
|
42 |
+
|
43 |
+
from skimage import morphology as morph
|
44 |
+
from scipy.ndimage import filters, measurements
|
45 |
+
from scipy.ndimage.morphology import (
|
46 |
+
binary_dilation,
|
47 |
+
binary_fill_holes,
|
48 |
+
distance_transform_cdt,
|
49 |
+
distance_transform_edt,
|
50 |
+
)
|
51 |
+
|
52 |
+
from skimage.segmentation import watershed
|
53 |
+
from skimage.exposure import rescale_intensity
|
54 |
+
from skimage.filters import sobel_h, sobel_v, gaussian
|
55 |
+
from skimage.morphology import disk, binary_opening
|
56 |
+
|
57 |
+
tqdm, _ = optional_import("tqdm", name="tqdm")
|
58 |
+
|
59 |
+
__all__ = ["sliding_window_inference"]
|
60 |
+
|
61 |
+
####
|
62 |
+
def normalize(mask, dtype=np.uint8):
|
63 |
+
return (255 * mask / np.amax(mask)).astype(dtype)
|
64 |
+
|
65 |
+
def fix_mirror_padding(ann):
|
66 |
+
"""Deal with duplicated instances due to mirroring in interpolation
|
67 |
+
during shape augmentation (scale, rotation etc.).
|
68 |
+
|
69 |
+
"""
|
70 |
+
current_max_id = np.amax(ann)
|
71 |
+
inst_list = list(np.unique(ann))
|
72 |
+
if 0 in inst_list:
|
73 |
+
inst_list.remove(0) # 0 is background
|
74 |
+
for inst_id in inst_list:
|
75 |
+
inst_map = np.array(ann == inst_id, np.uint8)
|
76 |
+
remapped_ids = measurements.label(inst_map)[0]
|
77 |
+
remapped_ids[remapped_ids > 1] += current_max_id
|
78 |
+
ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
|
79 |
+
current_max_id = np.amax(ann)
|
80 |
+
return ann
|
81 |
+
|
82 |
+
####
|
83 |
+
def get_bounding_box(img):
|
84 |
+
"""Get bounding box coordinate information."""
|
85 |
+
rows = np.any(img, axis=1)
|
86 |
+
cols = np.any(img, axis=0)
|
87 |
+
rmin, rmax = np.where(rows)[0][[0, -1]]
|
88 |
+
cmin, cmax = np.where(cols)[0][[0, -1]]
|
89 |
+
# due to python indexing, need to add 1 to max
|
90 |
+
# else accessing will be 1px in the box, not out
|
91 |
+
rmax += 1
|
92 |
+
cmax += 1
|
93 |
+
return [rmin, rmax, cmin, cmax]
|
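Usage note for `get_bounding_box` (the toy mask below is illustrative): `rmax` and `cmax` are exclusive, so the returned box can be used directly for slicing.

```python
# Minimal usage sketch for get_bounding_box; the mask content is illustrative.
import numpy as np
from utils import get_bounding_box

mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:5, 3:8] = 1
rmin, rmax, cmin, cmax = get_bounding_box(mask)
print(rmin, rmax, cmin, cmax)   # 2 5 3 8 -> mask[rmin:rmax, cmin:cmax] covers the object
```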
94 |
+
|
95 |
+
|
96 |
+
####
|
97 |
+
def cropping_center(x, crop_shape, batch=False):
|
98 |
+
"""Crop an input image at the centre.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
x: input array
|
102 |
+
crop_shape: dimensions of cropped array
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
x: cropped array
|
106 |
+
|
107 |
+
"""
|
108 |
+
orig_shape = x.shape
|
109 |
+
if not batch:
|
110 |
+
h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
|
111 |
+
w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
|
112 |
+
x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
|
113 |
+
else:
|
114 |
+
h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
|
115 |
+
w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
|
116 |
+
x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
|
117 |
+
return x
|
118 |
+
|
119 |
+
def gen_instance_hv_map(ann, crop_shape):
|
120 |
+
"""Input annotation must be of original shape.
|
121 |
+
|
122 |
+
The map is calculated only for instances within the crop portion
|
123 |
+
but based on the original shape in original image.
|
124 |
+
|
125 |
+
Perform following operation:
|
126 |
+
Obtain the horizontal and vertical distance maps for each
|
127 |
+
nuclear instance.
|
128 |
+
|
129 |
+
"""
|
130 |
+
orig_ann = ann.copy() # instance ID map
|
131 |
+
    fixed_ann = fix_mirror_padding(orig_ann)
    # re-cropping with fixed instance id map
    crop_ann = cropping_center(fixed_ann, crop_shape)
    # TODO: deal with 1 label warning
    crop_ann = morph.remove_small_objects(crop_ann, min_size=30)

    x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
    y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)

    inst_list = list(np.unique(crop_ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        inst_box = get_bounding_box(inst_map)  # rmin, rmax, cmin, cmax

        # expand the box by 2px
        # Because we first pad the ann at line 207, the bboxes
        # will remain valid after expansion
        inst_box[0] -= 2
        inst_box[2] -= 2
        inst_box[1] += 2
        inst_box[3] += 2

        # fix inst_box
        inst_box[0] = max(inst_box[0], 0)
        inst_box[2] = max(inst_box[2], 0)
        # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
        # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])

        inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]

        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
            continue

        # instance center of mass, rounded to nearest pixel
        inst_com = list(measurements.center_of_mass(inst_map))
        if np.isnan(measurements.center_of_mass(inst_map)).any():
            print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
            print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
            print(inst_map)
            print(inst_list)
            print(inst_box)
            print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))

        inst_com[0] = int(inst_com[0] + 0.5)
        inst_com[1] = int(inst_com[1] + 0.5)

        inst_x_range = np.arange(1, inst_map.shape[1] + 1)
        inst_y_range = np.arange(1, inst_map.shape[0] + 1)
        # shifting center of pixels grid to instance center of mass
        inst_x_range -= inst_com[1]
        inst_y_range -= inst_com[0]

        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

        # remove coord outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

        ####
        x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        x_map_box[inst_map > 0] = inst_x[inst_map > 0]

        y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        y_map_box[inst_map > 0] = inst_y[inst_map > 0]

    hv_map = np.dstack([x_map, y_map])
    return hv_map


def remove_small_objects(pred, min_size=64, connectivity=1):
    """Remove connected components smaller than the specified size.

    This function is taken from skimage.morphology.remove_small_objects, but the warning
    is removed when a single label is provided.

    Args:
        pred: input labelled array
        min_size: minimum size of instance in output array
        connectivity: The connectivity defining the neighborhood of a pixel.

    Returns:
        out: output array with instances removed under min_size

    """
    out = pred

    if min_size == 0:  # shortcut for efficiency
        return out

    if out.dtype == bool:
        selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
        ccs = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, selem, output=ccs)
    else:
        ccs = out

    try:
        component_sizes = np.bincount(ccs.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )

    too_small = component_sizes < min_size
    too_small_mask = too_small[ccs]
    out[too_small_mask] = 0

    return out


####
def gen_targets(ann, crop_shape, **kwargs):
    """Generate the targets for the network."""
    hv_map = gen_instance_hv_map(ann, crop_shape)
    np_map = ann.copy()
    np_map[np_map > 0] = 1

    hv_map = cropping_center(hv_map, crop_shape)
    np_map = cropping_center(np_map, crop_shape)

    target_dict = {
        "hv_map": hv_map,
        "np_map": np_map,
    }

    return target_dict

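# --- illustrative usage sketch (toy mask and shapes are assumptions, not from the upload) ---
#
#   import numpy as np
#   ann = np.zeros((256, 256), dtype=np.int32)
#   ann[60:100, 80:120] = 1                       # one toy "cell"
#   targets = gen_targets(ann, (256, 256))
#   hv_map, np_map = targets["hv_map"], targets["np_map"]
#   # hv_map: (256, 256, 2) horizontal/vertical offsets normalized to [-1, 1]
#   # np_map: (256, 256) binary foreground map
# --------------------------------------------------------------------------------------------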
####
|
275 |
+
def xentropy_loss(true, pred, reduction="mean"):
|
276 |
+
"""Cross entropy loss. Assumes NHWC!
|
277 |
+
|
278 |
+
Args:
|
279 |
+
pred: prediction array
|
280 |
+
true: ground truth array
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
cross entropy loss
|
284 |
+
|
285 |
+
"""
|
286 |
+
epsilon = 10e-8
|
287 |
+
# scale preds so that the class probs of each sample sum to 1
|
288 |
+
pred = pred / torch.sum(pred, -1, keepdim=True)
|
289 |
+
# manual computation of crossentropy
|
290 |
+
pred = torch.clamp(pred, epsilon, 1.0 - epsilon)
|
291 |
+
loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True)
|
292 |
+
loss = loss.mean() if reduction == "mean" else loss.sum()
|
293 |
+
return loss
|
294 |
+
|
295 |
+
|
296 |
+
####
|
297 |
+
def dice_loss(true, pred, smooth=1e-3):
|
298 |
+
"""`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC."""
|
299 |
+
inse = torch.sum(pred * true, (0, 1, 2))
|
300 |
+
l = torch.sum(pred, (0, 1, 2))
|
301 |
+
r = torch.sum(true, (0, 1, 2))
|
302 |
+
loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth)
|
303 |
+
loss = torch.sum(loss)
|
304 |
+
return loss
|
305 |
+
|
306 |
+
|
307 |
+
####
|
308 |
+
def mse_loss(true, pred):
|
309 |
+
"""Calculate mean squared error loss.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
true: ground truth of combined horizontal
|
313 |
+
and vertical maps
|
314 |
+
pred: prediction of combined horizontal
|
315 |
+
and vertical maps
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
loss: mean squared error
|
319 |
+
|
320 |
+
"""
|
321 |
+
loss = pred - true
|
322 |
+
loss = (loss * loss).mean()
|
323 |
+
return loss
|
324 |
+
|
325 |
+
|
326 |
+
####
|
327 |
+
def msge_loss(true, pred, focus):
|
328 |
+
"""Calculate the mean squared error of the gradients of
|
329 |
+
horizontal and vertical map predictions. Assumes
|
330 |
+
channel 0 is Vertical and channel 1 is Horizontal.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
true: ground truth of combined horizontal
|
334 |
+
and vertical maps
|
335 |
+
pred: prediction of combined horizontal
|
336 |
+
and vertical maps
|
337 |
+
focus: area where to apply loss (we only calculate
|
338 |
+
the loss within the nuclei)
|
339 |
+
|
340 |
+
Returns:
|
341 |
+
loss: mean squared error of gradients
|
342 |
+
|
343 |
+
"""
|
344 |
+
|
345 |
+
def get_sobel_kernel(size):
|
346 |
+
"""Get sobel kernel with a given size."""
|
347 |
+
assert size % 2 == 1, "Must be odd, got size=%d" % size
|
348 |
+
|
349 |
+
h_range = torch.arange(
|
350 |
+
-size // 2 + 1,
|
351 |
+
size // 2 + 1,
|
352 |
+
dtype=torch.float32,
|
353 |
+
device="cuda",
|
354 |
+
requires_grad=False,
|
355 |
+
)
|
356 |
+
v_range = torch.arange(
|
357 |
+
-size // 2 + 1,
|
358 |
+
size // 2 + 1,
|
359 |
+
dtype=torch.float32,
|
360 |
+
device="cuda",
|
361 |
+
requires_grad=False,
|
362 |
+
)
|
363 |
+
h, v = torch.meshgrid(h_range, v_range)
|
364 |
+
kernel_h = h / (h * h + v * v + 1.0e-15)
|
365 |
+
kernel_v = v / (h * h + v * v + 1.0e-15)
|
366 |
+
return kernel_h, kernel_v
|
367 |
+
|
368 |
+
####
|
369 |
+
def get_gradient_hv(hv):
|
370 |
+
"""For calculating gradient."""
|
371 |
+
kernel_h, kernel_v = get_sobel_kernel(5)
|
372 |
+
kernel_h = kernel_h.view(1, 1, 5, 5) # constant
|
373 |
+
kernel_v = kernel_v.view(1, 1, 5, 5) # constant
|
374 |
+
|
375 |
+
h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW
|
376 |
+
v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW
|
377 |
+
|
378 |
+
# can only apply in NCHW mode
|
379 |
+
h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2)
|
380 |
+
v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2)
|
381 |
+
dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1)
|
382 |
+
dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC
|
383 |
+
return dhv
|
384 |
+
|
385 |
+
focus = (focus[..., None]).float() # assume input NHW
|
386 |
+
focus = torch.cat([focus, focus], axis=-1)
|
387 |
+
true_grad = get_gradient_hv(true)
|
388 |
+
pred_grad = get_gradient_hv(pred)
|
389 |
+
loss = pred_grad - true_grad
|
390 |
+
loss = focus * (loss * loss)
|
391 |
+
# artificial reduce_mean with focused region
|
392 |
+
loss = loss.sum() / (focus.sum() + 1.0e-8)
|
393 |
+
return loss
|
394 |
+
|
395 |
+
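# --- illustrative loss combination (equal weights are an assumption, not the values used by
# --- the training scripts) ------------------------------------------------------------------
#
#   # true_np/pred_np: NxHxWx2 one-hot nuclei maps; true_hv/pred_hv: NxHxWx2 HV maps
#   loss_np = xentropy_loss(true_np, pred_np) + dice_loss(true_np, pred_np)
#   loss_hv = mse_loss(true_hv, pred_hv) + msge_loss(true_hv, pred_hv, focus=true_np[..., 1])
#   total_loss = loss_np + loss_hv
# --------------------------------------------------------------------------------------------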
|
396 |
+
def __proc_np_hv(pred, np_thres, ksize, overall_thres, obj_size_thres):
|
397 |
+
"""Process Nuclei Prediction with XY Coordinate Map.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei,
            channel 1 contains the regressed X-map,
            channel 2 contains the regressed Y-map
|
404 |
+
|
405 |
+
"""
|
406 |
+
pred = np.array(pred, dtype=np.float32)
|
407 |
+
|
408 |
+
blb_raw = pred[..., 0]
|
409 |
+
h_dir_raw = pred[..., 1]
|
410 |
+
v_dir_raw = pred[..., 2]
|
411 |
+
|
412 |
+
# processing
|
413 |
+
blb = np.array(blb_raw >= np_thres, dtype=np.int32)
|
414 |
+
|
415 |
+
blb = measurements.label(blb)[0]
|
416 |
+
blb = remove_small_objects(blb, min_size=10)
|
417 |
+
blb[blb > 0] = 1 # background is 0 already
|
418 |
+
|
419 |
+
h_dir = cv2.normalize(
|
420 |
+
h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
|
421 |
+
)
|
422 |
+
v_dir = cv2.normalize(
|
423 |
+
v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
|
424 |
+
)
|
425 |
+
|
426 |
+
sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
|
427 |
+
sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
|
428 |
+
|
429 |
+
sobelh = 1 - (
|
430 |
+
cv2.normalize(
|
431 |
+
sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
|
432 |
+
)
|
433 |
+
)
|
434 |
+
sobelv = 1 - (
|
435 |
+
cv2.normalize(
|
436 |
+
sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
|
437 |
+
)
|
438 |
+
)
|
439 |
+
|
440 |
+
overall = np.maximum(sobelh, sobelv)
|
441 |
+
overall = overall - (1 - blb)
|
442 |
+
overall[overall < 0] = 0
|
443 |
+
|
444 |
+
dist = (1.0 - overall) * blb
|
445 |
+
## nuclei values form mountains so invert to get basins
|
446 |
+
dist = -cv2.GaussianBlur(dist, (3, 3), 0)
|
447 |
+
|
448 |
+
overall = np.array(overall >= overall_thres, dtype=np.int32)
|
449 |
+
|
450 |
+
marker = blb - overall
|
451 |
+
marker[marker < 0] = 0
|
452 |
+
marker = binary_fill_holes(marker).astype("uint8")
|
453 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
454 |
+
marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
|
455 |
+
marker = measurements.label(marker)[0]
|
456 |
+
marker = remove_small_objects(marker, min_size=obj_size_thres)
|
457 |
+
|
458 |
+
proced_pred = watershed(dist, markers=marker, mask=blb)
|
459 |
+
|
460 |
+
return proced_pred
|
461 |
+
|
462 |
+
def __proc_np_hv_2(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
|
463 |
+
"""Process Nuclei Prediction with XY Coordinate Map.
|
464 |
+
|
465 |
+
Args:
|
466 |
+
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei,
            channel 1 contains the regressed X-map,
            channel 2 contains the regressed Y-map
|
470 |
+
|
471 |
+
"""
|
472 |
+
pred = np.array(pred, dtype=np.float32)
|
473 |
+
|
474 |
+
blb_raw = pred[..., 0]
|
475 |
+
h_dir_raw = pred[..., 1]
|
476 |
+
v_dir_raw = pred[..., 2]
|
477 |
+
|
478 |
+
# processing
|
479 |
+
blb = np.array(blb_raw >= np_thres, dtype=np.int32)
|
480 |
+
|
481 |
+
blb = measurements.label(blb)[0]
|
482 |
+
blb = remove_small_objects(blb, min_size=10)
|
483 |
+
blb[blb > 0] = 1 # background is 0 already
|
484 |
+
|
485 |
+
h_dir = rescale_intensity(h_dir_raw, out_range=(0, 1)).astype('float32')
|
486 |
+
v_dir = rescale_intensity(v_dir_raw, out_range=(0, 1)).astype('float32')
|
487 |
+
|
488 |
+
sobelh = sobel_v(h_dir).astype('float64')
|
489 |
+
sobelv = sobel_h(v_dir).astype('float64')
|
490 |
+
|
491 |
+
sobelh = 1 - rescale_intensity(sobelh, out_range=(0, 1)).astype('float32')
|
492 |
+
sobelv = 1 - rescale_intensity(sobelv, out_range=(0, 1)).astype('float32')
|
493 |
+
|
494 |
+
overall = np.maximum(sobelh, sobelv)
|
495 |
+
overall = overall - (1 - blb)
|
496 |
+
overall[overall < 0] = 0
|
497 |
+
|
498 |
+
dist = (1.0 - overall) * blb
|
499 |
+
## nuclei values form mountains so invert to get basins
|
500 |
+
dist = - gaussian(dist, sigma=0.8)
|
501 |
+
|
502 |
+
overall = np.array(overall >= overall_thres, dtype=np.int32)
|
503 |
+
|
504 |
+
marker = blb - overall
|
505 |
+
marker[marker < 0] = 0
|
506 |
+
marker = binary_fill_holes(marker).astype("uint8")
|
507 |
+
kernel = disk(2)
|
508 |
+
marker = binary_opening(marker, kernel)
|
509 |
+
marker = measurements.label(marker)[0]
|
510 |
+
marker = remove_small_objects(marker, min_size=obj_size_thres)
|
511 |
+
|
512 |
+
proced_pred = watershed(dist, markers=marker, mask=blb)
|
513 |
+
|
514 |
+
return proced_pred
|
515 |
+
|
516 |
+
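# --- illustrative post-processing sketch (array names and shapes are assumptions) -----------
#
#   # prob_map, h_map, v_map: HxW arrays predicted by the network
#   pred = np.dstack([prob_map, h_map, v_map])      # HxWx3
#   inst_map = __proc_np_hv_2(pred)                 # default thresholds from above
#   num_nuclei = len(np.unique(inst_map)) - 1       # 0 is background
# --------------------------------------------------------------------------------------------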
|
517 |
+
####
|
518 |
+
def colorize(ch, vmin, vmax):
|
519 |
+
"""Will clamp value value outside the provided range to vmax and vmin."""
|
520 |
+
cmap = plt.get_cmap("jet")
|
521 |
+
ch = np.squeeze(ch.astype("float32"))
|
522 |
+
vmin = vmin if vmin is not None else ch.min()
|
523 |
+
vmax = vmax if vmax is not None else ch.max()
|
524 |
+
ch[ch > vmax] = vmax # clamp value
|
525 |
+
ch[ch < vmin] = vmin
|
526 |
+
ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
|
527 |
+
# take RGB from RGBA heat map
|
528 |
+
ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
|
529 |
+
return ch_cmap
|
530 |
+
|
531 |
+
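# --- illustrative usage sketch --------------------------------------------------------------
#   # render the horizontal offset channel of an HV map as a jet heat map
#   h_rgb = colorize(hv_map[..., 0], vmin=-1, vmax=1)   # HxWx3 uint8
# --------------------------------------------------------------------------------------------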
|
532 |
+
####
|
533 |
+
def random_colors(N, bright=True):
|
534 |
+
"""Generate random colors.
|
535 |
+
|
536 |
+
To get visually distinct colors, generate them in HSV space then
|
537 |
+
convert to RGB.
|
538 |
+
"""
|
539 |
+
brightness = 1.0 if bright else 0.7
|
540 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
541 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
542 |
+
random.shuffle(colors)
|
543 |
+
return colors
|
544 |
+
|
545 |
+
|
546 |
+
####
|
547 |
+
def visualize_instances_map(
|
548 |
+
input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
|
549 |
+
):
|
550 |
+
"""Overlays segmentation results on image as contours.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
input_image: input image
|
554 |
+
inst_map: instance mask with unique value for every object
|
555 |
+
type_map: type mask with unique value for every class
|
556 |
+
type_colour: a dict of {type : colour} , `type` is from 0-N
|
557 |
+
and `colour` is a tuple of (R, G, B)
|
558 |
+
line_thickness: line thickness of contours
|
559 |
+
|
560 |
+
Returns:
|
561 |
+
overlay: output image with segmentation overlay as contours
|
562 |
+
"""
|
563 |
+
overlay = np.copy((input_image).astype(np.uint8))
|
564 |
+
|
565 |
+
inst_list = list(np.unique(inst_map)) # get list of instances
|
566 |
+
inst_list.remove(0) # remove background
|
567 |
+
|
568 |
+
inst_rng_colors = random_colors(len(inst_list))
|
569 |
+
inst_rng_colors = np.array(inst_rng_colors) * 255
|
570 |
+
inst_rng_colors = inst_rng_colors.astype(np.uint8)
|
571 |
+
|
572 |
+
for inst_idx, inst_id in enumerate(inst_list):
|
573 |
+
inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
|
574 |
+
y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
|
575 |
+
y1 = y1 - 2 if y1 - 2 >= 0 else y1
|
576 |
+
x1 = x1 - 2 if x1 - 2 >= 0 else x1
|
577 |
+
x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
|
578 |
+
y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
|
579 |
+
inst_map_crop = inst_map_mask[y1:y2, x1:x2]
|
580 |
+
contours_crop = cv2.findContours(
|
581 |
+
inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
582 |
+
)
|
583 |
+
# only has 1 instance per map, no need to check #contour detected by opencv
|
584 |
+
contours_crop = np.squeeze(
|
585 |
+
contours_crop[0][0].astype("int32")
|
586 |
+
) # * opencv protocol format may break
|
587 |
+
contours_crop += np.asarray([[x1, y1]]) # index correction
|
588 |
+
if type_map is not None:
|
589 |
+
type_map_crop = type_map[y1:y2, x1:x2]
|
590 |
+
type_id = np.unique(type_map_crop).max() # non-zero
|
591 |
+
inst_colour = type_colour[type_id]
|
592 |
+
else:
|
593 |
+
inst_colour = (inst_rng_colors[inst_idx]).tolist()
|
594 |
+
cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
|
595 |
+
return overlay
|
596 |
+
|
597 |
+
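# --- illustrative usage sketch (file name is a placeholder) ---------------------------------
#
#   import cv2
#   overlay = visualize_instances_map(image_rgb, inst_map, line_thickness=2)
#   cv2.imwrite('overlay.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
# --------------------------------------------------------------------------------------------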
|
598 |
+
def sliding_window_inference(
|
599 |
+
inputs: torch.Tensor,
|
600 |
+
roi_size: Union[Sequence[int], int],
|
601 |
+
sw_batch_size: int,
|
602 |
+
predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
|
603 |
+
overlap: float = 0.25,
|
604 |
+
mode: Union[BlendMode, str] = BlendMode.CONSTANT,
|
605 |
+
sigma_scale: Union[Sequence[float], float] = 0.125,
|
606 |
+
padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
|
607 |
+
cval: float = 0.0,
|
608 |
+
sw_device: Union[torch.device, str, None] = None,
|
609 |
+
device: Union[torch.device, str, None] = None,
|
610 |
+
progress: bool = False,
|
611 |
+
roi_weight_map: Union[torch.Tensor, None] = None,
|
612 |
+
*args: Any,
|
613 |
+
**kwargs: Any,
|
614 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
|
615 |
+
"""
|
616 |
+
Sliding window inference on `inputs` with `predictor`.
|
617 |
+
|
618 |
+
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
|
619 |
+
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
|
620 |
+
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
|
621 |
+
could be ([128,64,256], [64,32,128]).
|
622 |
+
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
|
623 |
+
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
|
624 |
+
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
|
625 |
+
|
626 |
+
When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
|
627 |
+
To maintain the same spatial sizes, the output image will be cropped to the original input size.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
inputs: input image to be processed (assuming NCHW[D])
|
631 |
+
roi_size: the spatial window size for inferences.
|
632 |
+
When its components have None or non-positives, the corresponding inputs dimension will be used.
|
633 |
+
if the components of the `roi_size` are non-positive values, the transform will use the
|
634 |
+
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
|
635 |
+
to `(32, 64)` if the second spatial dimension size of img is `64`.
|
636 |
+
sw_batch_size: the batch size to run window slices.
|
637 |
+
predictor: given input tensor ``patch_data`` in shape NCHW[D],
|
638 |
+
The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
|
639 |
+
with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
|
640 |
+
where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
|
641 |
+
N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
|
642 |
+
the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
|
643 |
+
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
|
644 |
+
to ensure the scaled output ROI sizes are still integers.
|
645 |
+
If the `predictor`'s input and output spatial sizes are different,
|
646 |
+
we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
|
647 |
+
overlap: Amount of overlap between scans.
|
648 |
+
mode: {``"constant"``, ``"gaussian"``}
|
649 |
+
How to blend output of overlapping windows. Defaults to ``"constant"``.
|
650 |
+
|
651 |
+
- ``"constant``": gives equal weight to all predictions.
|
652 |
+
- ``"gaussian``": gives less weight to predictions on edges of windows.
|
653 |
+
|
654 |
+
sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
|
655 |
+
Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
|
656 |
+
When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
|
657 |
+
spatial dimensions.
|
658 |
+
padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
|
659 |
+
Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
|
660 |
+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
|
661 |
+
cval: fill value for 'constant' padding mode. Default: 0
|
662 |
+
sw_device: device for the window data.
|
663 |
+
By default the device (and accordingly the memory) of the `inputs` is used.
|
664 |
+
Normally `sw_device` should be consistent with the device where `predictor` is defined.
|
665 |
+
device: device for the stitched output prediction.
|
666 |
+
By default the device (and accordingly the memory) of the `inputs` is used. If for example
|
667 |
+
set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
|
668 |
+
`inputs` and `roi_size`. Output is on the `device`.
|
669 |
+
progress: whether to print a `tqdm` progress bar.
|
670 |
+
roi_weight_map: pre-computed (non-negative) weight map for each ROI.
|
671 |
+
If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
|
672 |
+
args: optional args to be passed to ``predictor``.
|
673 |
+
kwargs: optional keyword args to be passed to ``predictor``.
|
674 |
+
|
675 |
+
Note:
|
676 |
+
- input must be channel-first and have a batch dim, supports N-D sliding window.
|
677 |
+
|
678 |
+
"""
|
679 |
+
compute_dtype = inputs.dtype
|
680 |
+
num_spatial_dims = len(inputs.shape) - 2
|
681 |
+
if overlap < 0 or overlap >= 1:
|
682 |
+
raise ValueError("overlap must be >= 0 and < 1.")
|
683 |
+
|
684 |
+
# determine image spatial size and batch size
|
685 |
+
# Note: all input images must have the same image size and batch size
|
686 |
+
batch_size, _, *image_size_ = inputs.shape
|
687 |
+
|
688 |
+
if device is None:
|
689 |
+
device = inputs.device
|
690 |
+
if sw_device is None:
|
691 |
+
sw_device = inputs.device
|
692 |
+
|
693 |
+
roi_size = fall_back_tuple(roi_size, image_size_)
|
694 |
+
# in case that image size is smaller than roi size
|
695 |
+
image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
|
696 |
+
pad_size = []
|
697 |
+
for k in range(len(inputs.shape) - 1, 1, -1):
|
698 |
+
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
|
699 |
+
half = diff // 2
|
700 |
+
pad_size.extend([half, diff - half])
|
701 |
+
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
|
702 |
+
|
703 |
+
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
|
704 |
+
|
705 |
+
# Store all slices in list
|
706 |
+
slices = dense_patch_slices(image_size, roi_size, scan_interval)
|
707 |
+
num_win = len(slices) # number of windows per image
|
708 |
+
total_slices = num_win * batch_size # total number of windows
|
709 |
+
|
710 |
+
# Create window-level importance map
|
711 |
+
valid_patch_size = get_valid_patch_size(image_size, roi_size)
|
712 |
+
if valid_patch_size == roi_size and (roi_weight_map is not None):
|
713 |
+
importance_map = roi_weight_map
|
714 |
+
else:
|
715 |
+
try:
|
716 |
+
importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
|
717 |
+
except BaseException as e:
|
718 |
+
raise RuntimeError(
|
719 |
+
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
|
720 |
+
) from e
|
721 |
+
importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
|
722 |
+
# handle non-positive weights
|
723 |
+
min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
|
724 |
+
importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
|
725 |
+
|
726 |
+
# Perform predictions
|
727 |
+
dict_key, output_image_list, count_map_list = None, [], []
|
728 |
+
_initialized_ss = -1
|
729 |
+
is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
|
730 |
+
|
731 |
+
# for each patch
|
732 |
+
for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
|
733 |
+
slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
|
734 |
+
unravel_slice = [
|
735 |
+
[slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
|
736 |
+
for idx in slice_range
|
737 |
+
]
|
738 |
+
window_data = torch.cat(
|
739 |
+
[convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
|
740 |
+
).to(sw_device)
|
741 |
+
seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
|
742 |
+
|
743 |
+
# convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
|
744 |
+
seg_prob_tuple: Tuple[torch.Tensor, ...]
|
745 |
+
if isinstance(seg_prob_out, torch.Tensor):
|
746 |
+
seg_prob_tuple = (seg_prob_out,)
|
747 |
+
elif isinstance(seg_prob_out, Mapping):
|
748 |
+
if dict_key is None:
|
749 |
+
dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
|
750 |
+
seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
|
751 |
+
is_tensor_output = False
|
752 |
+
else:
|
753 |
+
seg_prob_tuple = ensure_tuple(seg_prob_out)
|
754 |
+
is_tensor_output = False
|
755 |
+
|
756 |
+
# for each output in multi-output list
|
757 |
+
for ss, seg_prob in enumerate(seg_prob_tuple):
|
758 |
+
seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
|
759 |
+
|
760 |
+
# compute zoom scale: out_roi_size/in_roi_size
|
761 |
+
zoom_scale = []
|
762 |
+
for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
|
763 |
+
zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
|
764 |
+
):
|
765 |
+
_scale = out_w_i / float(in_w_i)
|
766 |
+
if not (img_s_i * _scale).is_integer():
|
767 |
+
warnings.warn(
|
768 |
+
f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
|
769 |
+
f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
|
770 |
+
)
|
771 |
+
zoom_scale.append(_scale)
|
772 |
+
|
773 |
+
if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
|
774 |
+
# construct multi-resolution outputs
|
775 |
+
output_classes = seg_prob.shape[1]
|
776 |
+
output_shape = [batch_size, output_classes] + [
|
777 |
+
int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
|
778 |
+
]
|
779 |
+
# allocate memory to store the full output and the count for overlapping parts
|
780 |
+
output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device='cpu'))
|
781 |
+
count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device='cpu'))
|
782 |
+
_initialized_ss += 1
|
783 |
+
|
784 |
+
# resizing the importance_map
|
785 |
+
resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
|
786 |
+
|
787 |
+
# store the result in the proper location of the full output. Apply weights from importance map.
|
788 |
+
for idx, original_idx in zip(slice_range, unravel_slice):
|
789 |
+
# zoom roi
|
790 |
+
original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
|
791 |
+
for axis in range(2, len(original_idx_zoom)):
|
792 |
+
zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
|
793 |
+
zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
|
794 |
+
if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
|
795 |
+
warnings.warn(
|
796 |
+
f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
|
797 |
+
f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
|
798 |
+
f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
|
799 |
+
f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
|
800 |
+
f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
|
801 |
+
"Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
|
802 |
+
)
|
803 |
+
original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
|
804 |
+
importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
|
805 |
+
# store results and weights
|
806 |
+
#print(output_image_list[ss][original_idx_zoom].device,importance_map_zoom.cpu().device,seg_prob.cpu().device)
|
807 |
+
output_image_list[ss][original_idx_zoom] += importance_map_zoom.cpu() * seg_prob[idx - slice_g].cpu()
|
808 |
+
count_map_list[ss][original_idx_zoom] += (
|
809 |
+
importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape).cpu()
|
810 |
+
)
|
811 |
+
|
812 |
+
# account for any overlapping sections
|
813 |
+
for ss in range(len(output_image_list)):
|
814 |
+
output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
|
815 |
+
|
816 |
+
# remove padding if image_size smaller than roi_size
|
817 |
+
for ss, output_i in enumerate(output_image_list):
|
818 |
+
if torch.isnan(output_i).any() or torch.isinf(output_i).any():
|
819 |
+
warnings.warn("Sliding window inference results contain NaN or Inf.")
|
820 |
+
|
821 |
+
zoom_scale = [
|
822 |
+
seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
|
823 |
+
]
|
824 |
+
|
825 |
+
final_slicing: List[slice] = []
|
826 |
+
for sp in range(num_spatial_dims):
|
827 |
+
slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
|
828 |
+
slice_dim = slice(
|
829 |
+
int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
|
830 |
+
int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
|
831 |
+
)
|
832 |
+
final_slicing.insert(0, slice_dim)
|
833 |
+
while len(final_slicing) < len(output_i.shape):
|
834 |
+
final_slicing.insert(0, slice(None))
|
835 |
+
output_image_list[ss] = output_i[final_slicing]
|
836 |
+
|
837 |
+
if dict_key is not None: # if output of predictor is a dict
|
838 |
+
final_output = dict(zip(dict_key, output_image_list))
|
839 |
+
else:
|
840 |
+
final_output = tuple(output_image_list) # type: ignore
|
841 |
+
final_output = final_output[0] if is_tensor_output else final_output # type: ignore
|
842 |
+
if isinstance(inputs, MetaTensor):
|
843 |
+
final_output = convert_to_dst_type(final_output, inputs)[0] # type: ignore
|
844 |
+
return final_output
|
845 |
+
|
846 |
+
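# --- illustrative usage sketch (the 1x1 conv is a stand-in predictor, not the project model) -
#
#   import torch
#   net = torch.nn.Conv2d(3, 2, kernel_size=1).eval()
#   img = torch.rand(1, 3, 1000, 1000)               # NCHW
#   with torch.no_grad():
#       out = sliding_window_inference(img, roi_size=(512, 512), sw_batch_size=4,
#                                      predictor=net, overlap=0.25, mode="gaussian")
#   # out: (1, 2, 1000, 1000), stitched from overlapping 512x512 windows
# --------------------------------------------------------------------------------------------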
|
847 |
+
def _get_scan_interval(
|
848 |
+
image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
|
849 |
+
) -> Tuple[int, ...]:
|
850 |
+
"""
|
851 |
+
Compute scan interval according to the image size, roi size and overlap.
|
852 |
+
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
|
853 |
+
use 1 instead to make sure sliding window works.
|
854 |
+
|
855 |
+
"""
|
856 |
+
if len(image_size) != num_spatial_dims:
|
857 |
+
raise ValueError("image coord different from spatial dims.")
|
858 |
+
if len(roi_size) != num_spatial_dims:
|
859 |
+
raise ValueError("roi coord different from spatial dims.")
|
860 |
+
|
861 |
+
scan_interval = []
|
862 |
+
for i in range(num_spatial_dims):
|
863 |
+
if roi_size[i] == image_size[i]:
|
864 |
+
scan_interval.append(int(roi_size[i]))
|
865 |
+
else:
|
866 |
+
interval = int(roi_size[i] * (1 - overlap))
|
867 |
+
scan_interval.append(interval if interval > 0 else 1)
|
868 |
+
return tuple(scan_interval)
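# --- worked example --------------------------------------------------------------------------
#   # roi_size=(512, 512), overlap=0.25 -> interval = int(512 * 0.75) = 384, so consecutive
#   # windows start 384 px apart and overlap by 128 px; when roi_size equals the image size,
#   # the interval is the full roi.
# ---------------------------------------------------------------------------------------------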
|
utils_modify.py
ADDED
@@ -0,0 +1,369 @@
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from stardist.big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
|
18 |
+
from stardist.matching import relabel_sequential
|
19 |
+
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
|
20 |
+
from stardist import random_label_cmap,ray_angles
|
21 |
+
from stardist import star_dist,edt_prob
|
22 |
+
from monai.data.meta_tensor import MetaTensor
|
23 |
+
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
|
24 |
+
from monai.transforms import Resize
|
25 |
+
from monai.utils import (
|
26 |
+
BlendMode,
|
27 |
+
PytorchPadMode,
|
28 |
+
convert_data_type,
|
29 |
+
convert_to_dst_type,
|
30 |
+
ensure_tuple,
|
31 |
+
fall_back_tuple,
|
32 |
+
look_up_option,
|
33 |
+
optional_import,
|
34 |
+
)
|
35 |
+
|
36 |
+
tqdm, _ = optional_import("tqdm", name="tqdm")
|
37 |
+
|
38 |
+
__all__ = ["sliding_window_inference"]
|
39 |
+
|
40 |
+
|
41 |
+
def sliding_window_inference_large(inputs, block_size, min_overlap, context, roi_size, sw_batch_size, predictor, device):
    h, w = inputs.shape[0], inputs.shape[1]
    if h < 5000 or w < 5000:
        test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
        prob = output_prob[0][0].cpu().numpy()
        dist = output_dist[0].cpu().numpy()
        dist = np.transpose(dist, (1, 2, 0))
        dist = np.maximum(1e-3, dist)
        points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

        coord = dist_to_coord(disti, points)

        labels_out = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
    else:
        n = inputs.ndim
        axes = 'YXC'
        grid = (1, 1, 1)
        if np.isscalar(block_size): block_size = n * [block_size]
        if np.isscalar(min_overlap): min_overlap = n * [min_overlap]
        if np.isscalar(context): context = n * [context]
        shape_out = (inputs.shape[0], inputs.shape[1])
        labels_out = np.zeros(shape_out, dtype=np.uint64)
        # print(inputs.dtype)
        block_size[2] = inputs.shape[2]
        min_overlap[2] = context[2] = 0
        block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v, g, a in zip(block_size, grid, axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v, g, a in zip(min_overlap, grid, axes))
        context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v, g, a in zip(context, grid, axes))
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
        blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
        label_offset = 1
        blocks = tqdm(blocks)
        for block in blocks:
            image = block.read(inputs, axes=axes)
            test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()
            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

            coord = dist_to_coord(disti, points)
            polys = dict(coord=coord, points=points, prob=probi)
            labels = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            labels = block.crop_context(labels, axes='YX')
            labels, polys = block.filter_objects(labels, polys, axes='YX')
            labels = relabel_sequential(labels, label_offset)[0]
            if labels_out is not None:
                block.write(labels_out, labels, axes='YX')
            # for k, v in polys.items():
            #     polys_all.setdefault(k, []).append(v)
            label_offset += len(polys['prob'])
            del labels
    # polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k, v in polys_all.items()}
    return labels_out

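# --- illustrative usage sketch (block/overlap/roi sizes are assumptions, not the project's
# --- inference settings) ---------------------------------------------------------------------
#
#   # img: HxWx3 float numpy image; model: StarDist-style network returning (dist, prob);
#   # images above the 5000-px threshold are tiled with stardist BlockND and re-stitched.
#   labels = sliding_window_inference_large(
#       img, block_size=2048, min_overlap=128, context=128,
#       roi_size=(512, 512), sw_batch_size=4, predictor=model, device='cuda:0',
#   )
#   # labels: HxW instance label map (uint64 in the blockwise branch)
# ---------------------------------------------------------------------------------------------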
def sliding_window_inference(
|
100 |
+
inputs: torch.Tensor,
|
101 |
+
roi_size: Union[Sequence[int], int],
|
102 |
+
sw_batch_size: int,
|
103 |
+
predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
|
104 |
+
overlap: float = 0.25,
|
105 |
+
mode: Union[BlendMode, str] = BlendMode.CONSTANT,
|
106 |
+
sigma_scale: Union[Sequence[float], float] = 0.125,
|
107 |
+
padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
|
108 |
+
cval: float = 0.0,
|
109 |
+
sw_device: Union[torch.device, str, None] = None,
|
110 |
+
device: Union[torch.device, str, None] = None,
|
111 |
+
progress: bool = False,
|
112 |
+
roi_weight_map: Union[torch.Tensor, None] = None,
|
113 |
+
*args: Any,
|
114 |
+
**kwargs: Any,
|
115 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
|
116 |
+
"""
|
117 |
+
Sliding window inference on `inputs` with `predictor`.
|
118 |
+
|
119 |
+
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
|
120 |
+
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
|
121 |
+
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
|
122 |
+
could be ([128,64,256], [64,32,128]).
|
123 |
+
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
|
124 |
+
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
|
125 |
+
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
|
126 |
+
|
127 |
+
When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
|
128 |
+
To maintain the same spatial sizes, the output image will be cropped to the original input size.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
inputs: input image to be processed (assuming NCHW[D])
|
132 |
+
roi_size: the spatial window size for inferences.
|
133 |
+
When its components have None or non-positives, the corresponding inputs dimension will be used.
|
134 |
+
if the components of the `roi_size` are non-positive values, the transform will use the
|
135 |
+
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
|
136 |
+
to `(32, 64)` if the second spatial dimension size of img is `64`.
|
137 |
+
sw_batch_size: the batch size to run window slices.
|
138 |
+
predictor: given input tensor ``patch_data`` in shape NCHW[D],
|
139 |
+
The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
|
140 |
+
with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
|
141 |
+
where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
|
142 |
+
N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
|
143 |
+
the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
|
144 |
+
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
|
145 |
+
to ensure the scaled output ROI sizes are still integers.
|
146 |
+
If the `predictor`'s input and output spatial sizes are different,
|
147 |
+
we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
|
148 |
+
overlap: Amount of overlap between scans.
|
149 |
+
mode: {``"constant"``, ``"gaussian"``}
|
150 |
+
How to blend output of overlapping windows. Defaults to ``"constant"``.
|
151 |
+
|
152 |
+
- ``"constant``": gives equal weight to all predictions.
|
153 |
+
- ``"gaussian``": gives less weight to predictions on edges of windows.
|
154 |
+
|
155 |
+
sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
|
156 |
+
Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
|
157 |
+
When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
|
158 |
+
spatial dimensions.
|
159 |
+
padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
|
160 |
+
Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
|
161 |
+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
|
162 |
+
cval: fill value for 'constant' padding mode. Default: 0
|
163 |
+
sw_device: device for the window data.
|
164 |
+
By default the device (and accordingly the memory) of the `inputs` is used.
|
165 |
+
Normally `sw_device` should be consistent with the device where `predictor` is defined.
|
166 |
+
device: device for the stitched output prediction.
|
167 |
+
By default the device (and accordingly the memory) of the `inputs` is used. If for example
|
168 |
+
set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
|
169 |
+
`inputs` and `roi_size`. Output is on the `device`.
|
170 |
+
progress: whether to print a `tqdm` progress bar.
|
171 |
+
roi_weight_map: pre-computed (non-negative) weight map for each ROI.
|
172 |
+
If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
|
173 |
+
args: optional args to be passed to ``predictor``.
|
174 |
+
kwargs: optional keyword args to be passed to ``predictor``.
|
175 |
+
|
176 |
+
Note:
|
177 |
+
- input must be channel-first and have a batch dim, supports N-D sliding window.
|
178 |
+
|
179 |
+
"""
|
180 |
+
compute_dtype = inputs.dtype
|
181 |
+
num_spatial_dims = len(inputs.shape) - 2
|
182 |
+
if overlap < 0 or overlap >= 1:
|
183 |
+
raise ValueError("overlap must be >= 0 and < 1.")
|
184 |
+
|
185 |
+
# determine image spatial size and batch size
|
186 |
+
# Note: all input images must have the same image size and batch size
|
187 |
+
batch_size, _, *image_size_ = inputs.shape
|
188 |
+
|
189 |
+
if device is None:
|
190 |
+
device = inputs.device
|
191 |
+
if sw_device is None:
|
192 |
+
sw_device = inputs.device
|
193 |
+
|
194 |
+
roi_size = fall_back_tuple(roi_size, image_size_)
|
195 |
+
# in case that image size is smaller than roi size
|
196 |
+
image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
|
197 |
+
pad_size = []
|
198 |
+
for k in range(len(inputs.shape) - 1, 1, -1):
|
199 |
+
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
|
200 |
+
half = diff // 2
|
201 |
+
pad_size.extend([half, diff - half])
|
202 |
+
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
|
203 |
+
|
204 |
+
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
|
205 |
+
|
206 |
+
# Store all slices in list
|
207 |
+
slices = dense_patch_slices(image_size, roi_size, scan_interval)
|
208 |
+
num_win = len(slices) # number of windows per image
|
209 |
+
total_slices = num_win * batch_size # total number of windows
|
210 |
+
|
211 |
+
# Create window-level importance map
|
212 |
+
valid_patch_size = get_valid_patch_size(image_size, roi_size)
|
213 |
+
if valid_patch_size == roi_size and (roi_weight_map is not None):
|
214 |
+
importance_map = roi_weight_map
|
215 |
+
else:
|
216 |
+
try:
|
217 |
+
importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
|
218 |
+
except BaseException as e:
|
219 |
+
raise RuntimeError(
|
220 |
+
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
|
221 |
+
) from e
|
222 |
+
importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
|
223 |
+
# handle non-positive weights
|
224 |
+
min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
|
225 |
+
importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
|
226 |
+
|
227 |
+
# Perform predictions
|
228 |
+
dict_key, output_image_list, count_map_list = None, [], []
|
229 |
+
_initialized_ss = -1
|
230 |
+
is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
|
231 |
+
|
232 |
+
# for each patch
|
233 |
+
for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
|
234 |
+
slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
|
235 |
+
unravel_slice = [
|
236 |
+
[slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
|
237 |
+
for idx in slice_range
|
238 |
+
]
|
239 |
+
window_data = torch.cat(
|
240 |
+
[convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
|
241 |
+
).to(sw_device)
|
242 |
+
seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
|
243 |
+
|
244 |
+
# convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
|
245 |
+
seg_prob_tuple: Tuple[torch.Tensor, ...]
|
246 |
+
if isinstance(seg_prob_out, torch.Tensor):
|
247 |
+
seg_prob_tuple = (seg_prob_out,)
|
248 |
+
elif isinstance(seg_prob_out, Mapping):
|
249 |
+
if dict_key is None:
|
250 |
+
dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
|
251 |
+
seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
|
252 |
+
is_tensor_output = False
|
253 |
+
else:
|
254 |
+
seg_prob_tuple = ensure_tuple(seg_prob_out)
|
255 |
+
is_tensor_output = False
|
256 |
+
|
257 |
+
# for each output in multi-output list
|
258 |
+
for ss, seg_prob in enumerate(seg_prob_tuple):
|
259 |
+
seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
|
260 |
+
|
261 |
+
# compute zoom scale: out_roi_size/in_roi_size
|
262 |
+
zoom_scale = []
|
263 |
+
for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
|
264 |
+
zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
|
265 |
+
):
|
266 |
+
_scale = out_w_i / float(in_w_i)
|
267 |
+
if not (img_s_i * _scale).is_integer():
|
268 |
+
warnings.warn(
|
269 |
+
f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
|
270 |
+
f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
|
271 |
+
)
|
272 |
+
zoom_scale.append(_scale)
|
273 |
+
|
274 |
+
if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
|
275 |
+
# construct multi-resolution outputs
|
276 |
+
output_classes = seg_prob.shape[1]
|
277 |
+
output_shape = [batch_size, output_classes] + [
|
278 |
+
int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
|
279 |
+
]
|
280 |
+
# allocate memory to store the full output and the count for overlapping parts
|
281 |
+
output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
|
282 |
+
count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
|
283 |
+
_initialized_ss += 1
|
284 |
+
|
285 |
+
# resizing the importance_map
|
286 |
+
resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
|
287 |
+
|
288 |
+
# store the result in the proper location of the full output. Apply weights from importance map.
|
289 |
+
for idx, original_idx in zip(slice_range, unravel_slice):
|
290 |
+
# zoom roi
|
291 |
+
original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
|
292 |
+
for axis in range(2, len(original_idx_zoom)):
|
293 |
+
zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
|
294 |
+
zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
|
295 |
+
if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
|
296 |
+
warnings.warn(
|
297 |
+
f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
|
298 |
+
f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
|
299 |
+
f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
|
300 |
+
f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
|
301 |
+
f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
|
302 |
+
"Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
|
303 |
+
)
|
304 |
+
original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
|
305 |
+
importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
|
306 |
+
# store results and weights
|
307 |
+
output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
|
308 |
+
count_map_list[ss][original_idx_zoom] += (
|
309 |
+
importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
|
310 |
+
)
|
311 |
+
|
312 |
+
# account for any overlapping sections
|
313 |
+
for ss in range(len(output_image_list)):
|
314 |
+
output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
|
315 |
+
|
316 |
+
# remove padding if image_size smaller than roi_size
|
317 |
+
for ss, output_i in enumerate(output_image_list):
|
318 |
+
if torch.isnan(output_i).any() or torch.isinf(output_i).any():
|
319 |
+
warnings.warn("Sliding window inference results contain NaN or Inf.")
|
320 |
+
|
321 |
+
zoom_scale = [
|
322 |
+
seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
|
323 |
+
]
|
324 |
+
|
325 |
+
final_slicing: List[slice] = []
|
326 |
+
for sp in range(num_spatial_dims):
|
327 |
+
slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
|
328 |
+
slice_dim = slice(
|
329 |
+
int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
|
330 |
+
int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
|
331 |
+
)
|
332 |
+
final_slicing.insert(0, slice_dim)
|
333 |
+
while len(final_slicing) < len(output_i.shape):
|
334 |
+
final_slicing.insert(0, slice(None))
|
335 |
+
output_image_list[ss] = output_i[final_slicing]
|
336 |
+
|
337 |
+
if dict_key is not None: # if output of predictor is a dict
|
338 |
+
final_output = dict(zip(dict_key, output_image_list))
|
339 |
+
else:
|
340 |
+
final_output = tuple(output_image_list) # type: ignore
|
341 |
+
final_output = final_output[0] if is_tensor_output else final_output
|
342 |
+
|
343 |
+
if isinstance(inputs, MetaTensor):
|
344 |
+
final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
|
345 |
+
return final_output
|
346 |
+
|
347 |
+
|
348 |
+
def _get_scan_interval(
|
349 |
+
image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
|
350 |
+
) -> Tuple[int, ...]:
|
351 |
+
"""
|
352 |
+
Compute scan interval according to the image size, roi size and overlap.
|
353 |
+
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
|
354 |
+
use 1 instead to make sure sliding window works.
|
355 |
+
|
356 |
+
"""
|
357 |
+
if len(image_size) != num_spatial_dims:
|
358 |
+
raise ValueError("image coord different from spatial dims.")
|
359 |
+
if len(roi_size) != num_spatial_dims:
|
360 |
+
raise ValueError("roi coord different from spatial dims.")
|
361 |
+
|
362 |
+
scan_interval = []
|
363 |
+
for i in range(num_spatial_dims):
|
364 |
+
if roi_size[i] == image_size[i]:
|
365 |
+
scan_interval.append(int(roi_size[i]))
|
366 |
+
else:
|
367 |
+
interval = int(roi_size[i] * (1 - overlap))
|
368 |
+
scan_interval.append(interval if interval > 0 else 1)
|
369 |
+
return tuple(scan_interval)
|