Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
·
3fa4d71
1
Parent(s):
8d4131e
local training check
Browse files- artifacts/image_prediction.png +0 -0
- docker-compose.yaml +2 -2
- src/train_optuna_callbacks.py +27 -6
artifacts/image_prediction.png
CHANGED
|
|
docker-compose.yaml
CHANGED
|
@@ -5,7 +5,7 @@ services:
|
|
| 5 |
build:
|
| 6 |
context: .
|
| 7 |
command: |
|
| 8 |
-
python -m src.
|
| 9 |
touch /app/checkpoints/train_done.flag
|
| 10 |
volumes:
|
| 11 |
- ./data:/app/data
|
|
@@ -25,7 +25,7 @@ services:
|
|
| 25 |
build:
|
| 26 |
context: .
|
| 27 |
command: |
|
| 28 |
-
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.
|
| 29 |
volumes:
|
| 30 |
- ./data:/app/data
|
| 31 |
- ./checkpoints:/app/checkpoints
|
|
|
|
| 5 |
build:
|
| 6 |
context: .
|
| 7 |
command: |
|
| 8 |
+
python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
|
| 9 |
touch /app/checkpoints/train_done.flag
|
| 10 |
volumes:
|
| 11 |
- ./data:/app/data
|
|
|
|
| 25 |
build:
|
| 26 |
context: .
|
| 27 |
command: |
|
| 28 |
+
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
|
| 29 |
volumes:
|
| 30 |
- ./data:/app/data
|
| 31 |
- ./checkpoints:/app/checkpoints
|
src/train_optuna_callbacks.py
CHANGED
|
@@ -15,8 +15,10 @@ from src.utils.logging_utils import setup_logger, task_wrapper
|
|
| 15 |
from loguru import logger
|
| 16 |
import rootutils
|
| 17 |
from lightning.pytorch.loggers import Logger
|
|
|
|
| 18 |
import optuna
|
| 19 |
from lightning.pytorch import Trainer
|
|
|
|
| 20 |
|
| 21 |
# Load environment variables
|
| 22 |
load_dotenv(find_dotenv(".env"))
|
|
@@ -25,7 +27,7 @@ load_dotenv(find_dotenv(".env"))
|
|
| 25 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
| 26 |
|
| 27 |
|
| 28 |
-
def instantiate_callbacks(callback_cfg: DictConfig) -> List[
|
| 29 |
"""Instantiate and return a list of callbacks from the configuration."""
|
| 30 |
callbacks: List[L.Callback] = []
|
| 31 |
|
|
@@ -125,7 +127,7 @@ def run_test_module(
|
|
| 125 |
return test_metrics[0] if test_metrics else {}
|
| 126 |
|
| 127 |
|
| 128 |
-
def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[
|
| 129 |
"""Objective function for Optuna hyperparameter tuning."""
|
| 130 |
|
| 131 |
# Sample hyperparameters for the model
|
|
@@ -144,9 +146,6 @@ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Call
|
|
| 144 |
# Trainer configuration with passed callbacks
|
| 145 |
trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
|
| 146 |
|
| 147 |
-
# Clear checkpoint directory
|
| 148 |
-
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
| 149 |
-
|
| 150 |
# Train and get val_acc for each epoch
|
| 151 |
val_accuracies = train_module(data_module, model, trainer)
|
| 152 |
|
|
@@ -177,13 +176,16 @@ def setup_trainer(cfg: DictConfig):
|
|
| 177 |
logger.info(f"Callbacks: {callbacks}")
|
| 178 |
|
| 179 |
if cfg.get("train", False):
|
|
|
|
|
|
|
|
|
|
| 180 |
pruner = optuna.pruners.MedianPruner()
|
| 181 |
study = optuna.create_study(
|
| 182 |
direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
|
| 183 |
)
|
| 184 |
study.optimize(
|
| 185 |
lambda trial: objective(trial, cfg, callbacks),
|
| 186 |
-
n_trials=
|
| 187 |
show_progress_bar=True,
|
| 188 |
)
|
| 189 |
|
|
@@ -194,7 +196,26 @@ def setup_trainer(cfg: DictConfig):
|
|
| 194 |
for key, value in best_trial.params.items():
|
| 195 |
logger.info(f" {key}: {value}")
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
if cfg.get("test", False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
| 199 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
| 200 |
trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
|
|
|
|
| 15 |
from loguru import logger
|
| 16 |
import rootutils
|
| 17 |
from lightning.pytorch.loggers import Logger
|
| 18 |
+
from lightning.pytorch.callbacks import Callback
|
| 19 |
import optuna
|
| 20 |
from lightning.pytorch import Trainer
|
| 21 |
+
import json
|
| 22 |
|
| 23 |
# Load environment variables
|
| 24 |
load_dotenv(find_dotenv(".env"))
|
|
|
|
| 27 |
root = rootutils.setup_root(__file__, indicator=".project-root")
|
| 28 |
|
| 29 |
|
| 30 |
+
def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
|
| 31 |
"""Instantiate and return a list of callbacks from the configuration."""
|
| 32 |
callbacks: List[L.Callback] = []
|
| 33 |
|
|
|
|
| 127 |
return test_metrics[0] if test_metrics else {}
|
| 128 |
|
| 129 |
|
| 130 |
+
def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[Callback]):
|
| 131 |
"""Objective function for Optuna hyperparameter tuning."""
|
| 132 |
|
| 133 |
# Sample hyperparameters for the model
|
|
|
|
| 146 |
# Trainer configuration with passed callbacks
|
| 147 |
trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
| 149 |
# Train and get val_acc for each epoch
|
| 150 |
val_accuracies = train_module(data_module, model, trainer)
|
| 151 |
|
|
|
|
| 176 |
logger.info(f"Callbacks: {callbacks}")
|
| 177 |
|
| 178 |
if cfg.get("train", False):
|
| 179 |
+
# Clear checkpoint directory
|
| 180 |
+
clear_checkpoint_directory(cfg.paths.ckpt_dir)
|
| 181 |
+
# find the best hyperparameters using Optuna and train the model
|
| 182 |
pruner = optuna.pruners.MedianPruner()
|
| 183 |
study = optuna.create_study(
|
| 184 |
direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
|
| 185 |
)
|
| 186 |
study.optimize(
|
| 187 |
lambda trial: objective(trial, cfg, callbacks),
|
| 188 |
+
n_trials=3,
|
| 189 |
show_progress_bar=True,
|
| 190 |
)
|
| 191 |
|
|
|
|
| 196 |
for key, value in best_trial.params.items():
|
| 197 |
logger.info(f" {key}: {value}")
|
| 198 |
|
| 199 |
+
# write the best hyperparameters to the config
|
| 200 |
+
best_hyperparams = {key: value for key, value in best_trial.params.items()}
|
| 201 |
+
best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
|
| 202 |
+
with open(best_hyperparams_path, "w") as f:
|
| 203 |
+
json.dump(best_hyperparams, f)
|
| 204 |
+
logger.info(f"Best hyperparameters saved to {best_hyperparams_path}")
|
| 205 |
+
|
| 206 |
if cfg.get("test", False):
|
| 207 |
+
best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
|
| 208 |
+
if best_hyperparams_path.exists():
|
| 209 |
+
with open(best_hyperparams_path, "r") as f:
|
| 210 |
+
best_hyperparams = json.load(f)
|
| 211 |
+
cfg.model.update(best_hyperparams)
|
| 212 |
+
logger.info(f"Loaded best hyperparameters for testing: {best_hyperparams}")
|
| 213 |
+
else:
|
| 214 |
+
logger.error(
|
| 215 |
+
"Best hyperparameters not found! Using default hyperparameters."
|
| 216 |
+
)
|
| 217 |
+
raise FileNotFoundError("Best hyperparameters not found!")
|
| 218 |
+
|
| 219 |
data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
|
| 220 |
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
|
| 221 |
trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
|