# File metadata: author JinHyeong99, commit f9b645e, 1.26 kB.
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import numpy as np
# Load the pretrained SegFormer checkpoint and its matching preprocessor.
# The checkpoint id is kept in one place so the model and preprocessor
# can never drift apart.
MODEL_ID = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"

try:
    # SegformerImageProcessor supersedes the deprecated SegformerFeatureExtractor
    # in newer transformers releases; fall back on older installs that lack it.
    from transformers import SegformerImageProcessor as _SegformerPreprocessor
except ImportError:
    _SegformerPreprocessor = SegformerFeatureExtractor

# The name `feature_extractor` is kept because segment_image() refers to it.
feature_extractor = _SegformerPreprocessor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
# Flattened RGB palette for the 19 Cityscapes train-ID classes, so each
# predicted class renders as a distinct colour instead of the near-black
# grey levels 0-18.
# NOTE(review): assumes the checkpoint's labels follow the standard
# Cityscapes train-ID order -- confirm against model.config.id2label.
_CITYSCAPES_PALETTE = [
    128, 64, 128,   # 0 road
    244, 35, 232,   # 1 sidewalk
    70, 70, 70,     # 2 building
    102, 102, 156,  # 3 wall
    190, 153, 153,  # 4 fence
    153, 153, 153,  # 5 pole
    250, 170, 30,   # 6 traffic light
    220, 220, 0,    # 7 traffic sign
    107, 142, 35,   # 8 vegetation
    152, 251, 152,  # 9 terrain
    70, 130, 180,   # 10 sky
    220, 20, 60,    # 11 person
    255, 0, 0,      # 12 rider
    0, 0, 142,      # 13 car
    0, 0, 70,       # 14 truck
    0, 60, 100,     # 15 bus
    0, 80, 100,     # 16 train
    0, 0, 230,      # 17 motorcycle
    119, 11, 32,    # 18 bicycle
]


def segment_image(image):
    """Segment *image* with SegFormer and return a colour-coded label map.

    Args:
        image: Input ``PIL.Image`` (the Gradio input uses ``type="pil"``),
            or ``None`` when the user submitted no image.

    Returns:
        A palette-mode (``"P"``) ``PIL.Image`` with the same width and height
        as *image*. Each pixel holds the predicted class index. Returns
        ``None`` if *image* is ``None``.
    """
    import torch  # local import: torch is not imported at the top of the file

    if image is None:
        return None

    # Preprocess and run the model. no_grad() skips autograd bookkeeping,
    # which is pure overhead during inference.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_labels, h, w)

    # SegFormer predicts at a reduced resolution. Upsample the logits to the
    # input size before argmax so the mask lines up with the image.
    # PIL's size is (W, H); interpolate expects (H, W).
    logits = torch.nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    label_map = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    # Build an "L" image from the class indices. putpalette converts it to
    # "P" mode with the Cityscapes colours attached.
    result_image = Image.fromarray(label_map)
    result_image.putpalette(_CITYSCAPES_PALETTE)
    return result_image
# Define the Gradio interface.
iface = gr.Interface(
    fn=segment_image,
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
    # gr.Image is the current equivalent.
    inputs=gr.Image(type="pil"),
    examples=["image1.jpg", "image2.jpg", "image3.jpg"],
    # segment_image returns a PIL image. The "plot" component only renders
    # plotting figures, so an image output is needed to display it.
    outputs=gr.Image(type="pil"),
    title="SegFormer Image Segmentation",
    description="Upload an image to segment it using the SegFormer model trained on Cityscapes dataset.",
)
# Launch the interface.
iface.launch()