Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Soutrik
		
	committed on
		
		
					Commit 
							
							·
						
						8d4131e
	
1
								Parent(s):
							
							de7d21e
								
optuna added as base
Browse files- configs/experiment/catdog_experiment.yaml +1 -1
- configs/trainer/default.yaml +2 -3
- src/train_new.py +84 -99
- src/train_old.py +260 -0
- src/{train.py → train_optuna_callbacks.py} +120 -98
    	
        configs/experiment/catdog_experiment.yaml
    CHANGED
    
    | @@ -39,7 +39,7 @@ model: | |
| 39 |  | 
| 40 | 
             
            trainer:
         | 
| 41 | 
             
              min_epochs: 1
         | 
| 42 | 
            -
              max_epochs:  | 
| 43 |  | 
| 44 | 
             
            callbacks:
         | 
| 45 | 
             
              model_checkpoint:
         | 
|  | |
| 39 |  | 
| 40 | 
             
            trainer:
         | 
| 41 | 
             
              min_epochs: 1
         | 
| 42 | 
            +
              max_epochs: 5
         | 
| 43 |  | 
| 44 | 
             
            callbacks:
         | 
| 45 | 
             
              model_checkpoint:
         | 
    	
        configs/trainer/default.yaml
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
            _target_: lightning.Trainer
         | 
| 2 |  | 
| 3 | 
             
            default_root_dir: ${paths.output_dir}
         | 
| 4 | 
             
            min_epochs: 1
         | 
| @@ -10,8 +9,7 @@ devices: auto | |
| 10 | 
             
            # mixed precision for extra speed-up
         | 
| 11 | 
             
            # precision: 16
         | 
| 12 |  | 
| 13 | 
            -
            # set True to ensure deterministic results
         | 
| 14 | 
            -
            # makes training slower but gives more reproducibility than just setting seeds
         | 
| 15 | 
             
            deterministic: True
         | 
| 16 |  | 
| 17 | 
             
            # Log every N steps in training and validation
         | 
| @@ -19,3 +17,4 @@ log_every_n_steps: 10 | |
| 19 | 
             
            fast_dev_run: False
         | 
| 20 |  | 
| 21 | 
             
            gradient_clip_val: 1.0
         | 
|  | 
|  | |
|  | |
| 1 |  | 
| 2 | 
             
            default_root_dir: ${paths.output_dir}
         | 
| 3 | 
             
            min_epochs: 1
         | 
|  | |
| 9 | 
             
            # mixed precision for extra speed-up
         | 
| 10 | 
             
            # precision: 16
         | 
| 11 |  | 
| 12 | 
            +
            # set True to ensure deterministic results; makes training slower but gives more reproducibility than just setting seeds
         | 
|  | |
| 13 | 
             
            deterministic: True
         | 
| 14 |  | 
| 15 | 
             
            # Log every N steps in training and validation
         | 
|  | |
| 17 | 
             
            fast_dev_run: False
         | 
| 18 |  | 
| 19 | 
             
            gradient_clip_val: 1.0
         | 
| 20 | 
            +
            gradient_clip_algorithm: 'norm'
         | 
    	
        src/train_new.py
    CHANGED
    
    | @@ -1,7 +1,5 @@ | |
| 1 | 
             
            """
         | 
| 2 | 
            -
            Train and evaluate a model using PyTorch Lightning.
         | 
| 3 | 
            -
            Initializes the DataModule, Model, Trainer, and runs training and testing.
         | 
| 4 | 
            -
            Initializes loggers and callbacks from the configuration using Hydra and target paths from the configuration.
         | 
| 5 | 
             
            """
         | 
| 6 |  | 
| 7 | 
             
            import os
         | 
| @@ -17,51 +15,34 @@ from src.utils.logging_utils import setup_logger, task_wrapper | |
| 17 | 
             
            from loguru import logger
         | 
| 18 | 
             
            import rootutils
         | 
| 19 | 
             
            from lightning.pytorch.loggers import Logger
         | 
| 20 | 
            -
             | 
|  | |
| 21 |  | 
| 22 | 
             
            # Load environment variables
         | 
| 23 | 
             
            load_dotenv(find_dotenv(".env"))
         | 
| 24 |  | 
| 25 | 
             
            # Setup root directory
         | 
| 26 | 
            -
             | 
| 27 | 
             
            root = rootutils.setup_root(__file__, indicator=".project-root")
         | 
| 28 |  | 
| 29 |  | 
| 30 | 
            -
            def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
         | 
| 31 | 
            -
                """Instantiate and return a list of callbacks from the configuration."""
         | 
| 32 | 
            -
                callbacks_ls: List[L.Callback] = []
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                if not callback_cfg:
         | 
| 35 | 
            -
                    logger.warning("No callback configs found! Skipping..")
         | 
| 36 | 
            -
                    return None
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                if not isinstance(callback_cfg, DictConfig):
         | 
| 39 | 
            -
                    raise TypeError("Callbacks config must be a DictConfig!")
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                for _, cb_conf in callback_cfg.items():
         | 
| 42 | 
            -
                    if "_target_" in cb_conf:
         | 
| 43 | 
            -
                        logger.info(f"Instantiating callback <{cb_conf._target_}>")
         | 
| 44 | 
            -
                        callbacks_ls.append(hydra.utils.instantiate(cb_conf))
         | 
| 45 | 
            -
             | 
| 46 | 
            -
                return callbacks_ls
         | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
             
            def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
         | 
| 50 | 
             
                """Instantiate and return a list of loggers from the configuration."""
         | 
| 51 | 
             
                loggers_ls: List[Logger] = []
         | 
| 52 |  | 
| 53 | 
            -
                if not logger_cfg:
         | 
| 54 | 
            -
                    logger.warning("No logger configs found! Skipping..")
         | 
| 55 | 
             
                    return loggers_ls
         | 
| 56 |  | 
| 57 | 
             
                if not isinstance(logger_cfg, DictConfig):
         | 
| 58 | 
             
                    raise TypeError("Logger config must be a DictConfig!")
         | 
| 59 |  | 
| 60 | 
             
                for _, lg_conf in logger_cfg.items():
         | 
