# Page metadata captured from the hosting site (not part of the program):
# author JinHyeong99, commit 09dbfd1, file size 2.32 kB.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
# ๋ชจ๋ธ๊ณผ ํŠน์ง• ์ถ”์ถœ๊ธฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-512-1024")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-512-1024")
# RGB colour for each predicted class id. `segment_image` indexes this table
# with the argmax label map, so row i paints every pixel labelled i.
# There are 19 rows, presumably one per Cityscapes training class
# (TODO: confirm against model.config.id2label).
_PALETTE_RGB = (
    (255, 0, 0),      # red
    (255, 228, 0),    # yellow
    (171, 242, 0),    # yellow-green
    (0, 216, 255),    # sky blue
    (0, 0, 255),      # blue
    (255, 0, 221),    # pink
    (116, 116, 116),  # grey
    (95, 0, 255),     # purple
    (255, 94, 0),     # orange
    (71, 200, 62),    # green
    (153, 0, 76),     # magenta
    (67, 116, 217),   # between sky blue and blue
    (153, 112, 0),    # mustard
    (87, 129, 0),     # olive green
    (255, 169, 169),  # light pink
    (35, 30, 183),    # dark blue
    (225, 186, 133),  # beige / skin tone
    (206, 251, 201),  # pale green
    (165, 102, 255),  # muted purple
)
colors = np.array(_PALETTE_RGB, dtype=np.uint8)
def segment_image(image):
    """Segment an image and return it blended with a colour-coded class map.

    Args:
        image: Input ``PIL.Image.Image``. Gradio passes ``None`` when the user
            submits without uploading an image.

    Returns:
        An RGBA ``PIL.Image.Image`` the same size as ``image``. It is a 50/50
        blend of the input and the predicted label map painted with the
        module-level ``colors`` palette. Returns ``None`` if ``image`` is
        ``None``.
    """
    # Nothing to segment: let Gradio show an empty output instead of a traceback.
    if image is None:
        return None

    # The preprocessor expects 3-channel input. RGBA, greyscale or palette
    # images may otherwise be rejected or normalised with the wrong channel count.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Per-pixel class id for the first (and only) image in the batch.
    label_map = logits.argmax(dim=1)[0].cpu().numpy()

    # Paint each class id with its palette colour. copy=False avoids a copy
    # when the palette is already uint8.
    color_map = colors[label_map].astype(np.uint8, copy=False)
    result_image = Image.fromarray(color_map)

    # The label map comes out at the model's output resolution, which differs
    # from the input size. Nearest-neighbour resizing keeps every pixel an exact
    # palette colour instead of interpolating between classes.
    result_image = result_image.resize(image.size, Image.NEAREST)

    # Image.blend needs both images in the same mode and size.
    return Image.blend(image.convert("RGBA"), result_image.convert("RGBA"), alpha=0.5)
# Gradio UI: one PIL image in, the blended segmentation image out.
# gr.Image replaces gr.inputs.Image, which Gradio 3.x deprecated and
# Gradio 4.0 removed.
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type='pil'),
    # Relative paths; presumably these files sit next to this script.
    examples=['image1.jpg', 'image2.jpg', 'image3.jpg'],
    outputs='image',
    title="SegFormer Image Segmentation",
    description="Upload an image to segment it using the SegFormer model trained on Cityscapes dataset."
)

# Start the Gradio app.
iface.launch()