|
|
|
|
|
""" |
|
Adapted from the MONAI tutorial: https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch |
|
""" |
|
|
|
import argparse |
|
import os, sys |
|
|
|
join = os.path.join |
|
|
|
|
|
from tqdm import tqdm |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import DataParallel |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.utils.tensorboard import SummaryWriter |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR |
|
from stardist import star_dist, edt_prob |
|
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label |
|
from stardist import random_label_cmap, ray_angles |
|
import monai |
|
from collections import OrderedDict |
|
from compute_metric import eval_tp_fp_fn, remove_boundary_cells |
|
from monai.data import decollate_batch, PILReader |
|
from monai.inferers import sliding_window_inference |
|
from monai.metrics import DiceMetric |
|
from monai.transforms import ( |
|
Activations, |
|
AsChannelFirstd, |
|
AddChanneld, |
|
AsDiscrete, |
|
CenterSpatialCropd, |
|
Compose, |
|
Lambdad, |
|
LoadImaged, |
|
|
|
SpatialPadd, |
|
RandSpatialCropd, |
|
RandRotate90d, |
|
ScaleIntensityd, |
|
RandAxisFlipd, |
|
RandZoomd, |
|
RandGaussianNoised, |
|
RandAdjustContrastd, |
|
RandGaussianSmoothd, |
|
RandHistogramShiftd, |
|
EnsureTyped, |
|
EnsureType, |
|
apply_transform, |
|
) |
|
from monai.visualize import plot_2d_or_3d_image |
|
import matplotlib.pyplot as plt |
|
from datetime import datetime |
|
import shutil |
|
from skimage import io |
|
from skimage.color import gray2rgb |
|
|
|
from models.unetr2d import UNETR2D |
|
from models.swin_unetr import SwinUNETR |
|
from models.flexible_unet_convext import FlexibleUNet_hv |
|
|
|
from utils import cropping_center, gen_targets, xentropy_loss, dice_loss, mse_loss, msge_loss |
|
|
|
import warnings |
|
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Printed once all module-level imports above have succeeded.
print("Successfully imported all requirements!")

# cuDNN is switched off globally for this script.
# NOTE(review): the reason is not visible in this file - confirm it is still required.
torch.backends.cudnn.enabled = False
|
|
|
def rm_n_mkdir(dir_path):
    """Recreate ``dir_path`` as an empty directory.

    If a directory already exists at ``dir_path`` it is deleted together with
    everything inside it; afterwards the directory (and any missing parents)
    is created again.
    """
    already_present = os.path.isdir(dir_path)
    if already_present:
        # Wipe the old tree so the caller starts from a clean directory.
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)
|
|
|
class HoverDataset(Dataset):
    """Dataset that augments a sample and attaches ``np_map`` / ``hv_map`` targets.

    Args:
        data: indexable collection of samples (dicts with at least ``'img'``
            and ``'label'`` entries once transformed).
        transform: callable applied through ``apply_transform``; ``None``
            returns the stored sample unchanged.
        mask_shape: shape handed to ``cropping_center`` for the image and
            both target maps.
    """

    def __init__(self, data, transform, mask_shape):
        self.data = data
        self.transform = transform
        self.mask_shape = mask_shape

    def __len__(self) -> int:
        return len(self.data)

    def _transform(self, index):
        """Return sample ``index``, passed through ``self.transform`` when one is set."""
        sample = self.data[index]
        if self.transform is None:
            return sample
        return apply_transform(self.transform, sample)

    def __getitem__(self, index):
        """Return the transformed sample extended with centre-cropped targets."""
        sample = self._transform(index)

        # The label tensor is squeezed and cast to an int32 array before target generation.
        inst_map = np.squeeze(sample['label'].numpy()).astype('int32')
        targets = gen_targets(inst_map, inst_map.shape[:2])

        # Centre-crop both target maps and store them back as tensors.
        for key in ('np_map', 'hv_map'):
            cropped = cropping_center(targets[key], self.mask_shape)
            targets[key] = torch.tensor(cropped)

        # The image is moved to channel-last for cropping, then back to channel-first.
        channel_last = sample['img'].permute(1, 2, 0)
        sample['img'] = cropping_center(channel_last, self.mask_shape).permute(2, 0, 1)

        sample.update(targets)
        return sample
