hwonheo commited on
Commit
942a94f
·
verified ·
1 Parent(s): 604568a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +244 -142
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import tempfile
3
  import threading
4
  import time
@@ -23,22 +25,24 @@ class MRIInference:
23
  self.output_shape = output_shape
24
 
25
  def load_image(self, file_path):
26
- # Load and preprocess MRI image
27
- image = nib.load(file_path).get_fdata()
28
- rotated_image = np.rot90(image, k=1, axes=(1, 2))
29
- mean, std = np.mean(rotated_image), np.std(rotated_image)
30
- normalized_image = (rotated_image - mean) / std
31
- min_val, max_val = np.min(normalized_image), np.max(normalized_image)
 
 
32
  scale = 255 / (max_val - min_val)
33
- normalized_image = scale * (normalized_image - min_val)
 
34
  scale_factors = (
35
  self.input_shape[0] / normalized_image.shape[0],
36
  self.input_shape[1] / normalized_image.shape[1],
37
  self.input_shape[2] / normalized_image.shape[2]
38
  )
39
- resampled_image = zoom(normalized_image, scale_factors, order=1)
40
- return torch.tensor(
41
- resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
42
 
43
  def save_image(self, image, file_name):
44
  # Save processed image to file
@@ -48,8 +52,7 @@ class MRIInference:
48
  self.output_shape[1] / image.shape[1],
49
  self.output_shape[2] / image.shape[2]
50
  )
51
- resampled_image = zoom(image, scale_factors, order=1)
52
- resampled_image = np.rot90(resampled_image, k=-1, axes=(1, 2))
53
  nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
54
 
55
  def match_sform_affine(self, orig_path, gen_path):
@@ -61,62 +64,77 @@ class MRIInference:
61
  matched_gen_img = nib.Nifti1Image(gen_data, orig_affine)
62
  nib.save(matched_gen_img, gen_path)
63
 
64
- def infer(self, input_tensor, original_file_path, output_path):
 
 
 
65
  # Perform inference on input tensor
66
  with torch.no_grad():
67
  self.model.eval()
68
  output = self.model(input_tensor.to(self.device))
69
- scale_factor = (1, 1, self.output_shape[2] / output.shape[4])
 
 
 
 
 
 
70
  resampled_output = zoom(
71
- output.squeeze().cpu().numpy(), scale_factor, order=1)
72
  generated_image = torch.tensor(resampled_output[np.newaxis, ...])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
74
  resampled_file_path = resample_to_isotropic(
75
  original_file_path, temp_orig_path)
76
- temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
77
- self.save_image(generated_image, temp_generated_path)
78
  self.match_sform_affine(resampled_file_path, temp_generated_path)
 
79
  resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')
80
  resample_to_isotropic(temp_generated_path, resampled_generated_path)
 
81
  base_name = os.path.basename(original_file_path)
82
  gen_file_name = f"{Path(base_name).stem}_{int(time.time())}_gen.nii.gz"
83
  warped_file_path = os.path.join(output_path, gen_file_name)
84
  affine_registration(
85
- resampled_file_path, resampled_generated_path, warped_file_path)
 
 
86
  for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
87
  os.remove(temp_file)
 
88
  return warped_file_path
89
-
90
- # Perform inference and handle images
91
- def run_inference(input_tensor, temp_file_path, output_path):
92
- try:
93
- warped_image_path = inference_engine.infer(
94
- input_tensor, temp_file_path, output_path)
95
-
96
- gen_file_name = temp_file_path.replace(".nii", "_gen.nii")
97
- download_file_path = os.path.join(output_path, gen_file_name)
98
- shutil.copy(warped_image_path, download_file_path)
99
-
100
- original_img = nib.load(temp_file_path).get_fdata()
101
- inferred_img = nib.load(warped_image_path).get_fdata()
102
-
103
- original_slice_path = os.path.join(output_path, "original_slice.jpg")
104
- inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
105
- save_middle_slice(original_img, original_slice_path)
106
- save_middle_slice(inferred_img, inferred_slice_path)
107
-
108
- return (original_slice_path, inferred_slice_path,
109
- download_file_path, gen_file_name)
110
- except Exception as e:
111
- st.error(f"Error during inference: {e}")
112
- return None, None, None, None
113
 
114
  # Image processing functions
115
  def resample_to_isotropic(image_path, output_path):
116
  # Resample image to isotropic resolution
117
  image = ants.image_read(image_path)
118
  resampled_image = ants.resample_image(
119
- image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=4)
120
  ants.image_write(resampled_image, output_path)
