File size: 2,076 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# -*- coding: utf-8 -*-
# RF pipeline: preprocess, train, and predict.
import sys
import logging
# from terragpu import unet_model
# from terragpu.decorators import DuplicateFilter
# from terragpu.ai.deep_learning.datamodules.segmentation_datamodule \
# import SegmentationDataModule
from pytorch_lightning import seed_everything # , trainer
# from pytorch_lightning import LightningModule, LightningDataModule
from terragpu.ai.deep_learning.console.cli import TerraGPULightningCLI
# -----------------------------------------------------------------------------
# main
#
# python rf_pipeline.py options here
# -----------------------------------------------------------------------------
def main():
# -------------------------------------------------------------------------
# Set logging
# -------------------------------------------------------------------------
logger = logging.getLogger()
logger.setLevel(logging.INFO)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)
# Set formatter and handlers
formatter = logging.Formatter(
"%(asctime)s; %(levelname)s; %(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
# -------------------------------------------------------------------------
# Execute pipeline step
# -------------------------------------------------------------------------
# Seed every library
seed_everything(1234, workers=True)
_ = TerraGPULightningCLI(save_config_callback=None)
# unet_model.UNetSegmentation, SegmentationDataModule)
# train
# trainer = pl.Trainer()
# trainer.fit(model, datamodule=dm)
# validate
# trainer.validate(datamodule=dm)
# test
# trainer.test(datamodule=dm)
# predict
# predictions = trainer.predict(datamodule=dm)
return
# -----------------------------------------------------------------------------
# Invoke the main
# -----------------------------------------------------------------------------
if __name__ == "__main__":
sys.exit(main())
|