import gradio as gr
import torch
import numpy as np
from transformers import SegformerForSemanticSegmentation
from loguru import logger
import rasterio as rio
from lib.utils import segment, compute_vndvi, compute_vdi
import os
import os.path as osp
# set temp dir for gradio
# os.environ["GRADIO_TEMP_DIR"] = "temp"
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')
# load model architecture
logger.info('Loading model architecture...')
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/mit-b5',
    num_labels=1,    # binary segmentation
    num_channels=3,  # RGB
    id2label={1: 'vine'},
    label2id={'vine': 1},
)
# load model weights
logger.info('Loading model weights...')
model.load_state_dict(torch.load('model/growseg.pt', map_location=device))
model = model.to(device)
model.eval()
def main(input, output, patch_size=512, stride=256, scaling_factor=1., rotate=False, batch_size=16, verbose=False, return_vndvi=True, return_vdi=True, window_size=360):
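    """Segment vines in an orthoimage and optionally compute vNDVI/VDI maps.

    The parameters mirror the Gradio controls defined below: `patch_size` and `stride`
    drive the patch-based sliding-window inference, `scaling_factor` resamples the image
    before inference, `rotate` toggles patch rotation, and `window_size` is the
    moving-window size (in pixels) used when computing the vNDVI/VDI maps.
    """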
    assert osp.splitext(output)[1].lower() in ['.tif', '.tiff', '.png', '.jpg', '.jpeg'], 'Output file format not supported'
    # read image
    logger.info(f'Reading image {input}...')
    with rio.open(input, 'r') as src:
        image = src.read()
        profile = src.profile
        shape = src.shape
    if profile['driver'] == 'GTiff':
        gsd = profile['transform'][0]  # ground sampling distance (NB: valid only if image is a GeoTIFF)
    else:
        gsd = None
    # Growseg works best on orthoimages with gsd in [1, 1.7] cm/px. You may want to
    # specify a scaling factor different from 1 if your image has a different gsd.
    # E.g.: SCALING_FACTOR = gsd / 1.5
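    # A hedged sketch (not part of the original pipeline): derive the scaling factor
    # from the GeoTIFF pixel size, assuming the transform units are metres per pixel
    # (so gsd * 100 gives cm/px) and targeting 1.5 cm/px:
    # if gsd is not None:
    #     scaling_factor = (gsd * 100) / 1.5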
    # segment
    logger.info('Segmenting image...')
    mask = segment(
        image,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=verbose
    )  # mask is a HxW float32 array in [0, 1], with -1 marking nodata pixels
    # apply threshold on confidence scores
    alpha = (mask == -1)
    mask = (mask > 0.5)
    # convert to uint8
    mask = (mask * 255).astype(np.uint8)
    # set nodata pixels to 1
    mask[alpha] = 1
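    # (mask encoding at this point: 255 = vine, 0 = background, 1 = nodata)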
    # if requested, compute the additional index maps
    if return_vndvi:
        logger.info('Computing VNDVI...')
        vndvi_rows_fig, vndvi_interrows_fig = compute_vndvi(image, mask, window_size)
    else:
        vndvi_rows_fig = vndvi_interrows_fig = None
    if return_vdi:
        logger.info('Computing VDI...')
        vdi_fig = compute_vdi(image, mask, window_size)
    else:
        vdi_fig = None
    # save mask (currently disabled: writing to the output path is still a TODO)
    """
    logger.info('Saving mask...')
    profile.update(
        dtype=rio.uint8,
        count=1,
        compress='lzw',
    )
    with rio.open(output, 'w', **profile) as dst:
        dst.write(mask[None, ...])
    logger.info(f'Mask saved to {output}')
    """
    # return the mask and, optionally, the additional outputs
    return image.transpose(1, 2, 0), mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.File(type='filepath', file_types=['image', '.tif', '.tiff'], label='Input image path'),
        gr.Textbox(value='mask.tif', label='Output mask path (TODO)'),
        gr.Slider(minimum=128, maximum=512, value=512, step=128, label='Patch size'),
        gr.Slider(minimum=0, maximum=256, value=256, step=64, label='Stride'),
        gr.Slider(minimum=0.1, maximum=10, value=1, step=0.05, label='Scaling factor'),
        gr.Checkbox(value=False, label='Rotate patches'),
        gr.Slider(minimum=4, maximum=128, value=16, step=4, label='Batch size'),
        gr.Checkbox(value=False, label='Verbose'),
        gr.Checkbox(value=True, label='Return VNDVI map'),
        gr.Checkbox(value=True, label='Return VDI map'),
        gr.Slider(minimum=10, maximum=600, value=360, step=1, label='Moving window size for computing vNDVI/VDI (suggestion: inversely proportional to the GSD [px/m])'),
    ],
    outputs=[
        gr.Image(type='numpy', format='png', label='Input image'),
        gr.Image(type='numpy', format='png', label='Predicted mask'),
        gr.Plot(format='png', label='VNDVI rows (dilated for visibility)'),
        gr.Plot(format='png', label='VNDVI interrows'),
        gr.Plot(format='png', label='VDI'),
    ],  # NB: if one of the outputs is None, it will not be displayed in the interface (https://github.com/gradio-app/gradio/issues/500#issuecomment-1046877766)
    title='Growseg',
    description='Segment vineyards in orthoimages',
    delete_cache=[3600, 3600],
)
demo.launch(share=True)
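# A minimal sketch of running the same pipeline without the UI (assumption: a local
# orthoimage named 'ortho.tif' exists; the file name is purely illustrative). Comment
# out the demo.launch() call above and uncomment the following for batch use:
# if __name__ == '__main__':
#     main('ortho.tif', 'mask.tif', return_vndvi=False, return_vdi=False)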