Pawntoqueen committed
Commit 2b30d70
1 Parent(s): 4a6fb9d
Files changed (1)
app.py +140 -0
app.py CHANGED
@@ -1,4 +1,144 @@
import gradio as gr
+ from monai.utils import first, set_determinism
+ from monai.transforms import (
+     Compose,
+     AddChanneld,
+     LoadImaged,
+     Resized,
+     ToTensord,
+     Spacingd,
+     Orientationd,
+     ScaleIntensityRanged,
+     CropForegroundd,
+     Activations,
+ )
+
+ from monai.networks.nets import UNet
+ from monai.networks.layers import Norm
+ from monai.data import CacheDataset, DataLoader, Dataset
+
+ import torch
+ import matplotlib.pyplot as plt
+
+ import os
+ from glob import glob
+ import numpy as np
+
+ from monai.inferers import sliding_window_inference
+
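+ # Dataset root and the directory holding the saved training artifacts (loss/metric curves, model weights).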
+ in_dir = 'Data_Train_Test'
+ model_dir = 'results/results'
+
+ train_loss = np.load(os.path.join(model_dir, 'loss_train.npy'))
+ train_metric = np.load(os.path.join(model_dir, 'metric_train.npy'))
+ test_loss = np.load(os.path.join(model_dir, 'loss_test.npy'))
+ test_metric = np.load(os.path.join(model_dir, 'metric_test.npy'))
+
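+ # Plot the train/test Dice loss and Dice metric curves in a 2x2 grid.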
+ def plot_result():
+     plt.figure("Results 25 june", (12, 6))
+     plt.subplot(2, 2, 1)
+     plt.title("Train dice loss")
+     x = [i + 1 for i in range(len(train_loss))]
+     y = train_loss
+     plt.xlabel("epoch")
+     plt.plot(x, y)
+
+     plt.subplot(2, 2, 2)
+     plt.title("Train metric DICE")
+     x = [i + 1 for i in range(len(train_metric))]
+     y = train_metric
+     plt.xlabel("epoch")
+     plt.plot(x, y)
+
+     plt.subplot(2, 2, 3)
+     plt.title("Test dice loss")
+     x = [i + 1 for i in range(len(test_loss))]
+     y = test_loss
+     plt.xlabel("epoch")
+     plt.plot(x, y)
+
+     plt.subplot(2, 2, 4)
+     plt.title("Test metric DICE")
+     x = [i + 1 for i in range(len(test_metric))]
+     y = test_metric
+     plt.xlabel("epoch")
+     plt.plot(x, y)
+
+     plt.show()
+
+
+
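+ # Pair each NIfTI volume with its segmentation mask; only the test split is used for inference below.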
+ path_train_volumes = sorted(glob(os.path.join(in_dir, "TrainVolumes", "*.nii.gz")))
+ path_train_segmentation = sorted(glob(os.path.join(in_dir, "TrainSegmentation", "*.nii.gz")))
+
+ path_test_volumes = sorted(glob(os.path.join(in_dir, "TestVolumes", "*.nii.gz")))
+ path_test_segmentation = sorted(glob(os.path.join(in_dir, "TestSegmentation", "*.nii.gz")))
+
+ train_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in zip(path_train_volumes, path_train_segmentation)]
+ test_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in zip(path_test_volumes, path_test_segmentation)]
+ test_files = test_files[0:9]
+
+
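+ # Preprocessing applied to each test volume/segmentation pair before inference.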
+ test_transforms = Compose(
+     [
+         LoadImaged(keys=["vol", "seg"]),
+         AddChanneld(keys=["vol", "seg"]),
+         Spacingd(keys=["vol", "seg"], pixdim=(1.5, 1.5, 1.0), mode=("bilinear", "nearest")),
+         Orientationd(keys=["vol", "seg"], axcodes="RAS"),
+         ScaleIntensityRanged(keys=["vol"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
+         CropForegroundd(keys=["vol", "seg"], source_key="vol"),
+         Resized(keys=["vol", "seg"], spatial_size=[128, 128, 64]),
+         ToTensord(keys=["vol", "seg"]),
+     ]
+ )
+
+ test_ds = Dataset(data=test_files, transform=test_transforms)
+ test_loader = DataLoader(test_ds, batch_size=1)
+
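+ # Build the 3D UNet and load the best-metric checkpoint; this assumes a CUDA device is available.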
+ device = torch.device("cuda:0")
+ model = UNet(
+     dimensions=3,
+     in_channels=1,
+     out_channels=2,
+     channels=(16, 32, 64, 128, 256),
+     strides=(2, 2, 2, 2),
+     num_res_units=2,
+     norm=Norm.BATCH,
+ ).to(device)
+
+ model.load_state_dict(torch.load(
+     os.path.join(model_dir, "best_metric_model.pth")))
+ model.eval()
+
+
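+ # Run sliding-window inference on the first test case, apply a sigmoid, and threshold at 0.53.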
+ sw_batch_size = 4
+ roi_size = (128, 128, 64)
+ with torch.no_grad():
+     test_patient = first(test_loader)
+     t_volume = test_patient['vol']
+     # t_segmentation = test_patient['seg']
+
+     test_outputs = sliding_window_inference(t_volume.to(device), roi_size, sw_batch_size, model)
+     sigmoid_activation = Activations(sigmoid=True)
+     test_outputs = sigmoid_activation(test_outputs)
+     test_outputs = test_outputs > 0.53
+
+ # Plot one slice (index 55) of the input image, the ground-truth label, and the model output.
+ def predict(image):
+     # The PIL image passed in by Gradio is currently unused; the demo always
+     # shows a fixed slice of the first test volume loaded above.
+     i = 55
+     plt.figure("check", (18, 6))
+     plt.subplot(1, 3, 1)
+     plt.title(f"image {i}")
+     plt.imshow(test_patient["vol"][0, 0, :, :, i], cmap="gray")
+     plt.subplot(1, 3, 2)
+     plt.title(f"label {i}")
+     plt.imshow(test_patient["seg"][0, 0, :, :, i] != 0)
+     plt.subplot(1, 3, 3)
+     plt.title(f"output {i}")
+     plt.imshow(test_outputs.detach().cpu()[0, 1, :, :, i])
+     plt.show()
+
+

gr.Interface(fn=predict,
             inputs=gr.inputs.Image(type="pil"),