| 61 | 
            -
                    if "_target_" in lg_conf:
         | 
| 62 | 
             
                        logger.info(f"Instantiating logger <{lg_conf._target_}>")
         | 
| 63 | 
            -
                         | 
| 64 | 
            -
             | 
|  | |
|  | |
| 65 | 
             
                return loggers_ls
         | 
| 66 |  | 
| 67 |  | 
| @@ -93,16 +74,19 @@ def clear_checkpoint_directory(ckpt_dir: str): | |
| 93 | 
             
            def train_module(
         | 
| 94 | 
             
                data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
         | 
| 95 | 
             
            ):
         | 
| 96 | 
            -
                """Train the model  | 
| 97 | 
            -
                logger.info("Starting training")
         | 
|  | |
| 98 | 
             
                trainer.fit(model, data_module)
         | 
| 99 | 
            -
                 | 
| 100 | 
            -
             | 
| 101 | 
            -
                 | 
| 102 | 
            -
             | 
| 103 | 
            -
                     | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
|  | |
|  | |
| 106 |  | 
| 107 |  | 
| 108 | 
             
            @task_wrapper
         | 
| @@ -122,77 +106,78 @@ def run_test_module( | |
| 122 | 
             
                return test_metrics[0] if test_metrics else {}
         | 
| 123 |  | 
| 124 |  | 
| 125 | 
            -
             | 
| 126 | 
            -
             | 
| 127 | 
            -
                """Set up and run the Trainer for training and testing."""
         | 
| 128 | 
            -
                # Display configuration
         | 
| 129 | 
            -
                logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
         | 
| 130 |  | 
| 131 | 
            -
                #  | 
| 132 | 
            -
                 | 
| 133 | 
            -
             | 
| 134 | 
            -
                )
         | 
| 135 | 
            -
                 | 
| 136 | 
            -
             | 
| 137 | 
            -
                # Display key paths
         | 
| 138 | 
            -
                for path_name in [
         | 
| 139 | 
            -
                    "root_dir",
         | 
| 140 | 
            -
                    "data_dir",
         | 
| 141 | 
            -
                    "log_dir",
         | 
| 142 | 
            -
                    "ckpt_dir",
         | 
| 143 | 
            -
                    "artifact_dir",
         | 
| 144 | 
            -
                    "output_dir",
         | 
| 145 | 
            -
                ]:
         | 
| 146 | 
            -
                    logger.info(
         | 
| 147 | 
            -
                        f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
         | 
| 148 | 
            -
                    )
         | 
| 149 |  | 
| 150 | 
            -
                # Initialize  | 
| 151 | 
            -
                 | 
| 152 | 
            -
                datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
         | 
| 153 | 
            -
                logger.info(f"Instantiating model <{cfg.model._target_}>")
         | 
| 154 | 
             
                model: L.LightningModule = hydra.utils.instantiate(cfg.model)
         | 
| 155 |  | 
| 156 | 
            -
                #  | 
| 157 | 
            -
                logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
         | 
| 158 | 
            -
                L.seed_everything(cfg.seed, workers=True)
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                # Set up callbacks, loggers, and Trainer
         | 
| 161 | 
            -
                callbacks = instantiate_callbacks(cfg.callbacks)
         | 
| 162 | 
            -
                logger.info(f"Callbacks: {callbacks}")
         | 
| 163 | 
             
                loggers = instantiate_loggers(cfg.logger)
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                 | 
| 166 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 167 | 
             
                )
         | 
| 168 |  | 
| 169 | 
            -
                 | 
| 170 | 
            -
             | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
                    train_metrics = train_module(datamodule, model, trainer)
         | 
| 174 | 
            -
                    (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
         | 
| 175 | 
            -
                        "Training completed.\n"
         | 
| 176 | 
             
                    )
         | 
| 177 | 
            -
             | 
| 178 | 
            -
             | 
| 179 | 
            -
                test_metrics = {}
         | 
| 180 | 
            -
                if cfg.get("test"):
         | 
| 181 | 
            -
                    test_metrics = run_test_module(cfg, datamodule, model, trainer)
         | 
| 182 | 
            -
             | 
| 183 | 
            -
                # Combine metrics and extract optimization metric
         | 
| 184 | 
            -
                all_metrics = {**train_metrics, **test_metrics}
         | 
| 185 | 
            -
                optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
         | 
| 186 | 
            -
                (
         | 
| 187 | 
            -
                    logger.warning(
         | 
| 188 | 
            -
                        f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
         | 
| 189 | 
             
                    )
         | 
| 190 | 
            -
                    if optimization_metric == 0.0
         | 
| 191 | 
            -
                    else logger.info(f"Optimization metric: {optimization_metric}")
         | 
| 192 | 
            -
                )
         | 
| 193 |  | 
| 194 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 195 |  | 
| 196 |  | 
| 197 | 
             
            if __name__ == "__main__":
         | 
| 198 | 
            -
                 | 
|  | |
| 1 | 
             
            """
         | 
| 2 | 
            +
            Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
         | 