121
  return output_path
122
 
@@ -126,33 +144,123 @@ def affine_registration(fixed_image_path, moving_image_path, output_path):
126
  moving_image = ants.image_read(moving_image_path)
127
  registration = ants.registration(
128
  fixed=fixed_image, moving=moving_image,
129
- type_of_transform='Elastic')
130
  ants.image_write(registration['warpedmovout'], output_path)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  @st.cache_data
133
- def load_model():
134
- # Load pre-trained model
135
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
136
  generator = ResnetGenerator().to(device)
137
- checkpoint_path = 'ckpt/ckpt_final/G_latest.pth'
 
 
 
 
 
138
  checkpoint = torch.load(checkpoint_path, map_location=device)
139
  generator.load_state_dict(checkpoint)
140
  return generator, device
141
 
142
- # Initialize model and inference engine
143
- generator, device = load_model()
144
- inference_engine = MRIInference(generator, device, (128, 32, 128), (128, 192, 128))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def save_middle_slice(image, file_path):
147
  # Save the middle slice of the MRI image
148
  middle_slice = image[image.shape[0] // 2]
 
 
 
 
149
  fig, ax = plt.subplots(figsize=(5, 5))
150
- ax.imshow(middle_slice, cmap='gray', aspect='auto')
151
  ax.axis('off')
152
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
153
  plt.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
154
  plt.close()
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def clear_output_folder(folder_path):
157
  # Clear contents of a specified folder
158
  for filename in os.listdir(folder_path):
@@ -169,27 +277,46 @@ def clear_session():
169
 
170
  # Main function for Streamlit UI
171
  def main():
172
- global original_slice_path, inferred_slice_path, download_file_path, gen_file_name
173
 
174
- # Setup sidebar with instructions
175
- st.sidebar.subheader("_How to Use EasySR_", divider='red')
176
  st.sidebar.markdown(
177
- "**Step-by-Step Guide:**\n\n"
178
- "1. **Prepare Your Data**: Make sure your rat brain MRI data "
179
- "is in NIFTI format. Convert if needed.\n\n"
180
- "2. **Upload Your MRI**: Drag and drop your NIFTI file "
181
- "or use the upload button.\n\n"
182
- "3. **Start the EasySR**: Click 'EasySR' to begin processing. "
183
- "It usually takes a few minutes.\n\n"
184
- "4. **Sit Back and Relax**: Wait while your data is processed quickly.\n\n"
185
- "5. **View and Download**: After processing, view the results and "
186
- "use the download button to save the enhanced MRI data.\n\n"
187
- "6. **Use as Needed**: Download and utilize your enhanced MRI. "
188
- "Continue using EasySR for more enhancements.\n\n"
189
- ":rocket: :red[*EasySR*] \t [Github](https://github.com/hwonheo/easysr)\n\n"
190
- ":hugging_face: :orange[*EasySR*] \t [Huggingface](https://huggingface.co/spaces/hwonheo/easysr)"
191
  )
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # Main interface layout
194
  st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
195
  st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')
@@ -206,88 +333,63 @@ def main():
206
  # File uploader for MRI files
207
  uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
208
  type=["nii", "nii.gz"], key='file_uploader')
 
 
 
 
 
 
 
 
 
209
 
210
  if uploaded_file is not None:
211
  # Store uploaded file in session state
212
  st.session_state['uploaded_file'] = uploaded_file
213
  file_name = uploaded_file.name
214
 
 
 
 
 
 
 
 
 
215
  # Inference start button
