import os
import datetime
import argparse
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from config import Config
from loss import PixLoss, ClsLoss
from dataset import MyData
from models.birefnet import BiRefNet, BiRefNetC2F
from utils import Logger, AverageMeter, set_seed, check_state_dict


parser = argparse.ArgumentParser(description='')
parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
parser.add_argument('--epochs', default=120, type=int)
parser.add_argument('--ckpt_dir', default='ckpt/tmp', help='Temporary folder')
parser.add_argument('--testsets', default='DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4', type=str)
parser.add_argument('--dist', default=False, type=lambda x: x == 'True')
parser.add_argument('--use_accelerate', action='store_true', help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...')
args = parser.parse_args()

if args.use_accelerate:
    from accelerate import Accelerator
    accelerator = Accelerator(
        mixed_precision=['no', 'fp16', 'bf16', 'fp8'][1],
        gradient_accumulation_steps=1,
    )
    # With accelerate, model/optimizer/loaders are wrapped by accelerator.prepare() (see Trainer),
    # so the manual DDP path below is disabled.
    args.dist = False

config = Config()
if config.rand_seed:
    set_seed(config.rand_seed)

# DDP
to_be_distributed = args.dist
if to_be_distributed:
    device = int(os.environ["LOCAL_RANK"])
    # Bind this process to its own GPU before creating the NCCL group, so collectives and
    # default CUDA allocations do not all land on cuda:0.
    torch.cuda.set_device(device)
    init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
else:
    device = config.device

# First epoch to train; overwritten by init_models_optimizers() when resuming from a checkpoint.
epoch_st = 1
# make dir for ckpt
os.makedirs(args.ckpt_dir, exist_ok=True)

# Init log file
logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
logger_loss_idx = 1

# log model and optimizer params
# logger.info("Model details:"); logger.info(model)
if args.use_accelerate and accelerator.mixed_precision != 'no':
    config.compile = False
logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile))
logger.info("Other hyperparameters:")
logger.info(args)
print('batch size:', config.batch_size)

# Keep the '+'-separated test sets only if the first one exists on disk; otherwise build no
# validation loaders at all.
if os.path.exists(os.path.join(config.data_root_dir, config.task, args.testsets.strip('+').split('+')[0])):
    args.testsets = args.testsets.strip('+').split('+')
else:
    args.testsets = []


