import gradio as gr
import torch
import numpy as np
from transformers import SegformerForSemanticSegmentation
from loguru import logger
import rasterio as rio
from lib.utils import segment, compute_vndvi, compute_vdi
import os
import os.path as osp

# set temp dir for gradio
#os.environ["GRADIO_TEMP_DIR"] = "temp"

# select the inference device once; reused by `main` below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

# load model architecture: SegFormer (MiT-B5 backbone) configured for
# single-class (vine vs. background) segmentation on 3-channel RGB input
logger.info('Loading model architecture...')
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/mit-b5',
    num_labels=1,      # binary segmentation
    num_channels=3,    # RGB
    id2label={1: 'vine'},
    label2id={'vine': 1},
)

# load fine-tuned weights and switch to inference mode
# (the duplicate device assignment that used to live here was redundant)
logger.info('Loading model weights...')
model.load_state_dict(torch.load("model/growseg.pt", map_location=device))
model = model.to(device)
model.eval()


def main(input, output, patch_size=512, stride=256, scaling_factor=1., rotate=False, batch_size=16, verbose=False, return_vndvi=True, return_vdi=True, window_size=360):
    """Segment vineyards in an orthoimage; optionally derive vNDVI/VDI figures.

    Args:
        input: path to the input image (any rasterio-readable raster).
        output: intended path of the output mask (saving is currently
            disabled -- see the TODO block below).
        patch_size: side (px) of the square patches fed to the model.
        stride: step (px) between consecutive patches; values below
            `patch_size` produce overlapping predictions.
        scaling_factor: resampling factor applied before inference.
            Growseg works best at a GSD of ~1-1.7 cm/px, so e.g. use
            `gsd / 1.5` for images at a different resolution.
        rotate: if True, also run inference on rotated patches.
        batch_size: number of patches per forward pass.
        verbose: forwarded to `segment` for progress output.
        return_vndvi: whether to compute the vNDVI row/inter-row figures.
        return_vdi: whether to compute the VDI figure.
        window_size: moving-window size (px) used by vNDVI/VDI computation.

    Returns:
        Tuple `(image_hwc, mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig)`
        where `image_hwc` is the HxWxC input image and `mask` a HxW uint8
        array; figure slots are None when the corresponding flag is False.

    Raises:
        ValueError: if `output` has an unsupported file extension.
    """
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if osp.splitext(output)[1].lower() not in ('.tif', '.tiff', '.png', '.jpg', '.jpeg'):
        raise ValueError('Output file format not supported')

    # read image (rasterio returns a CxHxW array)
    logger.info(f'Reading image {input}...')
    with rio.open(input, 'r') as src:
        image = src.read()
        profile = src.profile

    if profile['driver'] == 'GTiff':
        gsd = profile['transform'][0]  # ground sampling distance (NB: valid only if image is a GeoTIFF)
    else:
        gsd = None

    # NOTE: `gsd` is informational. Growseg works best on orthoimages with gsd
    # in [1, 1.7] cm/px; you may want to specify a scaling factor different
    # from 1 if your image has a different gsd, e.g. SCALING_FACTOR = gsd / 1.5

    # segment; result is a HxW float32 confidence map in [0, 1],
    # with -1 marking nodata pixels
    logger.info('Segmenting image...')
    mask = segment(
        image,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=verbose,
    )

    # remember nodata pixels, then apply threshold on confidence scores
    alpha = (mask == -1)
    mask = (mask > 0.5)

    # convert boolean mask to uint8 {0, 255}
    mask = (mask * 255).astype(np.uint8)

    # set nodata pixels to 1 (distinguishable from both 0 and 255)
    mask[alpha] = 1

    # if requested, compute additional vegetation-index figures
    if return_vndvi:
        logger.info('Computing VNDVI...')
        vndvi_rows_fig, vndvi_interrows_fig = compute_vndvi(image, mask, window_size)
    else:
        vndvi_rows_fig = vndvi_interrows_fig = None

    if return_vdi:
        logger.info('Computing VDI...')
        vdi_fig = compute_vdi(image, mask, window_size)
    else:
        vdi_fig = None

    # TODO: saving the mask to `output` is currently disabled:
    # logger.info('Saving mask...')
    # profile.update(
    #     dtype=rio.uint8,
    #     count=1,
    #     compress='lzw',
    # )
    # with rio.open(output, 'w', **profile) as dst:
    #     dst.write(mask[None, ...])
    # logger.info(f'Mask saved to {output}')

    # return HxWxC image, the mask, and the (possibly None) figures
    return image.transpose(1, 2, 0), mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig

# ---- Gradio UI ----------------------------------------------------------
# Widgets are declared up front, in the same order as `main`'s parameters.
_inputs = [
    gr.File(type='filepath', file_types=['image', '.tif', '.tiff'], label='Input image path'),
    gr.Textbox(value='mask.tif', label='Output mask path (TODO)'),
    gr.Slider(minimum=128, maximum=512, value=512, step=128, label='Patch size'),
    gr.Slider(minimum=0, maximum=256, value=256, step=64, label='Stride'),
    gr.Slider(minimum=0.1, maximum=10, value=1, step=0.05, label='Scaling factor'),
    gr.Checkbox(value=False, label='Rotate patches'),
    gr.Slider(minimum=4, maximum=128, value=16, step=4, label='Batch size'),
    gr.Checkbox(value=False, label='Verbose'),
    gr.Checkbox(value=True, label='Return VNDVI map'),
    gr.Checkbox(value=True, label='Return VDI map'),
    gr.Slider(minimum=10, maximum=600, value=360, step=1, label='Moving window size for computing vNDVI/VDI (suggestion: inversely proportional to the GSD [px/m])'),
]

# NB: if one of the outputs is None, it will not be displayed in the interface
# (https://github.com/gradio-app/gradio/issues/500#issuecomment-1046877766)
_outputs = [
    gr.Image(type='numpy', format='png', label='Input image'),
    gr.Image(type='numpy', format='png', label='Predicted mask'),
    gr.Plot(format='png', label='VNDVI rows (dilated for visibility)'),
    gr.Plot(format='png', label='VNDVI interrows'),
    gr.Plot(format='png', label='VDI'),
]

demo = gr.Interface(
    fn=main,
    inputs=_inputs,
    outputs=_outputs,
    title='Growseg',
    description='Segment vineyards in orthoimages',
    delete_cache=[3600, 3600],
)

demo.launch(share=True)