#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path

import draccus
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object


@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for optimizers, registered as a draccus choice and selected by name (e.g. "adam")."""

    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self, params: dict) -> torch.optim.Optimizer:
        """Instantiate the torch optimizer over the given parameters."""
        raise NotImplementedError


@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        # `grad_clip_norm` is consumed by the training loop, not by the optimizer itself.
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.Adam(params, **kwargs)


@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.AdamW(params, **kwargs)


@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.SGD(params, **kwargs)


def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Save the optimizer state: tensors go to a safetensors file, param groups to a JSON file."""
    state = optimizer.state_dict()
    param_groups = state.pop("param_groups")
    flat_state = flatten_dict(state)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Load a previously saved optimizer state into `optimizer` and return it."""
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)

    # Safetensors keys are strings; optimizer state keys must be parameter indices (ints).
    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
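

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how these pieces fit together: build an optimizer from an
# AdamConfig, take one step so per-parameter state exists, then save and reload it.
# The toy `nn.Linear` model and the `/tmp/optimizer_ckpt` directory are assumptions
# for illustration only; `Path`, the configs, and the save/load helpers come from
# this module, and `save_dir` must exist before calling the helpers.
#
#     import torch
#     from torch import nn
#
#     model = nn.Linear(4, 2)
#     cfg = AdamConfig(lr=1e-4)
#     optimizer = cfg.build(model.parameters())
#
#     # One dummy step so the optimizer has state tensors to serialize.
#     loss = model(torch.randn(8, 4)).sum()
#     loss.backward()
#     torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
#     optimizer.step()
#
#     save_dir = Path("/tmp/optimizer_ckpt")
#     save_dir.mkdir(parents=True, exist_ok=True)
#     save_optimizer_state(optimizer, save_dir)
#
#     # Later: rebuild an optimizer over the same parameters and restore its state.
#     fresh_optimizer = cfg.build(model.parameters())
#     fresh_optimizer = load_optimizer_state(fresh_optimizer, save_dir)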