#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from lerobot.common.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    PRETRAINED_MODEL_DIR,
    TRAINING_STATE_DIR,
    TRAINING_STEP,
)
from lerobot.common.datasets.utils import load_json, write_json
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
from lerobot.configs.train import TrainPipelineConfig


def log_output_dir(out_dir):
    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")


def get_step_identifier(step: int, total_steps: int) -> str:
    """Returns the step number zero-padded to at least 6 digits (e.g. "005000")."""
    num_digits = max(6, len(str(total_steps)))
    return f"{step:0{num_digits}d}"


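# Illustrative usage (not part of the upstream module): padding is at least 6
# digits so checkpoint directory names sort lexicographically in step order,
# and it grows with `total_steps` when a run exceeds a million steps.
#
#   get_step_identifier(5_000, 100_000)     # -> "005000"
#   get_step_identifier(5_000, 10_000_000)  # -> "00005000"

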
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Returns the checkpoint sub-directory corresponding to the step number."""
    step_identifier = get_step_identifier(step, total_steps)
    return output_dir / CHECKPOINTS_DIR / step_identifier


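# Illustrative result (assuming CHECKPOINTS_DIR == "checkpoints", consistent
# with the layout documented in save_checkpoint below):
#
#   get_step_checkpoint_dir(Path("outputs/train"), total_steps=100_000, step=5_000)
#   # -> Path("outputs/train/checkpoints/005000")

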
def save_training_step(step: int, save_dir: Path) -> None:
    """Writes the training step to a json file in 'save_dir'."""
    write_json({"step": step}, save_dir / TRAINING_STEP)


def load_training_step(save_dir: Path) -> int:
    """Reads the training step back from the json file written by save_training_step."""
    training_step = load_json(save_dir / TRAINING_STEP)
    return training_step["step"]


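# Illustrative round trip (assuming TRAINING_STEP == "training_step.json", as
# listed in save_checkpoint's layout below):
#
#   save_training_step(5_000, Path("some_dir"))  # writes {"step": 5000} to some_dir/training_step.json
#   load_training_step(Path("some_dir"))         # -> 5000

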
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Points the 'last checkpoint' symlink to 'checkpoint_dir' and returns the symlink's path."""
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    return last_checkpoint_dir


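# Illustrative effect (assuming LAST_CHECKPOINT_LINK == "last"): the symlink is
# relative, so the whole run directory can be moved without breaking it.
#
#   update_last_checkpoint(Path("outputs/train/checkpoints/005000"))
#
#   outputs/train/checkpoints/
#   ├── 005000/
#   └── last -> 005000

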
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
) -> None:
    """This function creates the following directory structure:

    005000/  # training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   └── train_config.json  # train config
    └── training_state/
        ├── optimizer_param_groups.json  # optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        checkpoint_dir (Path): The checkpoint directory in which to save the artifacts.
        step (int): The training step at that checkpoint.
        cfg (TrainPipelineConfig): The training config used for this run.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer): The optimizer to save the state from.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
    """
    pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
    policy.save_pretrained(pretrained_dir)
    cfg.save_pretrained(pretrained_dir)
    save_training_state(checkpoint_dir, step, optimizer, scheduler)


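# Illustrative end-to-end save (a sketch, not upstream code): how a training
# loop might combine the helpers above. `cfg.output_dir` and `cfg.steps` are
# assumed attribute names on TrainPipelineConfig for this example.
#
#   checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
#   save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, scheduler)
#   update_last_checkpoint(checkpoint_dir)

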
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Saves the training step, optimizer state, scheduler state, and rng state.

    Args:
        checkpoint_dir (Path): The checkpoint directory; artifacts are saved under its
            'training_state' subdirectory.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
            Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
            Defaults to None.
    """
    save_dir = checkpoint_dir / TRAINING_STATE_DIR
    save_dir.mkdir(parents=True, exist_ok=True)
    save_training_step(train_step, save_dir)
    save_rng_state(save_dir)
    if optimizer is not None:
        save_optimizer_state(optimizer, save_dir)
    if scheduler is not None:
        save_scheduler_state(scheduler, save_dir)


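# Illustrative standalone usage: save_checkpoint calls this internally, but it
# can also snapshot the training state alone. With optimizer and scheduler
# omitted, only the step and RNG state files are written.
#
#   save_training_state(Path("outputs/train/checkpoints/005000"), train_step=5_000)

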
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Loads the training step, optimizer state, scheduler state, and rng state.
    This is used to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): The optimizer to load the state_dict to.
        scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).

    Raises:
        NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir.

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
            state_dict loaded.
    """
    training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not training_state_dir.is_dir():
        raise NotADirectoryError(training_state_dir)

    load_rng_state(training_state_dir)
    step = load_training_step(training_state_dir)
    optimizer = load_optimizer_state(optimizer, training_state_dir)
    if scheduler is not None:
        scheduler = load_scheduler_state(scheduler, training_state_dir)

    return step, optimizer, scheduler
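

# Illustrative resume (a sketch): rebuild the policy, optimizer and scheduler
# as at the start of training, then restore their states from the latest
# checkpoint. Resolving the checkpoint via LAST_CHECKPOINT_LINK assumes the
# symlink maintained by update_last_checkpoint above; `output_dir` is a
# hypothetical variable supplied by the caller.
#
#   checkpoint_dir = output_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
#   step, optimizer, scheduler = load_training_state(checkpoint_dir, optimizer, scheduler)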