# source snapshot: commit f7f604d (7,087 bytes)
import os
import cv2
import sys
import tqdm
import torch
import datetime
import torch.nn as nn
import torch.distributed as dist
import torch.cuda as cuda
from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam, SGD
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast
# Make the repository root importable so the star imports below resolve
# regardless of where this script is launched from.  Order matters: the
# sys.path.append must happen BEFORE the lib/data/utils imports.
filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)
from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *
# Disable TF32 matmul/conv kernels on Ampere+ GPUs for full-FP32
# reproducibility (at some speed cost).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def train(opt, args):
    """Train the configured model, optionally across multiple GPUs via DDP.

    Args:
        opt: config object loaded from YAML, with ``Train`` and ``Model``
            sections (dataset, dataloader, optimizer, scheduler, checkpoint).
        args: CLI namespace with ``local_rank``, ``device_num``, ``resume``,
            ``verbose`` and ``debug`` attributes.

    Side effects:
        Writes ``latest.pth`` (model weights) and ``state.pth``
        (epoch/optimizer/scheduler state) under
        ``opt.Train.Checkpoint.checkpoint_dir``; optionally dumps debug
        tiles under a ``debug/`` subdirectory.
    """
    # NOTE(review): eval() on config-supplied class names executes arbitrary
    # code -- acceptable only because configs are trusted project files.
    train_dataset = eval(opt.Train.Dataset.type)(
        root=opt.Train.Dataset.root,
        sets=opt.Train.Dataset.sets,
        tfs=opt.Train.Dataset.transforms)

    if args.device_num > 1:
        # One process per GPU; NCCL backend with a generous 1h timeout for
        # slow first-epoch data loading.
        cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', rank=args.local_rank,
                                world_size=args.device_num,
                                timeout=datetime.timedelta(seconds=3600))
        train_sampler = DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = None

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.Train.Dataloader.batch_size,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              num_workers=opt.Train.Dataloader.num_workers,
                              pin_memory=opt.Train.Dataloader.pin_memory,
                              drop_last=True)

    # Optionally resume model weights and optimizer/scheduler/epoch state.
    model_ckpt = None
    state_ckpt = None
    if args.resume:
        latest_path = os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth')
        state_path = os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth')
        if os.path.isfile(latest_path):
            model_ckpt = torch.load(latest_path, map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from checkpoint')
        if os.path.isfile(state_path):
            state_ckpt = torch.load(state_path, map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from state')

    model = eval(opt.Model.name)(**opt.Model)
    if model_ckpt is not None:
        model.load_state_dict(model_ckpt)

    if args.device_num > 1:
        # SyncBN keeps batch statistics consistent across ranks.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = model.cuda()

    # Backbone parameters train at the base LR; everything else (decoder)
    # at 10x the base LR.
    backbone_params = nn.ParameterList()
    decoder_params = nn.ParameterList()
    for name, param in model.named_parameters():
        if 'backbone' in name:
            backbone_params.append(param)
        else:
            decoder_params.append(param)
    params_list = [{'params': backbone_params},
                   {'params': decoder_params, 'lr': opt.Train.Optimizer.lr * 10}]

    optimizer = eval(opt.Train.Optimizer.type)(
        params_list, opt.Train.Optimizer.lr,
        weight_decay=opt.Train.Optimizer.weight_decay)
    if state_ckpt is not None:
        optimizer.load_state_dict(state_ckpt['optimizer'])

    scaler = GradScaler() if opt.Train.Optimizer.mixed_precision is True else None

    scheduler = eval(opt.Train.Scheduler.type)(
        optimizer, gamma=opt.Train.Scheduler.gamma,
        minimum_lr=opt.Train.Scheduler.minimum_lr,
        max_iteration=len(train_loader) * opt.Train.Scheduler.epoch,
        warmup_iteration=opt.Train.Scheduler.warmup_iteration)
    if state_ckpt is not None:
        scheduler.load_state_dict(state_ckpt['scheduler'])

    model.train()

    start = 1
    if state_ckpt is not None:
        start = state_ckpt['epoch']

    epoch_iter = range(start, opt.Train.Scheduler.epoch + 1)
    if args.local_rank <= 0 and args.verbose is True:
        epoch_iter = tqdm.tqdm(epoch_iter, desc='Epoch',
                               total=opt.Train.Scheduler.epoch,
                               initial=start - 1, position=0,
                               bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')

    for epoch in epoch_iter:
        # FIX: set_epoch must be called on EVERY rank at the start of EVERY
        # epoch; previously it was only reached on the verbose rank-0
        # process, so other ranks (and non-verbose runs) reused the
        # epoch-0 shuffling order for the whole training.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        if args.local_rank <= 0 and args.verbose is True:
            step_iter = tqdm.tqdm(enumerate(train_loader, start=1), desc='Iter',
                                  total=len(train_loader), position=1, leave=False,
                                  bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
        else:
            step_iter = enumerate(train_loader, start=1)

        for i, sample in step_iter:
            optimizer.zero_grad()
            if opt.Train.Optimizer.mixed_precision is True and scaler is not None:
                # FIX: only the forward pass runs under autocast; backward,
                # scaler.step and scaler.update happen outside the context,
                # as the torch.cuda.amp documentation prescribes.
                with autocast():
                    sample = to_cuda(sample)
                    out = model(sample)
                scaler.scale(out['loss']).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
            else:
                sample = to_cuda(sample)
                out = model(sample)
                out['loss'].backward()
                optimizer.step()
                scheduler.step()

            if args.local_rank <= 0 and args.verbose is True:
                step_iter.set_postfix({'loss': out['loss'].item()})

        # Checkpointing and debug dumps happen on rank 0 only.
        if args.local_rank <= 0:
            os.makedirs(opt.Train.Checkpoint.checkpoint_dir, exist_ok=True)
            os.makedirs(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug'),
                        exist_ok=True)
            if epoch % opt.Train.Checkpoint.checkpoint_epoch == 0:
                # Unwrap DDP so the saved state_dict keys are module-prefix-free.
                model_ckpt = (model.module.state_dict() if args.device_num > 1
                              else model.state_dict())
                state_ckpt = {'epoch': epoch + 1,
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict()}
                torch.save(model_ckpt,
                           os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
                torch.save(state_ckpt,
                           os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth'))

            if args.debug is True:
                debout = debug_tile(sum([out[k] for k in opt.Train.Debug.keys], []),
                                    activation=torch.sigmoid)
                cv2.imwrite(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug',
                                         str(epoch) + '.png'), debout)

    # Final weight snapshot after the last epoch.
    if args.local_rank <= 0:
        torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
                   os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
if __name__ == '__main__':
    # Entry point: parse CLI flags, load the YAML config they point to,
    # then hand both to the training routine.
    cli_args = parse_args()
    config = load_config(cli_args.config)
    train(config, cli_args)