import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset

from model.seq2seq import DiffusionPredictor
from config import *
from dist_utils import *
from renderer import *

# This part is modified from: https://github.com/phizaz/diffae/blob/master/experiment.py
class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())

        self.conf = conf
        self.model = DiffusionPredictor(conf)
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this is shared for both model and latent
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None
        # pre-inferred latents and their stats; populated lazily in
        # train_dataloader when the train mode requires dataset inference
        self.conds = None
        self.conds_mean = None
        self.conds_std = None

        # initial variables for consistent sampling
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))
    def render(self, start, motion_direction_start, audio_driven,
               face_location, face_scale, ypr_info, noisyT, step_T,
               control_flag):
        if step_T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(step_T).make_sampler()

        pred_img = render_condition(self.conf, self.ema_model, sampler, start,
                                    motion_direction_start, audio_driven,
                                    face_location, face_scale, ypr_info,
                                    noisyT, control_flag)
        return pred_img
    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            # use the EMA weights when requested
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen
    def setup(self, stage=None) -> None:
        """
        make datasets & seed each worker separately
        """
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)
        ##############################################

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))
    def _train_dataloader(self, drop_last=True):
        """
        really make the dataloader
        """
        # make sure to use only a fraction of the batch size per worker
        # (the configured batch size is global!)
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader
    def train_dataloader(self):
        """
        return the dataloader:
        if diffusion mode => return the image dataset
        if latent mode => return the inferred latent dataset
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # usually we load self.conds from a file,
                # so we do not need to do this again!
                self.conds = self.infer_whole_dataset()
                # need to use float32! otherwise the mean & std will be off!
                # (1, c)
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            # return the dataset with pre-calculated conds
            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()
    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws
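        # e.g. a global conf.batch_size of 64 across 4 workers gives a local
        # batch size of 16 per worker (illustrative numbers)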
    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        # batch size here is global!
        # global_step already takes into account the accum batches
        return self.global_step * self.conf.batch_size_effective
    def is_last_accum(self, batch_idx):
        """
        is this the last gradient accumulation loop?
        used with gradient_accum > 1 to check whether the optimizer will
        perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0
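        # e.g. with accum_batches = 2, batch_idx 1, 3, 5, ... are the
        # iterations where the optimizer actually steps (illustrative)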
    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.
        """
        with amp.autocast(False):
            motion_start = batch['motion_start']  # torch.Size([B, 512])
            motion_direction = batch['motion_direction']  # torch.Size([B, 125, 20])
            audio_feats = batch['audio_feats'].float()  # torch.Size([B, 25, 250, 1024])
            face_location = batch['face_location'].float()  # torch.Size([B, 125])
            face_scale = batch['face_scale'].float()  # torch.Size([B, 125, 1])
            yaw_pitch_roll = batch['yaw_pitch_roll'].float()  # torch.Size([B, 125, 3])
            motion_direction_start = batch['motion_direction_start'].float()  # torch.Size([B, 20])

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                # with a numpy seed we have the problem that the sampled t's are related!
                t, weight = self.T_sampler.sample(len(motion_start),
                                                  motion_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    motion_direction_start=motion_direction_start,
                    motion_target=motion_direction,
                    motion_start=motion_start,
                    audio_feats=audio_feats,
                    face_location=face_location,
                    face_scale=face_scale,
                    yaw_pitch_roll=yaw_pitch_roll,
                    t=t)
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            # gather and average the losses across all workers for logging
            for key in losses.keys():
                losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in losses:
                    self.logger.experiment.add_scalar(
                        f'loss/{key}', losses[key], self.num_samples)

            return {'loss': loss}
    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step ...
        """
        if self.is_last_accum(batch_idx):
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # it trains only the latent, hence update only the latent
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)
    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        # fix the fp16 + clip grad norm problem with pytorch lightning
        # this is the currently correct way to do it
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)
    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out
    def split_tensor(self, x):
        """
        extract the tensor for the corresponding "worker" in the batch dimension

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
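        # e.g. with n = 8 rows and world_size = 4, rank 1 receives x[2:4]
        # (illustrative numbers)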
def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))
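# ema() implements target <- decay * target + (1 - decay) * source, i.e. an
# exponential moving average of the online model's weights held in the EMA copy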
class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup
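# WarmupLR ramps the learning-rate factor linearly from 0 to 1 over `warmup`
# steps, e.g. warmup=1000 gives a factor of 0.1 at step 100 and 1.0 from
# step 1000 onward (illustrative numbers)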
def is_time(num_samples, every, step_size):
    closest = (num_samples // every) * every
    return num_samples - closest < step_size
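# is_time fires once per `every` samples: it is True only while num_samples is
# within step_size of a multiple of `every`, e.g. every=1000, step_size=64
# gives True for num_samples in [1000, 1064) (illustrative numbers)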
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
    print('conf:', conf.name)
    model = LitModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
                                 save_last=True,
                                 save_top_k=-1,
                                 every_n_epochs=10)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    else:
        if conf.continue_from is not None:
            # continue from a checkpoint
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # important for working with gradient checkpoint
        plugins.append(DDPPlugin(find_unused_parameters=True))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        # gradient clipping is done in the model (on_before_optimizer_step)
        # instead of via gradient_clip_val, to avoid the amp + clipping bug
        # in pytorch lightning
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
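

# Minimal launch sketch. The concrete TrainConfig fields / preset helpers
# live in config.py (not shown here), so the constructor call below is an
# assumption to be adjusted to the actual presets.
if __name__ == '__main__':
    conf = TrainConfig()  # hypothetical: fill in fields or use a preset from config.py
    train(conf, gpus=[0], nodes=1)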
