# Spaces:
# Runtime error
# Runtime error
# Bootstrap: install detectron2 from source when it is not already importable
# (the detectron2 imports further down require it).
try:
    import detectron2  # noqa: F401  -- only probing whether it is installed
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the running interpreter's pip: a bare `pip` on PATH may belong to a
    # different Python. check_call raises CalledProcessError if the install fails,
    # so we stop here instead of failing later with a confusing ImportError.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # Packages installed after interpreter start-up may not be found until the
    # import system's finder caches are invalidated.
    importlib.invalidate_caches()
import streamlit as st | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
import torch | |
from detectron2.config import get_cfg | |
from detectron2.engine import DefaultPredictor | |
from detectron2.utils.visualizer import Visualizer | |
from detectron2.data import MetadataCatalog | |
# Model configuration and label metadata for the detector.
from pathlib import Path

# Resolve the model files next to this script instead of the current working
# directory, so the app also works when launched from another directory.
_APP_DIR = Path(__file__).resolve().parent

# Single source of truth for the label set. The ROI head size is derived from it,
# so the class count and the label names cannot drift apart.
CLASS_NAMES = [
    "clean water network", "communication network", "flood", "garbage",
    "gutter cover", "illegal parking", "layout and building", "park", "road",
    "sidewalk", "tree", "vandalism", "waterway",
]

cfg = get_cfg()
cfg.merge_from_file(str(_APP_DIR / "faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(CLASS_NAMES)  # 13
cfg.MODEL.WEIGHTS = str(_APP_DIR / "fewshot_all_30.pth")

# Metadata used by the Visualizer to map predicted class ids to names.
my_metadata = MetadataCatalog.get("fewshot_all_30")
my_metadata.thing_classes = CLASS_NAMES

# Fall back to CPU inference when no CUDA device is present; otherwise keep the
# device value coming from the config.
if not torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cpu"
def inference(image, score_thresh=None):
    """Run the detector on ``image`` and return an annotated RGB image array.

    Args:
        image: A PIL image, or an RGB array. PIL images are converted to
            3-channel RGB first, so RGBA/palette PNGs and grayscale JPEGs work.
        score_thresh: Minimum detection score to keep. When ``None`` (the
            default, matching the original one-argument call), the module-level
            ``min_thresh`` value set by the Streamlit slider is used.

    Returns:
        numpy.ndarray: The detectron2 ``Visualizer`` rendering of the
        predictions, drawn at half scale.
    """
    if score_thresh is None:
        # Backward-compatible fallback: the slider defines this global before
        # the UI calls inference().
        score_thresh = min_thresh

    # Work on a copy so per-request threshold changes never leak into the
    # shared module-level cfg.
    run_cfg = cfg.clone()
    run_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # NOTE(review): building the predictor on every call reloads the weights;
    # consider caching it if the deployed Streamlit version allows it.
    predictor = DefaultPredictor(run_cfg)

    if hasattr(image, "convert"):
        # PIL may hand us RGBA / P / L images; the model needs exactly 3 channels.
        image = image.convert("RGB")
    rgb = np.array(image)

    # DefaultPredictor expects BGR input (it converts to the model's format
    # itself); the Visualizer below expects RGB.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)

    visualizer = Visualizer(rgb, metadata=my_metadata, scale=0.5)
    # Larger labels than the Visualizer's computed default (private attribute).
    visualizer._default_font_size = 30
    instances = outputs["instances"].to("cpu")
    return visualizer.draw_instance_predictions(instances).get_image()
# Page header: layout, title, description and the list of detectable categories.
# set_page_config must remain the first Streamlit call in the script.
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")  # Wide layout; sidebar collapsed by default on smaller screens
st.title('Urban Problem Detection Model Demo')
st.markdown('This demo introduces an interactive playground for our trained Detectron2 model for urban problem/report detection.')
st.markdown('The model was trained on manually annotated images from image reports in JAKI (Jakarta Kini).')
st.markdown('The object categories it can detect are:')

_MINOR_WORDS = {"and", "or", "of", "the"}


def _display_name(class_name):
    """Title-case a class name for display, keeping minor words lowercase."""
    return " ".join(
        word if word in _MINOR_WORDS else word.capitalize()
        for word in class_name.split()
    )


# Derive the list from the model metadata so the page always matches the
# classes the model actually predicts (13 classes -> columns of 5, 5, 3).
_class_names = list(my_metadata.thing_classes)
_category_columns = st.columns(3)
_per_column = max(1, -(-len(_class_names) // len(_category_columns)))  # ceiling division
for _index, _column in enumerate(_category_columns):
    _chunk = _class_names[_index * _per_column:(_index + 1) * _per_column]
    if _chunk:
        _column.markdown("\n".join(f"- {_display_name(name)}" for name in _chunk))
# Input controls: image upload and score threshold. `min_thresh` must stay a
# module-level name because inference() falls back to it.
uploaded_file = st.file_uploader('Upload Image', type=['jpg', 'jpeg', 'png'])
min_thresh = st.slider("Minimum score", 0.0, 1.0, 0.5)
col1, col2 = st.columns((1, 1))

# Decode the upload once, up front. A corrupt or truncated file raises OSError
# (PIL's UnidentifiedImageError subclasses it); show an error instead of a traceback.
image = None
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        image.load()  # Image.open is lazy; force decoding inside the try
    except OSError:
        st.error('Could not read the uploaded file as an image.')
        image = None

if image is not None:
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with col2:
        # Model construction plus inference can take a while, especially on CPU.
        with st.spinner('Running detection...'):
            result = inference(image)
        st.image(result, caption='Output', use_column_width=True)