"""Gradio demo for Growseg: vineyard segmentation of orthoimages.

Loads a SegFormer (``nvidia/mit-b5`` backbone) binary segmentation model with the
weights stored in ``model/growseg.pt`` and exposes :func:`main` through a Gradio
interface that returns the input image, the predicted mask and, optionally,
vNDVI / VDI plots.
"""
import gradio as gr
import torch
import numpy as np
from transformers import SegformerForSemanticSegmentation
from loguru import logger
import rasterio as rio
from lib.utils import segment, compute_vndvi, compute_vdi
import os
import os.path as osp

# set temp dir for gradio
# os.environ["GRADIO_TEMP_DIR"] = "temp"

# Extensions accepted for the output mask path (saving itself is still a TODO).
_OUTPUT_EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')

# Select the device once and reuse it for both weight loading and inference.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

# load model architecture
logger.info('Loading model architecture...')
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/mit-b5',
    num_labels=1,  # binary segmentation
    num_channels=3,  # RGB
    id2label={1: 'vine'},
    label2id={'vine': 1},
)

# load model weights
logger.info('Loading model weights...')
# NOTE(review): torch.load unpickles the checkpoint; this is a local, trusted file.
# On torch >= 1.13, consider passing weights_only=True to restrict it to tensors.
model.load_state_dict(torch.load('model/growseg.pt', map_location=device))
model = model.to(device)
model.eval()


def _binarize_mask(scores):
    """Threshold confidence *scores* into a uint8 mask (255 vine, 0 other, 1 nodata).

    Pixels whose score equals -1 are treated as nodata and encoded as 1; all other
    pixels are 255 when the score is above 0.5 and 0 otherwise.
    """
    nodata = scores == -1
    mask = np.where(scores > 0.5, 255, 0).astype(np.uint8)
    mask[nodata] = 1
    return mask


def main(input, output, patch_size=512, stride=256, scaling_factor=1., rotate=False,
         batch_size=16, verbose=False, return_vndvi=True, return_vdi=True, window_size=360):
    """Segment vineyards in an orthoimage and optionally compute vNDVI / VDI plots.

    Args:
        input: path of the input image (any raster readable by rasterio). The name
            shadows the builtin but is kept for backward compatibility.
        output: intended output mask path. Only its extension is validated; saving
            is not implemented yet (see TODO below).
        patch_size, stride, scaling_factor, rotate, batch_size, verbose: forwarded
            to ``lib.utils.segment``.
        return_vndvi: whether to compute the vNDVI row/inter-row figures.
        return_vdi: whether to compute the VDI figure.
        window_size: moving-window size forwarded to ``compute_vndvi`` /
            ``compute_vdi``.

    Returns:
        Tuple ``(image_hwc, mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig)``.
        ``mask`` is uint8 (255 vine, 0 other, 1 nodata); disabled figures are None.

    Raises:
        gr.Error: if no input file is given or the output extension is unsupported.
    """
    # gr.Error is shown to the user in the UI and, unlike assert, survives `python -O`.
    if input is None:
        raise gr.Error('Please upload an input image.')
    if osp.splitext(output or '')[1].lower() not in _OUTPUT_EXTENSIONS:
        raise gr.Error('Output file format not supported')

    # Gradio sliders may deliver numbers as floats; window/batch sizes must be ints.
    patch_size = int(patch_size)
    stride = int(stride)
    batch_size = int(batch_size)
    window_size = int(window_size)

    # read image
    logger.info(f'Reading image {input}...')
    with rio.open(input, 'r') as src:
        image = src.read()  # bands-first array
        profile = src.profile

    # Pixel size from the affine transform; only meaningful for georeferenced GeoTIFFs.
    gsd = profile['transform'][0] if profile['driver'] == 'GTiff' else None
    if gsd is not None:
        # Units follow the raster's CRS (e.g. metres for projected CRSs); not checked here.
        logger.info(f'Pixel size from GeoTIFF transform: {gsd}')

    # Growseg works best on orthoimages with gsd in [1, 1.7] cm/px. You may want to
    # specify a scaling factor different from 1 if your image has a different gsd.
    # E.g.: SCALING_FACTOR = gsd / 1.5 (with gsd expressed in cm/px)

    # segment
    logger.info('Segmenting image...')
    scores = segment(
        image,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=verbose,
    )
    # scores: HxW float32 confidences in [0, 1], with -1 marking nodata pixels.
    mask = _binarize_mask(scores)

    # if requested, compute the additional vegetation-index figures
    vndvi_rows_fig = vndvi_interrows_fig = vdi_fig = None
    if return_vndvi:
        logger.info('Computing VNDVI...')
        vndvi_rows_fig, vndvi_interrows_fig = compute_vndvi(image, mask, window_size)
    if return_vdi:
        logger.info('Computing VDI...')
        vdi_fig = compute_vdi(image, mask, window_size)

    # TODO: saving the mask is disabled. Writing to a user-supplied path on the
    # server is unsafe; if re-enabled, restrict it to a known directory. Draft:
    #   logger.info('Saving mask...')
    #   profile.update(dtype=rio.uint8, count=1, compress='lzw')
    #   with rio.open(output, 'w', **profile) as dst:
    #       dst.write(mask[None, ...])
    #   logger.info(f'Mask saved to {output}')

    # rasterio returns bands first; move bands last for display.
    return image.transpose(1, 2, 0), mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig


demo = gr.Interface(
    fn=main,
    inputs=[
        gr.File(type='filepath', file_types=['image', '.tif', '.tiff'], label='Input image path'),
        gr.Textbox(value='mask.tif', label='Output mask path (TODO)'),
        gr.Slider(minimum=128, maximum=512, value=512, step=128, label='Patch size'),
        # A sliding-window stride must be positive, so 0 is not offered.
        gr.Slider(minimum=64, maximum=256, value=256, step=64, label='Stride'),
        gr.Slider(minimum=0.1, maximum=10, value=1, step=0.05, label='Scaling factor'),
        gr.Checkbox(value=False, label='Rotate patches'),
        gr.Slider(minimum=4, maximum=128, value=16, step=4, label='Batch size'),
        gr.Checkbox(value=False, label='Verbose'),
        gr.Checkbox(value=True, label='Return VNDVI map'),
        gr.Checkbox(value=True, label='Return VDI map'),
        gr.Slider(minimum=10, maximum=600, value=360, step=1, label='Moving window size for computing vNDVI/VDI (suggestion: inversely proportional to the GSD [px/m])'),
    ],
    outputs=[
        gr.Image(type='numpy', format='png', label='Input image'),
        gr.Image(type='numpy', format='png', label='Predicted mask'),
        gr.Plot(format='png', label='VNDVI rows (dilated for visibility)'),
        gr.Plot(format='png', label='VNDVI interrows'),
        gr.Plot(format='png', label='VDI'),
    ],
    # NB: if one of the outputs is None, it will not be displayed in the interface
    # (https://github.com/gradio-app/gradio/issues/500#issuecomment-1046877766)
    title='Growseg',
    description='Segment vineyards in orthoimages',
    delete_cache=[3600, 3600],
)

if __name__ == '__main__':
    demo.launch(share=True)