hwonheo commited on
Commit
9cbcfa0
·
1 Parent(s): 89562ee

Upload 8 files

Browse files
app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import nibabel as nib
4
+ import torch
5
+ from scipy.ndimage import zoom
6
+ from network.generator import ResnetGenerator
7
+ import ants
8
+ import tempfile
9
+ import os
10
+ import matplotlib.pyplot as plt
11
+ import shutil
12
+
13
+ # Class for handling MRI inference
14
+ class MRIInference:
15
+ def __init__(self, model, device, input_shape, output_shape):
16
+ self.model = model
17
+ self.device = device
18
+ self.input_shape = input_shape
19
+ self.output_shape = output_shape
20
+
21
+ def load_image(self, file_path):
22
+ # Load and preprocess the MRI image
23
+ image = nib.load(file_path).get_fdata()
24
+ rotated_image = np.rot90(image, k=1, axes=(1, 2))
25
+ mean = np.mean(rotated_image)
26
+ std = np.std(rotated_image)
27
+ normalized_image = (rotated_image - mean) / std
28
+ min_val = np.min(normalized_image)
29
+ max_val = np.max(normalized_image)
30
+ scale = 255 / (max_val - min_val)
31
+ normalized_image = scale * (normalized_image - min_val)
32
+ scale_factors = (
33
+ self.input_shape[0] / normalized_image.shape[0],
34
+ self.input_shape[1] / normalized_image.shape[1],
35
+ self.input_shape[2] / normalized_image.shape[2]
36
+ )
37
+ resampled_image = zoom(normalized_image, scale_factors, order=1)
38
+ return torch.tensor(
39
+ resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
40
+
41
+ def save_image(self, image, file_name):
42
+ # Save the processed image
43
+ image = image.squeeze().cpu().numpy()
44
+ scale_factors = (
45
+ self.output_shape[0] / image.shape[0],
46
+ self.output_shape[1] / image.shape[1],
47
+ self.output_shape[2] / image.shape[2]
48
+ )
49
+ resampled_image = zoom(image, scale_factors, order=1)
50
+ resampled_image = np.rot90(resampled_image, k=-1, axes=(1, 2))
51
+ nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
52
+
53
+ def match_sform_affine(self, orig_path, gen_path):
54
+ # Match the affine of original and generated images
55
+ orig_img = nib.load(orig_path)
56
+ orig_affine = orig_img.affine
57
+ gen_img = nib.load(gen_path)
58
+ gen_data = gen_img.get_fdata()
59
+ matched_gen_img = nib.Nifti1Image(gen_data, orig_affine)
60
+ nib.save(matched_gen_img, gen_path)
61
+
62
+ def infer(self, input_tensor, original_file_path, output_path):
63
+ # Inference process
64
+ with torch.no_grad():
65
+ self.model.eval()
66
+ output = self.model(input_tensor.to(self.device))
67
+ scale_factor = (1, 1, self.output_shape[2] / output.shape[4])
68
+ resampled_output = zoom(
69
+ output.squeeze().cpu().numpy(), scale_factor, order=1)
70
+ generated_image = torch.tensor(resampled_output[np.newaxis, ...])
71
+ temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
72
+ resampled_file_path = resample_to_isotropic(
73
+ original_file_path, temp_orig_path)
74
+ temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
75
+ self.save_image(generated_image, temp_generated_path)
76
+ self.match_sform_affine(resampled_file_path, temp_generated_path)
77
+ resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')
78
+ resample_to_isotropic(temp_generated_path, resampled_generated_path)
79
+ base_name = os.path.basename(original_file_path)
80
+ gen_file_name = base_name.replace(".nii", "_gen.nii")
81
+ warped_file_path = os.path.join(output_path, gen_file_name)
82
+ affine_registration(
83
+ resampled_file_path, resampled_generated_path, warped_file_path)
84
+ for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
85
+ os.remove(temp_file)
86
+ return warped_file_path
87
+
88
+ # Functions for image processing and Streamlit UI handling
89
+ def resample_to_isotropic(image_path, output_path):
90
+ image = ants.image_read(image_path)
91
+ resampled_image = ants.resample_image(
92
+ image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=4)
93
+ ants.image_write(resampled_image, output_path)
94
+ return output_path
95
+
96
+
97
+ def affine_registration(fixed_image_path, moving_image_path, output_path):
98
+ # Register affine of fixed and moving images
99
+ fixed_image = ants.image_read(fixed_image_path)
100
+ moving_image = ants.image_read(moving_image_path)
101
+ registration = ants.registration(
102
+ fixed=fixed_image, moving=moving_image,
103
+ type_of_transform='Elastic')
104
+ ants.image_write(registration['warpedmovout'], output_path)
105
+
106
+ @st.cache_data
107
+ def load_model():
108
+ # Load the trained model
109
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
110
+ generator = ResnetGenerator().to(device)
111
+ checkpoint_path = 'ckpt/ckpt_final/G_latest.pth'
112
+ checkpoint = torch.load(checkpoint_path, map_location=device)
113
+ generator.load_state_dict(checkpoint)
114
+ return generator, device
115
+
116
+ generator, device = load_model()
117
+ inference_engine = MRIInference(generator, device, (128, 32, 128), (128, 192, 128))
118
+
119
+ def save_middle_slice(image, file_path):
120
+ # Save middle slice of the MRI image
121
+ middle_slice = image[image.shape[0] // 2]
122
+ fig, ax = plt.subplots(figsize=(5, 5))
123
+ ax.imshow(middle_slice, cmap='gray', aspect='auto')
124
+ ax.axis('off')
125
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
126
+ plt.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
127
+ plt.close()
128
+
129
+ def clear_output_folder(folder_path):
130
+ # Clear contents of the specified folder
131
+ for filename in os.listdir(folder_path):
132
+ file_path = os.path.join(folder_path, filename)
133
+ if os.path.isfile(file_path) or os.path.islink(file_path):
134
+ os.unlink(file_path)
135
+ elif os.path.isdir(file_path):
136
+ shutil.rmtree(file_path)
137
+
138
+ def clear_session():
139
+ # Clear the session state
140
+ for key in list(st.session_state.keys()):
141
+ del st.session_state[key]
142
+
143
+ def main():
144
+ # Sidebar - How to Use Guide
145
+ st.sidebar.title("How to Use EasySR")
146
+ st.sidebar.markdown(
147
+ "**Step-by-Step Guide:**\n\n"
148
+ "1. **Prepare Your Data**: \n\n\tGet your rat brain T2 MRI data. "
149
+ "Ensure it's in NIFTI format. Convert if necessary.\n\n"
150
+ "2. **Upload Your MRI**: \n\n\tDrag and drop your NIFTI file or use "
151
+ "the upload button.\n\n"
152
+ "3. **Start the EasySR**: \n\n\tPress 'EasySR' and we'll handle the rest. "
153
+ "The process is quick, typically taking just a few seconds to complete!\n\n"
154
+ "4. **Sit Back and Relax**: \n\n\tNo long waits here - your data will be "
155
+ "processed in under a minute.\n\n"
156
+ "5. **View and Download**: \n\n\tAfter processing, view the results and "
157
+ "use the download button to save the MRI.\n\n"
158
+ "6. **Use as Needed**: \n\n\tDownload and use your enhanced MRI as you see fit. "
159
+ "Get your data more!\n\n #"
160
+ "#\n\n "
161
+ "#\n\n\n "
162
+ "#\n\n\n "
163
+ "GitHub: EasySR"
164
+ "\n\n "
165
+ "[github.com/hwonheo/easysr](https://github.com/hwonheo/easysr)"
166
+ "\n\n "
167
+ "Huggingface (space): EasySR"
168
+ "\n\n "
169
+ "[huggingface.co/spaces/hwonheo/easysr]"
170
+ "(https://huggingface.co/spaces/hwonheo/easysr)"
171
+ )
172
+
173
+ # Main function for Streamlit UI
174
+ st.markdown("<h1 style='text-align: center;'>EasySR:</h1>", unsafe_allow_html=True)
175
+ st.markdown("<h2 style='text-align: center;'>Rat Brain T2 MRI SR-Reconstruction</h2>", unsafe_allow_html=True)
176
+ st.title("\n")
177
+ col1, col2 = st.columns([0.5, 0.5])
178
+
179
+ original_slice_path = None
180
+ inferred_slice_path = None
181
+ download_file_path = None
182
+
183
+ output_path = "infer/generate"
184
+ if not os.path.exists(output_path):
185
+ os.makedirs(output_path)
186
+
187
+ with col1:
188
+ st.markdown("<h3 style='text-align: center;'>MRI File Upload (NIFTI)</h3>",
189
+ unsafe_allow_html=True)
190
+ uploaded_file = st.file_uploader("", type=["nii", "nii.gz"], key='file_uploader')
191
+
192
+ if uploaded_file is not None:
193
+ st.session_state['uploaded_file'] = uploaded_file
194
+ file_name = uploaded_file.name
195
+ infer_button = st.button("EasySR (start inference)", type="primary")
196
+
197
+ if infer_button:
198
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz") as tmp_file:
199
+ tmp_file.write(uploaded_file.getvalue())
200
+ file_path = tmp_file.name
201
+
202
+ try:
203
+ input_tensor = inference_engine.load_image(file_path)
204
+ warped_image_path = inference_engine.infer(
205
+ input_tensor, file_path, output_path)
206
+
207
+ gen_file_name = file_name.replace(".nii", "_gen.nii")
208
+ download_file_path = os.path.join(output_path, gen_file_name)
209
+ shutil.copy(warped_image_path, download_file_path)
210
+
211
+ original_img = nib.load(file_path).get_fdata()
212
+ inferred_img = nib.load(warped_image_path).get_fdata()
213
+
214
+ original_slice_path = os.path.join(output_path, "original_slice.jpg")
215
+ inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
216
+ save_middle_slice(original_img, original_slice_path)
217
+ save_middle_slice(inferred_img, inferred_slice_path)
218
+
219
+ except Exception as e:
220
+ st.error(f"Error during inference: {e}")
221
+
222
+ if file_path and os.path.exists(file_path):
223
+ os.remove(file_path)
224
+
225
+ with col2:
226
+ st.header("\n")
227
+ st.header("\n")
228
+ st.header("\n")
229
+ st.header("\n")
230
+ st.header("\n")
231
+ st.header("\n")
232
+ st.header("\n")
233
+ st.subheader("\n")
234
+ if download_file_path and os.path.exists(download_file_path):
235
+ with open(download_file_path, "rb") as file:
236
+ st.download_button(
237
+ label="Download (EasySR inferred-MRI)",
238
+ data=file,
239
+ file_name=gen_file_name,
240
+ mime="application/gzip",
241
+ type="primary"
242
+ )
243
+
244
+ if st.button('Clear All',
245
+ help='Caution: Pressing the Clear All button will delete the contents of the generate folder.'):
246
+ clear_output_folder('infer/generate')
247
+ clear_session()
248
+ st.experimental_rerun()
249
+
250
+ st.subheader("\n")
251
+ st.subheader("\n")
252
+ st.subheader("\n")
253
+ if original_slice_path and os.path.exists(original_slice_path):
254
+ st.subheader("Comparison of Inferred slice")
255
+ col3, col4 = st.columns([0.5, 0.5])
256
+ with col3:
257
+ if original_slice_path and os.path.exists(original_slice_path):
258
+ st.markdown("**Original**")
259
+ st.image(original_slice_path, caption="Original MRI", width=300)
260
+
261
+ with col4:
262
+ if inferred_slice_path and os.path.exists(inferred_slice_path):
263
+ st.markdown("**EasySR**")
264
+ st.image(inferred_slice_path, caption="Inferred MRI", width=300)
265
+
266
+ if __name__ == '__main__':
267
+ main()
ckpt/ckpt_final/D_latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:364c07a84b3e8af7519aed34c7a66f56a0ee5d8faa0f8812ff9f20423058f417
3
+ size 2806810
ckpt/ckpt_final/G_latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b454e05aeef19abaed090814723266eb17ff18b4e2b6cde276c377173af8d577
3
+ size 285648
infer/generate/easysr_generated_data_will_be_here_in.txt ADDED
File without changes
infer/input/input_your_infer_data_here.txt ADDED
File without changes
network/discriminator.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ def conv_block(ndf, in_channels, out_channels, kernel_size, stride, padding):
5
+ """Defines a convolutional block with convolution, batch normalization, and LeakyReLU activation."""
6
+ return nn.Sequential(
7
+ nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
8
+ nn.BatchNorm3d(out_channels),
9
+ nn.LeakyReLU(0.2, inplace=True)
10
+ )
11
+
12
+ class PatchDiscriminator(nn.Module):
13
+ def __init__(self, input_nc=1, ndf=16):
14
+ """Initializes the Patch Discriminator model.
15
+
16
+ Args:
17
+ input_nc (int): Number of input channels. Default is 1 (e.g., for grayscale images).
18
+ ndf (int): Number of filters in the first convolution layer. Default is 16.
19
+ """
20
+ super(PatchDiscriminator, self).__init__()
21
+
22
+ # Define convolutional blocks
23
+ self.conv1 = conv_block(ndf, input_nc, ndf, kernel_size=4, stride=2, padding=1)
24
+ self.conv2 = conv_block(ndf, ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
25
+ self.conv3 = conv_block(ndf * 2, ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
26
+ self.conv4 = conv_block(ndf * 4, ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
27
+
28
+ # Final convolution layer to reduce to a single channel output
29
+ self.conv5 = nn.Conv3d(ndf * 8, 1, kernel_size=4, padding=1)
30
+
31
+ # Flatten layer
32
+ self.flatten = nn.Flatten()
33
+
34
+ # Fully connected layer to adjust output size
35
+ self.fc = nn.Linear(539, 1) # Adjust '539' based on the flattened output size
36
+
37
+ # Sigmoid activation to obtain a probability
38
+ self.sigmoid = nn.Sigmoid()
39
+
40
+ def forward(self, x):
41
+ """Defines the forward pass of the discriminator."""
42
+ x = self.conv1(x)
43
+ x = self.conv2(x)
44
+ x = self.conv3(x)
45
+ x = self.conv4(x)
46
+ x = self.conv5(x)
47
+ x = self.flatten(x)
48
+ x = self.fc(x)
49
+ x = self.sigmoid(x)
50
+ return x
51
+
network/generator.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # Resnet Block
5
+ class ResnetBlock(nn.Module):
6
+ def __init__(self, inf, onf):
7
+ super(ResnetBlock, self).__init__()
8
+ self.conv_block = self.build_conv_block(inf, onf)
9
+
10
+ def build_conv_block(self, inf, onf):
11
+ conv_block = [
12
+ nn.Conv3d(inf, onf, kernel_size=3, stride=1, padding=1),
13
+ nn.BatchNorm3d(onf),
14
+ nn.LeakyReLU(0.2)
15
+ ]
16
+ conv_block += [
17
+ nn.Conv3d(onf, onf, kernel_size=3, stride=1, padding=1),
18
+ nn.BatchNorm3d(onf)
19
+ ]
20
+ return nn.Sequential(*conv_block)
21
+
22
+ def forward(self, x):
23
+ out = x + self.conv_block(x)
24
+ return out
25
+
26
+ # DeUpBlock for upsampling in the width dimension
27
+ class DeUpBlock(nn.Module):
28
+ def __init__(self, inf, onf):
29
+ super(DeUpBlock, self).__init__()
30
+ # Upsampling only in the width dimension
31
+ self.deupblock = nn.Sequential(
32
+ nn.ConvTranspose3d(inf, onf, kernel_size=(1, 6, 1), stride=(1, 6, 1), padding=(0, 0, 0)),
33
+ nn.LeakyReLU(0.2)
34
+ )
35
+
36
+ def forward(self, x):
37
+ return self.deupblock(x)
38
+
39
+ # Resnet Generator
40
+ class ResnetGenerator(nn.Module):
41
+ def __init__(self, input_nc=1, output_nc=1, ngf=16, n_residual_blocks=4):
42
+ super(ResnetGenerator, self).__init__()
43
+ self.n_residual_blocks = n_residual_blocks
44
+
45
+ self.conv_block1 = nn.Sequential(
46
+ nn.Conv3d(input_nc, ngf, kernel_size=3, padding=1),
47
+ nn.LeakyReLU(0.2)
48
+ )
49
+
50
+ for i in range(n_residual_blocks):
51
+ self.add_module(f'residual_block{i+1}', ResnetBlock(ngf, ngf))
52
+
53
+ self.conv_block2 = nn.Sequential(
54
+ nn.Conv3d(ngf, ngf, kernel_size=3, padding=1),
55
+ nn.BatchNorm3d(ngf)
56
+ )
57
+
58
+ self.deup = DeUpBlock(ngf, ngf)
59
+ self.conv3 = nn.Conv3d(ngf, output_nc, kernel_size=3, padding=1)
60
+
61
+ def forward(self, x):
62
+ x = self.conv_block1(x)
63
+ y = x.clone()
64
+ for i in range(self.n_residual_blocks):
65
+ y = self.__getattr__(f'residual_block{i+1}')(y)
66
+ x = self.conv_block2(y) + x
67
+ x = self.deup(x)
68
+ return self.conv3(x)
69
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ nibabel
2
+ numpy
3
+ tqdm
4
+ scipy
5
+ matplotlib
6
+ SimpleITK
7
+ torchio
8
+ antspyx
9
+ streamlit