# Spaces:
# Sleeping
# Sleeping
import os

import lightning as pl
import torch
from functorch.compile import compiled_function, draw_graph
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import (
    DeviceStatsMonitor,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelPruning,
)
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.profilers import PyTorchProfiler
from torchinfo import summary

from config import CONFIG
from data import LitMNISTDataModule
from model import LitMNISTModel
from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS
# Auxiliary utils: global torch runtime configuration for this training run.
# Allow TF32 matmuls on Ampere+ GPUs (faster, slightly reduced precision).
torch.set_float32_matmul_precision('high')
# NOTE(review): the original called torch.cuda.amp.autocast(enabled=True,
# dtype=torch.float32) as a bare statement. autocast is a context manager;
# constructing it without `with` has no effect, so the call was a no-op and
# has been removed. Use `with torch.autocast(...)` if mixed precision is
# actually wanted (the Trainer below runs with precision=32 anyway).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Make newly created tensors/modules default to the chosen device.
torch.set_default_device(device=device)
# Release cached CUDA allocations from any previous run (no-op on CPU).
torch.cuda.empty_cache()
# pl.seed_everything(123, workers=True)
## Loggers
# TensorBoard logger writing under logs/lightning_logs; log_graph=True asks
# Lightning to record the model graph (requires an example_input_array).
logger: pl_loggers.TensorBoardLogger = pl_loggers.TensorBoardLogger(
    save_dir='logs/',
    name="lightning_logs",
    log_graph=True,
)
## CallBacks
call_backs = [
    # Refresh the progress bar every 10 batches to limit terminal I/O.
    TQDMProgressBar(refresh_rate=10),
    # Keep only the single best checkpoint, ranked by validation loss.
    ModelCheckpoint(
        monitor="val/loss",
        dirpath=os.path.join('logs', 'chkpoints'),
        filename="{epoch:02d}",
        save_top_k=1,
    ),
    # Record device utilisation stats (including CPU) to the logger.
    DeviceStatsMonitor(cpu_stats=True),
    # EarlyStopping(monitor="val/loss",mode='min'),
    # Log the learning rate of every optimizer at each step.
    LearningRateMonitor(logging_interval='step'),
]
## Profilers
# Directory (next to this file) where all profiler artifacts are written.
perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
perf_profiler = PyTorchProfiler(
    dirpath=perf_dir,
    filename="perf_logs_pytorch",
    # Group results by operator input shapes for finer-grained attribution.
    group_by_input_shapes=True,
    # Emit NVTX ranges (for nsys/nvprof) only when a GPU is present.
    emit_nvtx=torch.cuda.is_available(),
    activities=(
        [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        if torch.cuda.is_available()
        else [
            torch.profiler.ProfilerActivity.CPU,
        ]
    ),
    # Profile 3 cycles of (1 wait + 1 warmup + 5 active) steps.
    # FIX: skip_first expects an int number of steps to skip; the original
    # passed True, which only works by bool-to-int coercion (== 1 step).
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=5, repeat=3, skip_first=1
    ),
    profile_memory=True,
    with_stack=True,
    with_flops=True,
    with_modules=True,
    # Write Chrome traces consumable by TensorBoard's profiler plugin.
    # (os.path.join already returns str — the original str() was redundant.)
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        os.path.join(perf_dir, 'trace')
    ),
)
## MNISTDataModule
# Build the MNIST datamodule from the shared CONFIG, then run its
# download/split hooks eagerly so a batch can be inspected before fit().
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    train_transform=TRAIN_TRANSFORMS,
    test_transform=TEST_TRANSFORMS,
)
dm.prepare_data()
dm.setup()
## MNISTModel
model = LitMNISTModel()
# model = LitMNISTModel.load_from_checkpoint(r'C:\Users\muthu\GitHub\Spaces π\UnSolvedMNIST\logs\chkpoints\epoch=04.ckpt')
# Single BATCH — pull one batch up front for graph logging and the summary.
batch = next(iter(dm.train_dataloader()))
# Computational graph: the TensorBoard logger's log_graph=True needs this.
model.example_input_array = batch[0]
# CPU Stats: profile a single forward pass and dump the op-level timing table.
with torch.autograd.profiler.profile() as prof:
    output = model.to(device)(batch[0].to(device))
# FIX: reuse perf_dir (defined with the profiler above) instead of
# re-deriving the same <this dir>/logs/profiler path two more times.
os.makedirs(perf_dir, exist_ok=True)
with open(os.path.join(perf_dir, "cpu_throttle.txt"), "w") as text_file:
    # .table() already returns a str; the original f-string wrapper was redundant.
    text_file.write(
        prof.key_averages().table(
            sort_by='self_cpu_time_total', top_level_events_only=False
        )
    )
# Model Summary
# torchinfo report: per-layer input/output shapes, parameter counts,
# kernel sizes and mult-adds, expanded 5 module levels deep.
summary(
    model,
    input_size=batch[0].shape,
    depth=5,
    verbose=2,
    col_width=16,
    col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
    row_settings=["var_names"],
)
## Trainer
# Full-precision (fp32) run with the profiler and callbacks wired in.
# Lightning's built-in summary is disabled because torchinfo printed one above.
trainer_kwargs = dict(
    max_epochs=CONFIG['training'].get('num_epochs', 15),
    logger=logger,
    profiler=perf_profiler,
    callbacks=call_backs,
    precision=32,
    enable_model_summary=False,
    enable_progress_bar=True,
)
trainer = pl.Trainer(**trainer_kwargs)
## Training
trainer.fit(model=model, datamodule=dm)
## Validation
trainer.validate(model, datamodule=dm)