# Scraped page header (not part of the program): SerdarHelli,
# commit 6f4e6ed "Upload 6 files (#1)", 2.86 kB.
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import random
import cv2
import torch
# Sample images shipped with the demo; they feed the example gallery.
image_list = [f"data/{index}.png" for index in range(1, 5)]

# Model identifiers offered in the model dropdown (the first is the default).
model_path = ["deprem-ml/deprem_satellite_semantic_whu"]
def visualize_instance_seg_mask(mask):
    """Colour a label mask with one random RGB colour per distinct label.

    Args:
        mask: 2-D array of labels, indexed as ``mask[row, column]``.

    Returns:
        A float array of shape ``(mask.shape[0], mask.shape[1], 3)`` with
        values in ``[0, 1]``. Pixels that share a label share a colour.
        Colours come from the global ``random`` generator, so they change
        between calls unless the generator is seeded.
    """
    mask = np.asarray(mask)

    # ``labels`` is the sorted set of distinct values; ``inverse`` maps every
    # pixel to the index of its label inside ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One RGB triple per label, drawn in label order. ``reshape(-1, 3)`` keeps
    # the palette two-dimensional even when the mask is empty.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)

    # A single fancy-indexing lookup replaces the per-pixel Python loop.
    # The explicit reshape makes the result independent of whether NumPy
    # returns ``inverse`` flattened or already shaped like ``mask``.
    return palette[inverse.reshape(mask.shape)] / 255
# Loaded (model, processor) pairs keyed by model id, so repeated predictions
# do not reload the checkpoint on every click.
_SEGFORMER_CACHE = {}


def _load_segformer(model_id):
    """Return the ``(model, processor)`` pair for *model_id*, loading it once."""
    if model_id not in _SEGFORMER_CACHE:
        model = SegformerForSemanticSegmentation.from_pretrained(model_id)
        try:
            # Load the preprocessing settings published with the checkpoint.
            # Calling ``SegformerImageProcessor(model_id)`` would instead pass
            # the id as the first constructor argument and silently use the
            # default settings.
            processor = SegformerImageProcessor.from_pretrained(model_id)
        except OSError:
            # The repository ships no processor config: use the defaults.
            processor = SegformerImageProcessor()
        _SEGFORMER_CACHE[model_id] = (model, processor)
    return _SEGFORMER_CACHE[model_id]


def Segformer_Segmentation(image_path, model_id):
    """Segment the image at *image_path* and save a colourised mask.

    Args:
        image_path: Path of the input image file.
        model_id: Identifier passed to ``from_pretrained`` to load the
            Segformer model and its image processor.

    Returns:
        A ``(image_path, output_path)`` tuple: the unchanged input path and
        the path of the written mask image (``"output.png"``).
    """
    # NOTE(review): the output path is fixed, so concurrent requests would
    # overwrite each other's result — confirm whether that matters here.
    output_save = "output.png"

    # Close the file handle promptly and normalise to 3-channel RGB so that
    # greyscale or RGBA PNGs are accepted too.
    with Image.open(image_path) as source_image:
        test_image = source_image.convert("RGB")

    model, processor = _load_segformer(model_id)
    inputs = processor(images=test_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height) while ``target_sizes`` expects
    # (height, width). Passing it resizes the mask back to the input
    # resolution instead of leaving it at the logits resolution.
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    result = np.array(result)
    result = visualize_instance_seg_mask(result)

    # The visualisation is in [0, 1]; scale it back and write 8-bit pixels.
    cv2.imwrite(output_save, np.rint(result * 255).astype(np.uint8))
    return image_path, output_save
# Example rows for the demo: each of the four sample images paired with the
# model identifier it should be run through.
examples = [
    [image_list[sample_idx], "deprem-ml/deprem_satellite_semantic_whu"]
    for sample_idx in range(4)
]

# Heading rendered at the top of the page.
title = "Deprem ML - Segformer Semantic Segmentation"
# Gradio UI: an input column (image, model choice, predict button) next to
# two output columns showing the original image and the predicted mask.
with gr.Blocks() as app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
    with gr.Row():
        with gr.Column():
            # The input widget below is an image, so label it as such
            # (the label previously read "Video").
            gr.Markdown("Image")
            input_video = gr.Image(type='filepath')
            model_id = gr.Dropdown(value=model_path[0], choices=model_path)
            input_video_button = gr.Button(value="Predict")

        with gr.Column():
            output_orijinal_image = gr.Image(type='filepath')

        with gr.Column():
            output_mask_image = gr.Image(type='filepath')

    # Clickable examples; their outputs are precomputed and cached.
    gr.Examples(
        examples,
        inputs=[input_video, model_id],
        outputs=[output_orijinal_image, output_mask_image],
        fn=Segformer_Segmentation,
        cache_examples=True,
    )
    # Run the segmentation when the user presses "Predict".
    input_video_button.click(
        Segformer_Segmentation,
        inputs=[input_video, model_id],
        outputs=[output_orijinal_image, output_mask_image],
    )

app.launch()