File size: 2,970 Bytes
8a33342 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import time
import models
import torch
from torchvision import transforms
from torchvision import transforms
def load_model(path, model):
    """Fill *model* with the weights stored at *path* and return it.

    The checkpoint is always mapped onto the CPU, so a file saved from a
    GPU run can still be loaded on a machine without CUDA.

    Args:
        path: Filesystem path of a checkpoint holding a ``state_dict``.
        model: The ``torch.nn.Module`` instance to populate in place.

    Returns:
        The same *model* object, now carrying the loaded weights.
    """
    cpu = torch.device('cpu')
    weights = torch.load(path, map_location=cpu)
    model.load_state_dict(weights)
    return model
def predict(img):
    """Segment *img* with the plain U-Net and return a binary mask image.

    The weights are read from ``model.pth`` on every call. The input is
    resized to 512x512 and normalised with the ImageNet mean/std. The
    sigmoid of the network output is then thresholded at 0.5.

    Args:
        img: Image as a NumPy array in HWC layout. ``main`` passes the
            result of ``cv2.imread``, i.e. BGR channel order.
            NOTE(review): the normalisation constants are the usual RGB
            ImageNet statistics; confirm the network was trained on BGR
            input, otherwise convert with ``cv2.cvtColor`` first.

    Returns:
        ``uint8`` array in HWC layout whose values are 0 or 255. This is
        presumably (512, 512, 1) if the network preserves the input
        resolution and emits the single channel requested by
        ``unet(3, 1)``; confirm.

    Side effects:
        Writes the mask to ``test.png`` in the current working directory.
    """
    model = load_model('model.pth', models.unet(3, 1))
    # Switch to inference mode. Layers such as BatchNorm/Dropout (if the
    # network has any) otherwise behave as in training. BatchNorm would
    # then use per-batch statistics of this single image.
    model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img = cv2.resize(img, (512, 512))
    # ToTensor converts an HWC uint8 array in [0, 255] to a CHW float tensor in [0, 1].
    tensor = transforms.ToTensor()(img).float()
    tensor = normalize(tensor)
    batch = torch.unsqueeze(tensor, dim=0)  # add batch dimension -> (1, C, H, W)

    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        output = model(batch)

    threshold = 0.5
    result = (torch.sigmoid(output) >= threshold).float()
    # Take the single batch item, move it to the CPU and convert it to NumPy.
    prediction_array = result[0].cpu().numpy()
    # {0, 1} -> {0, 255} as uint8, and CHW -> HWC for OpenCV/Streamlit.
    prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite("test.png", prediction_array)
    return prediction_array
def predicjt(img):
    """Segment *img* with the Spatial Attention U-Net and return a binary mask.

    The weights are read from ``saunet.pth`` on every call. The input is
    resized to 512x512 and normalised with the ImageNet mean/std. The
    sigmoid of the network output is then thresholded at 0.5.

    Args:
        img: Image as a NumPy array in HWC layout. ``main`` passes the
            result of ``cv2.imread``, i.e. BGR channel order.
            NOTE(review): the normalisation constants are the usual RGB
            ImageNet statistics; confirm the network was trained on BGR
            input, otherwise convert with ``cv2.cvtColor`` first.

    Returns:
        ``uint8`` array in HWC layout whose values are 0 or 255. This is
        presumably (512, 512, 1) if the network preserves the input
        resolution and emits the single channel requested by
        ``SAunet(3, 1)``; confirm.

    Side effects:
        Writes the mask to ``test1.png`` in the current working directory.
    """
    model1 = load_model('saunet.pth', models.SAunet(3, 1))
    # Switch to inference mode. Layers such as BatchNorm/Dropout (if the
    # network has any) otherwise behave as in training. BatchNorm would
    # then use per-batch statistics of this single image.
    model1.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img = cv2.resize(img, (512, 512))
    # ToTensor converts an HWC uint8 array in [0, 255] to a CHW float tensor in [0, 1].
    tensor = transforms.ToTensor()(img).float()
    tensor = normalize(tensor)
    batch = torch.unsqueeze(tensor, dim=0)  # add batch dimension -> (1, C, H, W)

    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        output = model1(batch)

    threshold = 0.5
    result = (torch.sigmoid(output) >= threshold).float()
    # Take the single batch item, move it to the CPU and convert it to NumPy.
    prediction_array = result[0].cpu().numpy()
    # {0, 1} -> {0, 255} as uint8, and CHW -> HWC for OpenCV/Streamlit.
    prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite("test1.png", prediction_array)
    return prediction_array
def main():
    """Render the Streamlit image-segmentation demo.

    The user picks one of a fixed set of test images, which is displayed.
    When "Segment" is pressed, the app shows the masks produced by
    ``predict`` (U-Net) and ``predicjt`` (Spatial Attention U-Net).
    """
    st.title("Image Segmentation Demo")
    # Predefined list of image files, resolved relative to the working directory.
    image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]
    selected_image_name = st.selectbox("Select an Image", image_names)

    # cv2.imread returns None (instead of raising) when the file is missing
    # or unreadable; stop here with a message instead of crashing later in
    # st.image / cv2.resize.
    selected_image = cv2.imread(selected_image_name)
    if selected_image is None:
        st.error(f"Could not read image file: {selected_image_name}")
        return

    # OpenCV decodes colour images in BGR order; declaring "RGB" here would
    # show the red and blue channels swapped.
    st.image(selected_image, channels="BGR")

    if st.button("Segment"):
        segmented_image = predict(selected_image)
        segmented_image1 = predicjt(selected_image)
        # The masks are single-channel, so the channel order does not matter.
        st.image(segmented_image, channels="RGB", caption='U-Net segmentation')
        st.image(segmented_image1, channels="RGB", caption='Spatial Attention U-Net segmentation ')
# Entry point: render the demo when this file is executed as a script.
if __name__ == "__main__":
    main()
|