rajsecrets0 commited on
Commit
1f2b3fa
·
verified ·
1 Parent(s): c9fa382

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ import numpy as np
6
+ import os
7
+ from osgeo import gdal
8
+
9
+ # Load the pretrained model
10
+ @st.cache(allow_output_mutation=True)
11
+ def load_model():
12
+ model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
13
+ pretrained=True, progress=True)
14
+ model.eval()
15
+ return model
16
+
17
+ # Function to load large TIFF images
18
+ def load_tiff_image(tiff_path):
19
+ try:
20
+ dataset = gdal.Open(tiff_path)
21
+ if dataset is None:
22
+ st.error("Failed to load the TIFF image. Please check the file format.")
23
+ return None
24
+ band = dataset.GetRasterBand(1) # Assuming grayscale or single band
25
+ image = band.ReadAsArray()
26
+ return image
27
+ except Exception as e:
28
+ st.error(f"Error loading image: {e}")
29
+ return None
30
+
31
+ # Preprocess image
32
+ def preprocess_image(image):
33
+ transform = transforms.Compose([
34
+ transforms.ToTensor(),
35
+ transforms.Resize((256, 256)), # Resize image for model input
36
+ transforms.Normalize(mean=[0.485], std=[0.229]) # Normalize
37
+ ])
38
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
39
+ return image_tensor
40
+
41
+ # Post-process prediction to display
42
+ def postprocess_prediction(pred):
43
+ pred = torch.sigmoid(pred)
44
+ pred = pred.squeeze().detach().numpy() # Remove batch dimension
45
+ pred = (pred > 0.5).astype(np.uint8) # Binary mask thresholding
46
+ return pred
47
+
48
+ # Streamlit app
49
+ st.title("TIFF Image Upload and Model Prediction")
50
+
51
+ # Upload image
52
+ uploaded_file = st.file_uploader("Upload a large TIFF image (up to 5GB)", type=["tiff"])
53
+
54
+ if uploaded_file is not None:
55
+ with open("temp_image.tiff", "wb") as f:
56
+ f.write(uploaded_file.getbuffer())
57
+
58
+ tiff_image = load_tiff_image("temp_image.tiff")
59
+
60
+ if tiff_image is not None:
61
+ st.write("Original Image")
62
+ st.image(tiff_image, caption="Uploaded Image", use_column_width=True)
63
+
64
+ model = load_model()
65
+
66
+ image = Image.fromarray(tiff_image)
67
+ image_tensor = preprocess_image(image)
68
+
69
+ with torch.no_grad():
70
+ prediction = model(image_tensor)
71
+
72
+ pred_image = postprocess_prediction(prediction)
73
+
74
+ st.write("Model Prediction")
75
+ st.image(pred_image, caption="Predicted Image", use_column_width=True)
76
+
77
+ os.remove("temp_image.tiff")