def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
    """Wrap ``dataset`` in a ``DataLoader``; incomplete trailing batches are dropped.

    Args:
        dataset: dataset to iterate.
        batch_size: samples per batch.
        to_be_distributed: shard the data with a ``DistributedSampler`` instead of ``shuffle``.
        is_train: in the non-distributed path, enables shuffling.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    if to_be_distributed:
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
            shuffle=False, sampler=DistributedSampler(dataset), drop_last=True
        )
    # NOTE(review): the trailing 0 in min() forces num_workers=0, so non-distributed loading always
    # runs in the main process. Presumably deliberate -- confirm before changing.
    return torch.utils.data.DataLoader(
        dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size, 0), pin_memory=True,
        shuffle=is_train, drop_last=True
    )


def init_data_loaders(to_be_distributed):
    """Build the training loader and one validation loader per name in ``args.testsets``.

    Args:
        to_be_distributed: whether the *training* loader is sharded across processes.
            Validation loaders are never distributed.

    Returns:
        ``(train_loader, test_loaders)``, where ``test_loaders`` maps test-set name -> DataLoader.
    """
    train_loader = prepare_dataloader(
        MyData(datasets=config.training_set, image_size=config.size, is_train=True),
        config.batch_size, to_be_distributed=to_be_distributed, is_train=True
    )
    print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set))
    test_loaders = {}
    for testset in args.testsets:
        _data_loader_test = prepare_dataloader(
            MyData(datasets=testset, image_size=config.size, is_train=False),
            config.batch_size_valid, is_train=False
        )
        print(len(_data_loader_test), "batches of valid dataloader {} have been created.".format(testset))
        test_loaders[testset] = _data_loader_test
    return train_loader, test_loaders


def _epoch_from_checkpoint_path(path):
    """Return the epoch N encoded in a checkpoint file name like ``epoch_N.pth``, or None if absent."""
    # splitext removes the real suffix; str.rstrip('.pth') would strip a *character set* instead.
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(stem.split('epoch_')[-1])
    except ValueError:
        return None


def init_models_optimizers(epochs, to_be_distributed):
    """Build model, optimizer and LR scheduler, optionally resuming weights from ``args.resume``.

    When ``args.resume`` is an existing file:
    - its (``check_state_dict``-processed) weights are loaded;
    - the global ``epoch_st`` becomes one past the epoch parsed from the ``epoch_<N>`` file name;
    - the LR schedule is fast-forwarded to match.

    Args:
        epochs: total number of training epochs. Non-positive entries of
            ``config.lr_decay_epochs`` are counted back from it.
        to_be_distributed: wrap the model in ``DistributedDataParallel``. Ignored under
            accelerate, where wrapping happens in ``accelerator.prepare``.

    Returns:
        ``(model, optimizer, lr_scheduler)``.

    Raises:
        ValueError: if ``config.model`` or ``config.optimizer`` is not a supported name.
    """
    global epoch_st

    # Backbone pretrained weights are only requested when no resume file exists (presumably
    # because the checkpoint would overwrite them anyway).
    bb_pretrained = not os.path.isfile(str(args.resume))
    if config.model == 'BiRefNet':
        model = BiRefNet(bb_pretrained=bb_pretrained)
    elif config.model == 'BiRefNetC2F':
        model = BiRefNetC2F(bb_pretrained=bb_pretrained)
    else:
        raise ValueError("Unsupported config.model: {!r}".format(config.model))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            state_dict = torch.load(args.resume, map_location='cpu', weights_only=True)
            state_dict = check_state_dict(state_dict)
            model.load_state_dict(state_dict)
            resumed_epoch = _epoch_from_checkpoint_path(args.resume)
            if resumed_epoch is None:
                logger.info("=> could not parse an epoch from '{}', starting at epoch {}".format(args.resume, epoch_st))
            else:
                epoch_st = resumed_epoch + 1
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if not args.use_accelerate:
        model = model.to(device)
        if to_be_distributed:
            model = DDP(model, device_ids=[device])
    if config.compile:
        model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
    if config.precisionHigh:
        torch.set_float32_matmul_precision('high')

    # Setting optimizer
    if config.optimizer == 'AdamW':
        optimizer = optim.AdamW(params=model.parameters(), lr=config.lr, weight_decay=1e-2)
    elif config.optimizer == 'Adam':
        optimizer = optim.Adam(params=model.parameters(), lr=config.lr, weight_decay=0)
    else:
        raise ValueError("Unsupported config.optimizer: {!r}".format(config.optimizer))
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        # A non-positive milestone lde is relative to the end of training: epochs + lde + 1.
        milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
        gamma=config.lr_decay_rate
    )
    # Only weights are restored on resume, not the scheduler state. Replay the already-finished
    # epochs so milestones stay aligned with absolute epoch numbers.
    if epoch_st > 1:
        with warnings.catch_warnings():
            # step() before any optimizer.step() emits an ordering UserWarning; harmless here.
            warnings.simplefilter('ignore', UserWarning)
            for _ in range(epoch_st - 1):
                lr_scheduler.step()
    logger.info("Optimizer details:")
    logger.info(optimizer)
    logger.info("Scheduler details:")
    logger.info(lr_scheduler)
    return model, optimizer, lr_scheduler


class Trainer:
    """Holds model, optimizer, scheduler, data loaders and losses; trains one epoch per call."""

    def __init__(
        self,
        data_loaders,
        model_opt_lrsch,
    ):
        """
        Args:
            data_loaders: ``(train_loader, test_loaders)`` as returned by ``init_data_loaders``.
            model_opt_lrsch: ``(model, optimizer, lr_scheduler)`` as returned by
                ``init_models_optimizers``.
        """
        self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
        self.train_loader, self.test_loaders = data_loaders
        if args.use_accelerate:
            self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer)
            for testset in self.test_loaders.keys():
                self.test_loaders[testset] = accelerator.prepare(self.test_loaders[testset])

        if config.out_ref:
            self.criterion_gdt = nn.BCELoss()

        # Setting Losses
        self.pix_loss = PixLoss()
        self.cls_loss = ClsLoss()

        # Others
        self.loss_log = AverageMeter()
        self.loss_dict = {}

    def _train_batch(self, batch):
        """Run forward, loss, backward and one optimizer step on a single batch.

        ``batch`` is indexed as ``batch[0]`` = inputs, ``batch[1]`` = ground truths,
        ``batch[2]`` = class labels. Per-term loss values are stored in ``self.loss_dict``, and
        the total loss is accumulated into ``self.loss_log`` weighted by the batch size.
        """
        if args.use_accelerate:
            # The accelerate-prepared dataloader handles device placement.
            inputs = batch[0]
            gts = batch[1]
            class_labels = batch[2]
        else:
            inputs = batch[0].to(device)
            gts = batch[1].to(device)
            class_labels = batch[2].to(device)
        scaled_preds, class_preds_lst = self.model(inputs)
        if config.out_ref:
            # With out_ref the model also returns (gdt predictions, gdt labels) ahead of the
            # scaled predictions. Their BCE losses are summed over all scales.
            (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
            for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
                _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
                _gdt_label = _gdt_label.sigmoid()
                loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
            # self.loss_dict['loss_gdt'] = loss_gdt.item()
        if None in class_preds_lst:
            loss_cls = 0.
        else:
            loss_cls = self.cls_loss(class_preds_lst, class_labels) * 1.0
            self.loss_dict['loss_cls'] = loss_cls.item()

        # Loss
        loss_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1)) * 1.0
        self.loss_dict['loss_pix'] = loss_pix.item()
        # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py
        loss = loss_pix + loss_cls
        if config.out_ref:
            loss = loss + loss_gdt * 1.0

        self.loss_log.update(loss.item(), inputs.size(0))
        self.optimizer.zero_grad()
        if args.use_accelerate:
            accelerator.backward(loss)
        else:
            loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch):
        """Train for one epoch, step the LR scheduler, and return this epoch's mean training loss."""
        global logger_loss_idx
        self.model.train()
        self.loss_dict = {}
        # Fresh meter per epoch: the reported value is this epoch's mean, not a mean over all epochs so far.
        self.loss_log = AverageMeter()
        # DistributedSampler derives its shuffle from the epoch; without set_epoch every epoch
        # would iterate in the same order.
        sampler = getattr(self.train_loader, 'sampler', None)
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)
        if epoch > args.epochs + config.finetune_last_epochs:
            # NOTE(review): these multipliers are applied at the start of *every* epoch past the
            # threshold, so they compound epoch over epoch -- confirm this is intended.
            if config.task == 'Matting':
                self.pix_loss.lambdas_pix_last['mae'] *= 1
                self.pix_loss.lambdas_pix_last['mse'] *= 0.9
                self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
            else:
                self.pix_loss.lambdas_pix_last['bce'] *= 0
                self.pix_loss.lambdas_pix_last['ssim'] *= 1
                self.pix_loss.lambdas_pix_last['iou'] *= 0.5
                self.pix_loss.lambdas_pix_last['mae'] *= 0.9

        for batch_idx, batch in enumerate(self.train_loader):
            self._train_batch(batch)
            # Logger
            if batch_idx % 20 == 0:
                info_progress = 'Epoch[{0}/{1}] Iter[{2}/{3}].'.format(epoch, args.epochs, batch_idx, len(self.train_loader))
                info_loss = 'Training Losses'
                for loss_name, loss_value in self.loss_dict.items():
                    info_loss += ', {}: {:.3f}'.format(loss_name, loss_value)
                logger.info(' '.join((info_progress, info_loss)))
        info_loss = '@==Final== Epoch[{0}/{1}] Training Loss: {loss.avg:.3f} '.format(epoch, args.epochs, loss=self.loss_log)
        logger.info(info_loss)

        self.lr_scheduler.step()
        return self.loss_log.avg


