# lightning.pytorch==2.1.1 | |
# Seed handed to the CLI's `seed_everything` option (integer, not a string).
seed_everything: 0
# Trainer settings: hardware selection, logging, callbacks and run length.
trainer:
  accelerator: cpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: True  # will use TensorBoardLogger
  callbacks:
    # Each callback is given as a class_path plus its init_args.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 30
  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 1
  enable_checkpointing: true
  default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v1
# Data module: where images/labels live, which bands are used, and the
# per-band normalisation statistics.
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 1
    constant_scale: 0.0001
    dataset_bands:  # what bands are in your data
      - VV
      - VH
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CLOUD
    output_bands:  # which bands do you want to fine-tune
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - VV
      - VH
      - CLOUD
    # NOTE(review): 4/3/2 map to RED/GREEN/BLUE in `dataset_bands` order, but
    # to SWIR_1/NIR_NARROW/RED in `output_bands` order -- confirm which
    # ordering the data module indexes into.
    rgb_indices:
      - 4
      - 3
      - 2
    # Train, val and test share the same image/label directories; the split
    # files below select which samples belong to each stage.
    train_data_root: ./../data/regions/uki/images/
    train_label_data_root: ./../data/regions/uki/labels_without_cloud/
    val_data_root: ./../data/regions/uki/images/
    val_label_data_root: ./../data/regions/uki/labels_without_cloud/
    test_data_root: ./../data/regions/uki/images/
    test_label_data_root: ./../data/regions/uki/labels_without_cloud/
    train_split: ./../data/regions/uki/splits/flood_train_data.txt
    test_split: ./../data/regions/uki/splits/flood_test_data.txt
    val_split: ./../data/regions/uki/splits/flood_val_data.txt
    img_grep: "*_image.tif"
    label_grep: "*_label.tif"
    no_label_replace: -1
    no_data_replace: 0
    # One entry per band, listed in `output_bands` order.
    means:
      - 0.08867253281911215  # BLUE
      - 0.09101736325581869  # GREEN
      - 0.08757093732833862  # RED
      - 0.1670982579167684  # NIR_NARROW
      - 0.09420119639078776  # SWIR_1
      - 0.07141083437601725  # SWIR_2
      - -0.0017641318140774339  # VV
      - -0.002356150351719506  # VH
      - 0.00002777560551961263  # CLOUD
    # Presumably in the same band order as `means` -- TODO confirm.
    stds:
      - 0.13656951175974685
      - 0.13202436625655786
      - 0.1307223895526036
      - 0.18946390520629108
      - 0.11561659013865118
      - 0.09351007561544347
      - 0.001035692652952644
      - 0.000864295592912648
      - 0.00004478924301636066
    num_classes: 2
# Segmentation task: backbone/decoder/head arguments go in `model_args`;
# loss, auxiliary heads and freezing flags are task-level init_args.
# NOTE(review): nesting was reconstructed from a flattened file -- confirm the
# placement of keys against the task's signature.
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: false
      backbone: granite_geospatial_uki
      backbone_pretrain_img_size: 512
      decoder_channels: 256
      # Same band order as `output_bands` in the data section.
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - VV
        - VH
        - CLOUD
      num_classes: 2
      head_dropout: 0.1
      decoder_num_convs: 4
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
    # Keyed by the aux head's `name` above.
    aux_loss:
      aux_head: 1.0
    # Matches `no_label_replace: -1` in the data section.
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
# Optimizer class and its constructor arguments.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
# Learning-rate scheduler; `monitor` names the logged metric it watches.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss