# -*- coding: utf-8 -*- | |
# Command-line entry point for the TerraGPU deep learning Lightning CLI.
import sys | |
import logging | |
# from terragpu import unet_model | |
# from terragpu.decorators import DuplicateFilter | |
# from terragpu.ai.deep_learning.datamodules.segmentation_datamodule \ | |
# import SegmentationDataModule | |
from pytorch_lightning import seed_everything # , trainer | |
# from pytorch_lightning import LightningModule, LightningDataModule | |
from terragpu.ai.deep_learning.console.cli import TerraGPULightningCLI | |
# ----------------------------------------------------------------------------- | |
# main | |
# | |
# python rf_pipeline.py options here | |
# ----------------------------------------------------------------------------- | |
def main():
    """Configure logging, seed the random generators, and launch the CLI.

    The root logger is set to INFO and given a stdout handler that formats
    records as ``<timestamp>; <level>; <message>``. The run is then seeded
    with ``seed_everything(1234, workers=True)`` and ``TerraGPULightningCLI``
    is constructed with ``save_config_callback=None``.

    NOTE(review): presumably the CLI parses ``sys.argv`` and runs the
    requested stage (fit / validate / test / predict) while it is being
    constructed, as Lightning CLIs normally do -- confirm against
    TerraGPULightningCLI.

    Returns:
        None, so ``sys.exit(main())`` exits with status 0.
    """
    # -------------------------------------------------------------------------
    # Set logging
    # -------------------------------------------------------------------------
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Attach the stdout handler only once: handlers accumulate on the root
    # logger, so a second call to main() in the same process would otherwise
    # print every record twice.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in root_logger.handlers
    )
    if not has_stdout_handler:
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter(
            "%(asctime)s; %(levelname)s; %(message)s", "%Y-%m-%d %H:%M:%S"))
        root_logger.addHandler(console)

    # -------------------------------------------------------------------------
    # Execute pipeline step
    # -------------------------------------------------------------------------
    # Seed before the CLI is built so everything it creates sees the same seed.
    seed_everything(1234, workers=True)

    # Only the side effects of constructing the CLI are needed; the instance
    # itself is not used afterwards.
    TerraGPULightningCLI(save_config_callback=None)
# ----------------------------------------------------------------------------- | |
# Invoke the main | |
# ----------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Run the pipeline only when executed as a script; main()'s return value
    # becomes the process exit status (None maps to 0).
    raise SystemExit(main())