216
- infer_button = st.button("EasySR (start inference)", type="primary")
217
-
218
- if infer_button:
219
- # Temporary directory for file processing
220
- temp_dir = tempfile.gettempdir()
221
- temp_file_path = os.path.join(temp_dir, file_name)
222
-
223
- # Write uploaded file to temp directory
224
- with open(temp_file_path, "wb") as tmp_file:
225
- tmp_file.write(uploaded_file.getvalue())
226
-
227
- # Load image and start inference in a thread
228
- input_tensor = inference_engine.load_image(temp_file_path)
229
-
230
- def inference_wrapper():
231
- # Running inference and processing results
232
- global original_slice_path, inferred_slice_path, download_file_path, gen_file_name
233
- try:
234
- warped_image_path = inference_engine.infer(
235
- input_tensor, temp_file_path, output_path)
236
- gen_file_name = file_name.replace(".nii", "_gen.nii")
237
- download_file_path = os.path.join(output_path, gen_file_name)
238
- shutil.copy(warped_image_path, download_file_path)
239
-
240
- # Load original and inferred images for display
241
- original_img = nib.load(temp_file_path).get_fdata()
242
- inferred_img = nib.load(warped_image_path).get_fdata()
243
- original_slice_path = os.path.join(output_path, "original_slice.jpg")
244
- inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
245
-
246
- # Save middle slice of both images for comparison
247
- save_middle_slice(original_img, original_slice_path)
248
- save_middle_slice(inferred_img, inferred_slice_path)
249
- except Exception as e:
250
- st.error(f"Error during inference: {e}")
251
- finally:
252
- if temp_file_path and os.path.exists(temp_file_path):
253
- os.remove(temp_file_path)
254
-
255
- # Start thread for inference
256
- inference_thread = threading.Thread(target=inference_wrapper)
257
- inference_thread.start()
258
-
259
- # Display spinner while processing
260
- with st.spinner("Processing your MRI image..."):
261
- inference_thread.join()
262
-
263
- # Display comparison images and download button after processing
264
- if original_slice_path and os.path.exists(original_slice_path) \
265
- and inferred_slice_path and os.path.exists(inferred_slice_path):
266
- st.subheader("Comparison of Original and EasySR Inferred Slice")
267
- col1, col2 = st.columns([0.5, 0.5])
268
- with col1:
269
- st.markdown("**Original**")
270
- st.image(original_slice_path, caption="Original MRI", width=300)
271
- with col2:
272
- st.markdown("**EasySR**")
273
- st.image(inferred_slice_path, caption="Inferred MRI", width=300)
274
-
275
- if download_file_path and os.path.exists(download_file_path):
276
- with open(download_file_path, "rb") as file:
277
- st.download_button(
278
- label="Download (EasySR Inferred-MRI)",
279
- data=file,
280
- file_name=gen_file_name,
281
- mime="application/gzip",
282
- type="primary"
283
- )
284
-
285
- # Button to clear generated content
286
- if st.button('Clear Generated All',
287
  help='Pressing this will delete the contents of the generate folder.'):
288
- clear_output_folder('infer/generate')
289
- clear_session()
290
- st.rerun()
291
 
292
  # Entry point for the Streamlit application
293
  if __name__ == '__main__':
 
1
  import os
2
+ import sys
3
+ import subprocess
4
  import tempfile
5
  import threading
6
  import time
 
25
  self.output_shape = output_shape
26
 
27
  def load_image(self, file_path):
28
+ # Load the image using nibabel
29
+ nib_image = nib.load(file_path)
30
+
31
+ image_data = nib_image.get_fdata()
32
+ rotated_image = np.rot90(image_data, k=1, axes=(1, 2))
33
+
34
+ # Standard normalization to 0-255
35
+ min_val, max_val = np.min(rotated_image), np.max(rotated_image)
36
  scale = 255 / (max_val - min_val)
37
+ normalized_image = scale * (rotated_image - min_val)
38
+
39
  scale_factors = (
40
  self.input_shape[0] / normalized_image.shape[0],
41
  self.input_shape[1] / normalized_image.shape[1],
42
  self.input_shape[2] / normalized_image.shape[2]
43
  )
44
+ resampled_image = zoom(normalized_image, scale_factors, order=3)
45
+ return torch.tensor(resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
 
46
 
47
  def save_image(self, image, file_name):
48
  # Save processed image to file
 
52
  self.output_shape[1] / image.shape[1],
53
  self.output_shape[2] / image.shape[2]
54
  )
55
+ resampled_image = zoom(image, scale_factors, order=3)
 
56
  nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
57
 
58
  def match_sform_affine(self, orig_path, gen_path):
 
64
  matched_gen_img = nib.Nifti1Image(gen_data, orig_affine)
65
  nib.save(matched_gen_img, gen_path)
66
 
67
+ def infer(self, aligned_image_path, original_file_path, output_path):
68
+ # Load and preprocess the image from aligned_image_path
69
+ input_tensor = self.load_image(aligned_image_path)
70
+
71
  # Perform inference on input tensor
72
  with torch.no_grad():
73
  self.model.eval()
74
  output = self.model(input_tensor.to(self.device))
75
+
76
+ # Resample output to target shape
77
+ scale_factor = (
78
+ self.output_shape[0] / output.shape[2],
79
+ self.output_shape[1] / output.shape[3],
80
+ self.output_shape[2] / output.shape[4]
81
+ )
82
  resampled_output = zoom(
83
+ output.squeeze().cpu().numpy(), scale_factor, order=3)
84
  generated_image = torch.tensor(resampled_output[np.newaxis, ...])
