#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path

import draccus
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base class for optimizer configurations, registered as draccus choices."""

    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self, params: dict) -> torch.optim.Optimizer:
        raise NotImplementedError
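
# A hedged sketch of the registry extension point (the class and choice name below are
# illustrative, not part of the library): a new optimizer config plugs in with the same
# decorator pair used by the built-in configs that follow.
#
#     @OptimizerConfig.register_subclass("my_optim")
#     @dataclass
#     class MyOptimConfig(OptimizerConfig):
#         lr: float = 1e-3
#         weight_decay: float = 0.0
#         grad_clip_norm: float = 10.0
#
#         def build(self, params: dict) -> torch.optim.Optimizer:
#             ...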

@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        # grad_clip_norm is consumed by the training loop, not by the optimizer itself.
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.Adam(params, **kwargs)

@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.AdamW(params, **kwargs)

@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.SGD(params, **kwargs)
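
# A minimal usage sketch (illustrative only; the linear model is a hypothetical
# stand-in): any registered config builds a torch optimizer from a model's parameters.
#
#     model = torch.nn.Linear(10, 2)
#     cfg = AdamWConfig(lr=3e-4)
#     optimizer = cfg.build(model.parameters())
#     assert cfg.type == "adamw"  # the name registered with draccus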

def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Save an optimizer's state as a safetensors file plus a JSON file for the param groups."""
    state = optimizer.state_dict()
    # Tensor state (e.g. Adam moments) goes into safetensors; the non-tensor
    # param groups (lr, betas, ...) are serialized separately as JSON.
    param_groups = state.pop("param_groups")
    flat_state = flatten_dict(state)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)

def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Restore a state previously written by `save_optimizer_state` into `optimizer`."""
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)
    # Parameter indices are serialized as strings; convert them back to ints.
    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
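
if __name__ == "__main__":
    # A hedged round-trip sketch, not part of the library: take one optimizer step so
    # the per-parameter Adam buffers exist, save the state, then restore it into a
    # freshly built optimizer. Assumes OPTIMIZER_STATE / OPTIMIZER_PARAM_GROUPS name
    # files inside `save_dir`.
    import tempfile

    model = torch.nn.Linear(4, 2)
    cfg = AdamConfig(lr=1e-3)
    optimizer = cfg.build(model.parameters())

    # One forward/backward/step populates per-parameter optimizer state.
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()

    with tempfile.TemporaryDirectory() as tmp:
        save_dir = Path(tmp)
        save_optimizer_state(optimizer, save_dir)

        restored = cfg.build(model.parameters())
        restored = load_optimizer_state(restored, save_dir)
        print("restored param groups:", restored.state_dict()["param_groups"])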