tommonopoli committed
Commit b53fda4 · 1 Parent(s): e7e9077

app first commit

Files changed (6):
  1. README.md +2 -2
  2. app.py +138 -0
  3. lib/utils.py +521 -0
  4. lib/viz_utils.py +125 -0
  5. model/growseg.pt +3 -0
  6. requirements.txt +11 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: Gaia Growseg Demo
- emoji: 🍃
+ title: Growseg Demo
+ emoji: 🍇
  colorFrom: blue
  colorTo: gray
  sdk: gradio
app.py ADDED
@@ -0,0 +1,138 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from transformers import SegformerForSemanticSegmentation
+ from loguru import logger
+ import rasterio as rio
+ from lib.utils import segment, compute_vndvi, compute_vdi
+ import os
+ import os.path as osp
+
+ # set temp dir for gradio
+ os.environ["GRADIO_TEMP_DIR"] = "/nfs/home/monopoli/VITIGEOSS/gaia-growseg-demo/temp"
+
+ # set device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ logger.info(f'Using device: {device}')
+
+ # load model architecture
+ logger.info('Loading model architecture...')
+ model = SegformerForSemanticSegmentation.from_pretrained(
+     'nvidia/mit-b5',
+     num_labels = 1,    # binary segmentation
+     num_channels = 3,  # RGB
+     id2label = {1: 'vine'},
+     label2id = {'vine': 1},
+ )
+
+ # load model weights
+ logger.info('Loading model weights...')
+ model.load_state_dict(torch.load('model/growseg.pt', map_location=device))
+ model = model.to(device)
+ model.eval()
+
+
+ def main(input, output, patch_size=512, stride=256, scaling_factor=1., rotate=False, batch_size=16, verbose=False, return_vndvi=True, return_vdi=True, window_size=360):
+     assert osp.splitext(output)[1].lower() in ['.tif', '.tiff', '.png', '.jpg', '.jpeg'], 'Output file format not supported'
+     assert osp.splitext(input)[1].lower() == osp.splitext(output)[1].lower(), f'Input and output file formats must match. Got {osp.splitext(input)[1].lower()} and {osp.splitext(output)[1].lower()} respectively.'
+
+     # read image
+     logger.info(f'Reading image {input}...')
+     with rio.open(input, 'r') as src:
+         image = src.read()
+         profile = src.profile
+         shape = src.shape
+
+     if profile['driver'] == 'GTiff':
+         gsd = profile['transform'][0]  # ground sampling distance (NB: valid only if the image is a GeoTIFF)
+     else:
+         gsd = None
+
+     # Growseg works best on orthoimages with a GSD in [1, 1.7] cm/px. You may want to
+     # specify a scaling factor different from 1 if your image has a different GSD,
+     # e.g. SCALING_FACTOR = gsd / 1.5.
+
+     # segment
+     logger.info('Segmenting image...')
+     mask = segment(
+         image,
+         model,
+         patch_size=patch_size,
+         stride=stride,
+         scaling_factor=scaling_factor,
+         rotate=rotate,
+         device=device,
+         batch_size=batch_size,
+         verbose=verbose,
+     )  # mask is a HxW float32 array in [0, 1], with -1 on nodata pixels
+
+     # remember the nodata pixels, then apply a threshold on the confidence scores
+     alpha = (mask == -1)
+     mask = (mask > 0.5)
+
+     # convert to uint8 (0 = interrow, 255 = vine row)
+     mask = (mask * 255).astype(np.uint8)
+
+     # set nodata pixels to 1
+     mask[alpha] = 1
+
+     # if requested, compute the additional outputs
+     if return_vndvi:
+         logger.info('Computing VNDVI...')
+         vndvi_rows_fig, vndvi_interrows_fig = compute_vndvi(image, mask, window_size=window_size)
+     else:
+         vndvi_rows_fig = vndvi_interrows_fig = None
+
+     if return_vdi:
+         logger.info('Computing VDI...')
+         vdi_fig = compute_vdi(image, mask, window_size)
+     else:
+         vdi_fig = None
+
+     # save mask (currently disabled)
+     """
+     logger.info('Saving mask...')
+     profile.update(
+         dtype=rio.uint8,
+         count=1,
+         compress='lzw',
+     )
+
+     with rio.open(output, 'w', **profile) as dst:
+         dst.write(mask[None, ...])
+
+     logger.info(f'Mask saved to {output}')
+     """
+
+     # return the mask and, if requested, the additional outputs
+     return image.transpose(1,2,0), mask, vndvi_rows_fig, vndvi_interrows_fig, vdi_fig
+
+
+ demo = gr.Interface(
+     fn=main,
+     inputs=[
+         gr.File(type='filepath', file_types=['image','.tif','.tiff'], label='Input image path'),
+         gr.Textbox(value='mask.tif', label='Output mask path'),
+         gr.Slider(minimum=128, maximum=512, value=512, step=128, label='Patch size'),
+         gr.Slider(minimum=0, maximum=256, value=256, step=64, label='Stride'),
+         gr.Slider(minimum=0.1, maximum=10, value=1, step=0.05, label='Scaling factor'),
+         gr.Checkbox(value=False, label='Rotate patches'),
+         gr.Slider(minimum=4, maximum=128, value=16, step=4, label='Batch size'),
+         gr.Checkbox(value=False, label='Verbose'),
+         gr.Checkbox(value=True, label='Return VNDVI map'),
+         gr.Checkbox(value=True, label='Return VDI map'),
+         gr.Slider(minimum=10, maximum=600, value=360, step=1, label='Moving window size for computing vNDVI/VDI (suggestion: inversely proportional to the GSD [px/m])'),
+     ],
+     outputs=[
+         gr.Image(type='numpy', format='png', label='Input image'),
+         gr.Image(type='numpy', format='png', label='Predicted mask'),
+         gr.Plot(format='png', label='VNDVI rows (dilated for visibility)'),
+         gr.Plot(format='png', label='VNDVI interrows'),
+         gr.Plot(format='png', label='VDI'),
+     ],  # NB: if one of the outputs is None, it will not be displayed in the interface (https://github.com/gradio-app/gradio/issues/500#issuecomment-1046877766)
+     title='Growseg',
+     description='Segment vineyards in orthoimages',
+     delete_cache=(3600, 3600),
+ )
+
+ demo.launch(share=True)
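
A side note on the scaling-factor input above: following the comment in main() (Growseg works best at a GSD of roughly 1-1.7 cm/px), the value to enter in the UI can be derived from the raster itself. A minimal sketch, assuming a GeoTIFF whose transform is expressed in the same cm/px units the comment uses; the file name is hypothetical:

import rasterio as rio

with rio.open('vineyard_ortho.tif') as src:
    gsd = src.profile['transform'][0]  # ground sampling distance, assumed cm/px

# rescale toward the middle of Growseg's preferred 1-1.7 cm/px range
scaling_factor = gsd / 1.5
print(f'Scaling factor to enter in the demo: {scaling_factor:.2f}')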
lib/utils.py ADDED
@@ -0,0 +1,521 @@
+ import numpy as np
+ import torch
+ import rasterio
+ import cv2
+ from transformers import SegformerForSemanticSegmentation
+ from tqdm import tqdm
+ from scipy.ndimage import grey_dilation
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ from .viz_utils import alpha_composite
+
+
+ def read_raster(path, order='CHW'):
+     """Read a raster file and return a numpy array"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+
+     with rasterio.open(path) as src:
+         img = src.read()
+
+     if order == 'HWC':
+         img = np.moveaxis(img, 0, -1)
+
+     return img
+
+
+ def write_raster(path, img, profile, order='CHW'):
+     """Write a numpy array to a raster file"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+
+     if order == 'HWC':
+         img = np.moveaxis(img, -1, 0)
+
+     with rasterio.open(path, 'w', **profile) as dst:
+         dst.write(img)
+
+
+ def resize(img, shape=None, scaling_factor=1., order='CHW'):
+     """Resize an image to a target shape or by a scaling factor (provide only one of the two)"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+     assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"
+
+     if order == 'CHW':
+         img = np.moveaxis(img, 0, -1)  # CHW -> HWC
+
+     if shape is not None:
+         img = cv2.resize(img, shape[::-1], interpolation=cv2.INTER_LINEAR)
+     else:
+         img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
+
+     # NB: cv2.resize returns a HW image if the input image is HW1: restore the C dimension
+     if len(img.shape) == 2:
+         img = img[..., None]
+
+     if order == 'CHW':
+         img = np.moveaxis(img, -1, 0)  # HWC -> CHW
+
+     return img
+
+
+ def minimum_needed_padding(img_size, patch_size: int, stride: int):
+     """
+     Compute the minimum padding needed to split an image into patches of a given size with a given stride.
+     Args:
+         img_size (tuple): the shape (H,W) of the image
+         patch_size (int): the size of the patches to extract
+         stride (int): the stride to use when extracting patches
+     Returns:
+         tuple: the padding (pad_t, pad_b, pad_l, pad_r) that makes the image evenly divisible into patches
+     """
+     img_size = np.array(img_size)
+     pad = np.where(
+         img_size <= patch_size,
+         (patch_size - img_size) % patch_size,  # the % patch_size handles the case img_size = (0,0)
+         (stride - (img_size - patch_size)) % stride
+     )
+     pad_t, pad_l = pad // 2
+     pad_b, pad_r = pad[0] - pad_t, pad[1] - pad_l
+
+     return pad_t, pad_b, pad_l, pad_r
+
+
+ def pad(img, pad, order='CHW'):
+     """Pad an image by the given pad values, in the format (pad_t, pad_b, pad_l, pad_r)"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+
+     pad_t, pad_b, pad_l, pad_r = pad
+
+     if order == 'HWC':
+         padded_img = np.pad(img, ((pad_t,pad_b), (pad_l,pad_r), (0,0)), mode='constant', constant_values=0)  # can also try mode='reflect'
+     else:
+         padded_img = np.pad(img, ((0,0), (pad_t,pad_b), (pad_l,pad_r)), mode='constant', constant_values=0)  # can also try mode='reflect'
+
+     if isinstance(img, torch.Tensor):
+         padded_img = torch.tensor(padded_img)
+
+     return padded_img
+
+
+ def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
+     """Extract patches from an image; each patch index is in the format (h_start, h_end, w_start, w_end)"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+
+     if order == 'HWC':
+         H, W = img.shape[:2]
+     else:
+         H, W = img.shape[1:]
+
+     # extract patches
+     patches = []
+     patches_idx = []
+     for i in range(0, H-patch_size+1, stride):
+         for j in range(0, W-patch_size+1, stride):
+
+             patches_idx.append((i, i+patch_size, j, j+patch_size))
+
+             if not only_return_idx:
+                 if order == 'HWC':
+                     patch = img[i:i+patch_size, j:j+patch_size, :]
+                 else:
+                     patch = img[:, i:i+patch_size, j:j+patch_size]
+                 patches.append(patch)
+
+     if only_return_idx:
+         return patches_idx
+     return patches, patches_idx
+
+
+ def segment_batch(batch, model):
+     """Run the model on a batch of patches and return per-pixel confidence scores"""
+     with torch.no_grad():
+         out = model(batch)  # (n_patches, 1, H, W) logits
+         if isinstance(model, SegformerForSemanticSegmentation):
+             out = upsample(out.logits, size=batch.shape[-2:])
+
+     # apply sigmoid
+     out = torch.sigmoid(out)  # logits -> confidence scores
+
+     return out
+
+
+ def upsample(x, size):
+     """Bilinearly upsample a 4D (NCHW) tensor"""
+     return torch.nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)
+
+
+ def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
+     """Merge (possibly overlapping) patches into a single image by averaging"""
+     assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
+     assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
+
+     if rotate:
+         # undo the test-time rotations: patches come in groups of 4
+         # (0°, 90°, 180°, 270°), so rotating patch k by -k*90° (mod 360°)
+         # restores the original orientation
+         axes_to_rotate = (0,1) if order == 'HWC' else (1,2)
+         patches = [np.rot90(p, -i, axes=axes_to_rotate) for i,p in enumerate(patches)]
+
+     # if canvas_shape is None, infer it from patches_idx
+     if canvas_shape is None:
+         patches_idx_zipped = list(zip(*patches_idx))
+         canvas_H = max(patches_idx_zipped[1])
+         canvas_W = max(patches_idx_zipped[3])
+     else:
+         canvas_H, canvas_W = canvas_shape
+
+     # initialize canvas
+     dtype = patches[0].dtype
+     if order == 'HWC':
+         canvas_C = patches[0].shape[-1]
+         canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=dtype)  # HWC
+         n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
+     else:
+         canvas_C = patches[0].shape[0]
+         canvas = np.zeros((canvas_C, canvas_H, canvas_W), dtype=dtype)  # CHW
+         n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))
+
+     # sum the patch contributions and count how many patches cover each pixel
+     for p, (t,b,l,r) in zip(patches, patches_idx):
+         if order == 'HWC':
+             canvas[t:b, l:r, :] += p
+             n_overlapping_patches[t:b, l:r, 0] += 1
+         else:
+             canvas[:, t:b, l:r] += p
+             n_overlapping_patches[0, t:b, l:r] += 1
+
+     # compute the average (pixels not covered by any patch stay 0)
+     canvas = np.divide(canvas, n_overlapping_patches, out=np.zeros_like(canvas), where=(n_overlapping_patches != 0))
+
+     return canvas
+
+
+ def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
+     """Segment an RGB image with a segmentation model. Returns a probability
+     map with values in [0, 1] (-1 on nodata pixels)"""
+
+     # some checks
+     assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
+     assert img.shape[0] in [3,4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
+     assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"
+
+     # prepare model for evaluation
+     model = model.to(device)
+     model.eval()
+
+     # prepare alpha channel
+     original_shape = img.shape
+     if img.shape[0] == 3:
+         # create a dummy alpha channel
+         alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
+     else:
+         # extract the alpha channel
+         img, alpha = img[:3], img[3]
+
+     # resize image
+     img = resize(img, scaling_factor=scaling_factor)
+
+     # pad image
+     pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
+     padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
+     padded_shape = padded_img.shape
+
+     # extract patch indexes
+     patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)
+
+     ### segment
+     masks = []
+     masks_idx = []
+
+     batch = []
+     for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx))):
+         t, b, l, r = p_idx
+
+         # extract patch
+         patch = padded_img[:, t:b, l:r]
+
+         # consider the patch only if it is valid (i.e. not all black nor all white)
+         if np.any(patch != 0) and np.any(patch != 255):
+
+             # convert the patch to a float32 torch.Tensor with values in [0,1]
+             patch = torch.tensor(patch).float() / 255.
+
+             # normalize the patch with the ImageNet mean and std
+             patch = (patch - torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)) / torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
+
+             # add the patch to the batch
+             batch.append(patch)
+             masks_idx.append(p_idx)
+
+             # (optional) for each patch extracted, consider also its rotated versions
+             if rotate:
+                 for rot in range(1,4):
+                     # NB: always rotate the original patch (by 90°, 180°, 270°),
+                     # since merge_patches expects groups of 0°/90°/180°/270°
+                     # when undoing the rotations
+                     rotated_patch = torch.rot90(patch, rot, dims=[1,2])
+                     batch.append(rotated_patch)
+                     masks_idx.append(p_idx)
+
+         # if the batch is full (or this is the last patch), run the prediction
+         if batch and (len(batch) >= batch_size or i == len(patches_idx)-1):
+
+             # move the batch to the device
+             batch = torch.stack(batch).to(device)
+
+             # perform the prediction
+             out = segment_batch(batch, model)
+
+             # append the predictions to masks
+             masks.append(out.cpu().numpy())
+
+             # reset the batch
+             batch = []
+
+     # concatenate the predictions
+     masks = np.concatenate(masks)  # (n_patches, 1, H, W)
+
+     # merge the patches
+     mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])  # (1, H, W)
+
+     # undo the padding
+     mask = mask[:, pad_t:padded_shape[1]-pad_b, pad_l:padded_shape[2]-pad_r]
+
+     # resize the mask to the original shape
+     mask = resize(mask, shape=original_shape[1:])
+
+     # apply the alpha channel, i.e. set to -1 the pixels where alpha is 0
+     mask = np.where(alpha == 0, -1, mask)
+
+     return mask.squeeze()
+
+
+ def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., normalize=False, return_min_max=False, verbose=False):
+     """Smooth a single-channel HWC image by averaging it over a sliding window
+     of the given size, moved with a step of `granularity` pixels"""
+     assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
+     assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
+
+     # check if an alpha channel was given, and cast it to np.float32 with values in [0,1]
+     if alpha is not None:
+         assert isinstance(alpha, np.ndarray), f'Alpha channel must be a numpy array. Got {type(alpha)}'
+         assert alpha.shape[2] == 1, f'Alpha channel must be formatted as HWC, with C = 1. Got a shape of {alpha.shape}'
+         assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
+         if alpha.dtype == np.uint8:
+             alpha = (alpha / 255).astype(np.float32)
+         elif alpha.dtype == bool:
+             alpha = alpha.astype(np.float32)
+     else:
+         alpha = np.ones_like(img)
+
+     # extract patches
+     patches, patches_idx = extract_patches(img, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
+     patches_alpha, _ = extract_patches(alpha, patch_size=window, stride=granularity, order='HWC', only_return_idx=False)
+
+     # keep only the patches with more than min_nonblank_pixels valid pixels
+     kept_patches = []
+     for i, p_a in tqdm(enumerate(patches_alpha), total=len(patches), disable=not verbose):
+         if p_a.sum() > min_nonblank_pixels * window**2:
+             kept_patches.append(i)
+     patches = [patches[i] for i in kept_patches]
+     patches_idx = [patches_idx[i] for i in kept_patches]
+     patches_alpha = [patches_alpha[i] for i in kept_patches]
+
+     # compute the average patch value (i.e. the density inside the patch)
+     patches_density = [np.full_like(p_a, (p * p_a).sum() / p_a.sum()) for p, p_a in zip(patches, patches_alpha)]
+
+     # merge the patches
+     pooled_img = merge_patches(patches_density, patches_idx, canvas_shape=img.shape[:2], order='HWC')
+
+     # apply alpha
+     pooled_img = pooled_img * alpha
+
+     if normalize:
+         # [0,1]-normalize
+         pooled_img_min = pooled_img.min()
+         pooled_img_max = pooled_img.max()
+         pooled_img = (pooled_img - pooled_img_min) / (pooled_img_max - pooled_img_min)
+
+         if return_min_max:
+             return pooled_img, pooled_img_min, pooled_img_max
+
+     return pooled_img
+
+
+ def compute_vndvi(image, mask, dilate_rows=True, window_size=360):
+     """Compute the visible-band NDVI (vNDVI) of a vineyard orthoimage and
+     return two figures: a heatmap of the rows and one of the interrows"""
+     assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
+     assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
+
+     # CHW -> HWC
+     image = image.transpose(1,2,0)
+
+     # extract channels
+     _image = image.astype(np.float32) / 255  # convert to float32 in [0,1]
+     R, G, B = _image[:,:,0], _image[:,:,1], _image[:,:,2]
+
+     # to avoid division by 0 due to the negative powers, replace 0 with 1 in the R and B channels
+     R = np.where(R == 0, 1, R)
+     B = np.where(B == 0, 1, B)
+
+     # compute the vndvi
+     vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
+
+     # clip values to [0,1]
+     vndvi = np.clip(vndvi, 0, 1)
+
+     # compute the 10th and 90th percentile of the vndvi over the whole vineyard
+     vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask != 1], [10,90])  # mask is 1 for nodata, 0 or 255 for valid pixels
+
+     # clip values between the 10th and 90th percentile
+     vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)
+
+     # perform sliding window average pooling to smooth the heatmap
+     # NB: the window takes into account only the rows
+     vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
+         np.where(mask == 255, vndvi_clipped, 0)[...,None],
+         window = int(window_size / 4),
+         granularity = 10,
+         alpha = (mask == 255)[...,None],
+         min_nonblank_pixels = 0.0,
+     )
+     # same, but for the interrows
+     vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
+         np.where(mask == 0, vndvi_clipped, 0)[...,None],
+         window = int(window_size / 4),
+         granularity = 10,
+         alpha = (mask == 0)[...,None],
+         min_nonblank_pixels = 0.0,
+     )
+
+     # if requested, dilate the rows for visibility
+     if dilate_rows:
+         dil_factor = int(window_size / 60)
+         mask_rows_dilated = grey_dilation(mask == 255, size=(dil_factor,dil_factor))
+         vndvi_rows_clipped_pooled_dilated = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor,dil_factor,1))
+
+     # for visualization purposes, normalize with vndvi_perc10 and vndvi_perc90
+     # (we want vndvi_perc10 to map to the first color of the colormap and
+     # vndvi_perc90 to the last one)
+     vndvi_rows_clipped_pooled_normalized = (vndvi_rows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
+     vndvi_interrows_clipped_pooled_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
+     if dilate_rows:
+         vndvi_rows_clipped_pooled_dilated_normalized = (vndvi_rows_clipped_pooled_dilated - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
+
+     # for visualization
+     vndvi_rows_img = alpha_composite(
+         image,
+         vndvi_rows_clipped_pooled_dilated_normalized if dilate_rows else vndvi_rows_clipped_pooled_normalized,
+         opacity = 1.0,
+         colormap = 'RdYlGn',
+         alpha_image = np.zeros_like(image[:,:,[0]]),
+         alpha_mask = mask_rows_dilated[...,None] if dilate_rows else (mask == 255)[...,None],
+     )
+
+     vndvi_interrows_img = alpha_composite(
+         image,
+         vndvi_interrows_clipped_pooled_normalized,
+         opacity = 1.0,
+         colormap = 'RdYlGn',
+         alpha_image = np.zeros_like(image[:,:,[0]]),
+         alpha_mask = (mask == 0)[...,None],
+     )
+
+     # add a colorbar
+     fig_rows, ax = plt.subplots(1, 1, figsize=(10, 10))
+     divider = make_axes_locatable(ax)
+     cax = divider.append_axes('right', size='5%', pad=0.15)
+     ax.imshow(vndvi_rows_img)
+     fig_rows.colorbar(
+         mappable = mpl.cm.ScalarMappable(
+             norm = mpl.colors.Normalize(vmin=vndvi_perc10, vmax=vndvi_perc90),
+             cmap = 'RdYlGn'),
+         cax = cax,
+         orientation = 'vertical',
+         label = 'vNDVI',
+         shrink = 1)
+
+     fig_interrows, ax = plt.subplots(1, 1, figsize=(10, 10))
+     divider = make_axes_locatable(ax)
+     cax = divider.append_axes('right', size='5%', pad=0.15)
+     ax.imshow(vndvi_interrows_img)
+     fig_interrows.colorbar(
+         mappable = mpl.cm.ScalarMappable(
+             norm = mpl.colors.Normalize(vmin=vndvi_perc10, vmax=vndvi_perc90),
+             cmap = 'RdYlGn'),
+         cax = cax,
+         orientation = 'vertical',
+         label = 'vNDVI',
+         shrink = 1)
+
+     return fig_rows, fig_interrows
+
+
+ def compute_vdi(image, mask, window_size=360):
+     """Compute a vine density index (VDI), i.e. the fraction of vine pixels
+     within a sliding window, and return it as a figure"""
+     assert image.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {image.dtype}"
+     assert mask.dtype == np.uint8, f"Input mask must be a numpy array with dtype np.uint8. Got {mask.dtype}"
+
+     # CHW -> HWC
+     image = image.transpose(1,2,0)
+
+     # compute the vdi
+     vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
+         (mask == 255)[...,None],
+         window = window_size,
+         granularity = 10,
+         alpha = (mask != 1)[...,None],  # mask is 1 for nodata, 0 or 255 for valid pixels
+         min_nonblank_pixels = 0.9,
+         normalize = True,
+         return_min_max = True,
+     )
+
+     # for visualization
+     vdi_img = alpha_composite(
+         image,
+         vdi,
+         opacity = 0.5,
+         colormap = 'jet_r',
+         alpha_image = (mask != 1)[...,None],
+         alpha_mask = (mask != 1)[...,None],
+     )
+
+     # add a colorbar
+     fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+     divider = make_axes_locatable(ax)
+     cax = divider.append_axes('right', size='5%', pad=0.15)
+     ax.imshow(vdi_img)
+     fig.colorbar(
+         mappable = mpl.cm.ScalarMappable(
+             norm = mpl.colors.Normalize(vmin=vdi_min, vmax=vdi_max),
+             cmap = 'jet_r'),
+         cax = cax,
+         orientation = 'vertical',
+         label = 'VDI',
+         shrink = 1)
+
+     return fig
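
A quick standalone sketch of the patching helpers above, checking that minimum_needed_padding produces a patch grid that tiles the padded image exactly (the sizes are arbitrary; pad here is the module's function, not a variable):

import numpy as np
from lib.utils import minimum_needed_padding, pad, extract_patches

H, W, patch_size, stride = 700, 1000, 512, 256  # arbitrary example sizes
pad_t, pad_b, pad_l, pad_r = minimum_needed_padding((H, W), patch_size, stride)

img = np.zeros((3, H, W), dtype=np.uint8)
padded = pad(img, (pad_t, pad_b, pad_l, pad_r))  # -> shape (3, 768, 1024)

# every pixel of the padded image is covered by at least one patch,
# and the last patch in each direction ends exactly on the border
idx = extract_patches(padded, patch_size=patch_size, stride=stride)
assert max(b for _, b, _, _ in idx) == padded.shape[1]
assert max(r for _, _, _, r in idx) == padded.shape[2]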
lib/viz_utils.py ADDED
@@ -0,0 +1,125 @@
+ import sys
+ import functools
+ import numpy as np
+ import cv2
+ import cmapy
+ from PIL import Image
+ import matplotlib
+
+
+ # BUGFIX in cmapy.py
+ def cmap(cmap_name, rgb_order=False):
+     """
+     Extract colormap color information as a LUT compatible with cv2.applyColormap().
+     Default channel order is BGR.
+
+     Args:
+         cmap_name: string, name of the colormap.
+         rgb_order: boolean, if false or not set, the returned array will be in
+                    BGR order (standard OpenCV format). If true, the order
+                    will be RGB.
+
+     Returns:
+         A numpy array of type uint8 containing the colormap.
+     """
+     c_map = matplotlib.colormaps.get_cmap(cmap_name)
+     rgba_data = matplotlib.cm.ScalarMappable(cmap=c_map).to_rgba(
+         np.arange(0, 1.0, 1.0 / 256.0), bytes=True
+     )
+     rgba_data = rgba_data[:, 0:-1].reshape((256, 1, 3))
+
+     # Convert to BGR (or RGB), uint8, for OpenCV.
+     cmap = np.zeros((256, 1, 3), np.uint8)
+
+     if not rgb_order:
+         cmap[:, :, :] = rgba_data[:, :, ::-1]
+     else:
+         cmap[:, :, :] = rgba_data[:, :, :]
+
+     return cmap
+
+ # If python 3, redefine cmap() to use lru_cache.
+ if sys.version_info > (3, 0):
+     cmap = functools.lru_cache(maxsize=200)(cmap)
+
+
+ def alpha_composite(img, msk, opacity=0.5, colormap=None, alpha_image=None, alpha_mask=None, red_mask=False):
+     """Alpha composite an RGB(A) image (img) and a grayscale mask (msk).
+     - If alpha_image is None, img's alpha channel is used (or, if not present,
+       initialized to all 255).
+     - If alpha_mask is None, msk is overlaid on img only where img's alpha
+       channel is not 0.
+     - If alpha_mask is not None, the above behavior is overridden and msk is
+       overlaid on img only where alpha_mask is not 0."""
+     # only HWC numpy arrays allowed
+     assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
+     assert isinstance(msk, np.ndarray), f'Input mask must be a numpy array. Got {type(msk)}'
+     if alpha_mask is not None:
+         assert isinstance(alpha_mask, np.ndarray), f'Alpha mask must be a numpy array. Got {type(alpha_mask)}'
+         assert alpha_mask.dtype in [np.float32, bool], f'Alpha mask must be of type np.float32 or bool. Got {alpha_mask.dtype}'
+         assert alpha_mask.shape[2] == 1, f'Alpha mask must be formatted as HWC, with C = 1. Got a shape of {alpha_mask.shape}'
+     assert img.shape[2] in [3,4], f'Input image must be formatted as HWC, with C = 3,4. Got a shape of {img.shape}'
+     assert msk.shape[2] == 1, f'Input mask must be formatted as HWC, with C = 1. Got a shape of {msk.shape}'
+     assert (opacity >= 0) and (opacity <= 1), f'Mask opacity must be between 0 and 1. Got {opacity}'
+
+     # copy, to avoid modifying the original arrays
+     img = img.copy()
+     msk = msk.copy()
+
+     if img.shape[2] == 3:
+         # add an opaque alpha channel to img
+         img = np.concatenate([
+             img,
+             np.full((img.shape[0], img.shape[1], 1), 255, dtype=np.uint8)
+         ], axis=-1)
+
+     if alpha_image is None:
+         # default to img's own alpha channel
+         alpha_image = img[:,:,[3]]
+     # convert the alpha image to bool
+     alpha_image = alpha_image.astype(bool)
+
+     if alpha_mask is None:
+         # default to alpha_image, so that alpha_mask is AT LEAST as restrictive as alpha_image
+         alpha_mask = alpha_image
+     # convert the alpha mask to bool
+     alpha_mask = alpha_mask.astype(bool)
+
+     if msk.dtype != np.uint8:
+         # convert the mask to a uint8 grayscale image ([0,1] -> [0,255])
+         # NB: normalize the pixels of the mask you are interested in to [0,1]
+         # before passing it as input!
+         msk = (msk * 255).astype(np.uint8)
+
+     # convert the mask from grayscale to RGBA
+     msk = cv2.cvtColor(msk, cv2.COLOR_GRAY2RGBA)
+
+     if colormap is not None:
+         # apply the specified colormap to msk
+         # NB: values near 0 are mapped to the first colors of the chosen
+         # colormap, values near 255 to the last ones
+         msk[:,:,:3] = cmapy.colorize(msk[:,:,:3], colormap, rgb_order=True)
+     elif red_mask:
+         # convert white to red
+         msk[:,:,[1,2]] = 0
+
+     # apply alpha_image to img's alpha channel
+     img[:,:,[3]] = (alpha_image * img[:,:,[3]]).astype(np.uint8)
+
+     # apply alpha_mask and opacity to msk's alpha channel
+     msk[:,:,[3]] = (alpha_mask * opacity * msk[:,:,[3]]).astype(np.uint8)
+
+     # alpha compositing
+     img_pil = Image.fromarray(img)
+     msk_pil = Image.fromarray(msk)
+     img_pil.alpha_composite(msk_pil)
+
+     return np.array(img_pil)
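
A minimal self-contained example of alpha_composite on synthetic data, overlaying a [0,1] heatmap on an RGB image with the same 'RdYlGn' colormap used by lib/utils.py (all values here are made up for illustration):

import numpy as np
from lib.viz_utils import alpha_composite

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)  # fake RGB orthophoto
heat = rng.random((256, 256, 1), dtype=np.float32)              # fake [0,1] heatmap
show = heat > 0.5                                               # overlay only the 'hot' half

overlay = alpha_composite(img, heat, opacity=0.7, colormap='RdYlGn', alpha_mask=show)
print(overlay.shape, overlay.dtype)  # (256, 256, 4) uint8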
model/growseg.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:293f45cfd45e402b2f2e4398aa32e32e2866431e7fa582293718c51bd1317182
+ size 338870239
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ numpy
+ scipy
+ rasterio
+ torch
+ transformers
+ tqdm
+ loguru
+ opencv-python-headless
+ pillow
+ matplotlib
+ cmapy
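
Note that gradio itself is not pinned here: on Hugging Face Spaces it is installed automatically from the sdk: gradio field in the README frontmatter, so presumably only local runs outside Spaces need it added to the environment.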