85
+
86
+ # Save the generated image
87
+ temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
88
+ self.save_image(generated_image, temp_generated_path)
89
+
90
+ # Get and print orientation code of the original image
91
+ orig_img = ants.image_read(original_file_path)
92
+ orig_orientation = ants.get_orientation(orig_img)
93
+
94
+ # Reorient the generated image based on original orientation
95
+ gen_img = nib.load(temp_generated_path)
96
+ gen_data = gen_img.get_fdata()
97
+ reoriented_image = ants.from_numpy(gen_data)
98
+ #print(f"Orientation of the original image: {orig_orientation}") ## print Orientation ##
99
+
100
+ if orig_orientation == 'LSP':
101
+ reoriented_image = ants.reorient_image2(reoriented_image, 'RAI')
102
+ elif orig_orientation == 'LPI':
103
+ reoriented_image = ants.reorient_image2(reoriented_image, 'RIP')
104
+ elif orig_orientation == 'RAS':
105
+ reoriented_image = ants.reorient_image2(reoriented_image, 'LSA')
106
+ # No reorientation for other cases
107
+
108
+ # Save the reoriented image
109
+ nib.save(nib.Nifti1Image(reoriented_image.numpy(), np.eye(4)), temp_generated_path)
110
+
111
+ # Match affine and resample
112
  temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
113
  resampled_file_path = resample_to_isotropic(
114
  original_file_path, temp_orig_path)
 
 
115
  self.match_sform_affine(resampled_file_path, temp_generated_path)
116
+
117
  resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')
118
  resample_to_isotropic(temp_generated_path, resampled_generated_path)
119
+
120
  base_name = os.path.basename(original_file_path)
121
  gen_file_name = f"{Path(base_name).stem}_{int(time.time())}_gen.nii.gz"
122
  warped_file_path = os.path.join(output_path, gen_file_name)
123
  affine_registration(
124
+ resampled_file_path, temp_generated_path, warped_file_path)
125
+
126
+ # Remove temporary files
127
  for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
128
  os.remove(temp_file)
129
+
130
  return warped_file_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Image processing functions
133
  def resample_to_isotropic(image_path, output_path):
134
  # Resample image to isotropic resolution
135
  image = ants.image_read(image_path)
136
  resampled_image = ants.resample_image(
137
+ image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=3)
138
  ants.image_write(resampled_image, output_path)
139
  return output_path
140
 
 
144
  moving_image = ants.image_read(moving_image_path)
145
  registration = ants.registration(
146
  fixed=fixed_image, moving=moving_image,
147
+ type_of_transform='Rigid')
148
  ants.image_write(registration['warpedmovout'], output_path)
149
 
150
+ def align_to_template(resampled_image_path, template_path, output_path):
151
+ # Align the resampled image to the template
152
+ moving_image = ants.image_read(resampled_image_path)
153
+ fixed_image = ants.image_read(template_path)
154
+ registration = ants.registration(
155
+ fixed=fixed_image, moving=moving_image,
156
+ type_of_transform='Rigid')
157
+ aligned_image = registration['warpedmovout']
158
+ ants.image_write(aligned_image, output_path)
159
+ return output_path
160
+
161
+ def download_model_if_needed(templates_folder):
162
+ """Downloads model from Hugging Face if template folder is empty or doesn't exist."""
163
+ if not os.path.exists(templates_folder) or not os.listdir(templates_folder):
164
+ print("Downloading model from Hugging Face...")
165
+ os.makedirs(templates_folder, exist_ok=True)
166
+ subprocess.run(["huggingface-cli", "download", "hwonheo/easysr_templates",
167
+ "--local-dir", "templates", "--local-dir-use-symlinks", "False"], check=True)
168
+
169
  @st.cache_data
170
+ def load_model(model_choice):
171
+ # Load pre-trained model based on user selection
172
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
173
  generator = ResnetGenerator().to(device)
174
+
175
+ if model_choice == "T1-Model":
176
+ checkpoint_path = 'ckpt/ckpt_final/G_latest_T1.pth'
177
+ else: # "Mixed-Model"
178
+ checkpoint_path = 'ckpt/ckpt_final/G_latest_Mixed.pth'
179
+
180
  checkpoint = torch.load(checkpoint_path, map_location=device)
181
  generator.load_state_dict(checkpoint)
182
  return generator, device
