#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from pathlib import Path from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.common.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, TRAINING_STEP, ) from lerobot.common.datasets.utils import load_json, write_json from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.utils.random_utils import load_rng_state, save_rng_state from lerobot.configs.train import TrainPipelineConfig def log_output_dir(out_dir): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") def get_step_identifier(step: int, total_steps: int) -> str: num_digits = max(6, len(str(total_steps))) return f"{step:0{num_digits}d}" def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path: """Returns the checkpoint sub-directory corresponding to the step number.""" step_identifier = get_step_identifier(step, total_steps) return output_dir / CHECKPOINTS_DIR / step_identifier def save_training_step(step: int, save_dir: Path) -> None: write_json({"step": step}, save_dir / TRAINING_STEP) def load_training_step(save_dir: Path) -> int: training_step = load_json(save_dir / TRAINING_STEP) return training_step["step"] def update_last_checkpoint(checkpoint_dir: Path) -> Path: last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK if last_checkpoint_dir.is_symlink(): last_checkpoint_dir.unlink() relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent) last_checkpoint_dir.symlink_to(relative_target) def save_checkpoint( checkpoint_dir: Path, step: int, cfg: TrainPipelineConfig, policy: PreTrainedPolicy, optimizer: Optimizer, scheduler: LRScheduler | None = None, ) -> None: """This function creates the following directory structure: 005000/ # training step at checkpoint ├── pretrained_model/ │ ├── config.json # policy config │ ├── model.safetensors # policy weights │ └── train_config.json # train config └── training_state/ ├── optimizer_param_groups.json # optimizer param groups ├── optimizer_state.safetensors # optimizer state ├── rng_state.safetensors # rng states ├── scheduler_state.json # scheduler state └── training_step.json # training step Args: cfg (TrainPipelineConfig): The training config used for this run. step (int): The training step at that checkpoint. policy (PreTrainedPolicy): The policy to save. optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler) def save_training_state( checkpoint_dir: Path, train_step: int, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, ) -> None: """ Saves the training step, optimizer state, scheduler state, and rng state. Args: save_dir (Path): The directory to save artifacts to. train_step (int): Current training step. optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. Defaults to None. """ save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir.mkdir(parents=True, exist_ok=True) save_training_step(train_step, save_dir) save_rng_state(save_dir) if optimizer is not None: save_optimizer_state(optimizer, save_dir) if scheduler is not None: save_scheduler_state(scheduler, save_dir) def load_training_state( checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None ) -> tuple[int, Optimizer, LRScheduler | None]: """ Loads the training step, optimizer state, scheduler state, and rng state. This is used to resume a training run. Args: checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. optimizer (Optimizer): The optimizer to load the state_dict to. scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). Raises: NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir Returns: tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their state_dict loaded. """ training_state_dir = checkpoint_dir / TRAINING_STATE_DIR if not training_state_dir.is_dir(): raise NotADirectoryError(training_state_dir) load_rng_state(training_state_dir) step = load_training_step(training_state_dir) optimizer = load_optimizer_state(optimizer, training_state_dir) if scheduler is not None: scheduler = load_scheduler_state(scheduler, training_state_dir) return step, optimizer, scheduler