Spaces:
Sleeping
Sleeping
initial commit
Browse files- app.py +58 -0
- apps/__pycache__/data_preprocessing.cpython-311.pyc +0 -0
- apps/__pycache__/model.cpython-311.pyc +0 -0
- apps/__pycache__/project_model2.cpython-311.pyc +0 -0
- apps/best_model_36.ckpt +3 -0
- apps/data_preprocessing.py +30 -0
- apps/model.py +11 -0
- apps/project_model2.py +91 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
from apps.data_preprocessing import preprocess_input
|
4 |
+
import os
|
5 |
+
import tempfile
|
6 |
+
import nibabel as nib
|
7 |
+
|
8 |
+
st.title("NIfTI File Uploader and Video Generator")
|
9 |
+
st.write("This tool accepts NIfTI `.nii` or `.nii.gz` files smaller than a specific size.")
|
10 |
+
uploaded_file = st.file_uploader("Upload a NIfTI file (.nii or .nii.gz)", type=["nii", "nii.gz"])
|
11 |
+
|
12 |
+
if uploaded_file is not None:
|
13 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
14 |
+
try:
|
15 |
+
# Save uploaded file to a temporary location
|
16 |
+
nifti_path = os.path.join(temp_dir, "uploaded_file.nii.gz")
|
17 |
+
with open(nifti_path, "wb") as f:
|
18 |
+
f.write(uploaded_file.read())
|
19 |
+
|
20 |
+
# Preprocess the uploaded NIfTI file
|
21 |
+
output_tensor, dataset = preprocess_input(nifti_path)
|
22 |
+
imgs = dataset[0]["CT"]["data"]
|
23 |
+
pred = output_tensor.argmax(0)
|
24 |
+
numpy_array = pred.numpy().astype(np.int16)
|
25 |
+
affine = np.eye(4)
|
26 |
+
|
27 |
+
# Save finalized mask
|
28 |
+
mask_path = os.path.join(temp_dir, "finalized_mask.nii")
|
29 |
+
nifti_image_mask = nib.Nifti1Image(numpy_array, affine)
|
30 |
+
nib.save(nifti_image_mask, mask_path)
|
31 |
+
|
32 |
+
# Save finalized image
|
33 |
+
nifti_array = imgs.squeeze(0).numpy()
|
34 |
+
image_path = os.path.join(temp_dir, "finalized_image.nii.gz")
|
35 |
+
nifti_image_img = nib.Nifti1Image(nifti_array, affine)
|
36 |
+
nib.save(nifti_image_img, image_path)
|
37 |
+
|
38 |
+
# Provide download buttons for the finalized files
|
39 |
+
with open(mask_path, "rb") as f:
|
40 |
+
st.download_button(
|
41 |
+
label="Download Finalized Mask",
|
42 |
+
data=f,
|
43 |
+
file_name="finalized_mask.nii",
|
44 |
+
mime="application/octet-stream"
|
45 |
+
)
|
46 |
+
|
47 |
+
with open(image_path, "rb") as f:
|
48 |
+
st.download_button(
|
49 |
+
label="Download Finalized Image",
|
50 |
+
data=f,
|
51 |
+
file_name="finalized_image.nii.gz",
|
52 |
+
mime="application/octet-stream"
|
53 |
+
)
|
54 |
+
|
55 |
+
except Exception as e:
|
56 |
+
st.error(f"An error occurred: {e}")
|
57 |
+
else:
|
58 |
+
st.info("Please upload a NIfTI file to generate the video.")
|
apps/__pycache__/data_preprocessing.cpython-311.pyc
ADDED
Binary file (2.73 kB). View file
|
|
apps/__pycache__/model.cpython-311.pyc
ADDED
Binary file (1.14 kB). View file
|
|
apps/__pycache__/project_model2.cpython-311.pyc
ADDED
Binary file (4.52 kB). View file
|
|
apps/best_model_36.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb29eaf7de0a7d1554b01b728b6fe5acbf4312c6cb8a8f43d1c6742534538969
|
3 |
+
size 70073984
|
apps/data_preprocessing.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchio as tio
|
2 |
+
import torch
|
3 |
+
from apps.model import model
|
4 |
+
def preprocess_input(uploaded_file):
|
5 |
+
subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
|
6 |
+
normalize_orientation = tio.ToCanonical()
|
7 |
+
preprocess_spatial = tio.Compose([
|
8 |
+
normalize_orientation,
|
9 |
+
tio.RescaleIntensity((0, 1)),
|
10 |
+
tio.Resize((300, 300, 400))
|
11 |
+
])
|
12 |
+
transform = preprocess_spatial
|
13 |
+
dataset = tio.SubjectsDataset([subject], transform=transform)
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
ckpt_path = 'apps/best_model_36.ckpt'
|
16 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
17 |
+
model.load_state_dict(checkpoint['state_dict'])
|
18 |
+
model.to(device)
|
19 |
+
model.eval()
|
20 |
+
grid_sampler = tio.inference.GridSampler(dataset[0], 96, (8, 8, 8))
|
21 |
+
aggregator = tio.inference.GridAggregator(grid_sampler)
|
22 |
+
patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=4)
|
23 |
+
with torch.no_grad():
|
24 |
+
for patches_batch in patch_loader:
|
25 |
+
input_tensor = patches_batch['CT']["data"].to(device) # Get batch of patches
|
26 |
+
locations = patches_batch[tio.LOCATION] # Get locations of patches
|
27 |
+
pred = model(input_tensor) # Compute prediction
|
28 |
+
aggregator.add_batch(pred, locations)
|
29 |
+
output_tensor = aggregator.get_output_tensor()
|
30 |
+
return output_tensor, dataset
|
apps/model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
from apps.project_model2 import UNet
|
3 |
+
|
4 |
+
class Segmenter(pl.LightningModule):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.model = UNet()
|
8 |
+
def forward(self, data):
|
9 |
+
pred = self.model(data)
|
10 |
+
return pred
|
11 |
+
model=Segmenter()
|
apps/project_model2.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class DoubleConv(torch.nn.Module):
|
4 |
+
"""
|
5 |
+
Helper Class which implements the intermediate Convolutions
|
6 |
+
"""
|
7 |
+
def __init__(self, in_channels, out_channels):
|
8 |
+
|
9 |
+
super().__init__()
|
10 |
+
self.step = torch.nn.Sequential(torch.nn.Conv3d(in_channels, out_channels, 3, padding=1),
|
11 |
+
torch.nn.ReLU(),
|
12 |
+
torch.nn.Conv3d(out_channels, out_channels, 3, padding=1),
|
13 |
+
torch.nn.ReLU())
|
14 |
+
|
15 |
+
def forward(self, X):
|
16 |
+
return self.step(X)
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
class UNet(torch.nn.Module):
|
22 |
+
"""
|
23 |
+
This class implements a UNet for the Segmentation
|
24 |
+
We use 3 down- and 3 UpConvolutions and two Convolutions in each step
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self):
|
28 |
+
"""Sets up the U-Net Structure
|
29 |
+
"""
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
|
33 |
+
############# DOWN SAMPLING #####################
|
34 |
+
self.layer1 = DoubleConv(1, 32)
|
35 |
+
self.layer2 = DoubleConv(32, 64)
|
36 |
+
self.layer3 = DoubleConv(64, 128)
|
37 |
+
self.layer4 = DoubleConv(128, 256)
|
38 |
+
|
39 |
+
#########################################
|
40 |
+
|
41 |
+
############## UP SAMPLING #######################
|
42 |
+
self.layer5 = DoubleConv(256 + 128, 128)
|
43 |
+
self.layer6 = DoubleConv(128+64, 64)
|
44 |
+
self.layer7 = DoubleConv(64+32, 32)
|
45 |
+
self.layer8 = torch.nn.Conv3d(32, 6, 1) # Output: 5 values -> background, upper jaw, lower jaw,upper teeth, lower teeth, artery
|
46 |
+
#########################################
|
47 |
+
|
48 |
+
self.maxpool = torch.nn.MaxPool3d(2)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
|
52 |
+
####### DownConv 1#########
|
53 |
+
x1 = self.layer1(x)
|
54 |
+
x1m = self.maxpool(x1)
|
55 |
+
###########################
|
56 |
+
|
57 |
+
####### DownConv 2#########
|
58 |
+
x2 = self.layer2(x1m)
|
59 |
+
x2m = self.maxpool(x2)
|
60 |
+
###########################
|
61 |
+
|
62 |
+
####### DownConv 3#########
|
63 |
+
x3 = self.layer3(x2m)
|
64 |
+
x3m = self.maxpool(x3)
|
65 |
+
###########################
|
66 |
+
|
67 |
+
##### Intermediate Layer ##
|
68 |
+
x4 = self.layer4(x3m)
|
69 |
+
###########################
|
70 |
+
|
71 |
+
####### UpCONV 1#########
|
72 |
+
x5 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x4) # Upsample with a factor of 2
|
73 |
+
x5 = torch.cat([x5, x3], dim=1) # Skip-Connection
|
74 |
+
x5 = self.layer5(x5)
|
75 |
+
###########################
|
76 |
+
|
77 |
+
####### UpCONV 2#########
|
78 |
+
x6 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x5)
|
79 |
+
x6 = torch.cat([x6, x2], dim=1) # Skip-Connection AKA downsampling
|
80 |
+
x6 = self.layer6(x6)
|
81 |
+
###########################
|
82 |
+
|
83 |
+
####### UpCONV 3#########
|
84 |
+
x7 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x6)
|
85 |
+
x7 = torch.cat([x7, x1], dim=1)
|
86 |
+
x7 = self.layer7(x7)
|
87 |
+
###########################
|
88 |
+
|
89 |
+
####### Predicted segmentation#########
|
90 |
+
ret = self.layer8(x7)
|
91 |
+
return ret
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.41.1
|
2 |
+
nibabel==5.3.2
|
3 |
+
torch==2.5.1
|
4 |
+
torchio==0.20.3
|
5 |
+
pytorch-lightning==2.5.0.post0
|