hwonheo commited on
Commit
84d5807
·
verified ·
1 Parent(s): 7d5920a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -108
app.py CHANGED
@@ -1,35 +1,34 @@
1
- import streamlit as st
2
- import numpy as np
3
- import nibabel as nib
4
- import torch
5
- from scipy.ndimage import zoom
6
- from network.generator import ResnetGenerator
7
  import time
8
  from pathlib import Path
 
9
  import ants
10
- import tempfile
11
- import os
12
  import matplotlib.pyplot as plt
13
- import shutil
14
- from pathlib import Path
 
 
 
 
15
 
16
- # Class for handling MRI inference
17
  class MRIInference:
18
  def __init__(self, model, device, input_shape, output_shape):
 
19
  self.model = model
20
  self.device = device
21
  self.input_shape = input_shape
22
  self.output_shape = output_shape
23
 
24
  def load_image(self, file_path):
25
- # Load and preprocess the MRI image
26
  image = nib.load(file_path).get_fdata()
27
  rotated_image = np.rot90(image, k=1, axes=(1, 2))
28
- mean = np.mean(rotated_image)
29
- std = np.std(rotated_image)
30
  normalized_image = (rotated_image - mean) / std
31
- min_val = np.min(normalized_image)
32
- max_val = np.max(normalized_image)
33
  scale = 255 / (max_val - min_val)
34
  normalized_image = scale * (normalized_image - min_val)
