Kukulauren committed
Commit 02443c1 · verified · 1 Parent(s): 1d8797e

initial commit
app.py ADDED
@@ -0,0 +1,58 @@
+ import streamlit as st
+ import numpy as np
+ from apps.data_preprocessing import preprocess_input
+ import os
+ import tempfile
+ import nibabel as nib
+ 
+ st.title("NIfTI File Uploader and Segmentation Mask Generator")
+ st.write("This tool accepts NIfTI `.nii` or `.nii.gz` files up to Streamlit's upload size limit (200 MB by default).")
+ uploaded_file = st.file_uploader("Upload a NIfTI file (.nii or .nii.gz)", type=["nii", "nii.gz"])
+ 
+ if uploaded_file is not None:
+     with tempfile.TemporaryDirectory() as temp_dir:
+         try:
+             # Save the uploaded file to a temporary location
+             nifti_path = os.path.join(temp_dir, "uploaded_file.nii.gz")
+             with open(nifti_path, "wb") as f:
+                 f.write(uploaded_file.read())
+ 
+             # Preprocess the uploaded NIfTI file and run inference
+             output_tensor, dataset = preprocess_input(nifti_path)
+             subject = dataset[0]
+             imgs = subject["CT"]["data"]
+             numpy_array = output_tensor.argmax(0).numpy().astype(np.int16)
+             affine = subject["CT"].affine  # preprocessed affine keeps mask and image aligned
+ 
+             # Save the finalized mask
+             mask_path = os.path.join(temp_dir, "finalized_mask.nii")
+             nifti_image_mask = nib.Nifti1Image(numpy_array, affine)
+             nib.save(nifti_image_mask, mask_path)
+ 
+             # Save the finalized image
+             nifti_array = imgs.squeeze(0).numpy()
+             image_path = os.path.join(temp_dir, "finalized_image.nii.gz")
+             nifti_image_img = nib.Nifti1Image(nifti_array, affine)
+             nib.save(nifti_image_img, image_path)
+ 
+             # Provide download buttons for the finalized files
+             with open(mask_path, "rb") as f:
+                 st.download_button(
+                     label="Download Finalized Mask",
+                     data=f,
+                     file_name="finalized_mask.nii",
+                     mime="application/octet-stream"
+                 )
+ 
+             with open(image_path, "rb") as f:
+                 st.download_button(
+                     label="Download Finalized Image",
+                     data=f,
+                     file_name="finalized_image.nii.gz",
+                     mime="application/octet-stream"
+                 )
+ 
+         except Exception as e:
+             st.error(f"An error occurred: {e}")
+ else:
+     st.info("Please upload a NIfTI file to generate the segmentation outputs.")
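A quick in-browser sanity check could render one axial slice of the image and the predicted mask before offering the downloads. A minimal sketch, not part of this commit (the middle-slice choice and the mask scaling are illustrative):

    # Hypothetical preview: show the middle axial slice of image and mask
    mid = numpy_array.shape[-1] // 2
    st.image(nifti_array[:, :, mid], caption="CT, middle axial slice", clamp=True)
    mask_slice = numpy_array[:, :, mid].astype(np.float32)
    st.image(mask_slice / max(float(mask_slice.max()), 1.0), caption="Predicted mask, middle axial slice", clamp=True)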
apps/__pycache__/data_preprocessing.cpython-311.pyc ADDED
Binary file (2.73 kB)
 
apps/__pycache__/model.cpython-311.pyc ADDED
Binary file (1.14 kB)
 
apps/__pycache__/project_model2.cpython-311.pyc ADDED
Binary file (4.52 kB)
 
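The three `__pycache__` entries above are compiled bytecode that the interpreter regenerates automatically; they are conventionally kept out of version control with a `__pycache__/` line in `.gitignore`.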
apps/best_model_36.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb29eaf7de0a7d1554b01b728b6fe5acbf4312c6cb8a8f43d1c6742534538969
+ size 70073984
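The checkpoint itself is stored through Git LFS, so this entry is only a pointer file; a plain clone without LFS downloads just these three lines. Fetching the actual ~70 MB weights requires `git lfs install` once, then `git lfs pull` inside the repository.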
apps/data_preprocessing.py ADDED
@@ -0,0 +1,30 @@
+ import torchio as tio
+ import torch
+ from apps.model import model
+ 
+ def preprocess_input(uploaded_file):
+     subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
+     transform = tio.Compose([
+         tio.ToCanonical(),              # normalize orientation to RAS+
+         tio.RescaleIntensity((0, 1)),   # scale intensities into [0, 1]
+         tio.Resize((300, 300, 400)),    # resample to a fixed grid
+     ])
+     dataset = tio.SubjectsDataset([subject], transform=transform)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     ckpt_path = 'apps/best_model_36.ckpt'
+     checkpoint = torch.load(ckpt_path, map_location=device)
+     model.load_state_dict(checkpoint['state_dict'])
+     model.to(device)
+     model.eval()
+     # Patch-based inference: 96^3 patches with 8-voxel overlap, stitched back into a full volume
+     grid_sampler = tio.inference.GridSampler(dataset[0], 96, (8, 8, 8))
+     aggregator = tio.inference.GridAggregator(grid_sampler)
+     patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=4)
+     with torch.no_grad():
+         for patches_batch in patch_loader:
+             input_tensor = patches_batch['CT']["data"].to(device)  # batch of patches
+             locations = patches_batch[tio.LOCATION]                # patch locations
+             pred = model(input_tensor)                             # per-patch prediction
+             aggregator.add_batch(pred, locations)
+     output_tensor = aggregator.get_output_tensor()
+     return output_tensor, dataset
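Note that preprocess_input reloads the ~70 MB checkpoint on every call, so a long-running Streamlit process re-reads the weights for each upload. A minimal sketch of loading them once and reusing the imported `model` instead (the `_load_model` helper is hypothetical, not part of this commit):

    import functools

    @functools.lru_cache(maxsize=1)
    def _load_model(ckpt_path='apps/best_model_36.ckpt'):
        # Runs once per process; later calls return the cached pair
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
        model.eval()
        return model, device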
apps/model.py ADDED
@@ -0,0 +1,11 @@
+ import pytorch_lightning as pl
+ from apps.project_model2 import UNet
+ 
+ class Segmenter(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = UNet()
+ 
+     def forward(self, data):
+         return self.model(data)
+ 
+ model = Segmenter()  # module-level instance imported by apps/data_preprocessing.py
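Because Segmenter subclasses pl.LightningModule, the weights could also be restored with Lightning's own loader rather than the manual state_dict handling in data_preprocessing.py. A sketch, assuming best_model_36.ckpt was written by a Lightning Trainer:

    from apps.model import Segmenter

    # Restores the module and its weights in one call
    segmenter = Segmenter.load_from_checkpoint("apps/best_model_36.ckpt", map_location="cpu")
    segmenter.eval()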
apps/project_model2.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ 
+ class DoubleConv(torch.nn.Module):
+     """
+     Helper class implementing the two intermediate convolutions of each U-Net step.
+     """
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.step = torch.nn.Sequential(
+             torch.nn.Conv3d(in_channels, out_channels, 3, padding=1),
+             torch.nn.ReLU(),
+             torch.nn.Conv3d(out_channels, out_channels, 3, padding=1),
+             torch.nn.ReLU(),
+         )
+ 
+     def forward(self, X):
+         return self.step(X)
+ 
+ 
+ class UNet(torch.nn.Module):
+     """
+     A 3D U-Net for segmentation, with three down- and three up-convolution
+     stages and two convolutions in each step.
+     """
+     def __init__(self):
+         """Sets up the U-Net structure."""
+         super().__init__()
+ 
+         ############# DOWN SAMPLING #####################
+         self.layer1 = DoubleConv(1, 32)
+         self.layer2 = DoubleConv(32, 64)
+         self.layer3 = DoubleConv(64, 128)
+         self.layer4 = DoubleConv(128, 256)
+         #################################################
+ 
+         ############## UP SAMPLING ######################
+         self.layer5 = DoubleConv(256 + 128, 128)
+         self.layer6 = DoubleConv(128 + 64, 64)
+         self.layer7 = DoubleConv(64 + 32, 32)
+         # Output: 6 channels -> background, upper jaw, lower jaw,
+         # upper teeth, lower teeth, artery
+         self.layer8 = torch.nn.Conv3d(32, 6, 1)
+         #################################################
+ 
+         self.maxpool = torch.nn.MaxPool3d(2)
+         self.upsample = torch.nn.Upsample(scale_factor=2, mode="trilinear")
+ 
+     def forward(self, x):
+         ####### DownConv 1 #########
+         x1 = self.layer1(x)
+         x1m = self.maxpool(x1)
+         ############################
+ 
+         ####### DownConv 2 #########
+         x2 = self.layer2(x1m)
+         x2m = self.maxpool(x2)
+         ############################
+ 
+         ####### DownConv 3 #########
+         x3 = self.layer3(x2m)
+         x3m = self.maxpool(x3)
+         ############################
+ 
+         ##### Intermediate layer ###
+         x4 = self.layer4(x3m)
+         ############################
+ 
+         ####### UpConv 1 ###########
+         x5 = self.upsample(x4)           # upsample by a factor of 2
+         x5 = torch.cat([x5, x3], dim=1)  # skip connection
+         x5 = self.layer5(x5)
+         ############################
+ 
+         ####### UpConv 2 ###########
+         x6 = self.upsample(x5)
+         x6 = torch.cat([x6, x2], dim=1)  # skip connection
+         x6 = self.layer6(x6)
+         ############################
+ 
+         ####### UpConv 3 ###########
+         x7 = self.upsample(x6)
+         x7 = torch.cat([x7, x1], dim=1)  # skip connection
+         x7 = self.layer7(x7)
+         ############################
+ 
+         ####### Predicted segmentation #########
+         ret = self.layer8(x7)
+         return ret
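With three 2× poolings, input spatial dimensions must be divisible by 8; the 96^3 patch size used in data_preprocessing.py satisfies this. A quick shape smoke test, run from the repo root (illustrative, not part of the commit):

    import torch
    from apps.project_model2 import UNet

    net = UNet()
    x = torch.randn(1, 1, 96, 96, 96)        # (batch, channel, D, H, W)
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1, 6, 96, 96, 96)     # six class channels, same spatial size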
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit==1.41.1
+ nibabel==5.3.2
+ torch==2.5.1
+ torchio==0.20.3
+ pytorch-lightning==2.5.0.post0
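app.py also imports numpy directly; it arrives transitively via torch and nibabel, but an explicit entry here would make the environment fully pinned. To try the app locally, install the dependencies with `pip install -r requirements.txt`, then launch it with `streamlit run app.py`.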