|  | |
|  | |
| 3 | 
             
            """
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
|  | |
| 15 | 
             
            from loguru import logger
         | 
| 16 | 
             
            import rootutils
         | 
| 17 | 
             
            from lightning.pytorch.loggers import Logger
         | 
| 18 | 
            +
            import optuna
         | 
| 19 | 
            +
            from lightning.pytorch import Trainer
         | 
| 20 |  | 
| 21 | 
             
            # Load environment variables
         | 
| 22 | 
             
            load_dotenv(find_dotenv(".env"))
         | 
| 23 |  | 
| 24 | 
             
            # Setup root directory
         | 
|  | |
| 25 | 
             
            root = rootutils.setup_root(__file__, indicator=".project-root")
         | 
| 26 |  | 
| 27 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 | 
             
            def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
         | 
| 29 | 
             
                """Instantiate and return a list of loggers from the configuration."""
         | 
| 30 | 
             
                loggers_ls: List[Logger] = []
         | 
| 31 |  | 
| 32 | 
            +
                if not logger_cfg or isinstance(logger_cfg, bool):
         | 
| 33 | 
            +
                    logger.warning("No valid logger configs found! Skipping..")
         | 
| 34 | 
             
                    return loggers_ls
         | 
| 35 |  | 
| 36 | 
             
                if not isinstance(logger_cfg, DictConfig):
         | 
| 37 | 
             
                    raise TypeError("Logger config must be a DictConfig!")
         | 
| 38 |  | 
| 39 | 
             
                for _, lg_conf in logger_cfg.items():
         | 
| 40 | 
            +
                    if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
         | 
| 41 | 
             
                        logger.info(f"Instantiating logger <{lg_conf._target_}>")
         | 
| 42 | 
            +
                        try:
         | 
| 43 | 
            +
                            loggers_ls.append(hydra.utils.instantiate(lg_conf))
         | 
| 44 | 
            +
                        except Exception as e:
         | 
| 45 | 
            +
                            logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
         | 
| 46 | 
             
                return loggers_ls
         | 
| 47 |  | 
| 48 |  | 
|  | |
| 74 | 
             
            def train_module(
         | 
| 75 | 
             
                data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
         | 
| 76 | 
             
            ):
         | 
| 77 | 
            +
                """Train the model, return validation accuracy for each epoch."""
         | 
| 78 | 
            +
                logger.info("Starting training with custom pruning")
         | 
| 79 | 
            +
             | 
| 80 | 
             
                trainer.fit(model, data_module)
         | 
| 81 | 
            +
                val_accuracies = []
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                for epoch in range(trainer.current_epoch):
         | 
| 84 | 
            +
                    val_acc = trainer.callback_metrics.get("val_acc")
         | 
| 85 | 
            +
                    if val_acc is not None:
         | 
| 86 | 
            +
                        val_accuracies.append(val_acc.item())
         | 
| 87 | 
            +
                        logger.info(f"Epoch {epoch}: val_acc={val_acc}")
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                return val_accuracies
         | 
| 90 |  | 
| 91 |  | 
| 92 | 
             
            @task_wrapper
         | 
|  | |
| 106 | 
             
                return test_metrics[0] if test_metrics else {}
         | 
| 107 |  | 
| 108 |  | 
| 109 | 
            +
            def objective(trial: optuna.trial.Trial, cfg: DictConfig):
         | 
| 110 | 
            +
                """Objective function for Optuna hyperparameter tuning."""
         | 
|  | |
|  | |
|  | |
| 111 |  | 
| 112 | 
            +
                # Sample hyperparameters for the model
         | 
| 113 | 
            +
                cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
         | 
| 114 | 
            +
                cfg.model.depth = trial.suggest_int("depth", 2, 6)
         | 
| 115 | 
            +
                cfg.model.lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
         | 
| 116 | 
            +
                cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 117 |  | 
| 118 | 
            +
                # Initialize data module and model
         | 
| 119 | 
            +
                data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
         | 
|  | |
|  | |
| 120 | 
             
                model: L.LightningModule = hydra.utils.instantiate(cfg.model)
         | 
| 121 |  | 
| 122 | 
            +
                # Set up logger
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 123 | 
             
                loggers = instantiate_loggers(cfg.logger)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                # Trainer configuration without pruning callback
         | 
| 126 | 
            +
                trainer = Trainer(**cfg.trainer, logger=loggers)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # Clear checkpoint directory
         | 
| 129 | 
            +
                clear_checkpoint_directory(cfg.paths.ckpt_dir)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                # Train and get val_acc for each epoch
         | 
| 132 | 
            +
                val_accuracies = train_module(data_module, model, trainer)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                # Report validation accuracy and prune if necessary
         | 
| 135 | 
            +
                for epoch, val_acc in enumerate(val_accuracies):
         | 
| 136 | 
            +
                    trial.report(val_acc, step=epoch)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    # Check if the trial should be pruned at this epoch
         | 
| 139 | 
            +
                    if trial.should_prune():
         | 
| 140 | 
            +
                        logger.info(f"Pruning trial at epoch {epoch}")
         | 
| 141 | 
            +
                        raise optuna.TrialPruned()
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                # Return the final validation accuracy as the objective metric
         | 
| 144 | 
            +
                return val_accuracies[-1] if val_accuracies else 0.0
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            @hydra.main(config_path="../configs", config_name="train", version_base="1.3")
         | 
| 148 | 
            +
            def setup_trainer(cfg: DictConfig):
         | 
| 149 | 
            +
                logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                setup_logger(
         | 
| 152 | 
            +
                    Path(cfg.paths.log_dir)
         | 
| 153 | 
            +
                    / ("train.log" if cfg.task_name == "train" else "eval.log")
         | 
| 154 | 
             
                )
         | 
| 155 |  | 
| 156 | 
            +
                if cfg.get("train", False):
         | 
| 157 | 
            +
                    pruner = optuna.pruners.MedianPruner()
         | 
| 158 | 
            +
                    study = optuna.create_study(
         | 
| 159 | 
            +
                        direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
         | 
|  | |
|  | |
|  | |
| 160 | 
             
                    )
         | 
| 161 | 
            +
                    study.optimize(
         | 
| 162 | 
            +
                        lambda trial: objective(trial, cfg), n_trials=3, show_progress_bar=True
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 163 | 
             
                    )
         | 
|  | |
|  | |
|  | |
| 164 |  | 
| 165 | 
            +
                    # Log best trial results
         | 
| 166 | 
            +
                    best_trial = study.best_trial
         | 
| 167 | 
            +
                    logger.info(f"Best trial number: {best_trial.number}")
         | 
| 168 | 
            +
                    logger.info(f"Best trial value (val_acc): {best_trial.value}")
         | 
| 169 | 
            +
                    for key, value in best_trial.params.items():
         | 
| 170 | 
            +
                        logger.info(f"  {key}: {value}")
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                if cfg.get("test", False):
         | 
| 173 | 
            +
                    data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
         | 
| 174 | 
            +
                    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
         | 
| 175 | 
            +
                    trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
         | 
| 176 | 
            +
                    test_metrics = run_test_module(cfg, data_module, model, trainer)
         | 
| 177 | 
            +
                    logger.info(f"Test metrics: {test_metrics}")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                return cfg.model if not cfg.get("test", False) else test_metrics
         | 
| 180 |  | 
| 181 |  | 
| 182 | 
             
            if __name__ == "__main__":
         | 
| 183 | 
            +
                setup_trainer()
         | 
    	
        src/train_old.py
    ADDED
    
    | @@ -0,0 +1,260 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import shutil
         | 
| 3 | 
            +
            from pathlib import Path
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import lightning as L
         | 
| 6 | 
            +
            from lightning.pytorch.loggers import Logger
         | 
| 7 | 
            +
            from typing import List
         | 
| 8 | 
            +
            from src.datamodules.dogbreed_datamodule import main_dataloader
         | 
| 9 | 
            +
            from src.utils.logging_utils import setup_logger, task_wrapper
         | 
| 10 | 
            +
            from loguru import logger
         | 
| 11 | 
            +
            from dotenv import load_dotenv, find_dotenv
         | 
| 12 | 
            +
            import rootutils
         | 
| 13 | 
            +
            import hydra
         | 
| 14 | 
            +
            from omegaconf import DictConfig, OmegaConf
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Load environment variables
         | 
| 17 | 
            +
            load_dotenv(find_dotenv(".env"))
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Setup root directory
         | 
| 20 | 
            +
            root = rootutils.setup_root(__file__, indicator=".project-root")
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
         | 
| 24 | 
            +
                """Instantiate and return a list of callbacks from the configuration."""
         | 
| 25 | 
            +
                callbacks: List[L.Callback] = []
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                if not callback_cfg:
         | 
| 28 | 
            +
                    logger.warning("No callback configs found! Skipping..")
         | 
| 29 | 
            +
                    return callbacks
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                if not isinstance(callback_cfg, DictConfig):
         | 
| 32 | 
            +
                    raise TypeError("Callbacks config must be a DictConfig!")
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                for _, cb_conf in callback_cfg.items():
         | 
| 35 | 
            +
                    if "_target_" in cb_conf:
         | 
| 36 | 
            +
                        logger.info(f"Instantiating callback <{cb_conf._target_}>")
         | 
| 37 | 
            +
                        callbacks.append(hydra.utils.instantiate(cb_conf))
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return callbacks
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
         | 
| 43 | 
            +
                """Instantiate and return a list of loggers from the configuration."""
         | 
| 44 | 
            +
                loggers_ls: List[Logger] = []
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                if not logger_cfg:
         | 
| 47 | 
            +
                    logger.warning("No logger configs found! Skipping..")
         | 
| 48 | 
            +
                    return loggers_ls
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                if not isinstance(logger_cfg, DictConfig):
         | 
| 51 | 
            +
                    raise TypeError("Logger config must be a DictConfig!")
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                for _, lg_conf in logger_cfg.items():
         | 
| 54 | 
            +
                    if "_target_" in lg_conf:
         | 
| 55 | 
            +
                        logger.info(f"Instantiating logger <{lg_conf._target_}>")
         | 
| 56 | 
            +
                        loggers_ls.append(hydra.utils.instantiate(lg_conf))
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                return loggers_ls
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def load_checkpoint_if_available(ckpt_path: str) -> str:
         | 
| 62 | 
            +
                """Check if the specified checkpoint exists and return the valid checkpoint path."""
         | 
| 63 | 
            +
                if ckpt_path and Path(ckpt_path).exists():
         | 
| 64 | 
            +
                    logger.info(f"Checkpoint found: {ckpt_path}")
         | 
| 65 | 
            +
                    return ckpt_path
         | 
| 66 | 
            +
                else:
         | 
| 67 | 
            +
                    logger.warning(
         | 
| 68 | 
            +
                        f"No checkpoint found at {ckpt_path}. Using current model weights."
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
                    return None
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def clear_checkpoint_directory(ckpt_dir: str):
                """Remove everything inside ``ckpt_dir`` but keep the directory itself.

                If ``ckpt_dir`` is not an existing directory it is created instead.
                Deletion is best-effort: a failure on one entry is logged and the
                remaining entries are still processed.

                Args:
                    ckpt_dir: Path of the checkpoint directory to empty or create.
                """
                ckpt_dir_path = Path(ckpt_dir)
                if not ckpt_dir_path.is_dir():
                    logger.info(
                        f"Checkpoint directory does not exist. Creating directory: {ckpt_dir}"
                    )
                    os.makedirs(ckpt_dir_path, exist_ok=True)
                    return

                logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
                failed = 0
                for item in ckpt_dir_path.iterdir():
                    try:
                        # Only real directories are removed recursively. Symlinks (even
                        # to directories), regular files and special files are unlinked
                        # so the link target is never touched.
                        if item.is_dir() and not item.is_symlink():
                            shutil.rmtree(item)
                        else:
                            item.unlink()
                    except OSError as e:
                        failed += 1
                        logger.error(f"Failed to delete {item}: {e}")

                # Do not claim success when some entries could not be removed.
                if failed:
                    logger.warning(
                        f"Checkpoint directory only partially cleared "
                        f"({failed} item(s) could not be deleted): {ckpt_dir}"
                    )
                else:
                    logger.info(f"Checkpoint directory cleared: {ckpt_dir}")
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            @task_wrapper
            def train_module(
                cfg: DictConfig,
                data_module: L.LightningDataModule,
                model: L.LightningModule,
                trainer: L.Trainer,
            ):
                """Fit ``model`` on ``data_module`` and return the trainer's callback metrics.

                Args:
                    cfg: Run configuration; not read inside this function body.
                    data_module: Data module handed to ``trainer.fit``.
                    model: Module to train.
                    trainer: Trainer that runs the fit loop.

                Returns:
                    ``trainer.callback_metrics`` as it stands after ``fit`` returns.
                """
                logger.info("Training the model")
                trainer.fit(model, data_module)

                train_metrics = trainer.callback_metrics
                try:
                    # Preferred summary: the two accuracy values as plain numbers.
                    train_acc = train_metrics["train_acc"].item()
                    val_acc = train_metrics["val_acc"].item()
                    summary = f"Training completed with the following metrics- train_acc: {train_acc} and val_acc: {val_acc}"
                except KeyError:
                    # One of the accuracies was not logged; dump whatever is there.
                    summary = f"Training completed with the following metrics:{train_metrics}"
                logger.info(summary)

                return train_metrics
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            @task_wrapper
            def run_test_module(
                cfg: DictConfig,
                datamodule: L.LightningDataModule,
                model: L.LightningModule,
                trainer: L.Trainer,
            ):
                """Run the test loop, loading ``cfg.ckpt_path`` when that checkpoint exists.

                Args:
                    cfg: Run configuration; the optional ``ckpt_path`` key is read.
                    datamodule: Data module; ``setup(stage="test")`` is called on it here.
                    model: Module to evaluate.
                    trainer: Trainer that runs the test loop.

                Returns:
                    The first entry of the list returned by ``trainer.test``, or an
                    empty dict when that list is empty.
                """
                logger.info("Testing the model")
                datamodule.setup(stage="test")

                # ``cfg.get`` instead of attribute access: a config without a
                # ``ckpt_path`` key then yields None (which the helper handles)
                # rather than raising on a struct-mode DictConfig.
                ckpt_path = load_checkpoint_if_available(cfg.get("ckpt_path"))

                # With ckpt_path=None the trainer evaluates the weights currently
                # held by ``model``.
                test_metrics = trainer.test(model, datamodule, ckpt_path=ckpt_path)
                logger.info(f"Test metrics:\n{test_metrics}")

                return test_metrics[0] if test_metrics else {}
         | 
| 134 | 
            +
             | 
| 135 | 
            +
             | 
| 136 | 
            +
            @hydra.main(config_path="../configs", config_name="train", version_base="1.1")
            def setup_run_trainer(cfg: DictConfig):
                """Build the data module, model and trainer from ``cfg``, then train/test.

                Steps, in order: log the config, configure file logging, log the
                configured paths, create the data module, seed, instantiate the model,
                callbacks, loggers and trainer, then run training and/or testing
                depending on the ``train`` / ``test`` config flags.

                Args:
                    cfg: Composed Hydra configuration.

                Returns:
                    The metric named by ``cfg.optimization_metric`` as a ``float``,
                    looked up in the merged train and test metrics (test values win on
                    a key clash). ``0.0`` when that metric is not present.
                """
                # show me the entire config
                logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

                # File logging: one log file per task type.
                log_name = "train.log" if cfg.task_name == "train" else "eval.log"
                setup_logger(Path(cfg.paths.log_dir) / log_name)

                # Log the configured locations for easier debugging of path issues.
                root_dir = cfg.paths.root_dir
                logger.info(f"Root directory: {root_dir}")
                # This lists the entries of root_dir (the old label called it the
                # current working directory, which it is not).
                logger.info(f"Root directory contents: {os.listdir(root_dir)}")
                logger.info(f"Checkpoint directory: {cfg.paths.ckpt_dir}")
                logger.info(f"Data directory: {cfg.paths.data_dir}")
                logger.info(f"Log directory: {cfg.paths.log_dir}")
                logger.info(f"Artifact directory: {cfg.paths.artifact_dir}")
                logger.info(f"Output directory: {cfg.paths.output_dir}")

                experiment_name = cfg.name
                logger.info(f"Experiment name: {experiment_name}")

                # Initialize DataModule
                # NOTE(review): the data module is created before seed_everything()
                # below; if its construction draws random numbers (e.g. a split),
                # that part is not covered by the seed - confirm.
                if experiment_name == "dogbreed_experiment":
                    logger.info("Setting up the DataModule")
                    dataset_df, datamodule = main_dataloader(cfg)
                    labels = dataset_df.label.nunique()
                    logger.info(f"Number of classes: {labels}")

                    # Keep a copy of the dataset table next to the other artifacts.
                    os.makedirs(cfg.paths.artifact_dir, exist_ok=True)
                    dataset_df.to_csv(
                        Path(cfg.paths.artifact_dir) / "dogbreed_dataset.csv", index=False
                    )
                else:
                    # Every other experiment builds its data module from ``cfg.data``.
                    # Previously only two hard-coded experiment names took this path
                    # and any other name left ``datamodule`` unbound.
                    logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
                    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)

                # Check for GPU availability
                logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")

                # Set seed for reproducibility
                L.seed_everything(cfg.seed, workers=True)

                # Initialize model
                logger.info(f"Instantiating model <{cfg.model._target_}>")
                model: L.LightningModule = hydra.utils.instantiate(cfg.model)
                logger.info(f"Model summary:\n{model}")

                # Set up callbacks and loggers
                logger.info("Setting up callbacks and loggers")
                callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
                logger.info(f"Callbacks: {callbacks}")
                loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))
                logger.info(f"Loggers: {loggers}")

                # Initialize Trainer
                logger.info(f"Instantiating trainer <{cfg.trainer._target_}>")
                trainer: L.Trainer = hydra.utils.instantiate(
                    cfg.trainer, callbacks=callbacks, logger=loggers
                )

                ckpt_dir_path = Path(cfg.paths.ckpt_dir)

                # Train and test the model based on config settings
                train_metrics = {}
                if cfg.get("train"):
                    # Start from an empty checkpoint directory.
                    clear_checkpoint_directory(cfg.paths.ckpt_dir)

                    logger.info("Training the model")
                    train_metrics = train_module(cfg, datamodule, model, trainer)

                    # Marker file telling later steps that training has finished.
                    done_flag_path = ckpt_dir_path / "train_done.flag"
                    done_flag_path.write_text("Training completed.\n")
                    logger.info(f"Training completion flag written to: {done_flag_path}")

                    logger.info(
                        f"Training completed. Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}"
                    )

                test_metrics = {}
                if cfg.get("test"):
                    # In a test-only run the checkpoint directory may not exist yet;
                    # listing it unconditionally would raise FileNotFoundError.
                    ckpt_contents = (
                        os.listdir(ckpt_dir_path) if ckpt_dir_path.is_dir() else []
                    )
                    logger.info(f"Checkpoint directory: {ckpt_contents}")
                    test_metrics = run_test_module(cfg, datamodule, model, trainer)

                # Combine metrics
                all_metrics = {**train_metrics, **test_metrics}

                # Extract and return the optimization metric
                metric_name = cfg.get("optimization_metric")
                optimization_metric = all_metrics.get(metric_name)
                if optimization_metric is None:
                    logger.warning(
                        f"Optimization metric '{metric_name}' not found in metrics. Returning 0."
                    )
                    return 0.0

                # Training metrics may be tensors; hand back a plain float.
                return float(optimization_metric)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            # Script entry point: Hydra composes the config and passes it in.
            if __name__ == "__main__":
                setup_run_trainer()
         | 
    	
        src/{train.py → train_optuna_callbacks.py}
    RENAMED
    
    | @@ -1,7 +1,5 @@ | |
| 1 | 
             
            """
         | 