183
 
184
+ def run_bias_field_correction(file_path, output_path, correction_type):
185
+ """Bias field correction script and return corrected file path"""
186
+ corrected_file_name = os.path.basename(file_path).replace('.nii', '_corrected.nii')
187
+ corrected_file_path = os.path.join(output_path, corrected_file_name)
188
+
189
+ subprocess.run([
190
+ sys.executable, "utils/BiasFieldCorrection.py",
191
+ "--input", file_path,
192
+ "--output", output_path,
193
+ "--type", correction_type
194
+ ])
195
+
196
+ # Rename the processed file if necessary
197
+ original_corrected_file_path = os.path.join(output_path, os.path.basename(file_path))
198
+ if os.path.exists(original_corrected_file_path) and original_corrected_file_path != corrected_file_path:
199
+ shutil.move(original_corrected_file_path, corrected_file_path)
200
+
201
+ return corrected_file_path
202
+
203
+ # Perform inference and handle images
204
+ def run_inference(inference_engine, aligned_image_path, original_file_path, output_path):
205
+ try:
206
+ # Perform inference using the aligned image and original file path
207
+ warped_image_path = inference_engine.infer(aligned_image_path, original_file_path, output_path)
208
+
209
+ # Generate file name for output
210
+ gen_file_name = os.path.basename(original_file_path).replace(".nii", "_gen.nii")
211
+ download_file_path = os.path.join(output_path, gen_file_name)
212
+
213
+ # Copy the processed file to the download path
214
+ shutil.copy(warped_image_path, download_file_path)
215
+
216
+ # Load original and inferred images for display
217
+ original_img = nib.load(original_file_path).get_fdata()
218
+ inferred_img = nib.load(warped_image_path).get_fdata()
219
+
220
+ # Save middle slices of both images for comparison
221
+ original_slice_path = os.path.join(output_path, "original_slice.jpg")
222
+ inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
223
+ save_middle_slice(original_img, original_slice_path)
224
+ save_middle_slice(inferred_img, inferred_slice_path)
225
+
226
+ # Return paths for UI display
227
+ return (original_slice_path, inferred_slice_path, download_file_path, gen_file_name)
228
+ except Exception as e:
229
+ st.error(f"Error during inference: {e}")
230
+ return None, None, None, None
231
 
232
  def save_middle_slice(image, file_path):
233
  # Save the middle slice of the MRI image
234
  middle_slice = image[image.shape[0] // 2]
235
+
236
+ # Rotate the image 90 degrees counterclockwise
237
+ rotated_slice = np.rot90(middle_slice)
238
+
239
  fig, ax = plt.subplots(figsize=(5, 5))
240
+ ax.imshow(rotated_slice, cmap='gray', aspect='auto')
241
  ax.axis('off')
242
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
243
  plt.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
244
  plt.close()
245
 
246
+ def display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name):
247
+ st.subheader("Comparison of Original and EasySR Inferred Slice")
248
+ col1, col2 = st.columns([0.5, 0.5])
249
+ with col1:
250
+ st.image(original_slice_path, caption="Original MRI", width=300)
251
+ with col2:
252
+ st.image(inferred_slice_path, caption="Inferred MRI", width=300)
253
+
254
+ if os.path.exists(download_file_path):
255
+ with open(download_file_path, "rb") as file:
256
+ st.download_button(
257
+ label="Download (EasySR Inferred-MRI)",
258
+ data=file,
259
+ file_name=gen_file_name,
260
+ mime="application/gzip",
261
+ type="primary"
262
+ )
263
+
264
  def clear_output_folder(folder_path):
265
  # Clear contents of a specified folder
266
  for filename in os.listdir(folder_path):
 
277
 
278
  # Main function for Streamlit UI
279
  def main():
280
+ global original_slice_path, inferred_slice_path, download_file_path, gen_file_name, intensity_adjust
281
 
282
+ st.sidebar.markdown("# ")
 
283
  st.sidebar.markdown(
284
+ "[![git](https://img.icons8.com/material-outlined/48/000000/github.png)]"
285
+ "(https://github.com/hwonheo/easysr)"
286
+ )
287
+ st.sidebar.markdown("# ")
288
+
289
+ # Setup sidebar with instructions and model selection
290
+ st.sidebar.subheader("*Model Selection*", divider='red')
291
+ model_choice = st.sidebar.selectbox(
292
+ "Choose the model type:",
293
+ ("Mixed-Model", "T1-Model"),
294
+ index=1 # Default is Combined-Model
 
 
 
295
  )
