import os

import torch
import lightning as pl
from torchinfo import summary

from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.callbacks import (
    DeviceStatsMonitor,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.callbacks.progress import TQDMProgressBar

from data import LitMNISTDataModule
from config import CONFIG
from model import LitMNISTModel
from utils import TRAIN_TRANSFORMS, TEST_TRANSFORMS


# Auxiliary utils
torch.set_float32_matmul_precision('high')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device=device)
torch.cuda.empty_cache()
# pl.seed_everything(123, workers=True)
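# NOTE: after set_default_device, factory calls such as torch.zeros() allocate
# on `device` unless an explicit device argument is passed.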


## Loggers
logger: pl_loggers.TensorBoardLogger = pl_loggers.TensorBoardLogger(
    save_dir='logs/', name="lightning_logs", log_graph=True
)

## Callbacks
callbacks = [
    TQDMProgressBar(refresh_rate=10),
    ModelCheckpoint(
        monitor="val/loss",
        dirpath=os.path.join('logs', 'chkpoints'),
        filename="{epoch:02d}",
        save_top_k=1,
    ),
    DeviceStatsMonitor(cpu_stats=True),
    # EarlyStopping(monitor="val/loss", mode='min'),
    LearningRateMonitor(logging_interval='step'),
]
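# NOTE: the monitored key "val/loss" must match a metric logged inside
# LitMNISTModel, e.g. self.log("val/loss", loss) in its validation_step.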

## Profilers
perf_dir = os.path.join(os.path.dirname(__file__), 'logs', 'profiler')
perf_profiler = PyTorchProfiler(
    dirpath=perf_dir,
    filename="perf_logs_pytorch",
    group_by_input_shapes=True,
    emit_nvtx=torch.cuda.is_available(),
    activities=(
        [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        if torch.cuda.is_available()
        else [torch.profiler.ProfilerActivity.CPU]
    ),
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=5, repeat=3, skip_first=1
    ),
    profile_memory=True,
    with_stack=True,
    with_flops=True,
    with_modules=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(
        os.path.join(perf_dir, 'trace')
    ),
)
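# The traces written by tensorboard_trace_handler can be viewed in TensorBoard
# (assuming the torch-tb-profiler plugin is installed); from the script's directory:
#   tensorboard --logdir logs/profiler/trace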



## MNISTDataModule
dm = LitMNISTDataModule(
    data_dir=CONFIG['data'].get('dir_path', '.'),
    batch_size=CONFIG.get('batch_size'),
    num_workers=CONFIG.get('num_workers'),
    test_transform=TEST_TRANSFORMS,
    train_transform=TRAIN_TRANSFORMS,
)
dm.prepare_data()
dm.setup()
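# Optional sanity check (commented out); fit/validate below rely on these loaders:
# print(f"train batches: {len(dm.train_dataloader())}, val batches: {len(dm.val_dataloader())}")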

## MNISTModel
model = LitMNISTModel()
# model = LitMNISTModel.load_from_checkpoint(r'C:\Users\muthu\GitHub\Spaces πŸš€\UnSolvedMNIST\logs\chkpoints\epoch=04.ckpt')

# Single batch for graph logging and a quick profile
batch = next(iter(dm.train_dataloader()))
# Computational graph: log_graph=True above needs an example input
model.example_input_array = batch[0]
# CPU stats for one forward pass
with torch.autograd.profiler.profile() as prof:
    output = model.to(device)(batch[0].to(device))

os.makedirs(perf_dir, exist_ok=True)
with open(os.path.join(perf_dir, "cpu_throttle.txt"), "w") as text_file:
    text_file.write(
        prof.key_averages().table(sort_by='self_cpu_time_total', top_level_events_only=False)
    )
# Model Summary
summary(
    model,
    input_size=batch[0].shape,
    depth=5,
    verbose=2,
    col_width=16,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "kernel_size",
        "mult_adds",
    ],
    row_settings=["var_names"],
)


## Trainer
trainer = pl.Trainer(
    max_epochs=CONFIG['training'].get('num_epochs', 15),
    logger=logger,
    profiler=perf_profiler,  # or 'advanced'
    callbacks=callbacks,
    precision=32,
    enable_model_summary=False,  # torchinfo summary above replaces Lightning's
    enable_progress_bar=True,
)
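# NOTE: precision=32 keeps fp32 training; set_float32_matmul_precision('high')
# above still permits TF32 matmuls on GPUs that support it.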

## Training
trainer.fit(model=model, datamodule=dm)
## Validation
trainer.validate(model, datamodule=dm)
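
## Testing (optional; assumes LitMNISTDataModule also sets up a test split)
# trainer.test(model, datamodule=dm)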