"""Training entry point: trains LitMNISTModel on MNIST with PyTorch Lightning.

Configures global torch runtime flags, a TensorBoard logger, callbacks and a
PyTorch profiler, builds the data module and model, dumps a one-batch CPU
profile and a torchinfo summary, then runs `trainer.fit` and
`trainer.validate`.
"""
import os

import torch
import lightning as pl
from torchinfo import summary
from lightning.pytorch import loggers as pl_loggers
from functorch.compile import compiled_function, draw_graph  # NOTE(review): unused in this script
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.callbacks import (
    DeviceStatsMonitor,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelPruning,
)
from lightning.pytorch.callbacks.progress import TQDMProgressBar

from data import LitMNISTDataModule
from config import CONFIG
from model import LitMNISTModel
from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS

# ---- Global torch runtime configuration -------------------------------------
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# BUG (removed): the original line
#   torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
# was a bare statement: it built a context manager that was never entered, so
# it had no effect at all. If mixed precision is wanted, wrap the forward pass
# in `with torch.autocast(device_type=..., dtype=torch.float16):` instead.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): set_default_device('cuda') can interact badly with
# multi-worker DataLoaders (workers allocate on the default device) — confirm
# this is intended.
torch.set_default_device(device=device)
torch.cuda.empty_cache()  # no-op when CUDA is unavailable/uninitialized
# pl.seed_everything(123, workers=True)


def main() -> None:
    """Build logger, callbacks, profiler, data, and model, then train/validate."""
    ## Loggers
    logger: pl_loggers.TensorBoardLogger = pl_loggers.TensorBoardLogger(
        save_dir='logs/', name="lightning_logs", log_graph=True
    )

    ## CallBacks
    call_backs = [
        TQDMProgressBar(refresh_rate=10),
        ModelCheckpoint(
            monitor="val/loss",
            dirpath=os.path.join('logs', 'chkpoints'),
            filename="{epoch:02d}",
            save_top_k=1,
        ),
        DeviceStatsMonitor(cpu_stats=True),
        # EarlyStopping(monitor="val/loss", mode='min'),
        LearningRateMonitor(logging_interval='step'),
    ]

    ## Profilers
    # NOTE(review): `perf_profiler` is fully configured but NOT passed to the
    # Trainer below (which uses the stock 'pytorch' profiler). Pass
    # `profiler=perf_profiler` to actually use this configuration.
    perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
    perf_profiler = PyTorchProfiler(
        dirpath=perf_dir,
        filename="perf_logs_pytorch",
        group_by_input_shapes=True,
        emit_nvtx=torch.cuda.is_available(),  # NVTX ranges only make sense with CUDA
        activities=(
            [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
            if torch.cuda.is_available()
            else [
                torch.profiler.ProfilerActivity.CPU,
            ]
        ),
        # `skip_first` is an int step count; the original passed `True`, which
        # was silently coerced to 1. Made explicit.
        schedule=torch.profiler.schedule(
            wait=1, warmup=1, active=5, repeat=3, skip_first=1
        ),
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        with_modules=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            str(os.path.join(perf_dir, 'trace'))
        ),
    )

    ## MNISTDataModule
    dm = LitMNISTDataModule(
        data_dir=CONFIG['data'].get('dir_path', '.'),
        batch_size=CONFIG.get('batch_size'),
        num_workers=CONFIG.get('num_workers'),
        test_transform=TEST_TRANSFORMS,
        train_transform=TRAIN_TRANSFORMS,
    )
    dm.prepare_data()
    dm.setup()

    ## MNISTModel
    model = LitMNISTModel()
    # model = LitMNISTModel.load_from_checkpoint(r'C:\Users\muthu\GitHub\Spaces 🚀\UnSolvedMNIST\logs\chkpoints\epoch=04.ckpt')

    # Single BATCH — used for the computational graph and the summary below
    batch = next(iter(dm.train_dataloader()))
    model.example_input_array = batch[0]

    # CPU Stats: profile one forward pass and dump the per-op table to disk
    with torch.autograd.profiler.profile() as prof:
        model.to(device)(batch[0].to(device))
    os.makedirs(name=perf_dir, exist_ok=True)
    with open(os.path.join(perf_dir, "cpu_throttle.txt"), "w") as text_file:
        text_file.write(
            f"{prof.key_averages().table(sort_by='self_cpu_time_total', top_level_events_only=False)}"
        )

    # Model Summary
    summary(
        model,
        input_size=batch[0].shape,
        depth=5,
        verbose=2,
        col_width=16,
        col_names=[
            "input_size",
            "output_size",
            "num_params",
            "kernel_size",
            "mult_adds",
        ],
        row_settings=["var_names"],
    )

    ## Trainer
    trainer = pl.Trainer(
        max_epochs=CONFIG['training'].get('num_epochs', 15),
        logger=logger,
        profiler='pytorch',  # alternatives: perf_profiler, 'advanced'
        callbacks=call_backs,
        precision=32,
        enable_model_summary=False,
        enable_progress_bar=True,
    )

    ## Training
    trainer.fit(model=model, datamodule=dm)
    ## Validation
    trainer.validate(model, datamodule=dm)


# Guard is required because the DataLoader uses worker processes: under the
# spawn start method (Windows/macOS) each worker re-imports this module, and
# without the guard training would be re-executed in every worker.
if __name__ == "__main__":
    main()