File size: 4,078 Bytes
fe8c91b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93b3d62
fe8c91b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: gpu  # "auto" or "cpu" also work
  strategy: auto
  devices: auto
  num_nodes: 1
  # true -> Lightning's default logger: TensorBoardLogger when tensorboard is
  # installed, otherwise CSVLogger.
  logger: true

  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # NOTE(review): validation runs once per epoch (check_val_every_n_epoch: 1)
    # and max_epochs is 5, so patience: 30 can never be exhausted and early
    # stopping is effectively disabled. Lower patience or raise max_epochs if
    # early stopping is wanted.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 30

  max_epochs: 5
  check_val_every_n_epoch: 1
  log_every_n_steps: 1
  enable_checkpointing: true
  default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v2
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 4
    num_workers: 1
    # Multiplier applied to raw pixel values (presumably DN -> reflectance;
    # the optical means below are ~0.1-0.2, consistent with that). Confirm
    # against the input data.
    constant_scale: 0.0001
    # Band order as stored in the input image files.
    dataset_bands:
      - VV
      - VH
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CLOUD
    # Bands selected from dataset_bands, in the order they are fed to the
    # model. Keep in sync with model_args.backbone_bands and with the order
    # of means/stds below.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - VV
      - VH
      - CLOUD
    # Red, green, blue channel indices for RGB previews. TerraTorch filters the
    # image to output_bands before plotting, so these index output_bands
    # (RED=2, GREEN=1, BLUE=0), not dataset_bands. Counting in dataset_bands
    # order ([4, 3, 2]) would select SWIR_1, NIR_NARROW, RED here.
    # NOTE(review): confirm against the installed terratorch version.
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: ./../data/regions/combined_uki_spain/images/
    train_label_data_root: ./../data/regions/combined_uki_spain/labels/
    val_data_root: ./../data/regions/combined_uki_spain/images/
    val_label_data_root: ./../data/regions/combined_uki_spain/labels/
    test_data_root: ./../data/regions/combined_uki_spain/images/
    test_label_data_root: ./../data/regions/combined_uki_spain/labels/
    train_split: ./../data/regions/combined_uki_spain/splits/flood_train_data.txt
    test_split: ./../data/regions/combined_uki_spain/splits/flood_test_data.txt
    val_split: ./../data/regions/combined_uki_spain/splits/flood_val_data.txt
    # Quoted: a leading "*" would otherwise be parsed as a YAML alias.
    img_grep: "*_image.tif"
    label_grep: "*_label.tif"
    # Replacement for no-data label pixels; equals the model's ignore_index
    # (-1), so those pixels are ignored by the loss.
    no_label_replace: -1
    # Replacement for no-data pixels in the images.
    no_data_replace: 0
    # Per-band normalization statistics, in output_bands order.
    means:
      - 0.1290484133335582        # BLUE
      - 0.13423481405157794       # GREEN
      - 0.1328938801112928        # RED
      - 0.20036851044035797       # NIR_NARROW
      - 0.13804629743141042       # SWIR_1
      - 0.10409700513471637       # SWIR_2
      - -0.0018052691820029847    # VV
      - -0.0023712696527645486    # VH
      - 0.000024014472961425782   # CLOUD

    stds:
      - 0.25406999374272976       # BLUE
      - 0.22949378991348005       # GREEN
      - 0.21689414406289836       # RED
      - 0.22552362238920548       # NIR_NARROW
      - 0.1600542128720416        # SWIR_1
      - 0.12602917719190815       # SWIR_2
      - 0.0011294842635096356     # VV
      - 0.0008879269711519241     # VH
      - 0.00004271712050839232    # CLOUD

    # Must agree with model_args.num_classes.
    num_classes: 2

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      # --- Encoder: granite_geospatial_uki, weights read from a local file.
      backbone: granite_geospatial_uki
      backbone_pretrained: true
      backbone_pretrained_cfg_overlay:
        file: ./../data/checkpoints/granite_geospatial_uki.pt
      backbone_pretrain_img_size: 512
      # Same bands, same order as data.output_bands.
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - VV
        - VH
        - CLOUD
      # Optional overrides, currently unused:
      # in_channels: 9
      # num_frames: 1
      # --- Necks: keep only the last encoder output, then reshape its
      # tokens back into an image-shaped feature map.
      necks:
        - name: SelectIndices
          indices: [-1]
        - name: ReshapeTokensToImage
      # --- Decoder and segmentation head.
      decoder: FCNDecoder
      decoder_channels: 256
      decoder_num_convs: 4
      head_channel_list: [256]
      head_dropout: 0.1
      num_classes: 2
    # Auxiliary FCN head on the last feature map; its loss weight is 1.0.
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
    aux_loss:
      aux_head: 1.0
    # Class-weighted cross-entropy; label -1 (data.no_label_replace) is ignored.
    loss: ce
    class_weights: [0.3, 0.7]
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    # Sliding-window inference with 512x512 tiles and a 496 px stride
    # (16 px overlap); overlapping predictions are presumably averaged.
    tiled_inference_parameters:
      h_crop: 512
      h_stride: 496
      w_crop: 512
      w_stride: 496
      average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # Keep the "." in the mantissa: YAML 1.1 loaders such as PyYAML resolve
    # exponent floats only when they contain a dot, so "6e-5" would become a
    # string.
    lr: 6.0e-5
    weight_decay: 0.05

lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    # Same metric that the EarlyStopping callback watches.
    # NOTE(review): no patience is set here. If torch's default of 10 applies,
    # the LR is never reduced within max_epochs: 5. Confirm this is intended.
    monitor: val/loss