import os
import torch
import lightning as pl
from torchinfo import summary
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.callbacks import (
    DeviceStatsMonitor,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelPruning,
)
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from data import LitMNISTDataModule
from config import CONFIG
from model import LitMNISTModel
from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS
# Auxiliary utils
torch.set_float32_matmul_precision('high')  # allow TF32 matmul kernels on Ampere+ GPUs
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device)
torch.cuda.empty_cache()
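# `torch.set_default_device` makes tensor factory calls (torch.zeros, torch.randn,
# ...) allocate on `device` unless an explicit device is passed. Mixed precision,
# if desired, belongs in the Trainer's `precision` argument below:
# `torch.cuda.amp.autocast` is a context manager and only has an effect when it
# wraps a forward pass.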
# pl.seed_everything(123, workers=True)
## Loggers
logger = pl_loggers.TensorBoardLogger(save_dir='logs/', name="lightning_logs", log_graph=True)
## Callbacks
call_backs = [
    TQDMProgressBar(refresh_rate=10),
    ModelCheckpoint(
        monitor="val/loss", dirpath=os.path.join('logs', 'chkpoints'), filename="{epoch:02d}", save_top_k=1,
    ),
    DeviceStatsMonitor(cpu_stats=True),
    # EarlyStopping(monitor="val/loss", mode='min'),
    LearningRateMonitor(logging_interval='step'),
]
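# Note: ModelCheckpoint (and the commented-out EarlyStopping) monitor "val/loss",
# which assumes LitMNISTModel logs that key, e.g. self.log("val/loss", loss) in
# its validation_step. Only the single best checkpoint is kept (save_top_k=1).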
## Profilers
perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
perf_profiler = PyTorchProfiler(
    dirpath=perf_dir,
    filename="perf_logs_pytorch",
    group_by_input_shapes=True,
    emit_nvtx=torch.cuda.is_available(),
    activities=(
        [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        if torch.cuda.is_available()
        else [
            torch.profiler.ProfilerActivity.CPU,
        ]
    ),
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=5, repeat=3, skip_first=1  # skip_first expects a step count, not a bool
    ),
    profile_memory=True,
    with_stack=True,
    with_flops=True,
    with_modules=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(perf_dir, 'trace')),
)
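# The schedule above skips the first step, then cycles: 1 idle step (wait),
# 1 warm-up step (warmup), and 5 recorded steps (active), repeated 3 times.
# Each completed cycle emits a trace into logs/profiler/trace for TensorBoard.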
## MNISTDataModule
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=TRAIN_TRANSFORMS,
)
dm.prepare_data()
dm.setup()
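# Lightning convention: prepare_data() is for one-time work (here presumably the
# MNIST download inside LitMNISTDataModule), setup() builds the splits. Calling
# them manually lets us pull a sample batch below before the Trainer runs.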
## MNISTModel
model = LitMNISTModel()
# model = LitMNISTModel.load_from_checkpoint(r'C:\Users\muthu\GitHub\Spaces π\UnSolvedMNIST\logs\chkpoints\epoch=04.ckpt')
# Single BATCH
batch = next(iter(dm.train_dataloader()))
# Computational graph
model.example_input_array = batch[0]
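# Because the TensorBoardLogger was created with log_graph=True, Lightning uses
# example_input_array to trace the model and log its computational graph.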
# CPU Stats
with torch.autograd.profiler.profile() as prof:
    output = model.to(device)(batch[0].to(device))
os.makedirs(perf_dir, exist_ok=True)
with open(os.path.join(perf_dir, "cpu_throttle.txt"), "w") as text_file:
    text_file.write(prof.key_averages().table(sort_by='self_cpu_time_total', top_level_events_only=False))
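# key_averages() aggregates profiler events with identical names; sorting by
# self_cpu_time_total surfaces the ops that spend the most time in their own
# code (excluding children), the usual first place to look for CPU bottlenecks.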
# Model Summary
summary(
    model,
    input_size=batch[0].shape,
    depth=5,
    verbose=2,
    col_width=16,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "kernel_size",
        "mult_adds",
    ],
    row_settings=["var_names"],
)
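# torchinfo runs a forward pass with a tensor of `input_size` to record per-layer
# shapes; "mult_adds" reports multiply-accumulate operations (a proxy for compute
# cost), and row_settings=["var_names"] labels rows with module attribute names.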
## Trainer
trainer = pl.Trainer(
    max_epochs=CONFIG['training'].get('num_epochs', 15),
    logger=logger,
    profiler=perf_profiler,  # or 'advanced' for Lightning's cProfile-based profiler
    callbacks=call_backs,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=True,
)
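# precision=32 keeps full FP32 training (Lightning 2.x also accepts the string
# "32-true"); enable_model_summary is off since torchinfo already printed a more
# detailed summary above.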
## Training
trainer.fit(model=model, datamodule=dm)
## Validation
trainer.validate(model, datamodule=dm)
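# To validate against the best checkpoint saved by ModelCheckpoint rather than
# the in-memory weights, pass ckpt_path='best':
# trainer.validate(model, datamodule=dm, ckpt_path='best')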