File size: 2,120 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from pytorch_lightning.utilities.cli import LightningCLI
import torch
class TerraGPULightningCLI(LightningCLI):
    """LightningCLI subclass that installs this project's default Trainer settings.

    All values registered here are parser *defaults*, so they can still be
    overridden from the command line or a config file
    (e.g. ``--trainer.max_epochs 50``).
    """

    def add_arguments_to_parser(self, parser):
        """Register default Trainer options (hardware, training loop, callbacks).

        Args:
            parser: The argument parser handed in by ``LightningCLI``; only its
                ``set_defaults`` method is used here.
        """
        trainer_defaults = {
            # --- Performance / hardware ---
            "trainer.accelerator": "auto",
            "trainer.devices": "auto",
            # NOTE(review): in the Lightning docs, auto_select_gpus only takes
            # effect when the device count is an integer; with devices="auto"
            # it may be a no-op - confirm against the installed version.
            "trainer.auto_select_gpus": True,
            "trainer.precision": 32,
            # --- Training loop ---
            "trainer.max_epochs": 500,
            "trainer.min_epochs": 1,
            # NOTE(review): autograd anomaly detection noticeably slows
            # training; confirm it is meant to be on by default.
            "trainer.detect_anomaly": True,
            # Kept explicit: DeviceStatsMonitor (below) needs an active logger.
            "trainer.logger": True,
            "trainer.default_root_dir": "output_model",
            # --- Callbacks ---
            "trainer.callbacks": [
                {"class_path": "pytorch_lightning.callbacks.DeviceStatsMonitor"},
                {
                    # Requires the LightningModule to log a metric named
                    # "val_loss"; EarlyStopping errors out by default otherwise.
                    "class_path": "pytorch_lightning.callbacks.EarlyStopping",
                    "init_args": {
                        "monitor": "val_loss",
                        "patience": 5,
                        "mode": "min",
                    },
                },
                # ModelCheckpoint is currently disabled; re-enable by uncommenting:
                # {
                #     "class_path": "pytorch_lightning.callbacks.ModelCheckpoint",
                #     "init_args": {
                #         "dirpath": "output_model",
                #         "monitor": "val_loss",
                #         "auto_insert_metric_name": True,
                #     },
                # },
            ],
        }
        parser.set_defaults(trainer_defaults)

        # TODO: a default optimizer (Adam, lr=0.01) is intended but not wired up.
        # When adding it, "class_path" must be the dotted import path as a
        # *string*, not the class object, e.g.
        #     {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}}
        # NOTE(review): the optimizer group must exist on the parser first
        # (e.g. parser.add_optimizer_args(torch.optim.Adam), or LightningCLI's
        # automatic optimizer arguments) - confirm for the installed version.
|