296
 
297
+ st.sidebar.header("\n")
298
+
299
+ # Setup sidebar with instructions
300
+ st.sidebar.subheader("_How to Use EasySR_", divider='red')
301
+ with st.sidebar.expander("Step-by-Step Guide:"):
302
+ st.markdown(
303
+ "1. **Prepare Your Data**: Make sure your rat brain MRI data "
304
+ "is in NIFTI format. Convert if needed.\n\n"
305
+ "2. **Upload Your MRI**: Drag and drop your NIFTI file "
306
+ "or use the upload button.\n\n"
307
+ "3. **Start the EasySR**: Click 'EasySR' to begin processing. "
308
+ "It usually takes a few minutes.\n\n"
309
+ "4. **Sit Back and Relax**: Wait while your data is processed quickly.\n\n"
310
+ "5. **View and Download**: After processing, view the results and "
311
+ "use the download button to save the enhanced MRI data.\n\n"
312
+ "6. **Use as Needed**: Download and utilize your enhanced MRI. "
313
+ "Continue using EasySR for more enhancements.\n\n"
314
+ )
315
+
316
+ # Initialize model and inference engine with the selected model
317
+ generator, device = load_model(model_choice)
318
+ inference_engine = MRIInference(generator, device, (128, 128, 64), (128, 128, 192))
319
+
320
  # Main interface layout
321
  st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
322
  st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')
 
333
  # File uploader for MRI files
334
  uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
335
  type=["nii", "nii.gz"], key='file_uploader')
336
+
337
+ # Checkbox for intensity adjustment
338
+ intensity_adjust = st.checkbox("Bias Field Correction (enhance signal intensity)",
339
+ help="Apply intensity truncation and bias correction to an image: "
340
+ "Check this option if the input image exhibits low signal intensity "
341
+ "(common in T2RARE, TOF, etc.) or if the output from the inference "
342
+ "process appears weakly signaled. This will enhance the signals by "
343
+ "N4-bias correction and very low- or high-signal intensity truncation, "
344
+ "yielding clearer and more defined results.")
345
 
346
  if uploaded_file is not None:
347
  # Store uploaded file in session state
348
  st.session_state['uploaded_file'] = uploaded_file
349
  file_name = uploaded_file.name
350
 
351
+ # Temporary directory for file processing
352
+ temp_dir = tempfile.gettempdir()
353
+ temp_file_path = os.path.join(temp_dir, file_name)
354
+
355
+ # Write uploaded file to temp directory
356
+ with open(temp_file_path, "wb") as tmp_file:
357
+ tmp_file.write(uploaded_file.getvalue())
358
+
359
  # Inference start button
360
+ if st.button("EasySR (start inference)", type="primary"):
361
+ try:
362
+ # Bias Field Correction
363
+ corrected_file_path = run_bias_field_correction(
364
+ temp_file_path, temp_dir, "abp") if intensity_adjust else temp_file_path
365
+
366
+ # Ensure template files are available
367
+ templates_folder = "templates"
368
+ download_model_if_needed(templates_folder)
369
+ template_path = os.path.join(templates_folder, "bmc_t2_rat.nii.gz")
370
+
371
+ # Resample and align the image
372
+ resampled_path = resample_to_isotropic(
373
+ corrected_file_path, os.path.join(temp_dir, "resampled.nii.gz"))
374
+ aligned_path = align_to_template(
375
+ resampled_path, template_path, os.path.join(temp_dir, "aligned.nii.gz"))
376
+
377
+ # Perform inference and process results
378
+ original_slice_path, inferred_slice_path, download_file_path, gen_file_name = run_inference(
379
+ inference_engine, aligned_path, corrected_file_path, output_path)
380
+
381
+ # Display results
382
+ display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name)
383
+
384
+ except Exception as e:
385
+ st.error(f"Error during inference: {e}")
386
+
387
+ # Button to clear generated content
388
+ if st.button('Clear Generated All',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  help='Pressing this will delete the contents of the generate folder.'):
390
+ clear_output_folder('infer/generate')
391
+ clear_session()
392
+ st.rerun()
393
 
394
  # Entry point for the Streamlit application
395
  if __name__ == '__main__':
requirements.txt CHANGED
@@ -6,4 +6,5 @@ matplotlib
6
  SimpleITK
7
  torchio
8
  antspyx
9
- streamlit
 
 
6
  SimpleITK
7
  torchio
8
  antspyx
9
+ streamlit
10
+ scikit-image