def _save_checkpoint(model, path):
    """Save the unwrapped model weights to ``path``, writing from a single process only."""
    if args.use_accelerate:
        # unwrap_model strips the wrappers added by accelerator.prepare(). With a single process
        # there is no DDP wrapper, so ``.module`` would not exist.
        if accelerator.is_main_process:
            torch.save(accelerator.unwrap_model(model).state_dict(), path)
    elif to_be_distributed:
        # Every rank holds identical weights; only rank 0 writes so ranks don't race on the same file.
        if torch.distributed.get_rank() == 0:
            torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def main():
    """Build loaders, model and optimizer, then train from ``epoch_st`` to ``args.epochs``.

    Checkpoints are saved as ``epoch_<N>.pth`` in ``args.ckpt_dir``, for epochs within the last
    ``config.save_last`` that are multiples of ``config.save_step``.
    """
    trainer = Trainer(
        data_loaders=init_data_loaders(to_be_distributed),
        model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
    )

    # epoch_st is read only after init_models_optimizers() has possibly updated it.
    for epoch in range(epoch_st, args.epochs + 1):
        trainer.train_epoch(epoch)
        # Save checkpoint
        if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
            _save_checkpoint(trainer.model, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)))
    if to_be_distributed:
        destroy_process_group()


if __name__ == '__main__':
    main()