|
|
|
def valid_step(model, batch_data):
    """Run one inference pass on a validation batch and collect numpy results.

    Args:
        model: ``torch.nn.Module`` whose call returns a pair indexed as
            ``(hv_output, np_output)``, both channel-first.
        batch_data: dict holding the tensors ``"img"``, ``"np_map"`` and
            ``"hv_map"``.

    Returns:
        ``{"raw": {...}}`` with the numpy arrays ``imgs``, ``true_np``,
        ``true_hv``, ``prob_np`` (softmax probability of class 1) and
        ``pred_hv``. Predictions are channel-last.
    """
    model.eval()

    imgs = batch_data["img"]
    true_np = batch_data["np_map"]
    true_hv = batch_data["hv_map"]
    batch_size = imgs.shape[0]

    # Follow the model's own device instead of assuming CUDA; fall back to
    # the previous behaviour for a parameter-less model.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda")
    imgs_gpu = imgs.to(device).type(torch.float32)

    true_np = torch.squeeze(true_np).type(torch.int64)
    true_hv = torch.squeeze(true_hv).type(torch.float32)
    if batch_size == 1:
        # squeeze() without ``dim`` also drops a batch axis of length 1
        # (e.g. the last, incomplete validation batch). Restore it so the
        # per-patch loops downstream still iterate over patches.
        true_np = true_np.unsqueeze(0)
        true_hv = true_hv.unsqueeze(0)

    true_dict = {
        "np": true_np,
        "hv": true_hv,
    }

    with torch.no_grad():
        preds = model(imgs_gpu)
        pred_dict = {'np': preds[1], 'hv': preds[0]}
        # Channel-first -> channel-last.
        pred_dict = OrderedDict(
            (k, v.permute(0, 2, 3, 1).contiguous()) for k, v in pred_dict.items()
        )
        # Keep only the probability of class index 1.
        pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1]

    result_dict = {
        "raw": {
            "imgs": imgs.numpy(),
            "true_np": true_dict["np"].numpy(),
            "true_hv": true_dict["hv"].numpy(),
            "prob_np": pred_dict["np"].cpu().numpy(),
            "pred_hv": pred_dict["hv"].cpu().numpy(),
        }
    }

    return result_dict
|
|
|
def proc_valid_step_output(raw_data, nr_types=None):
    """Summarise one validation batch into scalar metrics.

    Args:
        raw_data: dict with the arrays ``"prob_np"``, ``"true_np"``,
            ``"pred_hv"`` and ``"true_hv"``, indexed by patch on the first axis.
        nr_types: accepted for interface compatibility; not used.

    Returns:
        dict with ``np_acc`` (fraction of pixels where the 0.5-thresholded
        probability equals the ground truth), ``np_dice`` (Dice score for
        label 1) and ``hv_mse`` (summed squared error divided by the number
        of pixels in ``true_np``).
    """
    prob_np = raw_data["prob_np"]
    true_np = raw_data["true_np"]
    pred_hv = raw_data["pred_hv"]
    true_hv = raw_data["true_hv"]
    patch_count = len(raw_data["true_np"])

    dice_inter = 0
    dice_total = 0
    correct_pixels = 0
    for idx in range(patch_count):
        truth = true_np[idx]
        prediction = (prob_np[idx] > 0.5).astype(np.int32)
        # Foreground (label 1) masks as 0/1 integers for the Dice terms.
        truth_fg = (truth == 1).astype(np.int32)
        pred_fg = (prediction == 1).astype(np.int32)
        dice_inter += (pred_fg * truth_fg).sum()
        dice_total += (pred_fg + truth_fg).sum()
        correct_pixels += (prediction == truth).sum()

    nr_pixels = len(true_np) * np.size(true_np[0])

    squared_error_sum = 0
    for idx in range(patch_count):
        diff = pred_hv[idx] - true_hv[idx]
        squared_error_sum += (diff * diff).sum()

    return {
        'np_acc': correct_pixels / nr_pixels,
        # Small epsilon guards against an empty prediction and empty truth.
        'np_dice': 2 * dice_inter / (dice_total + 1.0e-8),
        'hv_mse': squared_error_sum / nr_pixels,
    }