| 2 | 
            -
            Train and evaluate a model using PyTorch Lightning.
         | 
| 3 | 
            -
            Initializes the DataModule, Model, Trainer, and runs training and testing.
         | 
| 4 | 
            -
            Initializes loggers and callbacks from the configuration using Hydra configuration but with a more modular approach without direct instantiation.
         | 
| 5 | 
             
            """
         | 
| 6 |  | 
| 7 | 
             
            import os
         | 
| @@ -10,47 +8,61 @@ from pathlib import Path | |
| 10 | 
             
            from typing import List
         | 
| 11 | 
             
            import torch
         | 
| 12 | 
             
            import lightning as L
         | 
| 13 | 
            -
            from lightning.pytorch.loggers import Logger, TensorBoardLogger, CSVLogger
         | 
| 14 | 
            -
            from lightning.pytorch.callbacks import (
         | 
| 15 | 
            -
                ModelCheckpoint,
         | 
| 16 | 
            -
                EarlyStopping,
         | 
| 17 | 
            -
                RichModelSummary,
         | 
| 18 | 
            -
                RichProgressBar,
         | 
| 19 | 
            -
            )
         | 
| 20 | 
             
            from dotenv import load_dotenv, find_dotenv
         | 
| 21 | 
             
            import hydra
         | 
| 22 | 
             
            from omegaconf import DictConfig, OmegaConf
         | 
| 23 | 
            -
            from src.datamodules.catdog_datamodule import CatDogImageDataModule
         | 
| 24 | 
             
            from src.utils.logging_utils import setup_logger, task_wrapper
         | 
| 25 | 
             
            from loguru import logger
         | 
| 26 | 
             
            import rootutils
         | 
|  | |
|  | |
|  | |
| 27 |  | 
| 28 | 
             
            # Load environment variables
         | 
| 29 | 
             
            load_dotenv(find_dotenv(".env"))
         | 
| 30 |  | 
| 31 | 
             
            # Setup root directory
         | 
| 32 | 
            -
             | 
| 33 | 
             
            root = rootutils.setup_root(__file__, indicator=".project-root")
         | 
| 34 |  | 
| 35 |  | 
| 36 | 
            -
            def  | 
| 37 | 
            -
                """ | 