35
  scale_factors = (
@@ -42,7 +41,7 @@ class MRIInference:
42
  resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
43
 
44
  def save_image(self, image, file_name):
45
- # Save the processed image
46
  image = image.squeeze().cpu().numpy()
47
  scale_factors = (
48
  self.output_shape[0] / image.shape[0],
@@ -52,9 +51,9 @@ class MRIInference:
52
  resampled_image = zoom(image, scale_factors, order=1)
53
  resampled_image = np.rot90(resampled_image, k=-1, axes=(1, 2))
54
  nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
55
-
56
  def match_sform_affine(self, orig_path, gen_path):
57
- # Match the affine of original and generated images
58
  orig_img = nib.load(orig_path)
59
  orig_affine = orig_img.affine
60
  gen_img = nib.load(gen_path)
@@ -63,7 +62,7 @@ class MRIInference:
63
  nib.save(matched_gen_img, gen_path)
64
 
65
  def infer(self, input_tensor, original_file_path, output_path):
66
- # Inference process
67
  with torch.no_grad():
68
  self.model.eval()
69
  output = self.model(input_tensor.to(self.device))
@@ -87,18 +86,42 @@ class MRIInference:
87
  for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
88
  os.remove(temp_file)
89
  return warped_file_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- # Functions for image processing and Streamlit UI handling
 
 
 
 
 
 
92
  def resample_to_isotropic(image_path, output_path):
 
93
  image = ants.image_read(image_path)
94
  resampled_image = ants.resample_image(
95
  image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=4)
96
  ants.image_write(resampled_image, output_path)
97
  return output_path
98
 
99
-
100
  def affine_registration(fixed_image_path, moving_image_path, output_path):
101
- # Register affine of fixed and moving images
102
  fixed_image = ants.image_read(fixed_image_path)
103
  moving_image = ants.image_read(moving_image_path)
104
  registration = ants.registration(
@@ -108,7 +131,7 @@ def affine_registration(fixed_image_path, moving_image_path, output_path):
108
 
109
  @st.cache_data
110
  def load_model():
111
- # Load the trained model
112
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
113
  generator = ResnetGenerator().to(device)
114
  checkpoint_path = 'ckpt/ckpt_final/G_latest.pth'
@@ -116,11 +139,12 @@ def load_model():
116
  generator.load_state_dict(checkpoint)
117
  return generator, device
118
 
 
119
  generator, device = load_model()
120
  inference_engine = MRIInference(generator, device, (128, 32, 128), (128, 192, 128))
121
 
122
  def save_middle_slice(image, file_path):
123
- # Save middle slice of the MRI image
124
  middle_slice = image[image.shape[0] // 2]
125
  fig, ax = plt.subplots(figsize=(5, 5))
126
  ax.imshow(middle_slice, cmap='gray', aspect='auto')
@@ -130,7 +154,7 @@ def save_middle_slice(image, file_path):
130
  plt.close()
131
 
132
  def clear_output_folder(folder_path):
133
- # Clear contents of the specified folder
134
  for filename in os.listdir(folder_path):
135
  file_path = os.path.join(folder_path, filename)
136
  if os.path.isfile(file_path) or os.path.islink(file_path):
@@ -139,46 +163,38 @@ def clear_output_folder(folder_path):
139
  shutil.rmtree(file_path)
140
 
141
  def clear_session():
142
- # Clear the session state
143
  for key in list(st.session_state.keys()):
144
  del st.session_state[key]
145
 
 
146
  def main():
147
- # Sidebar - How to Use Guide
 
 
148
  st.sidebar.subheader("_How to Use EasySR_", divider='red')
149
-
150
  st.sidebar.markdown(
151
  "**Step-by-Step Guide:**\n\n"
152
- "1. **Prepare Your Data**: \n\n\tGet your rat brain T2 MRI data. "
153
- "Ensure it's in NIFTI format. Convert if necessary.\n\n"
154
- "2. **Upload Your MRI**: \n\n\tDrag and drop your NIFTI file or use "
155
- "the upload button.\n\n"
156
- "3. **Start the EasySR**: \n\n\tPress 'EasySR' and we'll handle the rest. "
157
- "The process is quick, typically taking just a few seconds to complete!\n\n"
158
- "4. **Sit Back and Relax**: \n\n\tNo long waits here - your data will be "
159
- "processed in under a minute.\n\n"
160
- "5. **View and Download**: \n\n\tAfter processing, view the results and "
161
- "use the download button to save the MRI.\n\n"
162
- "6. **Use as Needed**: \n\n\tDownload and use your enhanced MRI as you see fit. "
163
- "Get your data more!\n\n #"
164
- "#\n\n "
165
- "#\n\n\n "
166
- "#\n\n\n "
167
- "GitHub: EasySR"
168
- "\n\n "
169
- "[github.com/hwonheo/easysr](https://github.com/hwonheo/easysr)"
170
- "\n\n "
171
- "Huggingface (space): EasySR"
172
- "\n\n "
173
- "[huggingface.co/spaces/hwonheo/easysr]"
174
- "(https://huggingface.co/spaces/hwonheo/easysr)"
175
  )
176
 
177
- # Main function for Streamlit UI
178
  st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
179
- st.subheader("_Easy WebUI App for Rat Brain MRI 3D SR-Recon DL Inference_", divider='red')
180
-
181
 
 
182
  original_slice_path = None
183
  inferred_slice_path = None
184
  download_file_path = None
@@ -187,76 +203,93 @@ def main():
187
  if not os.path.exists(output_path):
188
  os.makedirs(output_path)
189
 
190
-
191
  uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
192
- type=["nii", "nii.gz"], key='file_uploader')
193
 
194
  if uploaded_file is not None:
 
195
  st.session_state['uploaded_file'] = uploaded_file
196
  file_name = uploaded_file.name
 
 
197
  infer_button = st.button("EasySR (start inference)", type="primary")
198
 
199
  if infer_button:
 
200
  temp_dir = tempfile.gettempdir()
201
  temp_file_path = os.path.join(temp_dir, file_name)
202
 
 
203
  with open(temp_file_path, "wb") as tmp_file:
204
  tmp_file.write(uploaded_file.getvalue())
205
 
206
- try:
207
- input_tensor = inference_engine.load_image(temp_file_path)
208
- warped_image_path = inference_engine.infer(
209
- input_tensor, temp_file_path, output_path)
210
-
211
- gen_file_name = file_name.replace(".nii", "_gen.nii")
212
- download_file_path = os.path.join(output_path, gen_file_name)
213
- shutil.copy(warped_image_path, download_file_path)
214
-
215
- original_img = nib.load(temp_file_path).get_fdata()
216
- inferred_img = nib.load(warped_image_path).get_fdata()
217
-
218
- original_slice_path = os.path.join(output_path, "original_slice.jpg")
219
- inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
220
- save_middle_slice(original_img, original_slice_path)
221
- save_middle_slice(inferred_img, inferred_slice_path)
222
-
223
- except Exception as e:
224
- st.error(f"Error during inference: {e}")
225
-
226
- if temp_file_path and os.path.exists(temp_file_path):
227
- os.remove(temp_file_path)
228
-
229
- if download_file_path and os.path.exists(download_file_path):
230
- with open(download_file_path, "rb") as file:
231
- st.download_button(
232
- label="Download (EasySR inferred-MRI)",
233
- data=file,
234
- file_name=gen_file_name,
235
- mime="application/gzip",
236
- type="primary"
237
- )
238
-
239
- if st.button('Clear Generated All',
240
- help='Caution: Pressing the Clear All button will delete the contents of the generate folder.'):
241
- clear_output_folder('infer/generate')
242
- clear_session()
243
- st.rerun()
244
-
245
- st.subheader("\n")
246
- st.subheader("\n")
247
- st.subheader("\n")
248
- if original_slice_path and os.path.exists(original_slice_path):
249
- st.subheader("Comparison of Inferred slice")
250
- col1, col2 = st.columns([0.5, 0.5])
251
- with col1:
252
- if original_slice_path and os.path.exists(original_slice_path):
253
  st.markdown("**Original**")
254
  st.image(original_slice_path, caption="Original MRI", width=300)
255
-
256
- with col2:
257
- if inferred_slice_path and os.path.exists(inferred_slice_path):
258
  st.markdown("**EasySR**")
259
  st.image(inferred_slice_path, caption="Inferred MRI", width=300)
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if __name__ == '__main__':
262
- main()
 
 
1
+ import os
2
+ import tempfile
3
+ import threading
 
 
 
4
  import time
5
  from pathlib import Path
6
+ import shutil
7
  import ants
 
 
8
  import matplotlib.pyplot as plt
9
+ import nibabel as nib
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+ from network.generator import ResnetGenerator
14
+ from scipy.ndimage import zoom
15
 
16
+ # Class to handle MRI image inference
17
  class MRIInference:
18
  def __init__(self, model, device, input_shape, output_shape):
19
+ # Initialize with model, device, and shapes
20
  self.model = model
21
  self.device = device
22
  self.input_shape = input_shape
23
  self.output_shape = output_shape
24
 
25
  def load_image(self, file_path):
26
+ # Load and preprocess MRI image
27
  image = nib.load(file_path).get_fdata()
28
  rotated_image = np.rot90(image, k=1, axes=(1, 2))
29
+ mean, std = np.mean(rotated_image), np.std(rotated_image)
 
30
  normalized_image = (rotated_image - mean) / std
31
+ min_val, max_val = np.min(normalized_image), np.max(normalized_image)
 
32
  scale = 255 / (max_val - min_val)
33
  normalized_image = scale * (normalized_image - min_val)
34
  scale_factors = (
 
41
  resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
42
 
43
  def save_image(self, image, file_name):
44
+ # Save processed image to file
45
  image = image.squeeze().cpu().numpy()
46
  scale_factors = (
47
  self.output_shape[0] / image.shape[0],
 
51
  resampled_image = zoom(image, scale_factors, order=1)
52
  resampled_image = np.rot90(resampled_image, k=-1, axes=(1, 2))
53
  nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
54
+
55
  def match_sform_affine(self, orig_path, gen_path):
56
+ # Match affine transformation of original and generated images
57
  orig_img = nib.load(orig_path)
58
  orig_affine = orig_img.affine
59
  gen_img = nib.load(gen_path)
 
62
  nib.save(matched_gen_img, gen_path)
63
 
64
  def infer(self, input_tensor, original_file_path, output_path):
65
+ # Perform inference on input tensor
66
  with torch.no_grad():
67
  self.model.eval()
68
  output = self.model(input_tensor.to(self.device))
 
86
  for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
87
  os.remove(temp_file)
88
  return warped_file_path
89
+
90
+ # Perform inference and handle images
91
+ def run_inference(input_tensor, temp_file_path, output_path):
92
+ try:
93
+ warped_image_path = inference_engine.infer(
94
+ input_tensor, temp_file_path, output_path)
95
+
96
+ gen_file_name = temp_file_path.replace(".nii", "_gen.nii")
97
+ download_file_path = os.path.join(output_path, gen_file_name)
98
+ shutil.copy(warped_image_path, download_file_path)
99
+
100
+ original_img = nib.load(temp_file_path).get_fdata()
101
+ inferred_img = nib.load(warped_image_path).get_fdata()
102
+
103
+ original_slice_path = os.path.join(output_path, "original_slice.jpg")
104
+ inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
105
+ save_middle_slice(original_img, original_slice_path)
106
+ save_middle_slice(inferred_img, inferred_slice_path)
107
 
108
+ return (original_slice_path, inferred_slice_path,
109
+ download_file_path, gen_file_name)
110
+ except Exception as e:
111
+ st.error(f"Error during inference: {e}")
112
+ return None, None, None, None
113
+
114
+ # Image processing functions
115
  def resample_to_isotropic(image_path, output_path):
116
+ # Resample image to isotropic resolution
117
  image = ants.image_read(image_path)
118
  resampled_image = ants.resample_image(
119
  image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=4)
120
  ants.image_write(resampled_image, output_path)
121
  return output_path
122
 
 
123
  def affine_registration(fixed_image_path, moving_image_path, output_path):
124
+ # Perform affine registration between two images
125
  fixed_image = ants.image_read(fixed_image_path)
126
  moving_image = ants.image_read(moving_image_path)
127
  registration = ants.registration(
 
131
 
132
  @st.cache_data
133
  def load_model():
134
+ # Load pre-trained model
135
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
136
  generator = ResnetGenerator().to(device)
137
  checkpoint_path = 'ckpt/ckpt_final/G_latest.pth'
 
139
  generator.load_state_dict(checkpoint)
140
  return generator, device
141
 
142
+ # Initialize model and inference engine
143
  generator, device = load_model()
144
  inference_engine = MRIInference(generator, device, (128, 32, 128), (128, 192, 128))
145
 
146
  def save_middle_slice(image, file_path):
147
+ # Save the middle slice of the MRI image
148
  middle_slice = image[image.shape[0] // 2]
149
  fig, ax = plt.subplots(figsize=(5, 5))
150
  ax.imshow(middle_slice, cmap='gray', aspect='auto')
 
154
  plt.close()
155
 
156
  def clear_output_folder(folder_path):
157
+ # Clear contents of a specified folder
158
  for filename in os.listdir(folder_path):
159
  file_path = os.path.join(folder_path, filename)
160
  if os.path.isfile(file_path) or os.path.islink(file_path):
 
163
  shutil.rmtree(file_path)
164
 
165
  def clear_session():
166
+ # Clear Streamlit session state
167
  for key in list(st.session_state.keys()):
168
  del st.session_state[key]
169
 
170
+ # Main function for Streamlit UI
171
  def main():
172
+ global original_slice_path, inferred_slice_path, download_file_path, gen_file_name
173
+
174
+ # Setup sidebar with instructions
175
  st.sidebar.subheader("_How to Use EasySR_", divider='red')
 
176
  st.sidebar.markdown(
177
  "**Step-by-Step Guide:**\n\n"
178
+ "1. **Prepare Your Data**: Make sure your rat brain MRI data "
179
+ "is in NIFTI format. Convert if needed.\n\n"
180
+ "2. **Upload Your MRI**: Drag and drop your NIFTI file "
181
+ "or use the upload button.\n\n"
182
+ "3. **Start the EasySR**: Click 'EasySR' to begin processing. "
183
+ "It usually takes a few minutes.\n\n"
184
+ "4. **Sit Back and Relax**: Wait while your data is processed quickly.\n\n"
185
+ "5. **View and Download**: After processing, view the results and "
186
+ "use the download button to save the enhanced MRI data.\n\n"
187
+ "6. **Use as Needed**: Download and utilize your enhanced MRI. "
188
+ "Continue using EasySR for more enhancements.\n\n"
189
+ ":rocket: :red[*EasySR*] \t [Github](https://github.com/hwonheo/easysr)\n\n"
190
+ ":hugging_face: :orange[*EasySR*] \t [Huggingface](https://huggingface.co/spaces/hwonheo/easysr)"
 
 
 
 
 
 
 
 
 
 
191
  )
192
 
193
+ # Main interface layout
194
  st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
195
+ st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')
 
196
 
197
+ # Initialize paths for processing results
198
  original_slice_path = None
199
  inferred_slice_path = None
200
  download_file_path = None
 
203
  if not os.path.exists(output_path):
204
  os.makedirs(output_path)
205
 
206
+ # File uploader for MRI files
207
  uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
208
+ type=["nii", "nii.gz"], key='file_uploader')
209
 
210
  if uploaded_file is not None:
211
+ # Store uploaded file in session state
212
  st.session_state['uploaded_file'] = uploaded_file
213
  file_name = uploaded_file.name
214
+
215
+ # Inference start button
216
  infer_button = st.button("EasySR (start inference)", type="primary")
217
 
218
  if infer_button:
219
+ # Temporary directory for file processing
220
  temp_dir = tempfile.gettempdir()
221
  temp_file_path = os.path.join(temp_dir, file_name)
222
 
223
+ # Write uploaded file to temp directory
224
  with open(temp_file_path, "wb") as tmp_file:
225
  tmp_file.write(uploaded_file.getvalue())
226
 
227
+ # Load image and start inference in a thread
228
+ input_tensor = inference_engine.load_image(temp_file_path)
229
+
230
+ def inference_wrapper():
231
+ # Running inference and processing results
232
+ global original_slice_path, inferred_slice_path, download_file_path, gen_file_name
233
+ try:
234
+ warped_image_path = inference_engine.infer(
235
+ input_tensor, temp_file_path, output_path)
236
+ gen_file_name = file_name.replace(".nii", "_gen.nii")
237
+ download_file_path = os.path.join(output_path, gen_file_name)
238
+ shutil.copy(warped_image_path, download_file_path)
239
+
240
+ # Load original and inferred images for display
241
+ original_img = nib.load(temp_file_path).get_fdata()
242
+ inferred_img = nib.load(warped_image_path).get_fdata()
243
+ original_slice_path = os.path.join(output_path, "original_slice.jpg")
244
+ inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
245
+
246
+ # Save middle slice of both images for comparison
247
+ save_middle_slice(original_img, original_slice_path)
248
+ save_middle_slice(inferred_img, inferred_slice_path)
249
+ except Exception as e:
250
+ st.error(f"Error during inference: {e}")
251
+ finally:
252
+ if temp_file_path and os.path.exists(temp_file_path):
253
+ os.remove(temp_file_path)
254
+
255
+ # Start thread for inference
256
+ inference_thread = threading.Thread(target=inference_wrapper)
257
+ inference_thread.start()
258
+
259
+ # Display spinner while processing
260
+ with st.spinner("Processing your MRI image..."):
261
+ inference_thread.join()
262
+
263
+ # Display comparison images and download button after processing
264
+ if original_slice_path and os.path.exists(original_slice_path) \
265
+ and inferred_slice_path and os.path.exists(inferred_slice_path):
266
+ st.subheader("Comparison of Original and EasySR Inferred Slice")
267
+ col1, col2 = st.columns([0.5, 0.5])
268
+ with col1:
 
 
 
 
 
269
  st.markdown("**Original**")
270
  st.image(original_slice_path, caption="Original MRI", width=300)
271
+ with col2:
 
 
272
  st.markdown("**EasySR**")
273
  st.image(inferred_slice_path, caption="Inferred MRI", width=300)
274
 
275
+ if download_file_path and os.path.exists(download_file_path):
276
+ with open(download_file_path, "rb") as file:
277
+ st.download_button(
278
+ label="Download (EasySR Inferred-MRI)",
279
+ data=file,
280
+ file_name=gen_file_name,
281
+ mime="application/gzip",
282
+ type="primary"
283
+ )
284
+
285
+ # Button to clear generated content
286
+ if st.button('Clear Generated All',
287
+ help='Pressing this will delete the contents of the generate folder.'):
288
+ clear_output_folder('infer/generate')
289
+ clear_session()
290
+ st.rerun()
291
+
292
+ # Entry point for the Streamlit application
293
  if __name__ == '__main__':
294
+ main()
295
+