|
|
|
def main():
    """Train ``FlexibleUNet_hv`` on the configured dataset and validate periodically.

    Command-line arguments select the data root, seed, loader workers, crop
    sizes, batch size, epoch count, validation/checkpoint intervals, learning
    rate and visible GPU.

    Side effects:
        * The hard-coded work directory is deleted and recreated
          (``rm_n_mkdir``), then receives a timestamped copy of this script,
          TensorBoard event files, ``train_loss.txt`` and ``<epoch>.pth``
          checkpoints every ``--save_interval`` epochs.
        * ``CUDA_VISIBLE_DEVICES`` is set from ``--gpu_id``.
    """
    modelname = 'star-hover'
    strategy = 'aug256_out256'
    parser = argparse.ArgumentParser("Baseline for Microscopy image segmentation")

    parser.add_argument(
        "--data_path",
        default="/mntnfs/med_data5/louwei/consep/",
        type=str,
        help="training data path; subfolders: images, labels",
    )
    parser.add_argument("--seed", default=10, type=int)
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument(
        "--model_name", default="efficientunet", help="select mode: unet, unetr, swinunetr"
    )
    parser.add_argument("--input_size", default=512, type=int, help="after rand crop")
    parser.add_argument("--mask_size", default=256, type=int, help="after gen target")
    parser.add_argument("--batch_size", default=12, type=int, help="Batch size per GPU")
    parser.add_argument("--max_epochs", default=800, type=int)
    parser.add_argument("--val_interval", default=1, type=int)
    parser.add_argument("--save_interval", default=10, type=int)
    parser.add_argument("--initial_lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument('--gpu_id', type=str, default='0', help='gpu id')

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    work_dir = f'/mntnfs/med_data5/louwei/hover_stardist/class_{modelname}_{strategy}'

    # Seed both NumPy and PyTorch so that --seed also covers weight
    # initialisation and loader shuffling.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_path = join(work_dir)
    rm_n_mkdir(model_path)
    run_id = datetime.now().strftime("%Y%m%d-%H%M")
    # Keep a snapshot of the exact script used for this run.
    shutil.copyfile(
        __file__, join(model_path, run_id + "_" + os.path.basename(__file__))
    )

    img_path = join(args.data_path, "Train/Images_3channels")
    gt_path = join(args.data_path, "Train/tif")
    val_img_path = join(args.data_path, "Test/Images_3channels")
    val_gt_path = join(args.data_path, "Test/tif")

    # Label files share the image file name with a .tif extension.
    img_names = sorted(os.listdir(img_path))
    val_img_names = sorted(os.listdir(val_img_path))

    train_files = [
        {
            "img": join(img_path, name),
            "label": join(gt_path, name.replace('.png', '.tif')),
            'name': name,
        }
        for name in img_names
    ]
    val_files = [
        {
            "img": join(val_img_path, name),
            "label": join(val_gt_path, name.replace('.png', '.tif')),
            'name': name,
        }
        for name in val_img_names
    ]
    print(
        f"training image num: {len(train_files)}, validation image num: {len(val_files)}"
    )

    def load_img(img):
        """Read an image file as float32, expanding 2-D images to RGB."""
        ret = io.imread(img)
        if len(ret.shape) == 2:
            ret = gray2rgb(ret)
        return ret.astype('float32')

    def load_ann(ann):
        """Read a label file as a squeezed float32 array."""
        ret = np.squeeze(io.imread(ann)).astype('float32')
        return ret

    train_transforms = Compose(
        [
            Lambdad(('img',), load_img),
            Lambdad(('label',), load_ann),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(
                keys=["img"], channel_dim=-1, allow_missing_keys=True
            ),
            RandAxisFlipd(keys=["img", "label"], prob=0.5),
            RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
            RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
            RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
            RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
            RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
            RandZoomd(
                keys=["img", "label"],
                prob=0.15,
                min_zoom=0.5,
                max_zoom=2.0,
                mode=["area", "nearest"],
            ),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    val_transforms = Compose(
        [
            Lambdad(('img',), load_img),
            Lambdad(('label',), load_ann),
            AddChanneld(keys=["label"], allow_missing_keys=True),
            AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
            EnsureTyped(keys=["img", "label"]),
        ]
    )

    mask_shape = (args.mask_size, args.mask_size)

    # Sanity check: load one sample directly and one through a DataLoader.
    check_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=mask_shape)
    print(len(check_ds))
    tmp = check_ds[0]
    print(tmp['img'].shape, tmp['label'].shape, tmp['hv_map'].shape, tmp['np_map'].shape)
    check_loader = DataLoader(check_ds, batch_size=1, num_workers=args.num_workers)
    check_data = monai.utils.misc.first(check_loader)
    print(
        "sanity check:",
        check_data["img"].shape,
        torch.max(check_data["img"]),
        check_data["label"].shape,
        torch.max(check_data["label"]),
        check_data["hv_map"].shape,
        torch.max(check_data["hv_map"]),
        check_data["np_map"].shape,
        torch.max(check_data["np_map"]),
    )

    train_ds = HoverDataset(data=train_files, transform=train_transforms, mask_shape=mask_shape)
    print(len(train_ds))
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    val_ds = HoverDataset(data=val_files, transform=val_transforms, mask_shape=mask_shape)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=args.num_workers)

    model = FlexibleUNet_hv(
        in_channels=3,
        out_channels=2 + 2,
        backbone='convnext_small',
        pretrained=True,
        n_rays=2,
        prob_out_channels=2,
    )

    optimizer = torch.optim.AdamW(model.parameters(), args.initial_lr)
    scheduler = StepLR(optimizer, 100, 0.1)

    device = "cuda"
    model = model.to(device)

    val_interval = args.val_interval
    save_interval = args.save_interval
    epoch_loss_values = []
    writer = SummaryWriter(model_path)

    loss_file = f'{work_dir}/train_loss.txt'
    if os.path.exists(loss_file):
        os.remove(loss_file)

    # +1 so that the epoch numbered ``max_epochs`` is actually trained and
    # its checkpoint can be written.
    for epoch in range(1, args.max_epochs + 1):
        model.train()
        epoch_loss = 0
        running_np_1, running_np_2, running_hv_1, running_hv_2 = 0.0, 0.0, 0.0, 0.0
        step = 0
        stream = tqdm(train_loader)
        for step, batch_data in enumerate(stream, start=1):
            inputs, true_np, true_hv = batch_data["img"], batch_data["np_map"], batch_data['hv_map']
            true_np = true_np.to(device).type(torch.int64)
            true_hv = true_hv.to(device).type(torch.float32)
            true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32)
            # The batch is already a tensor; move it instead of re-wrapping it.
            inputs = inputs.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            pred_hv, pred_np = model(inputs)
            # Channel-first -> channel-last to line up with the targets.
            pred_hv = pred_hv.permute(0, 2, 3, 1).contiguous()
            pred_np = pred_np.permute(0, 2, 3, 1).contiguous()
            pred_np = F.softmax(pred_np, dim=-1)

            loss_np_1 = xentropy_loss(true_np_onehot, pred_np)
            loss_np_2 = dice_loss(true_np_onehot, pred_np)
            loss_hv_1 = mse_loss(true_hv, pred_hv)
            loss_hv_2 = msge_loss(true_hv, pred_hv, true_np_onehot[..., 1])
            loss = loss_np_1 + loss_np_2 + loss_hv_1 + loss_hv_2
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            running_np_1 += loss_np_1.item()
            running_np_2 += loss_np_2.item()
            running_hv_1 += loss_hv_1.item()
            running_hv_2 += loss_hv_2.item()

            stream.set_description(
                f'Epoch {epoch} | np bce: {running_np_1 / step:.4f}, np dice: {running_np_2 / step:.4f}, hv mse: {running_hv_1 / step:.4f}, hv msge: {running_hv_2 / step:.4f}')

        # Guard the averages against an empty training loader.
        num_steps = max(step, 1)
        epoch_loss /= num_steps
        epoch_loss_values.append(epoch_loss)
        writer.add_scalar("train_loss", epoch_loss, epoch)
        writer.add_scalar("np_bce", running_np_1 / num_steps, epoch)
        writer.add_scalar("np_dice", running_np_2 / num_steps, epoch)
        writer.add_scalar("hv_mse", running_hv_1 / num_steps, epoch)
        writer.add_scalar("hv_msge", running_hv_2 / num_steps, epoch)
        print(f"epoch {epoch} average loss: {epoch_loss:.4f}, lr: {optimizer.param_groups[0]['lr']}")

        with open(loss_file, 'a') as f:
            f.write(f'Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_bce:{running_np_1/num_steps:.4f}\tnp_dice:{running_np_2/num_steps:.4f}\thv_mse:{running_hv_1/num_steps:.4f}\thv_msge:{running_hv_2/num_steps:.4f}\n')

        if epoch % save_interval == 0:
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss_values,
            }
            torch.save(checkpoint, join(model_path, str(epoch) + ".pth"))

        if epoch % val_interval == 0:
            running_np_acc, running_np_dice, running_hv_mse = 0.0, 0.0, 0.0
            val_idx = 0
            stream_val = tqdm(val_loader)
            for val_idx, batch_data in enumerate(stream_val, start=1):
                raw_data = valid_step(model, batch_data)['raw']
                track_dict = proc_valid_step_output(raw_data)
                running_np_acc += track_dict['np_acc']
                running_np_dice += track_dict['np_dice']
                running_hv_mse += track_dict['hv_mse']
                # Update the validation bar, not the finished training bar.
                stream_val.set_description(f'Epoch {epoch} | np acc: {running_np_acc / val_idx:.4f}, np dice: {running_np_dice / val_idx:.4f}, hv mse: {running_hv_mse / val_idx:.4f}')

            num_val = max(val_idx, 1)
            # Validation scalars get their own tags so they do not collide
            # with the training series "np_dice" / "hv_mse".
            writer.add_scalar("val_np_acc", running_np_acc / num_val, epoch)
            writer.add_scalar("val_np_dice", running_np_dice / num_val, epoch)
            writer.add_scalar("val_hv_mse", running_hv_mse / num_val, epoch)
            print(f'Epoch {epoch} | np acc: {running_np_acc / num_val:.4f}, np dice: {running_np_dice / num_val:.4f}, hv mse: {running_hv_mse / num_val:.4f}')

            with open(loss_file, 'a') as f:
                f.write(f'Validation | Epoch{epoch}\tloss:{epoch_loss:.4f}\tnp_acc:{running_np_acc/num_val:.4f}\tnp_dice:{running_np_dice/num_val:.4f}\thv_mse:{running_hv_mse/num_val:.4f}\n')

        scheduler.step()

    # Flush pending TensorBoard events before exiting.
    writer.close()
|
|
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|