| 38 | 
            -
                 | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
                    " | 
| 42 | 
            -
                     | 
| 43 | 
            -
             | 
| 44 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 45 |  | 
| 46 |  | 
| 47 | 
            -
            def  | 
| 48 | 
            -
                """ | 
| 49 | 
            -
                 | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 54 |  | 
| 55 |  | 
| 56 | 
             
            def load_checkpoint_if_available(ckpt_path: str) -> str:
         | 
| @@ -81,16 +93,19 @@ def clear_checkpoint_directory(ckpt_dir: str): | |
| 81 | 
             
            def train_module(
         | 
| 82 | 
             
                data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
         | 
| 83 | 
             
            ):
         | 
| 84 | 
            -
                """Train the model  | 
| 85 | 
            -
                logger.info("Starting training")
         | 
|  | |
| 86 | 
             
                trainer.fit(model, data_module)
         | 
| 87 | 
            -
                 | 
| 88 | 
            -
             | 
| 89 | 
            -
                 | 
| 90 | 
            -
             | 
| 91 | 
            -
                     | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
|  | |
|  | |
| 94 |  | 
| 95 |  | 
| 96 | 
             
            @task_wrapper
         | 
| @@ -110,77 +125,84 @@ def run_test_module( | |
| 110 | 
             
                return test_metrics[0] if test_metrics else {}
         | 
| 111 |  | 
| 112 |  | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
                """Set up and run the Trainer for training and testing."""
         | 
| 116 | 
            -
                # Display configuration
         | 
| 117 | 
            -
                logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
         | 
| 118 |  | 
| 119 | 
            -
                #  | 
| 120 | 
            -
                 | 
| 121 | 
            -
             | 
| 122 | 
            -
                )
         | 
