tommonopoli's picture
removed assert that input and output extensions must match
aa485da
raw
history blame
4.8 kB
import gradio as gr
import torch
import numpy as np
from transformers import SegformerForSemanticSegmentation
from loguru import logger
import rasterio as rio
from lib.utils import segment, compute_vndvi, compute_vdi
import os
import os.path as osp
# set temp dir for gradio
#os.environ["GRADIO_TEMP_DIR"] = "temp"

# Select the compute device exactly once, so that the device reported in the
# log is the same one used for loading the weights and running the model.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

# Build the SegFormer architecture on top of the MiT-B5 backbone, configured
# with a single output label ('vine') and 3 input channels.
logger.info('Loading model architecture...')
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/mit-b5',
    num_labels=1,  # binary segmentation
    num_channels=3,  # RGB
    id2label={1: 'vine'},
    label2id={'vine': 1},
)

# Load the fine-tuned weights from the checkpoint shipped with the app.
# NOTE(review): torch.load unpickles the file; only use with a trusted
# checkpoint (consider weights_only=True if the installed torch supports it).
logger.info('Loading model weights...')
model.load_state_dict(torch.load('model/growseg.pt', map_location=device))
model = model.to(device)
model.eval()  # inference mode
def main(input, output, patch_size=512, stride=256, scaling_factor=1., rotate=False, batch_size=16, verbose=False, return_vndvi=True, return_vdi=True, window_size=360):
    """Segment a raster image and optionally compute the vNDVI / VDI figures.

    Args:
        input: Path of the image to open with rasterio. (The name shadows the
            builtin but is kept for interface compatibility.)
        output: Path intended for the output mask. Only its extension is
            validated here; the mask is not written to disk (saving is
            currently disabled, see the TODO near the end of the function).
        patch_size: Forwarded to ``segment``.
        stride: Forwarded to ``segment``.
        scaling_factor: Forwarded to ``segment``.
        rotate: Forwarded to ``segment``.
        batch_size: Forwarded to ``segment``.
        verbose: Forwarded to ``segment``.
        return_vndvi: When true, call ``compute_vndvi`` and return its two
            figures; otherwise both are ``None``.
        return_vdi: When true, call ``compute_vdi`` and return its figure;
            otherwise it is ``None``.
        window_size: Forwarded to ``compute_vndvi`` and ``compute_vdi``.

    Returns:
        A 5-tuple ``(image, mask, vndvi_rows_fig, vndvi_interrows_fig,
        vdi_fig)`` where ``image`` is the input raster transposed to
        channels-last and ``mask`` is a uint8 array with 255 for pixels whose
        score exceeds 0.5, 0 otherwise, and 1 for pixels that ``segment``
        flagged with -1.

    Raises:
        AssertionError: If the extension of ``output`` is not supported.
    """
    # Explicit raise instead of `assert`, so that the check is not stripped
    # when Python runs with -O. The exception type and message are unchanged.
    supported_extensions = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
    if osp.splitext(output)[1].lower() not in supported_extensions:
        raise AssertionError('Output file format not supported')

    # read image
    logger.info(f'Reading image {input}...')
    with rio.open(input, 'r') as src:
        image = src.read()  # band-first array (it is transposed before returning)
        profile = src.profile

    if profile['driver'] == 'GTiff':
        gsd = profile['transform'][0]  # ground sampling distance (NB: valid only if image is a GeoTIFF)
    else:
        gsd = None
    # Growseg works best on orthoimages with gsd in [1, 1.7] cm/px. You may want to
    # specify a scaling factor different from 1 if your image has a different gsd.
    # E.g.: SCALING_FACTOR = gsd / 1.5
    if gsd is not None:
        # Surface the value so the user can pick a suitable scaling factor.
        logger.info(f'Ground sampling distance (pixel width from the geotransform): {gsd}')

    # segment
    logger.info('Segmenting image...')
    mask = segment(
        image,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=verbose,
    )  # mask is a HxW float32 array in [0, 1]

    # Remember the pixels marked with -1 before thresholding overwrites them
    # (presumably the nodata marker used by `segment` - TODO confirm).
    alpha = (mask == -1)
    # apply threshold on confidence scores
    mask = (mask > 0.5)
    # convert to uint8 (0 / 255)
    mask = (mask * 255).astype(np.uint8)
    # set nodata pixels to 1
    mask[alpha] = 1

    # if requested, compute additional outputs
    if return_vndvi:
        logger.info('Computing VNDVI...')
        vndvi_rows_fig, vndvi_interrows_fig = compute_vndvi(image, mask, window_size)
    else:
        vndvi_rows_fig = vndvi_interrows_fig = None

    if return_vdi:
        logger.info('Computing VDI...')
        vdi_fig = compute_vdi(image, mask, window_size)
    else:
        vdi_fig = None

    # TODO: save mask (currently disabled)
    # logger.info('Saving mask...')
    # profile.update(
    #     dtype=rio.uint8,
    #     count=1,
    #     compress='lzw',
    # )
    # with rio.open(output, 'w', **profile) as dst:
    #     dst.write(mask[None, ...])
    # logger.info(f'Mask saved to {output}')

    # return the image (channels-last), the mask and the optional figures
    return image.transpose(1, 2, 0), mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig
# Input widgets, listed in the same order as the positional parameters of
# `main` (input, output, patch_size, stride, scaling_factor, rotate,
# batch_size, verbose, return_vndvi, return_vdi, window_size).
_demo_inputs = [
    gr.File(type='filepath', file_types=['image', '.tif', '.tiff'], label='Input image path'),
    gr.Textbox(value='mask.tif', label='Output mask path (TODO)'),
    gr.Slider(minimum=128, maximum=512, value=512, step=128, label='Patch size'),
    gr.Slider(minimum=0, maximum=256, value=256, step=64, label='Stride'),
    gr.Slider(minimum=0.1, maximum=10, value=1, step=0.05, label='Scaling factor'),
    gr.Checkbox(value=False, label='Rotate patches'),
    gr.Slider(minimum=4, maximum=128, value=16, step=4, label='Batch size'),
    gr.Checkbox(value=False, label='Verbose'),
    gr.Checkbox(value=True, label='Return VNDVI map'),
    gr.Checkbox(value=True, label='Return VDI map'),
    gr.Slider(
        minimum=10,
        maximum=600,
        value=360,
        step=1,
        label='Moving window size for computing vNDVI/VDI (suggestion: inversely proportional to the GSD [px/m])',
    ),
]

# Output widgets, one per element of the tuple returned by `main`.
# NB: if one of the outputs is None, it will not be displayed in the interface
# (https://github.com/gradio-app/gradio/issues/500#issuecomment-1046877766)
_demo_outputs = [
    gr.Image(type='numpy', format='png', label='Input image'),
    gr.Image(type='numpy', format='png', label='Predicted mask'),
    gr.Plot(format='png', label='VNDVI rows (dilated for visibility)'),
    gr.Plot(format='png', label='VNDVI interrows'),
    gr.Plot(format='png', label='VDI'),
]

# Assemble the web interface around `main`.
demo = gr.Interface(
    fn=main,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title='Growseg',
    description='Segment vineyards in orthoimages',
    delete_cache=[3600, 3600],
)

# Start the app; share=True requests a public share link.
demo.launch(share=True)