# Hugging Face Space page header captured with the source:
# Spaces: Runtime error (status shown twice on the page)
from transformers import SegformerForSemanticSegmentation | |
from transformers import SegformerImageProcessor | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import random | |
import cv2 | |
import torch | |
from imutils import perspective | |
def midpoint(ptA, ptB):
    """Return the point halfway between ``ptA`` and ``ptB``.

    Only the first two coordinates of each point are used; the result is an
    ``(x, y)`` tuple.
    """
    mid_x = 0.5 * (ptA[0] + ptB[0])
    mid_y = 0.5 * (ptA[1] + ptB[1])
    return (mid_x, mid_y)
# Module-level constants for mask post-processing.

# 5x5 all-ones structuring element (used for the morphological opening in
# cca_analysis).
kernel1 = np.ones(shape=(5, 5), dtype=np.float32)

blur_radius = 0.5

# 3x3 kernel with 9 at the centre and -1 everywhere else, scaled by 1/9.
_sharpen_base = np.full((3, 3), -1)
_sharpen_base[1, 1] = 9
kernel_sharpening = _sharpen_base * (1 / 9)
def cca_analysis(image, predicted_mask):
    """Draw an oriented bounding box for every connected region of a mask.

    The mask is resized to the resolution of ``image``, cleaned with a
    morphological opening, binarised with Otsu's threshold and split into
    8-connected components. For each component the minimum-area rectangle,
    the midpoints of its four sides and the two lines joining opposite
    midpoints are drawn.

    Args:
        image: Image to annotate. It is viewed through ``np.asarray``, so when
            an ndarray is passed the drawing happens in place on that array.
        predicted_mask: 3-channel mask image (it is converted with
            ``cv2.COLOR_BGR2GRAY``).

    Returns:
        The annotated image array.
    """
    canvas = np.asarray(image)
    print(canvas.shape)

    # cv2.resize takes dsize as (width, height) == (shape[1], shape[0]).
    # Using shape[1] twice squashed the mask to a square, so the boxes did
    # not line up with non-square images.
    target_size = (canvas.shape[1], canvas.shape[0])
    mask_img = cv2.resize(predicted_mask, target_size, interpolation=cv2.INTER_AREA)

    # Opening removes small specks before thresholding.
    mask_img = cv2.morphologyEx(mask_img, cv2.MORPH_OPEN, kernel1, iterations=1)
    gray = cv2.cvtColor(mask_img, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    labels = cv2.connectedComponents(thresh, connectivity=8)[1]

    count2 = 0
    for label in np.unique(labels):
        # Label 0 is skipped (treated as background).
        if label == 0:
            continue

        # Binary mask containing only the current component.
        component = np.zeros(thresh.shape, dtype="uint8")
        component[labels == label] = 255

        # Find contours and determine the contour area of the first one.
        contours, _hierarchy = cv2.findContours(
            component.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        contour = contours[0]
        c_area = cv2.contourArea(contour)

        # Area threshold for counting a component as a detected region.
        # NOTE(review): the threshold only affects this counter; confirm
        # whether small components should also be skipped when drawing.
        if c_area > 100:
            count2 += 1

        # Oriented bounding box, with corners in the order unpacked below
        # (top-left, top-right, bottom-right, bottom-left).
        rect = cv2.minAreaRect(contour)
        box = np.array(cv2.boxPoints(rect), dtype="int")
        box = perspective.order_points(box)

        # Random colour per component, each channel drawn from 0..149.
        color = [int(channel) for channel in np.random.choice(range(150), size=3)]
        cv2.drawContours(canvas, [box.astype("int")], 0, color, 2)

        (tl, tr, br, bl) = box
        # Midpoints of the top and bottom edges ...
        (tltrX, tltrY) = midpoint(tl, tr)
        (blbrX, blbrY) = midpoint(bl, br)
        # ... and of the left and right edges.
        (tlblX, tlblY) = midpoint(tl, bl)
        (trbrX, trbrY) = midpoint(tr, br)

        # Draw the midpoints on the image.
        for (mx, my) in ((tltrX, tltrY), (blbrX, blbrY), (tlblX, tlblY), (trbrX, trbrY)):
            cv2.circle(canvas, (int(mx), int(my)), 5, (255, 0, 0), -1)

        # Lines joining the midpoints of opposite edges.
        cv2.line(canvas, (int(tltrX), int(tltrY)), (int(blbrX), int(blbrY)), color, 2)
        cv2.line(canvas, (int(tlblX), int(tlblY)), (int(trbrX), int(trbrY)), color, 2)

    return canvas
def to_rgb(img):
    """Expand a 2-D mask into a 3-channel ``uint8`` image.

    Every channel is a copy of ``img`` multiplied by 255, so a 0/1 mask
    becomes a black/white image. Values are scaled in floating point and then
    cast to ``uint8``.

    Args:
        img: 2-D array of shape ``(height, width)``.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)``.
    """
    # The output keeps the input's (rows, cols) order. Allocating it as
    # (shape[1], shape[0], 3) made every non-square mask fail to broadcast.
    plane = np.asarray(img, dtype=np.float64)
    stacked = np.repeat(plane[:, :, np.newaxis], 3, axis=2)
    return np.uint8(stacked * 255)
# Example images bundled with the demo: data/1.png through data/4.png.
image_list = ["data/{}.png".format(index) for index in (1, 2, 3, 4)]

# Model identifiers offered in the "Model Name" dropdown.
model_path = [
    'deprem-ml/deprem_satellite_semantic_whu',
]
def visualize_instance_seg_mask(mask):
    """Colour a label mask with one random colour per distinct label.

    Args:
        mask: 2-D array of labels, shape ``(height, width)``.

    Returns:
        Float array of shape ``(height, width, 3)`` with values in ``[0, 1]``;
        all pixels sharing a label share a colour. Colours are drawn with the
        ``random`` module, so they differ between calls unless it is seeded.
    """
    # Start from an all-zero image with the mask's resolution and 3 channels.
    image = np.zeros((mask.shape[0], mask.shape[1], 3))

    # One random 0..255 colour per distinct label.
    labels = np.unique(mask)
    label2color = {
        label: (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
        for label in labels
    }

    # Paint each label with a single boolean-mask assignment instead of
    # visiting every pixel in a Python loop.
    for label, color in label2color.items():
        image[mask == label] = color

    return image / 255
# Loaded (model, processor) pairs keyed by model id, so repeated predictions
# do not reload the weights on every call.
_SEGFORMER_CACHE = {}


def _load_segformer(model_id):
    """Return a cached ``(model, processor)`` pair for ``model_id``."""
    if model_id not in _SEGFORMER_CACHE:
        model = SegformerForSemanticSegmentation.from_pretrained(model_id)
        try:
            # Use the preprocessing configuration published with the model.
            # Calling SegformerImageProcessor(model_id) instead passes the id
            # string as the first constructor argument and never loads it.
            processor = SegformerImageProcessor.from_pretrained(model_id)
        except OSError:
            # No processor configuration could be loaded for this model id:
            # fall back to the processor's default settings.
            processor = SegformerImageProcessor()
        _SEGFORMER_CACHE[model_id] = (model, processor)
    return _SEGFORMER_CACHE[model_id]


def Segformer_Segmentation(image_path, model_id, postpro):
    """Segment an image with a SegFormer model and save the visualisation.

    Args:
        image_path: Path of the image file to segment.
        model_id: Model identifier passed to ``from_pretrained``.
        postpro: ``"Connected Components Labelling"`` draws oriented boxes on
            the input image via ``cca_analysis``; any other value produces a
            randomly coloured label map via ``visualize_instance_seg_mask``.

    Returns:
        Tuple ``(image_path, output_path)`` where ``output_path`` is the file
        the result was written to (``"output.png"``).

    Raises:
        ValueError: If no path is given or the image cannot be read.
    """
    output_save = "output.png"

    if not image_path:
        raise ValueError("No input image was provided.")
    test_image = cv2.imread(image_path)
    if test_image is None:
        # cv2.imread signals an unreadable file by returning None.
        raise ValueError("Could not read image: {}".format(image_path))

    model, processor = _load_segformer(model_id)

    # cv2.imread yields BGR channel order; feed the model RGB. The BGR
    # original is kept for drawing and for cv2.imwrite below.
    rgb_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
    inputs = processor(images=rgb_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    result = processor.post_process_semantic_segmentation(outputs)[0]
    result = np.array(result)

    if postpro == "Connected Components Labelling":
        result = to_rgb(result)
        result = cca_analysis(test_image, result)
    else:
        # visualize_instance_seg_mask returns values in [0, 1]; scale back to
        # the 0..255 range before writing.
        result = visualize_instance_seg_mask(result)
        result = result * 255

    cv2.imwrite(output_save, result)
    return image_path, output_save
# One row per bundled example image. Each row must supply a value for every
# input component the examples are wired to (image, model name, post-process
# mode); the original rows omitted the third value although the prediction
# function takes three arguments.
examples = [
    [example_path, "deprem-ml/deprem_satellite_semantic_whu", "Connected Components Labelling"]
    for example_path in image_list[:4]
]

title = "Deprem ML - Segformer Semantic Segmentation"
# Build the Gradio UI: a controls column followed by two image columns (the
# original image and the segmentation result), then launch the app.
app = gr.Blocks()

with app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))

    with gr.Row():
        # Inputs: image upload, model selection, post-processing mode.
        with gr.Column():
            input_video = gr.Image(type='filepath')
            model_id = gr.Dropdown(
                value=model_path[0],
                choices=model_path,
                label="Model Name",
            )
            cca = gr.Dropdown(
                value="Connected Components Labelling",
                choices=["Connected Components Labelling", "No Post Process"],
                label="Post Process",
            )
            input_video_button = gr.Button(value="Predict")

        # First output: the original image.
        with gr.Column():
            output_orijinal_image = gr.Image(type='filepath')

        # Second output: the saved segmentation result.
        with gr.Column():
            output_mask_image = gr.Image(type='filepath')

    # The examples and the Predict button share the same wiring.
    gr.Examples(
        examples,
        inputs=[input_video, model_id, cca],
        outputs=[output_orijinal_image, output_mask_image],
        fn=Segformer_Segmentation,
        cache_examples=True,
    )
    input_video_button.click(
        Segformer_Segmentation,
        inputs=[input_video, model_id, cca],
        outputs=[output_orijinal_image, output_mask_image],
    )

app.launch(debug=True)