import abc
import copy
import os

import pytorch_lightning as pl
import torch
from torch import distributed as dist

from utils.lr_scheduler import *


class AbstractModel(pl.LightningModule):
    def __init__(self,
                 lr_scheduler_kwargs: dict = None,
                 optimizer_kwargs: dict = None,
                 save_path: str = None,
                 from_checkpoint: str = None,
                 load_prev_scheduler: bool = False,
                 save_weights_only: bool = True):
        """
        Args:
            lr_scheduler_kwargs: Kwargs for the learning rate scheduler
            optimizer_kwargs: Kwargs for the optimizer
            save_path: Path to save the trained model
            from_checkpoint: Path to a checkpoint to load the model from
            load_prev_scheduler: Whether to load the previous scheduler state from the checkpoint
            save_weights_only: Whether to save only the weights or also the optimizer and lr_scheduler states
        """
        super().__init__()
        self.initialize_model()

        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)
            self.metrics[stage] = stage_metrics

        if lr_scheduler_kwargs is None:
            # Default lr_scheduler
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs

        if optimizer_kwargs is None:
            # Default optimizer
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print("No optimizer_kwargs provided. The default optimizer is AdamW.")
        else:
            self.optimizer_kwargs = optimizer_kwargs
        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only

        # temp_step is used for accumulating gradients
        self.temp_step = 0
        self.step = 0
        self.epoch = 0

        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    def initialize_model(self) -> None:
        """
        All model initialization should be done here.
        Note that the whole model must be stored as ``self.model`` so that model saving and loading work.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """
        Forward propagation
        """
        raise NotImplementedError

    def initialize_metrics(self, stage: str) -> dict:
        """
        Initialize metrics for each stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric objects
        """
        raise NotImplementedError

    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """
        Args:
            stage: "train", "valid" or "test"
            outputs: Model outputs for calculating loss
            labels: Labels for calculating loss

        Returns:
            loss
        """
        raise NotImplementedError

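    # NOTE: initialize_model, forward, initialize_metrics and loss_func above form the
    # abstract interface that every concrete subclass must implement; an illustrative
    # subclass is sketched at the end of this file.
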
    def load_weights(self, model, weights) -> None:
        """
        Load ``weights`` into ``model``, reporting missing and unused parameters.
        """
        model_dict = model.state_dict()

        unused_params = []
        missed_params = list(model_dict.keys())

        for k, v in weights.items():
            if k in model_dict.keys():
                model_dict[k] = v
                missed_params.remove(k)
            else:
                unused_params.append(k)

        if len(missed_params) > 0:
            print(f"\033[31mSome weights of {type(model).__name__} were not "
                  f"initialized from the model checkpoint: {missed_params}\033[0m")

        if len(unused_params) > 0:
            print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")

        model.load_state_dict(model_dict)

    def optimizer_step(
            self,
            epoch: int,
            batch_idx: int,
            optimizer,
            optimizer_closure=None,
    ) -> None:
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0

    # For pytorch-lightning 1.9.5
    # def optimizer_step(
    #         self,
    #         epoch: int,
    #         batch_idx: int,
    #         optimizer,
    #         optimizer_idx: int = 0,
    #         optimizer_closure=None,
    #         on_tpu: bool = False,
    #         using_native_amp: bool = False,
    #         using_lbfgs: bool = False,
    # ) -> None:
    #     super().optimizer_step(
    #         epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
    #     )
    #     self.temp_step += 1
    #     if self.temp_step == self.trainer.accumulate_grad_batches:
    #         self.step += 1
    #         self.temp_step = 0

    def on_train_epoch_end(self):
        self.epoch += 1

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('train', outputs, labels)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)

        loss = self.loss_func('valid', outputs, labels)
        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)

        loss = self.loss_func('test', outputs, labels)
        self.test_outputs.append(loss)
        return loss

    def on_train_start(self) -> None:
        # Load the previous scheduler state if it was parsed from the checkpoint
        if getattr(self, "prev_scheduler", None) is not None:
            try:
                self.step = self.prev_scheduler["global_step"]
                self.epoch = self.prev_scheduler["epoch"]
                self.best_value = self.prev_scheduler["best_value"]
                self.lr_scheduler.load_state_dict(self.prev_scheduler["lr_scheduler"])
                print(f"Previous training global step: {self.step}")
                print(f"Previous training epoch: {self.epoch}")
                print(f"Previous best value: {self.best_value}")
                print(f"Previous lr_scheduler: {self.prev_scheduler['lr_scheduler']}")

                # Load optimizer state
                if hasattr(self.trainer.strategy, "deepspeed_engine"):
                    # For DeepSpeed strategy
                    try:
                        self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
                    except Exception as e:
                        print(e)
                else:
                    # For DDP strategy
                    self.optimizer.load_state_dict(self.prev_scheduler["optimizer"])

            except Exception as e:
                print(e)
                raise Exception("Error in loading the previous scheduler. Please set load_prev_scheduler=False")

    def on_validation_epoch_start(self) -> None:
        self.valid_outputs = []

    def on_test_epoch_start(self) -> None:
        self.test_outputs = []

    def load_checkpoint(self, from_checkpoint: str) -> None:
        """
        Args:
            from_checkpoint: Path to the checkpoint.
        """
        # If ``from_checkpoint`` is a directory, load the checkpoint inside it
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])

        if self.load_prev_scheduler:
            state_dict.pop("model")
            self.prev_scheduler = state_dict

    def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
        """
        Save the model to ``save_path``
        Args:
            save_path: Path to save the model
            save_info: Other info to save
            save_weights_only: Whether to save only the model weights
        """
        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)

        state_dict = {} if save_info is None else save_info
        state_dict["model"] = self.model.state_dict()

        # Convert model weights to fp32
        for k, v in state_dict["model"].items():
            state_dict["model"][k] = v.float()

        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, "best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

            # If not using DeepSpeed, save the optimizer state
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

        torch.save(state_dict, save_path)

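    # Illustrative note (an addition, not part of the original code): a checkpoint written
    # by ``save_checkpoint`` with ``save_weights_only=False`` is a plain ``torch.save`` dict
    # and can be inspected offline, e.g.:
    #
    #     ckpt = torch.load("path/to/model.pt", map_location="cpu")  # hypothetical path
    #     # keys: "model", "global_step", "epoch", "best_value", "lr_scheduler"
    #     # (plus "optimizer" when DeepSpeed is not used, and any entries from save_info)
    #     some_model.load_state_dict(ckpt["model"])                  # hypothetical model
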
    def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
        """
        Check whether to save the model. If save_path is not None and now_value is the best so far, save the model.
        Args:
            now_value: Current metric value
            mode: "min" or "max", i.e. whether lower or higher values are better
            save_info: Other info to save
        """
        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is not None:
            # In case there are variables to be interpolated into the save path
            save_path = eval(f"f'{self.save_path}'")
            save_dir = os.path.dirname(save_path)
            os.makedirs(save_dir, exist_ok=True)

            # Check whether to save the model
            best_value = getattr(self, "best_value", None)
            if best_value is not None:
                if (mode == "min" and now_value >= best_value) or (mode == "max" and now_value <= best_value):
                    return

            self.best_value = now_value

            # For DeepSpeed strategy
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                if not self.save_weights_only:
                    self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")

                # Save a complete checkpoint
                if dist.get_rank() == 0:
                    basename = os.path.basename(save_path)
                    ckpt_path = os.path.join(save_path, f"{basename}.pt")
                    self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)

            # For the normal (non-DeepSpeed) case
            else:
                if dist.get_rank() == 0:
                    self.save_checkpoint(save_path, save_info, self.save_weights_only)

    def reset_metrics(self, stage) -> None:
        """
        Reset metrics for a given stage
        Args:
            stage: "train", "valid" or "test"
        """
        for metric in self.metrics[stage].values():
            metric.reset()

    def get_log_dict(self, stage: str) -> dict:
        """
        Get the log dict for a stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric values
        """
        return {name: metric.compute() for name, metric in self.metrics[stage].items()}

    def log_info(self, info: dict) -> None:
        """
        Record metrics during training and testing
        Args:
            info: Dict of metrics
        """
        if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

        # No weight decay for layer norm and bias parameters
        no_decay = ['LayerNorm.weight', 'bias']
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        # The optimizer class is looked up by name under torch.optim
        optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
        self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                       lr=self.lr_scheduler_kwargs['init_lr'],
                                       **copy_optimizer_kwargs)

        # The scheduler class is looked up by name among the classes imported from utils.lr_scheduler
        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        lr_scheduler = tmp_kwargs.pop("class")
        self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)

    def configure_optimizers(self):
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.lr_scheduler,
                                 "interval": "step",
                                 "frequency": 1}
                }
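

# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original module): a minimal
# concrete subclass showing how AbstractModel might be used. The toy network,
# the torchmetrics-based metric, and the keyword-argument batch format are
# assumptions made for demonstration only; AbstractModel itself only requires
# initialize_model, forward, initialize_metrics and loss_func to be overridden,
# and expects to run inside a Lightning Trainer with torch.distributed
# initialized (its logging and saving helpers call dist.get_rank()).
# ---------------------------------------------------------------------------
import torch.nn as nn
import torchmetrics


class ToyClassificationModel(AbstractModel):
    def initialize_model(self) -> None:
        # The whole model must live in ``self.model`` so that checkpoint saving
        # and loading in AbstractModel work.
        self.model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, inputs):
        return self.model(inputs)

    def initialize_metrics(self, stage: str) -> dict:
        # Keys become attribute names, so they are prefixed with the stage.
        return {f"{stage}_acc": torchmetrics.Accuracy(task="multiclass", num_classes=2)}

    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        # Update the stage metric alongside the loss computation.
        self.metrics[stage][f"{stage}_acc"].update(outputs, labels)
        return torch.nn.functional.cross_entropy(outputs, labels)

    def on_validation_epoch_end(self):
        # One plausible pattern (assumption): compute metrics, log them, reset them,
        # and let check_save_condition decide whether to write a checkpoint.
        log_dict = self.get_log_dict("valid")
        self.log_info(log_dict)
        self.reset_metrics("valid")
        self.check_save_condition(log_dict["valid_acc"], mode="max")


if __name__ == "__main__":
    # Construct the toy model with explicit optimizer / scheduler kwargs; the "class"
    # entries are looked up by name in torch.optim and utils.lr_scheduler, and "init_lr"
    # is passed to both the optimizer and the scheduler, mirroring the defaults in __init__.
    model = ToyClassificationModel(
        optimizer_kwargs={"class": "AdamW", "betas": (0.9, 0.98), "weight_decay": 0.01},
        lr_scheduler_kwargs={"class": "ConstantLRScheduler", "init_lr": 1e-4},
        save_path="checkpoints/toy_model.pt",
    )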