# UnsolvedMNIST / train.py — training entry point.
# (Hugging Face file-viewer chrome removed; original commit: "src file added", af3a445, 4.13 kB)
import os
import torch
import lightning as pl
from torchinfo import summary
from lightning.pytorch import loggers as pl_loggers
from functorch.compile import compiled_function,draw_graph
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.callbacks import (
DeviceStatsMonitor,
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
ModelPruning
)
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from data import LitMNISTDataModule
from config import CONFIG
from model import LitMNISTModel
from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS
# Auxilary utils
# Allow TF32 tensor-core matmuls for float32 on supporting GPUs (speed over exact precision).
torch.set_float32_matmul_precision('high')
# NOTE(fix): the original called torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
# here as a bare statement. Constructing the autocast context manager without entering it
# (no `with`) is a no-op, and dtype=float32 would disable mixed precision anyway — removed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Route default tensor construction (torch.zeros, ...) to the chosen device.
torch.set_default_device(device=device)
if torch.cuda.is_available():
    # Release cached GPU memory left over from earlier runs before training starts.
    torch.cuda.empty_cache()
# pl.seed_everything(123, workers=True)
## Loggers
# TensorBoard logger; log_graph=True records the model graph (requires the model
# to expose example_input_array, which is set later in this script).
logger: pl_loggers.TensorBoardLogger = pl_loggers.TensorBoardLogger(
    save_dir='logs/',
    name="lightning_logs",
    log_graph=True,
)

## CallBacks
# Progress bar, best-checkpoint saving on validation loss, device statistics,
# and per-step learning-rate tracking.
call_backs = [
    TQDMProgressBar(refresh_rate=10),
    ModelCheckpoint(
        monitor="val/loss",
        dirpath=os.path.join('logs', 'chkpoints'),
        filename="{epoch:02d}",
        save_top_k=1,
    ),
    DeviceStatsMonitor(cpu_stats=True),
    # EarlyStopping(monitor="val/loss",mode='min'),
    LearningRateMonitor(logging_interval='step'),
]
## Profilers
# Directory where profiler tables and TensorBoard traces are written.
perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')

# Profile CPU always; add CUDA activity (and NVTX ranges) only when a GPU exists.
_activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    _activities.append(torch.profiler.ProfilerActivity.CUDA)

# Kineto-backed profiler: memory/FLOPs/stack/module capture under a
# wait(1) -> warmup(1) -> active(5) schedule, repeated three times.
perf_profiler = PyTorchProfiler(
    dirpath=perf_dir,
    filename="perf_logs_pytorch",
    group_by_input_shapes=True,
    emit_nvtx=torch.cuda.is_available(),
    activities=_activities,
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=5, repeat=3, skip_first=True
    ),
    profile_memory=True,
    with_stack=True,
    with_flops=True,
    with_modules=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(str(os.path.join(perf_dir, 'trace'))),
)
## MNISTDataModule
# Build the MNIST datamodule from CONFIG and run its setup hooks eagerly so a
# sample batch can be drawn below, before Trainer.fit is called.
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    train_transform=TRAIN_TRANSFORMS,
    test_transform=TEST_TRANSFORMS,
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
)
dm.prepare_data()
dm.setup()
## MNISTModel
model = LitMNISTModel()
# model = LitMNISTModel.load_from_checkpoint(r'C:\Users\muthu\GitHub\Spaces πŸš€\UnSolvedMNIST\logs\chkpoints\epoch=04.ckpt')

# Single BATCH
# Draw one training batch to drive graph logging, CPU profiling and the summary.
batch = next(iter(dm.train_dataloader()))

# Computational graph
# The TensorBoard logger's log_graph=True needs an example input on the model.
model.example_input_array = batch[0]

# CPU Stats
# One forward pass under the autograd profiler; dump the per-op table to a file.
with torch.autograd.profiler.profile() as prof:
    output = model.to(device)(batch[0].to(device))

_profile_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
os.makedirs(name=_profile_dir, exist_ok=True)
with open(os.path.join(_profile_dir, "cpu_throttle.txt"), "w") as text_file:
    text_file.write(f"{prof.key_averages().table(sort_by='self_cpu_time_total',top_level_events_only=False)}")
# Model Summary
# Layer-by-layer torchinfo report (shapes, parameter counts, kernel sizes,
# mult-adds) driven by the shape of one real input batch.
summary(
    model,
    input_size=batch[0].shape,
    depth=5,
    verbose=2,
    col_width=16,
    col_names=(
        "input_size",
        "output_size",
        "num_params",
        "kernel_size",
        "mult_adds",
    ),
    row_settings=["var_names"],
)
## Trainer
# Full float32 training wired to the TensorBoard logger, the PyTorch profiler
# and the callback list defined earlier; Lightning's built-in model summary is
# disabled because torchinfo already printed one above.
trainer = pl.Trainer(
    max_epochs=CONFIG['training'].get('num_epochs', 15),
    precision=32,
    logger=logger,
    callbacks=call_backs,
    profiler=perf_profiler,
    enable_progress_bar=True,
    enable_model_summary=False,
)

## Training
trainer.fit(model=model, datamodule=dm)

## Validation
trainer.validate(model, datamodule=dm)