| 123 | 
            -
                 | 
| 124 | 
            -
             | 
| 125 | 
            -
                # Display key paths
         | 
| 126 | 
            -
                for path_name in [
         | 
| 127 | 
            -
                    "root_dir",
         | 
| 128 | 
            -
                    "data_dir",
         | 
| 129 | 
            -
                    "log_dir",
         | 
| 130 | 
            -
                    "ckpt_dir",
         | 
| 131 | 
            -
                    "artifact_dir",
         | 
| 132 | 
            -
                    "output_dir",
         | 
| 133 | 
            -
                ]:
         | 
| 134 | 
            -
                    logger.info(
         | 
| 135 | 
            -
                        f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
         | 
| 136 | 
            -
                    )
         | 
| 137 |  | 
| 138 | 
            -
                # Initialize  | 
| 139 | 
            -
                 | 
| 140 | 
            -
                datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
         | 
| 141 | 
            -
                logger.info(f"Instantiating model <{cfg.model._target_}>")
         | 
| 142 | 
             
                model: L.LightningModule = hydra.utils.instantiate(cfg.model)
         | 
| 143 |  | 
| 144 | 
            -
                #  | 
| 145 | 
            -
                 | 
| 146 | 
            -
                L.seed_everything(cfg.seed, workers=True)
         | 
| 147 |  | 
| 148 | 
            -
                #  | 
| 149 | 
            -
                 | 
