Francesco Capuano
Initial commit
529ed6b
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.utils import load_json, write_json
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
from lerobot.configs.train import TrainPipelineConfig
def log_output_dir(out_dir):
    """Log the run's output directory, prefixed by a bold yellow label."""
    label = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(f"{label} {out_dir}")
def get_step_identifier(step: int, total_steps: int) -> str:
    """Return ``step`` as a zero-padded decimal string.

    The pad width is the digit count of ``total_steps``, with a floor of 6.
    """
    width = len(str(total_steps))
    if width < 6:
        width = 6
    return format(step, f"0{width}d")
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Return ``output_dir / CHECKPOINTS_DIR / <zero-padded step>`` for the given step."""
    checkpoints_root = output_dir / CHECKPOINTS_DIR
    return checkpoints_root / get_step_identifier(step, total_steps)
def save_training_step(step: int, save_dir: Path) -> None:
    """Persist ``{"step": step}`` as JSON at ``save_dir / TRAINING_STEP``."""
    target = save_dir / TRAINING_STEP
    payload = {"step": step}
    write_json(payload, target)
def load_training_step(save_dir: Path) -> int:
    """Read back the ``"step"`` entry of the JSON file at ``save_dir / TRAINING_STEP``."""
    return load_json(save_dir / TRAINING_STEP)["step"]
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Point the ``LAST_CHECKPOINT_LINK`` symlink, placed next to ``checkpoint_dir``, at it.

    The link is created in ``checkpoint_dir.parent`` with a relative target (the
    directory's own name), so it keeps resolving if the parent folder is moved as a
    whole. Any existing symlink at that location, including a dangling one, is replaced.

    Args:
        checkpoint_dir (Path): The step checkpoint directory the link should point to.

    Raises:
        FileExistsError: Raised by ``Path.symlink_to`` if a regular file or directory
            (not a symlink) already occupies the link path.

    Returns:
        Path: The path of the (re)created symlink.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    # is_symlink() is also True for dangling links, unlike exists().
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    # Relative to the parent, i.e. just the directory name.
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    # The signature promises a Path; previously this implicitly returned None.
    return last_checkpoint_dir
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
) -> None:
    """This function creates the following directory structure:

    005000/  #  training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   └── train_config.json  # train config
    └── training_state/
        ├── optimizer_param_groups.json  #  optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        checkpoint_dir (Path): The step checkpoint directory to write into
            (e.g. as returned by ``get_step_checkpoint_dir``).
        step (int): The training step at that checkpoint.
        cfg (TrainPipelineConfig): The training config used for this run.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer): The optimizer to save the state from.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from.
            Defaults to None, in which case no scheduler state is written.
    """
    pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
    # Policy weights/config and the train config share the same directory.
    policy.save_pretrained(pretrained_dir)
    cfg.save_pretrained(pretrained_dir)
    save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Saves the training step, optimizer state, scheduler state, and rng state.

    Everything is written into ``checkpoint_dir / TRAINING_STATE_DIR``, which is
    created (including parents) if it doesn't exist yet.

    Args:
        checkpoint_dir (Path): The checkpoint directory; artifacts are saved in its
            training-state sub-directory.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
            Defaults to None, in which case no optimizer state is written.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
            Defaults to None, in which case no scheduler state is written.
    """
    save_dir = checkpoint_dir / TRAINING_STATE_DIR
    save_dir.mkdir(parents=True, exist_ok=True)
    save_training_step(train_step, save_dir)
    save_rng_state(save_dir)
    if optimizer is not None:
        save_optimizer_state(optimizer, save_dir)
    if scheduler is not None:
        save_scheduler_state(scheduler, save_dir)
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Restore the rng state, training step, optimizer state and (optionally) scheduler
    state saved by ``save_training_state``, in order to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): Optimizer that receives the saved state_dict.
        scheduler (LRScheduler | None): Scheduler that receives the saved state_dict;
            when None, no scheduler state is read and None is returned in its slot.

    Raises:
        NotADirectoryError: If 'checkpoint_dir' has no 'training_state' sub-directory.

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: The restored training step, followed by
        the optimizer and scheduler with their state_dict loaded.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not state_dir.is_dir():
        raise NotADirectoryError(state_dir)

    # Restore rng state first, then the step counter, optimizer and scheduler.
    load_rng_state(state_dir)
    resumed_step = load_training_step(state_dir)
    restored_optimizer = load_optimizer_state(optimizer, state_dir)
    restored_scheduler = (
        load_scheduler_state(scheduler, state_dir) if scheduler is not None else None
    )
    return resumed_step, restored_optimizer, restored_scheduler