# lightning.pytorch==2.4.0
#
# Fine-tuning configuration for a pixel-wise regression task, written in the
# `class_path` / `init_args` layout.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# this file; verify key placement against the signatures of the referenced
# classes before relying on it.
seed_everything: 42

### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      # NOTE(review): this is `../experiments`, while `default_root_dir`
      # below is `./experiments` -- confirm the two different directories
      # are intentional.
      save_dir: ../experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        # Same metric that the LR scheduler monitors (see `lr_scheduler`).
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Augmentations applied to training samples only.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # List every band present in the input data, in order.
    # Each -1 is a placeholder for a band that exists in the data but is discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Bands from the input data that are actually used.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Presumably positions within `output_bands` (2=RED, 1=GREEN, 0=BLUE) -- TODO confirm.
    rgb_indices:
      - 2
      - 1
      - 0
    # Root directories of the training, validation and test data splits.
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Per-band mean of the training dataset (one entry per band in `output_bands`).
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    # Per-band standard deviation of the training dataset
    # (one entry per band in `output_bands`).
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # Nodata value in the label data.
    no_label_replace: -1
    # Nodata value in the input data.
    no_data_replace: 0

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # Equals the number of entries in `bands` below.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as `no_label_replace` in the data section; presumably label
    # pixels with this value are excluded from the loss -- TODO confirm.
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Uncomment this block for tiled inference.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true

### Optimization
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

# NOTE(review): placed at top level; confirm this is where the consuming CLI
# expects `out_dtype`.
out_dtype: float32