|
from pytorch_lightning.utilities.cli import LightningCLI |
|
|
|
import torch |
|
|
|
|
|
class TerraGPULightningCLI(LightningCLI):
    """LightningCLI preconfigured with project-wide ``trainer.*`` defaults.

    All values below are registered through ``parser.set_defaults``, so they
    act as defaults only; values supplied on the command line or in a config
    file take precedence.
    """

    def add_arguments_to_parser(self, parser):
        """Register the project's default trainer settings on *parser*.

        Args:
            parser: The argument parser handed to this LightningCLI hook.
                Defaults are set on it using dotted ``trainer.<option>`` keys.
        """
        # Callbacks are described as ``class_path`` / ``init_args`` specs
        # (dotted import paths), not as already-constructed instances.
        default_callbacks = [
            {"class_path": "pytorch_lightning.callbacks.DeviceStatsMonitor"},
            {
                "class_path": "pytorch_lightning.callbacks.EarlyStopping",
                "init_args": {
                    "monitor": "val_loss",
                    "patience": 5,
                    "mode": "min",
                },
            },
        ]

        # NOTE(review): no default optimizer is configured here. If
        # Adam(lr=0.01) is meant to be the default, it has to be registered
        # with the parser (e.g. via ``parser.add_optimizer_args``) rather than
        # just built as a dict; confirm the intent.
        parser.set_defaults(
            {
                # Hardware / numeric setup.
                "trainer.accelerator": "auto",
                "trainer.devices": "auto",
                "trainer.auto_select_gpus": True,
                "trainer.precision": 32,
                # Training-loop bounds.
                "trainer.max_epochs": 500,
                "trainer.min_epochs": 1,
                # NOTE(review): anomaly detection is usually a debugging aid
                # with a runtime cost. Confirm it should be on by default.
                "trainer.detect_anomaly": True,
                # Logging and output location.
                "trainer.logger": True,
                "trainer.default_root_dir": "output_model",
                "trainer.callbacks": default_callbacks,
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|