digimap-web-app / app.py
joshangngoching's picture
Update app.py
7e7266c verified
raw
history blame
1.36 kB
import streamlit as st
from PIL import Image
from transformers import pipeline
import numpy as np
import cv2
import matplotlib.cm as cm
@st.cache_resource
def _load_segmentation_pipeline():
    """Build the image-segmentation pipeline once and reuse it.

    Streamlit re-executes this script from the top on every user
    interaction; caching the pipeline as a resource keeps the model from
    being re-created on each of those reruns.
    """
    return pipeline(
        "image-segmentation",
        "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
    )


# Module-level handle used by the upload handling code further down.
semantic_segmentation = _load_segmentation_pipeline()
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
def draw_masks_fromDict(image, results):
    """Blend a colour-coded overlay of every segmentation mask onto ``image``.

    Args:
        image: Image as a NumPy array, either ``(H, W)`` or ``(H, W, C)``.
            Inputs with fewer than three channels have their first channel
            replicated; extra channels (e.g. alpha) are dropped.
            Assumed to be ``uint8`` -- TODO confirm against callers, since
            ``cv2.addWeighted`` is given a ``uint8`` overlay to blend with.
        results: Sequence of dicts, each with a ``'mask'`` entry convertible
            to a 2-D array matching the image height/width. Non-zero mask
            pixels are treated as belonging to that segment.

    Returns:
        An ``(H, W, 3)`` array blending 30% of the image with 70% of the
        coloured overlay.
    """
    # Local import: the file only binds ``matplotlib.cm`` (as ``cm``).
    import matplotlib

    # Normalise to exactly three channels so the per-pixel colour assignment
    # below is well-defined for greyscale and RGBA inputs too.
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] < 3:
        image = np.repeat(image[:, :, :1], 3, axis=2)
    else:
        image = image[:, :, :3]
    image = np.ascontiguousarray(image)

    # ``cm.get_cmap`` was removed in Matplotlib 3.9; prefer the colormap
    # registry and fall back only where the registry does not exist.
    try:
        colormap = matplotlib.colormaps['nipy_spectral']
    except AttributeError:
        colormap = cm.get_cmap('nipy_spectral')

    overlay = image.copy()
    total = len(results)
    for index, result in enumerate(results):
        # Any non-zero mask pixel belongs to this segment.
        segment = np.asarray(result['mask']).astype(bool)
        # Spread segments evenly over the colormap; drop its alpha component.
        rgb = [int(channel * 255) for channel in colormap(index / total)[:3]]
        overlay[segment] = rgb

    overlay = overlay.astype(np.uint8)
    return cv2.addWeighted(image, 0.3, overlay, 0.7, 0)
if uploaded_file is not None:
    uploaded_image = Image.open(uploaded_file)
    st.image(uploaded_image, caption='Uploaded Image.', use_column_width=True)
    st.write("")

    # PNG uploads may be RGBA, palette-based or greyscale. The overlay code
    # and the final ``Image.fromarray(..., 'RGB')`` both need a 3-channel
    # array, so normalise the mode before any array work.
    image = uploaded_image.convert("RGB")

    segmentation_results = semantic_segmentation(image)
    st.json(segmentation_results)

    image_with_masks = draw_masks_fromDict(np.array(image), segmentation_results)
    image_with_masks_pil = Image.fromarray(image_with_masks, 'RGB')
    st.image(image_with_masks_pil, caption='Segmented Image', use_column_width=True)