| 150 | 
            -
             | 
| 151 | 
            -
                 | 
| 152 | 
            -
                 | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 155 | 
             
                )
         | 
| 156 |  | 
| 157 | 
            -
                #  | 
| 158 | 
            -
                 | 
| 159 | 
            -
                 | 
| 160 | 
            -
                    clear_checkpoint_directory(cfg.paths.ckpt_dir)
         | 
| 161 | 
            -
                    train_metrics = train_module(datamodule, model, trainer)
         | 
| 162 | 
            -
                    (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
         | 
| 163 | 
            -
                        "Training completed.\n"
         | 
| 164 | 
            -
                    )
         | 
| 165 |  | 
| 166 | 
            -
                 | 
| 167 | 
            -
             | 
| 168 | 
            -
             | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
                # Combine metrics and extract optimization metric
         | 
| 172 | 
            -
                all_metrics = {**train_metrics, **test_metrics}
         | 
| 173 | 
            -
                optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
         | 
| 174 | 
            -
                (
         | 
| 175 | 
            -
                    logger.warning(
         | 
| 176 | 
            -
                        f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
         | 
| 177 | 
             
                    )
         | 
| 178 | 
            -
                     | 
| 179 | 
            -
             | 
| 180 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 181 |  | 
| 182 | 
            -
                return  | 
| 183 |  | 
| 184 |  | 
| 185 | 
             
            if __name__ == "__main__":
         | 
| 186 | 
            -
                 | 
|  | |
| 1 | 
             
            """
         | 
| 2 | 
            +
            Train and evaluate a model using PyTorch Lightning with Optuna for hyperparameter optimization.
         | 
|  | |
|  | |
| 3 | 
             
            """
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
|  | |
| 8 | 
             
            from typing import List
         | 
| 9 | 
             
            import torch
         | 
| 10 | 
             
            import lightning as L
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 11 | 
             
            from dotenv import load_dotenv, find_dotenv
         | 
| 12 | 
             
            import hydra
         | 
| 13 | 
             
            from omegaconf import DictConfig, OmegaConf
         | 
|  | |
| 14 | 
             
            from src.utils.logging_utils import setup_logger, task_wrapper
         | 
| 15 | 
             
            from loguru import logger
         | 
| 16 | 
             
            import rootutils
         | 
| 17 | 
            +
            from lightning.pytorch.loggers import Logger
         | 
| 18 | 
            +
            import optuna
         | 
| 19 | 
            +
            from lightning.pytorch import Trainer
         | 
| 20 |  | 
| 21 | 
             
            # Load environment variables
         | 
| 22 | 
             
            load_dotenv(find_dotenv(".env"))
         | 
| 23 |  | 
| 24 | 
             
            # Setup root directory
         | 
|  | |
| 25 | 
             
            root = rootutils.setup_root(__file__, indicator=".project-root")
         | 
| 26 |  | 
| 27 |  | 
| 28 | 
            +
            def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
                """Instantiate every callback described in ``callback_cfg``.

                Args:
                    callback_cfg: Mapping of arbitrary names to callback configs. Only
                        entries that are themselves a ``DictConfig`` containing a
                        ``_target_`` key are instantiated; all others are ignored.

                Returns:
                    The instantiated callbacks in config order; an empty list when
                    ``callback_cfg`` is empty or ``None``.

                Raises:
                    TypeError: If ``callback_cfg`` is truthy but not a ``DictConfig``.
                """
                callbacks: List[L.Callback] = []

                if not callback_cfg:
                    logger.warning("No callback configs found! Skipping..")
                    return callbacks

                if not isinstance(callback_cfg, DictConfig):
                    raise TypeError("Callbacks config must be a DictConfig!")

                for _, cb_conf in callback_cfg.items():
                    # The isinstance guard skips None/scalar entries, on which the
                    # membership test would otherwise raise (or substring-match a str).
                    if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
                        logger.info(f"Instantiating callback <{cb_conf._target_}>")
                        callbacks.append(hydra.utils.instantiate(cb_conf))

                return callbacks
         | 
| 45 |  | 
| 46 |  | 
| 47 | 
            +
            def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
                """Build the experiment loggers described in ``logger_cfg``.

                Args:
                    logger_cfg: Mapping of arbitrary names to logger configs. A falsy
                        value or a plain ``bool`` means "no loggers".

                Returns:
                    The loggers that could be instantiated. Entries without a
                    ``_target_`` key are skipped; entries whose instantiation raises
                    are logged as errors and skipped as well.

                Raises:
                    TypeError: If ``logger_cfg`` is truthy, not a bool and not a
                        ``DictConfig``.
                """
                if isinstance(logger_cfg, bool) or not logger_cfg:
                    logger.warning("No valid logger configs found! Skipping..")
                    return []

                if not isinstance(logger_cfg, DictConfig):
                    raise TypeError("Logger config must be a DictConfig!")

                loggers_ls: List[Logger] = []
                for _, lg_conf in logger_cfg.items():
                    has_target = isinstance(lg_conf, DictConfig) and "_target_" in lg_conf
                    if not has_target:
                        continue

                    logger.info(f"Instantiating logger <{lg_conf._target_}>")
                    try:
                        instance = hydra.utils.instantiate(lg_conf)
                    except Exception as e:
                        # Best effort: one broken logger must not stop the others.
                        logger.error(f"Failed to instantiate logger {lg_conf}: {e}")
                    else:
                        loggers_ls.append(instance)
                return loggers_ls
         | 
| 66 |  | 
| 67 |  | 
| 68 | 
             
            def load_checkpoint_if_available(ckpt_path: str) -> str:
         | 
|  | |
| 93 | 
             
            def train_module(
                data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
            ):
                """Train the model and return the ``val_acc`` seen after each validation run.

                The values are collected while training runs, by a temporary callback
                that reads ``trainer.callback_metrics["val_acc"]`` at the end of every
                validation run. Reading ``callback_metrics`` once per epoch *after*
                ``fit`` has returned would only repeat the final value.

                Args:
                    data_module: Data module handed to ``trainer.fit``.
                    model: Module to train.
                    trainer: Trainer that runs the fit loop. A recording callback is
                        attached for the duration of ``fit`` and removed afterwards.

                Returns:
                    A list of floats, one per validation run in which ``val_acc`` was
                    available, in chronological order. Empty if it was never logged.
                """

                class _ValAccHistory(L.Callback):
                    """Record ``val_acc`` from the callback metrics after each validation run."""

                    def __init__(self):
                        super().__init__()
                        self.history = []

                    def on_validation_end(self, trainer, pl_module):
                        # The pre-training sanity check also runs validation; it is
                        # not a real epoch, so it is not recorded.
                        if trainer.sanity_checking:
                            return
                        val_acc = trainer.callback_metrics.get("val_acc")
                        if val_acc is not None:
                            self.history.append(float(val_acc))

                logger.info("Starting training with custom pruning")

                recorder = _ValAccHistory()
                trainer.callbacks.append(recorder)
                try:
                    trainer.fit(model, data_module)
                finally:
                    # Leave the trainer's callback list the way it was handed in.
                    if recorder in trainer.callbacks:
                        trainer.callbacks.remove(recorder)

                val_accuracies = recorder.history
                for epoch, val_acc in enumerate(val_accuracies):
                    logger.info(f"Epoch {epoch}: val_acc={val_acc}")

                return val_accuracies
         | 
| 109 |  | 
| 110 |  | 
| 111 | 
             
            @task_wrapper
         | 
|  | |
| 125 | 
             
                return test_metrics[0] if test_metrics else {}
         | 
| 126 |  | 
| 127 |  | 
| 128 | 
            +
            def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
                """Optuna objective: sample model hyperparameters, train, return ``val_acc``.

                The sampled values are written into ``cfg.model`` in place before the
                model is instantiated.

                Args:
                    trial: Optuna trial used to sample hyperparameters and to report
                        intermediate values.
                    cfg: Run configuration; ``cfg.model`` is mutated by this function.
                    callbacks: Callback instances passed as-is to the ``Trainer``.

                Returns:
                    The last validation accuracy returned by ``train_module``, or
                    ``0.0`` when it returned no values.

                Raises:
                    optuna.TrialPruned: If the pruner rejects the trial at one of the
                        reported steps.
                """
                # Sample hyperparameters for the model
                cfg.model.embed_dim = trial.suggest_categorical("embed_dim", [64, 128, 256])
                cfg.model.depth = trial.suggest_int("depth", 2, 6)
                # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
                cfg.model.lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
                cfg.model.mlp_ratio = trial.suggest_float("mlp_ratio", 1.0, 4.0)

                # Initialize data module and model
                data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
                model: L.LightningModule = hydra.utils.instantiate(cfg.model)

                # Set up logger; a config without a ``logger`` section yields no loggers.
                loggers = instantiate_loggers(cfg.get("logger"))

                # Trainer configuration with passed callbacks
                # NOTE(review): the same callback objects are handed to the Trainer of
                # every trial that receives this list; if they keep state between
                # runs, it may carry over - confirm.
                trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)

                # Clear checkpoint directory
                clear_checkpoint_directory(cfg.paths.ckpt_dir)

                # Train and get the list of validation accuracies
                val_accuracies = train_module(data_module, model, trainer)

                # Report validation accuracy and prune if necessary. Training has
                # already finished at this point, so pruning here only marks the trial
                # as pruned; it does not save training time.
                for epoch, val_acc in enumerate(val_accuracies):
                    trial.report(val_acc, step=epoch)

                    # Check if the trial should be pruned at this epoch
                    if trial.should_prune():
                        logger.info(f"Pruning trial at epoch {epoch}")
                        raise optuna.TrialPruned()

                # Return the final validation accuracy as the objective metric
                return val_accuracies[-1] if val_accuracies else 0.0
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            @hydra.main(config_path="../configs", config_name="train", version_base="1.3")
            def setup_trainer(cfg: DictConfig):
                """Hydra entry point: run the Optuna study and/or the test stage.

                When ``cfg.train`` is truthy, an Optuna study (maximize, median pruner)
                is run over ``objective``. When ``cfg.test`` is truthy, a fresh data
                module, model and trainer are built from ``cfg`` and tested.

                Args:
                    cfg: Hydra configuration. The optional key ``n_trials`` sets the
                        number of Optuna trials and defaults to 5.

                Returns:
                    The test metrics when ``cfg.test`` is truthy, otherwise ``cfg.model``.
                """
                logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

                # Log to train.log for the "train" task and to eval.log otherwise
                log_name = "train.log" if cfg.task_name == "train" else "eval.log"
                setup_logger(Path(cfg.paths.log_dir) / log_name)

                # Instantiate callbacks
                callbacks = instantiate_callbacks(cfg.callbacks)
                logger.info(f"Callbacks: {callbacks}")

                if cfg.get("train", False):
                    # Trial count is configurable; the default keeps the previous value
                    n_trials = cfg.get("n_trials", 5)

                    pruner = optuna.pruners.MedianPruner()
                    study = optuna.create_study(
                        direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
                    )
                    study.optimize(
                        lambda trial: objective(trial, cfg, callbacks),
                        n_trials=n_trials,
                        show_progress_bar=True,
                    )

                    # Log best trial results
                    best_trial = study.best_trial
                    logger.info(f"Best trial number: {best_trial.number}")
                    logger.info(f"Best trial value (val_acc): {best_trial.value}")
                    for key, value in best_trial.params.items():
                        logger.info(f"  {key}: {value}")

                if cfg.get("test", False):
                    # NOTE(review): `objective` overwrites cfg.model in place, so after
                    # a study this model is built from the last trial's sampled values,
                    # not necessarily the best trial's -- confirm this is intended.
                    data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
                    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
                    trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
                    test_metrics = run_test_module(cfg, data_module, model, trainer)
                    logger.info(f"Test metrics: {test_metrics}")

                return cfg.model if not cfg.get("test", False) else test_metrics
         | 
| 205 |  | 
| 206 |  | 
| 207 | 
             
            if __name__ == "__main__":
                # Script entry point; `cfg` is supplied by the @hydra.main decorator.
                setup_trainer()
         | 
