# run/Train.py
import os
import cv2
import sys
import tqdm
import torch
import datetime
import torch.nn as nn
import torch.distributed as dist
import torch.cuda as cuda
from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam, SGD
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast

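# Put the repository root on sys.path so the repo's packages (lib, data, utils)
# resolve when this script is run directly.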
filepath = os.path.split(os.path.abspath(__file__))[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)
from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *
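
# Disable TF32 on Ampere+ GPUs so matmul and cuDNN kernels run in full fp32 precision.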
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


def train(opt, args):
    # The dataset class is looked up by name from the config (eval-based registry).
    train_dataset = eval(opt.Train.Dataset.type)(
        root=opt.Train.Dataset.root,
        sets=opt.Train.Dataset.sets,
        tfs=opt.Train.Dataset.transforms)
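
    # Single-process vs. DDP: with more than one device, each process pins one
    # GPU and joins the NCCL group, and a DistributedSampler shards the dataset.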
    if args.device_num > 1:
        cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', rank=args.local_rank,
                                world_size=args.device_num,
                                timeout=datetime.timedelta(seconds=3600))
        train_sampler = DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = None
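
    # shuffle must be False whenever a sampler is supplied; DistributedSampler
    # already reshuffles per epoch.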
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.Train.Dataloader.batch_size,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              num_workers=opt.Train.Dataloader.num_workers,
                              pin_memory=opt.Train.Dataloader.pin_memory,
                              drop_last=True)
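
    # Optional resume: 'latest.pth' holds model weights, 'state.pth' holds
    # optimizer/scheduler/epoch state.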
    model_ckpt = None
    state_ckpt = None
    if args.resume is True:
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth')):
            model_ckpt = torch.load(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'),
                                    map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from checkpoint')
        if os.path.isfile(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth')):
            state_ckpt = torch.load(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth'),
                                    map_location='cpu')
            if args.local_rank <= 0:
                print('Resume from state')
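
    # The model class is looked up by name; the whole opt.Model node is passed as kwargs.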
    model = eval(opt.Model.name)(**opt.Model)
    if model_ckpt is not None:
        model.load_state_dict(model_ckpt)

    if args.device_num > 1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = model.cuda()
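
    # Two parameter groups: the pretrained backbone trains at the base lr,
    # everything else (the decoder) at 10x.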
    backbone_params = nn.ParameterList()
    decoder_params = nn.ParameterList()
    for name, param in model.named_parameters():
        if 'backbone' in name:
            backbone_params.append(param)
        else:
            decoder_params.append(param)

    params_list = [{'params': backbone_params},
                   {'params': decoder_params, 'lr': opt.Train.Optimizer.lr * 10}]

    optimizer = eval(opt.Train.Optimizer.type)(
        params_list, opt.Train.Optimizer.lr, weight_decay=opt.Train.Optimizer.weight_decay)
    if state_ckpt is not None:
        optimizer.load_state_dict(state_ckpt['optimizer'])

    if opt.Train.Optimizer.mixed_precision is True:
        scaler = GradScaler()
    else:
        scaler = None
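
    # The scheduler is stepped once per iteration, so its horizon is
    # iterations-per-epoch times the configured number of epochs.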
    scheduler = eval(opt.Train.Scheduler.type)(optimizer, gamma=opt.Train.Scheduler.gamma,
                                               minimum_lr=opt.Train.Scheduler.minimum_lr,
                                               max_iteration=len(train_loader) * opt.Train.Scheduler.epoch,
                                               warmup_iteration=opt.Train.Scheduler.warmup_iteration)
    if state_ckpt is not None:
        scheduler.load_state_dict(state_ckpt['scheduler'])

    model.train()

    start = 1
    if state_ckpt is not None:
        start = state_ckpt['epoch']

    epoch_iter = range(start, opt.Train.Scheduler.epoch + 1)
    if args.local_rank <= 0 and args.verbose is True:
        epoch_iter = tqdm.tqdm(epoch_iter, desc='Epoch', total=opt.Train.Scheduler.epoch,
                               initial=start - 1, position=0,
                               bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
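
    # Main loop: epochs outer, iterations inner; only rank 0 drives progress bars.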
    for epoch in epoch_iter:
        if args.device_num > 1 and train_sampler is not None:
            # set_epoch must run on *every* rank, not just rank 0, otherwise
            # the per-epoch shuffles diverge across processes.
            train_sampler.set_epoch(epoch)

        if args.local_rank <= 0 and args.verbose is True:
            step_iter = tqdm.tqdm(enumerate(train_loader, start=1), desc='Iter',
                                  total=len(train_loader), position=1, leave=False,
                                  bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:40}{r_bar}')
        else:
            step_iter = enumerate(train_loader, start=1)
        for i, sample in step_iter:
            optimizer.zero_grad()
            if opt.Train.Optimizer.mixed_precision is True and scaler is not None:
                # AMP path: run the forward pass under autocast, then backprop
                # on the scaled loss outside the autocast region.
                with autocast():
                    sample = to_cuda(sample)
                    out = model(sample)
                scaler.scale(out['loss']).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
            else:
                # Full-precision path.
                sample = to_cuda(sample)
                out = model(sample)
                out['loss'].backward()
                optimizer.step()
                scheduler.step()

            if args.local_rank <= 0 and args.verbose is True:
                step_iter.set_postfix({'loss': out['loss'].item()})
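
        # End of epoch: rank 0 writes checkpoints and optional debug tiles.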
        if args.local_rank <= 0:
            os.makedirs(opt.Train.Checkpoint.checkpoint_dir, exist_ok=True)
            os.makedirs(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug'), exist_ok=True)

            if epoch % opt.Train.Checkpoint.checkpoint_epoch == 0:
                if args.device_num > 1:
                    model_ckpt = model.module.state_dict()
                else:
                    model_ckpt = model.state_dict()
                state_ckpt = {'epoch': epoch + 1,
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict()}
                torch.save(model_ckpt, os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
                torch.save(state_ckpt, os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'state.pth'))

            if args.debug is True:
                debout = debug_tile(sum([out[k] for k in opt.Train.Debug.keys], []), activation=torch.sigmoid)
                cv2.imwrite(os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'debug', str(epoch) + '.png'), debout)

    if args.local_rank <= 0:
        torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
                   os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))


if __name__ == '__main__':
    args = parse_args()
    opt = load_config(args.config)
    train(opt, args)
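
# Example launches, assuming parse_args() in utils.misc exposes a --config flag
# (the exact flags are defined there; the config path below is a placeholder):
#   single GPU: python run/Train.py --config configs/train.yaml
#   multi GPU : python -m torch.distributed.launch --nproc_per_node=2 run/Train.py --config configs/train.yaml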