File size: 2,970 Bytes
8a33342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import time
import models
import torch

from torchvision import transforms
from torchvision import transforms

def load_model(path, model):
    """Populate *model* with the weights stored at *path*.

    The checkpoint is deserialized with ``torch.load`` onto the CPU (so no GPU
    is required) and then copied into *model* via ``load_state_dict``.

    Args:
        path: Filesystem path to a saved state dict.
        model: ``torch.nn.Module`` whose parameters are overwritten in place.

    Returns:
        The same *model* object, now holding the loaded weights.
    """
    cpu_device = torch.device('cpu')
    checkpoint = torch.load(path, map_location=cpu_device)
    model.load_state_dict(checkpoint)
    return model

def predict(img):
    """Segment *img* with the plain U-Net model and return a binary mask.

    Builds ``models.unet(3, 1)``, loads its weights from ``model.pth`` on the
    CPU and runs a single forward pass in inference mode.

    Args:
        img: Image array as returned by ``cv2.imread`` (presumably H x W x 3
            uint8). NOTE(review): cv2 loads pixels in BGR order while the
            normalization constants below are the common ImageNet RGB
            statistics -- confirm the model was trained on BGR input.

    Returns:
        A uint8 HWC array holding 255 where the sigmoid probability is
        >= 0.5 and 0 elsewhere (spatial size 512x512, assuming the network
        preserves the input resolution -- TODO confirm). The same array is
        also written to ``test.png`` in the working directory.
    """
    model = models.unet(3, 1)
    model = load_model('model.pth', model)
    # Switch to inference mode so layers such as BatchNorm/Dropout (if the
    # architecture has them) behave deterministically instead of using
    # training-time behaviour on a batch of a single image.
    model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img = cv2.resize(img, (512, 512))
    # ToTensor turns an HWC uint8 array in [0, 255] into a CHW float in [0, 1].
    tensor = transforms.ToTensor()(img).float()
    tensor = normalize(tensor)
    batch = torch.unsqueeze(tensor, dim=0)  # add batch dim -> (1, C, H, W)

    # Gradients are never used here; disabling autograd saves memory/time.
    with torch.no_grad():
        output = model(batch)
        probabilities = torch.sigmoid(output)

    threshold = 0.5
    mask = (probabilities >= threshold).float()
    prediction = mask[0].cpu()  # drop the batch dimension
    # Scale {0, 1} -> {0, 255} and move channels last (CHW -> HWC) for cv2/st.image.
    prediction_array = (prediction.numpy() * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite("test.png", prediction_array)
    return prediction_array

def predicjt(img):
    """Segment *img* with the spatial-attention U-Net and return a binary mask.

    (The misspelled name is kept because callers use it.) Builds
    ``models.SAunet(3, 1)``, loads its weights from ``saunet.pth`` on the CPU
    and runs a single forward pass in inference mode.

    Args:
        img: Image array as returned by ``cv2.imread`` (presumably H x W x 3
            uint8). NOTE(review): cv2 loads pixels in BGR order while the
            normalization constants below are the common ImageNet RGB
            statistics -- confirm the model was trained on BGR input.

    Returns:
        A uint8 HWC array holding 255 where the sigmoid probability is
        >= 0.5 and 0 elsewhere (spatial size 512x512, assuming the network
        preserves the input resolution -- TODO confirm). The same array is
        also written to ``test1.png`` in the working directory.
    """
    model1 = models.SAunet(3, 1)
    model1 = load_model('saunet.pth', model1)
    # Switch to inference mode so layers such as BatchNorm/Dropout (if the
    # architecture has them) behave deterministically instead of using
    # training-time behaviour on a batch of a single image.
    model1.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img = cv2.resize(img, (512, 512))
    # ToTensor turns an HWC uint8 array in [0, 255] into a CHW float in [0, 1].
    tensor = transforms.ToTensor()(img).float()
    tensor = normalize(tensor)
    batch = torch.unsqueeze(tensor, dim=0)  # add batch dim -> (1, C, H, W)

    # Gradients are never used here; disabling autograd saves memory/time.
    with torch.no_grad():
        output = model1(batch)
        probabilities = torch.sigmoid(output)

    threshold = 0.5
    mask = (probabilities >= threshold).float()
    prediction = mask[0].cpu()  # drop the batch dimension
    # Scale {0, 1} -> {0, 255} and move channels last (CHW -> HWC) for cv2/st.image.
    prediction_array = (prediction.numpy() * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite("test1.png", prediction_array)
    return prediction_array
def main():
    """Render the Streamlit image-segmentation demo.

    Lets the user pick one of a fixed set of test images, displays it, and
    when "Segment" is pressed runs both ``predict`` (U-Net) and ``predicjt``
    (spatial-attention U-Net) and displays the two resulting masks.
    """
    st.title("Image Segmentation Demo")

    # Predefined list of image names (paths relative to the working directory).
    image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]

    # Create a selection box for the images
    selected_image_name = st.selectbox("Select an Image", image_names)

    # cv2.imread does not raise on a missing/unreadable file -- it returns None,
    # which would otherwise crash later in st.image / cv2.resize.
    selected_image = cv2.imread(selected_image_name)
    if selected_image is None:
        st.error(f"Could not read image file: {selected_image_name}")
        return

    # cv2 returns pixels in BGR order; tell Streamlit so red/blue are not swapped.
    st.image(selected_image, channels="BGR")

    if st.button("Segment"):
        # Run both models on the same (unmodified) input image.
        segmented_image = predict(selected_image)
        segmented_image1 = predicjt(selected_image)

        # Display the segmented masks
        st.image(segmented_image, channels="RGB", caption='U-Net segmentation')
        st.image(segmented_image1, channels="RGB",
                 caption='Spatial Attention U-Net segmentation ')

# Run the demo app when this file is executed as the main script.


if __name__ == "__main__":
    main()