Pawntoqueen commited on
Commit
846083b
1 Parent(s): 8e10b87
Files changed (1) hide show
  1. app.py +1 -137
app.py CHANGED
@@ -1,143 +1,7 @@
1
  import gradio as gr
2
- from monai.utils import first, set_determinism
3
- from monai.transforms import(
4
- Compose,
5
- AddChanneld,
6
- LoadImaged,
7
- Resized,
8
- ToTensord,
9
- Spacingd,
10
- Orientationd,
11
- ScaleIntensityRanged,
12
- CropForegroundd,
13
- Activations,
14
- )
15
 
16
- from monai.networks.nets import UNet
17
- from monai.networks.layers import Norm
18
- from monai.data import CacheDataset, DataLoader, Dataset
19
-
20
- import torch
21
- import matplotlib.pyplot as plt
22
-
23
- import os
24
- from glob import glob
25
- import numpy as np
26
-
27
- from monai.inferers import sliding_window_inference
28
-
29
- in_dir = 'Data_Train_Test'
30
- model_dir = 'results/results'
31
-
32
- train_loss = np.load(os.path.join(model_dir, 'loss_train.npy'))
33
- train_metric = np.load(os.path.join(model_dir, 'metric_train.npy'))
34
- test_loss = np.load(os.path.join(model_dir, 'loss_test.npy'))
35
- test_metric = np.load(os.path.join(model_dir, 'metric_test.npy'))
36
-
37
- def plot_result():
38
- plt.figure("Results 25 june", (12, 6))
39
- plt.subplot(2, 2, 1)
40
- plt.title("Train dice loss")
41
- x = [i + 1 for i in range(len(train_loss))]
42
- y = train_loss
43
- plt.xlabel("epoch")
44
- plt.plot(x, y)
45
-
46
- plt.subplot(2, 2, 2)
47
- plt.title("Train metric DICE")
48
- x = [i + 1 for i in range(len(train_metric))]
49
- y = train_metric
50
- plt.xlabel("epoch")
51
- plt.plot(x, y)
52
-
53
- plt.subplot(2, 2, 3)
54
- plt.title("Test dice loss")
55
- x = [i + 1 for i in range(len(test_loss))]
56
- y = test_loss
57
- plt.xlabel("epoch")
58
- plt.plot(x, y)
59
-
60
- plt.subplot(2, 2, 4)
61
- plt.title("Test metric DICE")
62
- x = [i + 1 for i in range(len(test_metric))]
63
- y = test_metric
64
- plt.xlabel("epoch")
65
- plt.plot(x, y)
66
-
67
- plt.show()
68
-
69
-
70
-
71
- path_train_volumes = sorted(glob(os.path.join(in_dir, "TrainVolumes", "*.nii.gz")))
72
- path_train_segmentation = sorted(glob(os.path.join(in_dir, "TrainSegmentation", "*.nii.gz")))
73
-
74
- path_test_volumes = sorted(glob(os.path.join(in_dir, "TestVolumes", "*.nii.gz")))
75
- path_test_segmentation = sorted(glob(os.path.join(in_dir, "TestSegmentation", "*.nii.gz")))
76
-
77
- train_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in zip(path_train_volumes, path_train_segmentation)]
78
- test_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in zip(path_test_volumes, path_test_segmentation)]
79
- test_files = test_files[0:9]
80
-
81
-
82
- test_transforms = Compose(
83
- [
84
- LoadImaged(keys=["vol", "seg"]),
85
- AddChanneld(keys=["vol", "seg"]),
86
- Spacingd(keys=["vol", "seg"], pixdim=(1.5,1.5,1.0), mode=("bilinear", "nearest")),
87
- Orientationd(keys=["vol", "seg"], axcodes="RAS"),
88
- ScaleIntensityRanged(keys=["vol"], a_min=-200, a_max=200,b_min=0.0, b_max=1.0, clip=True),
89
- CropForegroundd(keys=['vol', 'seg'], source_key='vol'),
90
- Resized(keys=["vol", "seg"], spatial_size=[128,128,64]),
91
- ToTensord(keys=["vol", "seg"]),
92
- ]
93
- )
94
-
95
- test_ds = Dataset(data=test_files, transform=test_transforms)
96
- test_loader = DataLoader(test_ds, batch_size=1)
97
-
98
- device = torch.device("cuda:0")
99
- model = UNet(
100
- dimensions=3,
101
- in_channels=1,
102
- out_channels=2,
103
- channels=(16, 32, 64, 128, 256),
104
- strides=(2, 2, 2, 2),
105
- num_res_units=2,
106
- norm=Norm.BATCH,
107
- ).to(device)
108
-
109
- model.load_state_dict(torch.load(
110
- os.path.join(model_dir, "best_metric_model.pth")))
111
- model.eval()
112
-
113
-
114
- sw_batch_size = 4
115
- roi_size = (128, 128, 64)
116
- with torch.no_grad():
117
- test_patient = first(test_loader)
118
- t_volume = test_patient['vol']
119
- #t_segmentation = test_patient['seg']
120
-
121
- test_outputs = sliding_window_inference(t_volume.to(device), roi_size, sw_batch_size, model)
122
- sigmoid_activation = Activations(sigmoid=True)
123
- test_outputs = sigmoid_activation(test_outputs)
124
- test_outputs = test_outputs > 0.53
125
-
126
- # plot the slice [:, :, 80]
127
  def predict():
128
- i =55
129
- plt.figure("check", (18, 6))
130
- plt.subplot(1, 3, 1)
131
- plt.title(f"image {i}")
132
- plt.imshow(test_patient["vol"][0, 0, :, :, i], cmap="gray")
133
- plt.subplot(1, 3, 2)
134
- plt.title(f"label {i}")
135
- plt.imshow(test_patient["seg"][0, 0, :, :, i] != 0)
136
- plt.subplot(1, 3, 3)
137
- plt.title(f"output {i}")
138
- plt.imshow(test_outputs.detach().cpu()[0, 1, :, :, i])
139
- plt.show()
140
-
141
 
142
 
143
  gr.Interface(fn=predict,
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def predict():
4
+ break
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  gr.Interface(fn=predict,