blumenstiel commited on
Commit
7749565
·
1 Parent(s): c1e129b

Initial files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,61 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ # Prithvi-EO-2.0
6
+
7
+ Prithvi-EO-2.0 is the second generation EO foundation model jointly developed by IBM, NASA, and Jülich Supercomputing Centre.
8
+
9
+ ## Architecture Overview
10
+
11
+ Prithvi-EO-2.0 is based on the ViT architecture, pre-trained using a masked autoencoder (MAE) approach, with two major modifications as shown in the figure below. First, we introduce a random dropout mechanism that completely removes different bands before the patch embeddings, with the aim of improving the ability of the model to deal with missingness of data. Second, we make modifications to support inputs with temporal and multi-spectral characteristics.
12
+
13
+ ![model_architecture](assets/modal_architecture.jpg)
14
+
15
+ Our main modifications to the ViT architecture are the 3D positional embedding and the 3D patch embedding, which are required to deal with spatiotemporal data. We have also included metadata and process metadata about the actual geolocation (e.g. latitude and longitude) and date (i.e. year and day-of-year ranging 1-365). This is done by adding biases that are calculated via 2D sine-cosine positional encoding and added to the 3D positional embeddings and 3D patch embeddings via a learned weighted sum (i.e. the weight given is a parameter learned during pretraining). Since this metadata is often not available, we pretrained Prithvi-EO-2.0 allowing for this to be absent via a dropout.
16
+
17
+ ## Pre-trained Models
18
+
19
+ | Model | Details | Weights |
20
+ | ------------- | ------------- |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
21
+ |Prithvi-EO-2.0-300M | Pretrained 300M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) |
22
+ |Prithvi-EO-2.0-300M-TL | Pretrained 300M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) |
23
+ |Prithvi-EO-2.0-600M | Pretrained 600M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M) | |
24
+ |Prithvi-EO-2.0-600M-TL | Pretrained 600M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL) |
25
+
26
+ The models were pre-trained at the Julich Supercomputing Center with NASA's HLS V2 product (30m granularity) using 4.2M samples with six bands in the following order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
27
+
28
+ ## Benchmarking
29
+ The model was benchmarked on GEO-Bench across 12 different earth observation classification and segmentation tasks at different resolutions against some of the most popular geospatial foundation models. Below the average score across all GEO-Bench tasks is shown.
30
+
31
+ ![geobench_overall_600M_TL.png](assets/geobench_overall_600M_TL.png)
32
+
33
+ ## Demo and inference
34
+ We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
35
+
36
+ There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
37
+
38
+ ```
39
+ python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
40
+ ```
41
+
42
+ ## Finetuning
43
+
44
+ You can finetune the model using [TerraTorch](https://github.com/IBM/terratorch).
45
+
46
+ ### Feedback
47
+
48
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
49
+
50
+ ### Citation
51
+
52
+ If this model helped your research, please cite `Prithvi-EO-2.0` in your publications. Here are two BibTeX entries as examples:
53
+
54
+ ```
55
+ @article{Prithvi-EO-2-preprint,
56
+ author = {},
57
+ title = {{Title}},
58
+ journal = {arxiv},
59
+ year = {2024}
60
+ }
61
+ ```
assets/geobench_overall_600M_TL.png ADDED
assets/logos.png ADDED
assets/modal_architecture.jpg ADDED
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "prithvi_eo_v2_600_tl",
3
+ "num_features": 1024,
4
+ "pretrained_cfg": {
5
+ "img_size": 224,
6
+ "num_frames": 4,
7
+ "patch_size": [1, 14, 14],
8
+ "in_chans": 6,
9
+ "embed_dim": 1280,
10
+ "depth": 32,
11
+ "num_heads": 16,
12
+ "decoder_embed_dim": 512,
13
+ "decoder_depth": 8,
14
+ "decoder_num_heads": 16,
15
+ "mlp_ratio": 4,
16
+ "coords_encoding": ["time", "location"],
17
+ "coords_scale_learn": true,
18
+ "mask_ratio": 0.75,
19
+ "norm_pix_loss": false,
20
+ "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
21
+ "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
22
+ "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
23
+ "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL",
24
+ "paper_ids": "arXiv:X.X"
25
+ }
26
+ }
examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: a2e1f9d91fedf9b286aaeef5197f4715f3caf2851187356d598d9fe78beb7c6b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 92b5e2072f9b72fee207b8aec2f91f5c42f42f60950c8ca10d9022192d2cfb1a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 24feb1904fc62268494c9c0d8628124a41621cb4ee705d82cbce7586121c91c5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
inference.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import os
4
+ from typing import List, Union
5
+ import re
6
+ import datetime
7
+ import numpy as np
8
+ import pandas as pd
9
+ import rasterio
10
+ import torch
11
+ import yaml
12
+ from einops import rearrange
13
+
14
+ from functools import partial
15
+ from prithvi_mae import PrithviMAE
16
+
17
+ NO_DATA = -9999
18
+ NO_DATA_FLOAT = 0.0001
19
+ OFFSET = 0
20
+ PERCENTILE = 99.9
21
+
22
+
23
+ def process_channel_group(orig_img, new_img, channels, mean, std):
24
+ """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
25
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
26
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
27
+
28
+ Args:
29
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
30
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
31
+ channels: list of indices representing RGB channels.
32
+ mean: list of mean values for each band.
33
+ std: list of std values for each band.
34
+
35
+ Returns:
36
+ torch.Tensor with shape (num_channels, height, width) for original image
37
+ torch.Tensor with shape (num_channels, height, width) for the other image
38
+ """
39
+
40
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
41
+ std = torch.tensor(np.asarray(std)[:, None, None])
42
+ orig_img = orig_img[channels, ...]
43
+ valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
44
+ valid_mask[orig_img == NO_DATA_FLOAT] = False
45
+
46
+ # Back to original data range
47
+ orig_img = (orig_img * std[channels]) + mean[channels]
48
+ new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
49
+
50
+ # Rescale (enhancing contrast)
51
+ max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
52
+ min_value = OFFSET
53
+
54
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
55
+ new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
56
+
57
+ # No data as zeros
58
+ orig_img[~valid_mask] = 0
59
+ new_img[~valid_mask] = 0
60
+
61
+ return orig_img, new_img
62
+
63
+
64
+ def read_geotiff(file_path: str):
65
+ """Read all bands from *file_path* and return image + meta info.
66
+
67
+ Args:
68
+ file_path: path to image file.
69
+
70
+ Returns:
71
+ np.ndarray with shape (bands, height, width)
72
+ meta info dict
73
+ """
74
+
75
+ with rasterio.open(file_path) as src:
76
+ img = src.read()
77
+ meta = src.meta
78
+ try:
79
+ coords = src.lnglat()
80
+ except:
81
+ # Cannot read coords
82
+ coords = None
83
+
84
+ return img, meta, coords
85
+
86
+
87
+ def save_geotiff(image, output_path: str, meta: dict):
88
+ """Save multi-band image in Geotiff file.
89
+
90
+ Args:
91
+ image: np.ndarray with shape (bands, height, width)
92
+ output_path: path where to save the image
93
+ meta: dict with meta info.
94
+ """
95
+
96
+ with rasterio.open(output_path, "w", **meta) as dest:
97
+ for i in range(image.shape[0]):
98
+ dest.write(image[i, :, :], i + 1)
99
+
100
+ return
101
+
102
+
103
+ def _convert_np_uint8(float_image: torch.Tensor):
104
+ image = float_image.numpy() * 255.0
105
+ image = image.astype(dtype=np.uint8)
106
+
107
+ return image
108
+
109
+
110
+ def load_example(
111
+ file_paths: List[str],
112
+ mean: List[float],
113
+ std: List[float],
114
+ indices: Union[list[int], None] = None,
115
+ ):
116
+ """Build an input example by loading images in *file_paths*.
117
+
118
+ Args:
119
+ file_paths: list of file paths .
120
+ mean: list containing mean values for each band in the images in *file_paths*.
121
+ std: list containing std values for each band in the images in *file_paths*.
122
+
123
+ Returns:
124
+ np.array containing created example
125
+ list of meta info for each image in *file_paths*
126
+ """
127
+
128
+ imgs = []
129
+ metas = []
130
+ temporal_coords = []
131
+ location_coords = []
132
+
133
+ for file in file_paths:
134
+ img, meta, coords = read_geotiff(file)
135
+
136
+ # Rescaling (don't normalize on nodata)
137
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
138
+ if indices is not None:
139
+ img = img[..., indices]
140
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
141
+
142
+ imgs.append(img)
143
+ metas.append(meta)
144
+ if coords is not None:
145
+ location_coords.append(coords)
146
+
147
+ try:
148
+ match = re.search(r'(\d{7,8}T\d{6})', file)
149
+ if match:
150
+ year = int(match.group(1)[:4])
151
+ julian_day = match.group(1).split('T')[0][4:]
152
+ if len(julian_day) == 3:
153
+ julian_day = int(julian_day)
154
+ else:
155
+ julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
156
+ temporal_coords.append([year, julian_day])
157
+ except Exception as e:
158
+ print(f'Could not extract timestamp for {file} ({e})')
159
+
160
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
161
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
162
+ imgs = np.expand_dims(imgs, axis=0) # add batch di
163
+
164
+ return imgs, temporal_coords, location_coords, metas
165
+
166
+
167
+ def run_model(
168
+ model: torch.nn.Module,
169
+ input_data: torch.Tensor,
170
+ temporal_coords: None | torch.Tensor,
171
+ location_coords: None | torch.Tensor,
172
+ mask_ratio: float,
173
+ device: torch.device,
174
+ ):
175
+ """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
176
+
177
+ Args:
178
+ model: MAE model to run.
179
+ input_data: torch.Tensor with shape (B, C, T, H, W).
180
+ mask_ratio: mask ratio to use.
181
+ device: device where model should run.
182
+
183
+ Returns:
184
+ 3 torch.Tensor with shape (B, C, T, H, W).
185
+ """
186
+
187
+ with torch.no_grad():
188
+ x = input_data.to(device)
189
+
190
+ _, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)
191
+
192
+ # Create mask and prediction images (un-patchify)
193
+ mask_img = (
194
+ model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
195
+ )
196
+ pred_img = model.unpatchify(pred).detach().cpu()
197
+
198
+ # Mix visible and predicted patches
199
+ rec_img = input_data.clone()
200
+ rec_img[mask_img == 1] = pred_img[
201
+ mask_img == 1
202
+ ] # binary mask: 0 is keep, 1 is remove
203
+
204
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
205
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
206
+
207
+ return rec_img, mask_img
208
+
209
+
210
+ def save_rgb_imgs(
211
+ input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
212
+ ):
213
+ """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
214
+
215
+ Args:
216
+ input_img: input torch.Tensor with shape (C, T, H, W).
217
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
218
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
219
+ channels: list of indices representing RGB channels.
220
+ mean: list of mean values for each band.
221
+ std: list of std values for each band.
222
+ output_dir: directory where to save outputs.
223
+ meta_data: list of dicts with geotiff meta info.
224
+ """
225
+
226
+ for t in range(input_img.shape[1]):
227
+ rgb_orig, rgb_pred = process_channel_group(
228
+ orig_img=input_img[:, t, :, :],
229
+ new_img=rec_img[:, t, :, :],
230
+ channels=channels,
231
+ mean=mean,
232
+ std=std,
233
+ )
234
+
235
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
236
+
237
+ # Saving images
238
+
239
+ save_geotiff(
240
+ image=_convert_np_uint8(rgb_orig),
241
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
242
+ meta=meta_data[t],
243
+ )
244
+
245
+ save_geotiff(
246
+ image=_convert_np_uint8(rgb_pred),
247
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
248
+ meta=meta_data[t],
249
+ )
250
+
251
+ save_geotiff(
252
+ image=_convert_np_uint8(rgb_mask),
253
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
254
+ meta=meta_data[t],
255
+ )
256
+
257
+
258
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
259
+ """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
260
+
261
+ Args:
262
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
263
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
264
+ mean: list of mean values for each band.
265
+ std: list of std values for each band.
266
+ output_dir: directory where to save outputs.
267
+ meta_data: list of dicts with geotiff meta info.
268
+ """
269
+
270
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
271
+ std = torch.tensor(np.asarray(std)[:, None, None])
272
+
273
+ for t in range(rec_img.shape[1]):
274
+ # Back to original data range
275
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
276
+
277
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
278
+
279
+ # Saving images
280
+
281
+ save_geotiff(
282
+ image=rec_img_t,
283
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
284
+ meta=meta_data[t],
285
+ )
286
+
287
+ save_geotiff(
288
+ image=mask_img_t,
289
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
290
+ meta=meta_data[t],
291
+ )
292
+
293
+
294
+ def main(
295
+ data_files: List[str],
296
+ config_path: str,
297
+ checkpoint: str,
298
+ output_dir: str,
299
+ rgb_outputs: bool,
300
+ mask_ratio: float = None,
301
+ input_indices: list[int] = None,
302
+ ):
303
+ os.makedirs(output_dir, exist_ok=True)
304
+
305
+ # Get parameters --------
306
+
307
+ import json
308
+ with open(config_path, "r") as f:
309
+ config = yaml.safe_load(f)['pretrained_cfg']
310
+
311
+ batch_size = 1
312
+ bands = config['bands']
313
+ num_frames = len(data_files)
314
+ mean = config['mean']
315
+ std = config['std']
316
+ coords_encoding = config['coords_encoding']
317
+ img_size = config['img_size']
318
+ mask_ratio = mask_ratio or config['mask_ratio']
319
+
320
+ print(
321
+ f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
322
+ )
323
+ if len(data_files) != 3:
324
+ print(
325
+ "The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
326
+ )
327
+
328
+ if torch.cuda.is_available():
329
+ device = torch.device("cuda")
330
+ else:
331
+ device = torch.device("cpu")
332
+
333
+ print(f"Using {device} device.\n")
334
+
335
+ # Loading data ---------------------------------------------------------------------------------
336
+
337
+ input_data, temporal_coords, location_coords, meta_data = load_example(
338
+ file_paths=data_files, indices=input_indices, mean=mean, std=std
339
+ )
340
+
341
+ if len(temporal_coords) != num_frames and 'time' in coords_encoding:
342
+ coords_encoding.pop('time')
343
+ if not len(location_coords) and 'location' in coords_encoding:
344
+ coords_encoding.pop('location')
345
+
346
+ # Create model and load checkpoint -------------------------------------------------------------
347
+
348
+ config.update(
349
+ coords_encoding=coords_encoding,
350
+ num_frames=num_frames,
351
+ in_chans=len(bands),
352
+ )
353
+
354
+ model = PrithviMAE(**config)
355
+
356
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
357
+ print(f"\n--> Model has {total_params:,} parameters.\n")
358
+
359
+ model.to(device)
360
+
361
+ state_dict = torch.load(checkpoint, map_location=device)
362
+ # discard fixed pos_embedding weight
363
+ for k in list(state_dict.keys()):
364
+ if 'pos_embed' in k:
365
+ del state_dict[k]
366
+ model.load_state_dict(state_dict, strict=False)
367
+ print(f"Loaded checkpoint from {checkpoint}")
368
+
369
+ # Running model --------------------------------------------------------------------------------
370
+
371
+ model.eval()
372
+ channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
373
+
374
+ # Reflect pad if not divisible by img_size
375
+ original_h, original_w = input_data.shape[-2:]
376
+ pad_h = img_size - (original_h % img_size)
377
+ pad_w = img_size - (original_w % img_size)
378
+ input_data = np.pad(
379
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
380
+ )
381
+
382
+ # Build sliding window
383
+ batch = torch.tensor(input_data, device="cpu")
384
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
385
+ h1, w1 = windows.shape[3:5]
386
+ windows = rearrange(
387
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
388
+ )
389
+
390
+ # Split into batches if number of windows > batch_size
391
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
392
+ windows = torch.tensor_split(windows, num_batches, dim=0)
393
+
394
+ temporal_coords = torch.Tensor(temporal_coords, device=device).unsqueeze(0)
395
+ location_coords = torch.Tensor(location_coords[0], device=device).unsqueeze(0)
396
+
397
+ # Run model
398
+ rec_imgs = []
399
+ mask_imgs = []
400
+ for x in windows:
401
+ rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
402
+ rec_imgs.append(rec_img)
403
+ mask_imgs.append(mask_img)
404
+
405
+ rec_imgs = torch.concat(rec_imgs, dim=0)
406
+ mask_imgs = torch.concat(mask_imgs, dim=0)
407
+
408
+ # Build images from patches
409
+ rec_imgs = rearrange(
410
+ rec_imgs,
411
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
412
+ h=img_size,
413
+ w=img_size,
414
+ b=1,
415
+ c=len(bands),
416
+ t=num_frames,
417
+ h1=h1,
418
+ w1=w1,
419
+ )
420
+ mask_imgs = rearrange(
421
+ mask_imgs,
422
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
423
+ h=img_size,
424
+ w=img_size,
425
+ b=1,
426
+ c=len(bands),
427
+ t=num_frames,
428
+ h1=h1,
429
+ w1=w1,
430
+ )
431
+
432
+ # Cut padded images back to original size
433
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
434
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
435
+ batch_full = batch[..., :original_h, :original_w]
436
+
437
+ # Build output images
438
+ if rgb_outputs:
439
+ for d in meta_data:
440
+ d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
441
+
442
+ save_rgb_imgs(
443
+ batch_full[0, ...],
444
+ rec_imgs_full[0, ...],
445
+ mask_imgs_full[0, ...],
446
+ channels,
447
+ mean,
448
+ std,
449
+ output_dir,
450
+ meta_data,
451
+ )
452
+ else:
453
+ for d in meta_data:
454
+ d.update(compress="lzw", nodata=0)
455
+
456
+ save_imgs(
457
+ rec_imgs_full[0, ...],
458
+ mask_imgs_full[0, ...],
459
+ mean,
460
+ std,
461
+ output_dir,
462
+ meta_data,
463
+ )
464
+
465
+ print("Done!")
466
+
467
+
468
+ if __name__ == "__main__":
469
+ parser = argparse.ArgumentParser("MAE run inference", add_help=False)
470
+
471
+ parser.add_argument(
472
+ "--data_files",
473
+ type=str,
474
+ nargs="+",
475
+ default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
476
+ "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
477
+ "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
478
+ ],
479
+ help="Path to the data files. Assumes multi-band files.",
480
+ )
481
+ parser.add_argument(
482
+ "--config",
483
+ "-c",
484
+ type=str,
485
+ default="config.json",
486
+ help="Path to json file containing model training parameters.",
487
+ )
488
+ parser.add_argument(
489
+ "--checkpoint",
490
+ type=str,
491
+ default="Prithvi_EO_V2_300M_TL.pt",
492
+ help="Path to a checkpoint file to load from.",
493
+ )
494
+ parser.add_argument(
495
+ "--output_dir",
496
+ type=str,
497
+ default="output",
498
+ help="Path to the directory where to save outputs.",
499
+ )
500
+ parser.add_argument(
501
+ "--mask_ratio",
502
+ default=0.75,
503
+ type=float,
504
+ help="Masking ratio (percentage of removed patches). "
505
+ "If None (default) use same value used for pretraining.",
506
+ )
507
+ parser.add_argument(
508
+ "--input_indices",
509
+ default=None,
510
+ type=int,
511
+ nargs="+",
512
+ help="0-based indices of channels to be selected from the input. By default takes all.",
513
+ )
514
+ parser.add_argument(
515
+ "--rgb_outputs",
516
+ action="store_true",
517
+ help="If present, output files will only contain RGB channels. "
518
+ "Otherwise, all bands will be saved.",
519
+ )
520
+ args = parser.parse_args()
521
+
522
+ main(**vars(args))
prithvi_mae.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # References:
16
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
17
+ # transformers: https://github.com/huggingface/transformers
18
+ # --------------------------------------------------------
19
+
20
+ from functools import partial
21
+ from typing import List, Tuple
22
+
23
+ import logging
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ from einops import rearrange
28
+ from timm.layers import to_2tuple
29
+ from timm.models.vision_transformer import Block
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """
34
+ Create 3D sin/cos positional embeddings.
35
+
36
+ Args:
37
+ embed_dim (int):
38
+ Embedding dimension.
39
+ grid_size (tuple[int, int, int] | list[int]):
40
+ The grid depth, height and width.
41
+ add_cls_token (bool, *optional*, defaults to False):
42
+ Whether or not to add a classification (CLS) token.
43
+
44
+ Returns:
45
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
46
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
47
+ """
48
+
49
+ assert embed_dim % 16 == 0
50
+
51
+ t_size, h_size, w_size = grid_size
52
+
53
+ w_embed_dim = embed_dim // 16 * 6
54
+ h_embed_dim = embed_dim // 16 * 6
55
+ t_embed_dim = embed_dim // 16 * 4
56
+
57
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
58
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
59
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
60
+
61
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
62
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
63
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
64
+
65
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
66
+
67
+ if add_cls_token:
68
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
69
+ return pos_embed
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
75
+ """
76
+ if embed_dim % 2 != 0:
77
+ raise ValueError("embed_dim must be even")
78
+
79
+ omega = np.arange(embed_dim // 2, dtype=float)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
94
+ """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
95
+ it was modified to cast omega values to pos.dtype which must be float (and not int as in
96
+ regular positional embeddings). This was required in order to allow for native FSDP mixed
97
+ precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
98
+ instead of manually forcing float32.
99
+
100
+ embed_dim: output dimension for each position
101
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
102
+ out: (M, D)
103
+ """
104
+ assert embed_dim % 2 == 0
105
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
106
+
107
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = torch.sin(out) # (M, D/2)
115
+ emb_cos = torch.cos(out) # (M, D/2)
116
+
117
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
118
+
119
+ return emb
120
+
121
+
122
+ def _init_weights(module):
123
+ """Initialize the weights"""
124
+ if isinstance(module, nn.Linear):
125
+ nn.init.xavier_uniform_(module.weight)
126
+ if module.bias is not None:
127
+ module.bias.data.zero_()
128
+ elif isinstance(module, nn.LayerNorm):
129
+ module.bias.data.zero_()
130
+ module.weight.data.fill_(1.0)
131
+
132
+
133
+ class PatchEmbed(nn.Module):
134
+ """3D version of timm.models.vision_transformer.PatchEmbed"""
135
+ def __init__(
136
+ self,
137
+ input_size: Tuple[int, int, int] = (1, 224, 224),
138
+ patch_size: Tuple[int, int, int] = (1, 16, 16),
139
+ in_chans: int = 3,
140
+ embed_dim: int = 768,
141
+ norm_layer: nn.Module | None = None,
142
+ flatten: bool = True,
143
+ bias: bool = True,
144
+ ):
145
+ super().__init__()
146
+ self.input_size = input_size
147
+ self.patch_size = patch_size
148
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
149
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
150
+ self.flatten = flatten
151
+
152
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
153
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
154
+
155
+ def forward(self, x):
156
+ B, C, T, H, W = x.shape
157
+
158
+ if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
159
+ logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
160
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
161
+
162
+ x = self.proj(x)
163
+ if self.flatten:
164
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
165
+ x = self.norm(x)
166
+ return x
167
+
168
+
169
+ class TemporalEncoder(nn.Module):
170
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
171
+ super().__init__()
172
+ self.embed_dim = embed_dim
173
+ self.year_embed_dim = embed_dim // 2
174
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
175
+
176
+ # If trainable, initialize scale with small number
177
+ if trainable_scale:
178
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
179
+ else:
180
+ self.register_buffer('scale', torch.ones(1))
181
+
182
+ def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
183
+ """
184
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
185
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
186
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
187
+ """
188
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
189
+
190
+ year = _get_1d_sincos_embed_from_grid_torch(
191
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
192
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
193
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
194
+
195
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
196
+
197
+ if tokens_per_frame is not None:
198
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
199
+
200
+ return embedding # B, T*tokens_per_frame, embed_dim
201
+
202
+
203
+ class LocationEncoder(nn.Module):
204
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
205
+ super().__init__()
206
+ self.embed_dim = embed_dim
207
+ self.lat_embed_dim = embed_dim // 2
208
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
209
+
210
+ # If trainable, initialize scale with small number
211
+ if trainable_scale:
212
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
213
+ else:
214
+ self.register_buffer('scale', torch.ones(1))
215
+
216
+ def forward(self, location_coords: torch.Tensor):
217
+ """
218
+ location_coords: lat and lon info with shape (B, 2).
219
+ """
220
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
221
+
222
+ lat = _get_1d_sincos_embed_from_grid_torch(
223
+ self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
224
+ lon = _get_1d_sincos_embed_from_grid_torch(
225
+ self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
226
+
227
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
228
+
229
+ return embedding # B, 1, embed_dim
230
+
231
+
232
+ class PrithviViT(nn.Module):
233
+ """ Prithvi ViT Encoder"""
234
+ def __init__(self,
235
+ img_size: int | Tuple[int, int] = 224,
236
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
237
+ num_frames: int = 1,
238
+ in_chans: int = 3,
239
+ embed_dim: int = 1024,
240
+ depth: int = 24,
241
+ num_heads: int = 16,
242
+ mlp_ratio: float = 4.,
243
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
244
+ coords_encoding: List[str] | None = None,
245
+ coords_scale_learn: bool = False,
246
+ encoder_only: bool = True, # needed for timm
247
+ ** kwargs,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.feature_info = []
252
+ self.encoder_only = encoder_only
253
+ self.in_chans = in_chans
254
+ self.num_frames = num_frames
255
+ self.embed_dim = embed_dim
256
+ self.img_size = to_2tuple(img_size)
257
+ if isinstance(patch_size, int):
258
+ patch_size = (1, patch_size, patch_size)
259
+
260
+ # 3D patch embedding
261
+ self.patch_embed = PatchEmbed(
262
+ input_size=(num_frames,) + self.img_size,
263
+ patch_size=patch_size,
264
+ in_chans=in_chans,
265
+ embed_dim=embed_dim,
266
+ )
267
+
268
+ # Optional temporal and location embedding
269
+ coords_encoding = coords_encoding or []
270
+ self.temporal_encoding = 'time' in coords_encoding
271
+ self.location_encoding = 'location' in coords_encoding
272
+ if self.temporal_encoding:
273
+ assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
274
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
275
+ if self.location_encoding:
276
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
277
+
278
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
279
+ self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
280
+
281
+ # Transformer layers
282
+ self.blocks = []
283
+ for i in range(depth):
284
+ self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
285
+ self.feature_info.append(
286
+ {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
287
+ )
288
+ self.blocks = nn.ModuleList(self.blocks)
289
+
290
+ self.norm = norm_layer(embed_dim)
291
+
292
+ self.initialize_weights()
293
+
294
+ def initialize_weights(self):
295
+ # initialize (and freeze) position embeddings by sin-cos embedding
296
+ pos_embed = get_3d_sincos_pos_embed(
297
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
298
+ )
299
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
300
+
301
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
302
+ w = self.patch_embed.proj.weight.data
303
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
304
+
305
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
306
+ torch.nn.init.normal_(self.cls_token, std=0.02)
307
+ self.apply(_init_weights)
308
+
309
+ def random_masking(self, sequence, mask_ratio, noise=None):
310
+ """
311
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
312
+ noise.
313
+
314
+ Args:
315
+ sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
316
+ mask_ratio (float): mask ratio to use.
317
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
318
+ mainly used for testing purposes to control randomness and maintain the reproducibility
319
+ """
320
+ batch_size, seq_length, dim = sequence.shape
321
+ len_keep = int(seq_length * (1 - mask_ratio))
322
+
323
+ if noise is None:
324
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
325
+
326
+ # sort noise for each sample
327
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
328
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
329
+
330
+ # keep the first subset
331
+ ids_keep = ids_shuffle[:, :len_keep]
332
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
333
+
334
+ # generate the binary mask: 0 is keep, 1 is remove
335
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
336
+ mask[:, :len_keep] = 0
337
+ # unshuffle to get the binary mask
338
+ mask = torch.gather(mask, dim=1, index=ids_restore)
339
+
340
+ return sequence_unmasked, mask, ids_restore
341
+
342
+ def _get_pos_embed(self, x):
343
+ t, h, w = x.shape[-3:]
344
+
345
+ pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
346
+ self.embed_dim,
347
+ (
348
+ t // self.patch_embed.patch_size[0],
349
+ h // self.patch_embed.patch_size[1],
350
+ w // self.patch_embed.patch_size[2],
351
+ ),
352
+ add_cls_token=True,
353
+ )).float().unsqueeze(0).to(x)
354
+
355
+ return pos_embed
356
+
357
+
358
+ def forward(
359
+ self, x: torch.Tensor,
360
+ temporal_coords: None | torch.Tensor = None,
361
+ location_coords: None | torch.Tensor = None,
362
+ mask_ratio=0.75
363
+ ):
364
+ if x.shape[-3:] != self.patch_embed.input_size:
365
+ # changed input size
366
+ pos_embed = self._get_pos_embed(x)
367
+ else:
368
+ pos_embed = self.pos_embed
369
+
370
+ # embed patches
371
+ x = self.patch_embed(x)
372
+
373
+ # add pos embed w/o cls token
374
+ x = x + pos_embed[:, 1:, :]
375
+
376
+ if self.temporal_encoding:
377
+ num_tokens_per_frame = x.shape[1] // self.num_frames
378
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
379
+ x = x + temporal_encoding
380
+ if self.location_encoding:
381
+ location_encoding = self.location_embed_enc(location_coords)
382
+ x = x + location_encoding
383
+
384
+ # masking: length -> length * mask_ratio
385
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
386
+
387
+ # append cls token
388
+ cls_token = self.cls_token + pos_embed[:, :1, :]
389
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
390
+ x = torch.cat((cls_tokens, x), dim=1)
391
+
392
+ # apply Transformer blocks
393
+ for block in self.blocks:
394
+ x = block(x)
395
+ x = self.norm(x)
396
+
397
+ return x, mask, ids_restore
398
+
399
+ def forward_features(
400
+ self,
401
+ x: torch.Tensor,
402
+ temporal_coords: None | torch.Tensor = None,
403
+ location_coords: None | torch.Tensor = None,
404
+ ) -> list[torch.Tensor]:
405
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
406
+ # add time dim
407
+ x = x.unsqueeze(2)
408
+
409
+ if x.shape[-3:] != self.patch_embed.input_size:
410
+ pos_embed = self._get_pos_embed(x)
411
+ else:
412
+ pos_embed = self.pos_embed
413
+
414
+ # embed patches
415
+ x = self.patch_embed(x)
416
+
417
+ # add pos embed w/o cls token
418
+ x = x + pos_embed[:, 1:, :]
419
+
420
+ if self.temporal_encoding:
421
+ num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
422
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
423
+ x = x + temporal_encoding
424
+ if self.location_encoding:
425
+ location_encoding = self.location_embed_enc(location_coords)
426
+ x = x + location_encoding
427
+
428
+ # append cls token
429
+ cls_token = self.cls_token + pos_embed[:, :1, :]
430
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
431
+ x = torch.cat((cls_tokens, x), dim=1)
432
+
433
+ # apply Transformer blocks
434
+ out = []
435
+ for block in self.blocks:
436
+ x = block(x)
437
+ out.append(x.clone())
438
+
439
+ x = self.norm(x)
440
+ out[-1] = x
441
+ return out
442
+
443
+ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
444
+ out = []
445
+ effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
446
+ for x in features:
447
+ x_no_token = x[:, 1:, :]
448
+ number_of_tokens = x_no_token.shape[1]
449
+ tokens_per_timestep = number_of_tokens // effective_time_dim
450
+ h = int(np.sqrt(tokens_per_timestep))
451
+ encoded = rearrange(
452
+ x_no_token,
453
+ "batch (t h w) e -> batch (t e) h w",
454
+ e=self.embed_dim,
455
+ t=effective_time_dim,
456
+ h=h,
457
+ )
458
+ out.append(encoded)
459
+ return out
460
+
461
+
462
+ class MAEDecoder(nn.Module):
463
+ """ Transformer Decoder used in the Prithvi MAE"""
464
+ def __init__(self,
465
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
466
+ grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
467
+ in_chans: int = 3,
468
+ encoder_embed_dim: int = 1024,
469
+ decoder_embed_dim: int = 512,
470
+ depth: int = 8,
471
+ num_heads: int = 16,
472
+ mlp_ratio: float = 4.,
473
+ norm_layer: nn.Module = nn.LayerNorm,
474
+ coords_encoding: List[str] | None = None,
475
+ coords_scale_learn: bool = False,
476
+ ):
477
+ super().__init__()
478
+
479
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
480
+ self.decoder_embed_dim = decoder_embed_dim
481
+ self.grid_size = grid_size
482
+ if isinstance(patch_size, int):
483
+ patch_size = (1, patch_size, patch_size)
484
+ self.patch_size = patch_size
485
+ self.num_frames = self.grid_size[0] * patch_size[0]
486
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
487
+
488
+ # Optional temporal and location embedding
489
+ coords_encoding = coords_encoding or []
490
+ self.temporal_encoding = 'time' in coords_encoding
491
+ self.location_encoding = 'location' in coords_encoding
492
+ if self.temporal_encoding:
493
+ self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
494
+ if self.location_encoding:
495
+ self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
496
+
497
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
498
+
499
+ self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
500
+
501
+ self.decoder_blocks = nn.ModuleList(
502
+ [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
503
+ )
504
+
505
+ self.decoder_norm = norm_layer(decoder_embed_dim)
506
+ self.decoder_pred = nn.Linear(decoder_embed_dim,
507
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
508
+ bias=True)
509
+
510
+ self.initialize_weights()
511
+
512
+ def initialize_weights(self):
513
+ # initialize (and freeze) position embeddings by sin-cos embedding
514
+ decoder_pos_embed = get_3d_sincos_pos_embed(
515
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
516
+ )
517
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
518
+
519
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
520
+ torch.nn.init.normal_(self.mask_token, std=0.02)
521
+ self.apply(_init_weights)
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.Tensor,
526
+ ids_restore: torch.Tensor,
527
+ temporal_coords: None | torch.Tensor = None,
528
+ location_coords: None | torch.Tensor = None,
529
+ input_size: list[int] = None,
530
+ ):
531
+ # embed tokens
532
+ x = self.decoder_embed(hidden_states)
533
+
534
+ t, h, w = input_size[-3:]
535
+ decoder_pos_embed = torch.from_numpy(
536
+ get_3d_sincos_pos_embed(
537
+ self.decoder_embed_dim,
538
+ (
539
+ t // self.patch_size[0],
540
+ h // self.patch_size[1],
541
+ w // self.patch_size[2],
542
+ ),
543
+ add_cls_token=True,
544
+ )
545
+ ).to(x)
546
+
547
+ # append mask tokens to sequence
548
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
549
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
550
+ # unshuffle
551
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
552
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
553
+ # add pos embed
554
+ x = x + decoder_pos_embed
555
+
556
+ # remove cls token
557
+ x_ = x[:, 1:, :]
558
+
559
+ if self.temporal_encoding:
560
+ num_tokens_per_frame = x_.shape[1] // self.num_frames
561
+ temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
562
+ # Add temporal encoding w/o cls token
563
+ x_ = x_ + temporal_encoding
564
+ if self.location_encoding:
565
+ location_encoding = self.location_embed_dec(location_coords)
566
+ # Add location encoding w/o cls token
567
+ x_ = x_ + location_encoding
568
+
569
+ # append cls token
570
+ x = torch.cat([x[:, :1, :], x_], dim=1)
571
+
572
+ # apply Transformer layers (blocks)
573
+ for block in self.decoder_blocks:
574
+ x = block(x)
575
+ x = self.decoder_norm(x)
576
+
577
+ # predictor projection
578
+ pred = self.decoder_pred(x)
579
+
580
+ # remove cls token
581
+ pred = pred[:, 1:, :]
582
+
583
+ return pred
584
+
585
+
586
+ class PrithviMAE(nn.Module):
587
+ """ Prithvi Masked Autoencoder"""
588
+
589
+ def __init__(self,
590
+ img_size: int | Tuple[int, int] = 224,
591
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
592
+ num_frames: int = 3,
593
+ in_chans: int = 3,
594
+ embed_dim: int = 1024,
595
+ depth: int = 24,
596
+ num_heads: int = 16,
597
+ decoder_embed_dim: int = 512,
598
+ decoder_depth: int = 8,
599
+ decoder_num_heads: int = 16,
600
+ mlp_ratio: float = 4.,
601
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
602
+ norm_pix_loss: bool = False,
603
+ coords_encoding: List[str] | None = None,
604
+ coords_scale_learn: bool = False,
605
+ encoder_only: bool = False,
606
+ **kwargs,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.encoder = PrithviViT(
611
+ img_size=img_size,
612
+ num_frames=num_frames,
613
+ patch_size=patch_size,
614
+ in_chans=in_chans,
615
+ embed_dim=embed_dim,
616
+ depth=depth,
617
+ num_heads=num_heads,
618
+ mlp_ratio=mlp_ratio,
619
+ norm_layer=norm_layer,
620
+ coords_encoding=coords_encoding,
621
+ coords_scale_learn=coords_scale_learn,
622
+ )
623
+
624
+ self.encoder_only = encoder_only
625
+
626
+ if not encoder_only:
627
+ self.decoder = MAEDecoder(
628
+ patch_size=patch_size,
629
+ grid_size=self.encoder.patch_embed.grid_size,
630
+ in_chans=in_chans,
631
+ encoder_embed_dim=embed_dim,
632
+ decoder_embed_dim=decoder_embed_dim,
633
+ depth=decoder_depth,
634
+ num_heads=decoder_num_heads,
635
+ mlp_ratio=mlp_ratio,
636
+ norm_layer=norm_layer,
637
+ coords_encoding=coords_encoding,
638
+ coords_scale_learn=coords_scale_learn,
639
+ )
640
+ else:
641
+ self.decoder = nn.Identity()
642
+
643
+ self.norm_pix_loss = norm_pix_loss
644
+
645
+ def patchify(self, pixel_values):
646
+ """
647
+ Args:
648
+ pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
649
+ Pixel values.
650
+
651
+ Returns:
652
+ torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
653
+ Patchified pixel values.
654
+ """
655
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
656
+ num_channels = self.encoder.in_chans
657
+
658
+ # patchify
659
+ patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
660
+ c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
661
+
662
+
663
+ return patchified_pixel_values
664
+
665
+ def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
666
+ """
667
+ Args:
668
+ patchified_pixel_values (`torch.FloatTensor` of shape
669
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
670
+ Patchified pixel values.
671
+ image_size (`Tuple[int, int]`, *optional*):
672
+ Original image size.
673
+
674
+ Returns:
675
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
676
+ Pixel values.
677
+ """
678
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
679
+ image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
680
+ original_height, original_width = image_size
681
+ num_patches_h = original_height // patch_size_h
682
+ num_patches_w = original_width // patch_size_w
683
+ num_channels = self.encoder.in_chans
684
+
685
+ pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
686
+ c=num_channels, h=num_patches_h, w=num_patches_w,
687
+ s=patch_size_t, p=patch_size_h, q=patch_size_w)
688
+ return pixel_values
689
+
690
+ def forward_loss(self, pixel_values, pred, mask):
691
+ """
692
+ Args:
693
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
694
+ Pixel values.
695
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
696
+ Predicted pixel values.
697
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
698
+ Tensor indicating which patches are masked (1) and which are not (0).
699
+
700
+ Returns:
701
+ `torch.FloatTensor`: Pixel reconstruction loss.
702
+ """
703
+ target = self.patchify(pixel_values)
704
+ if self.norm_pix_loss:
705
+ mean = target.mean(dim=-1, keepdim=True)
706
+ var = target.var(dim=-1, keepdim=True)
707
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
708
+
709
+ loss = (pred - target) ** 2
710
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
711
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
712
+ return loss
713
+
714
+ def forward(
715
+ self,
716
+ pixel_values: torch.Tensor,
717
+ temporal_coords: None | torch.Tensor = None,
718
+ location_coords: None | torch.Tensor = None,
719
+ mask_ratio: float = 0.75
720
+ ):
721
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
722
+ # add time dim
723
+ pixel_values = pixel_values.unsqueeze(2)
724
+
725
+ latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
726
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
727
+ loss = self.forward_loss(pixel_values, pred, mask)
728
+ return loss, pred, mask
729
+
730
+ def forward_features(
731
+ self,
732
+ x: torch.Tensor,
733
+ temporal_coords: None | torch.Tensor = None,
734
+ location_coords: None | torch.Tensor = None,
735
+ ) -> List[torch.Tensor]:
736
+ return self.encoder.forward_features(x, temporal_coords, location_coords)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ einops
5
+ rasterio