from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import random
import cv2
import torch


image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]

model_path = ['deprem-ml/deprem_satellite_semantic_whu']

def visualize_instance_seg_mask(mask):
    # Initialize an all-zero image with the resolution of the
    # segmentation mask and 3 color channels.
    image = np.zeros((mask.shape[0], mask.shape[1], 3))

    # Assign a random RGB color to each label present in the mask.
    labels = np.unique(mask)
    label2color = {
        label: (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
        for label in labels
    }

    # Color every pixel according to its label. This per-pixel loop is
    # slow for large masks; a vectorized sketch follows below.
    for height in range(image.shape[0]):
        for width in range(image.shape[1]):
            image[height, width, :] = label2color[mask[height, width]]

    # Scale to [0, 1] floats for display.
    image = image / 255
    return image
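
# A vectorized alternative to the per-pixel loop above -- a minimal
# sketch, not part of the original app. It assumes `mask` holds small
# non-negative integer class ids, and builds a (num_labels, 3) lookup
# table so the whole mask is colored in one fancy-indexing operation.
def visualize_instance_seg_mask_fast(mask):
    rng = np.random.default_rng()
    palette = np.zeros((int(mask.max()) + 1, 3))
    for label in np.unique(mask):
        palette[label] = rng.integers(0, 256, size=3)
    # Index the palette with the mask itself: (H, W) -> (H, W, 3).
    return palette[mask] / 255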


def Segformer_Segmentation(image_path, model_id):
    output_save = "output.png"

    test_image = Image.open(image_path)

    # Note: the model and processor are reloaded on every call; see the
    # caching sketch below.
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    processor = SegformerImageProcessor.from_pretrained(model_id)

    inputs = processor(images=test_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted mask back to the input resolution;
    # PIL's size is (width, height), target_sizes expects (height, width).
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    result = result.numpy()
    result = visualize_instance_seg_mask(result)
    # The mask is in [0, 1] floats; cast to uint8 before writing.
    cv2.imwrite(output_save, (result * 255).astype(np.uint8))

    return image_path, output_save
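
# The model and processor above are reloaded on every prediction. A
# minimal caching sketch (not part of the original app; `load_model`
# is a hypothetical helper) that keeps loaded weights in memory so
# repeated predictions with the same model id reuse them:
from functools import lru_cache

@lru_cache(maxsize=2)
def load_model(model_id):
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    processor = SegformerImageProcessor.from_pretrained(model_id)
    return model, processor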

# Each example pairs a sample image with the default model id.
examples = [[image, model_path[0]] for image in image_list]

title = "Deprem ML - Segformer Semantic Segmentation"

app = gr.Blocks()
with app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
    with gr.Row():
        with gr.Column():
            gr.Markdown("Image")
            input_image = gr.Image(type='filepath')
            model_id = gr.Dropdown(value=model_path[0], choices=model_path)
            predict_button = gr.Button(value="Predict")

        with gr.Column():
            output_original_image = gr.Image(type='filepath')

        with gr.Column():
            output_mask_image = gr.Image(type='filepath')

    gr.Examples(
        examples,
        inputs=[input_image, model_id],
        outputs=[output_original_image, output_mask_image],
        fn=Segformer_Segmentation,
        cache_examples=True,
    )
    predict_button.click(
        Segformer_Segmentation,
        inputs=[input_image, model_id],
        outputs=[output_original_image, output_mask_image],
    )

app.launch()