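"""Convert a trained LHM checkpoint (a run's `model.safetensors` file) into a
HuggingFace-style `save_pretrained` directory under ./exps/releases.

Extra command-line arguments are parsed as OmegaConf dotlist overrides, e.g.
`convert.global_step=100000` (the value shown is illustrative).
"""
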
import argparse
import sys
import traceback
from tempfile import TemporaryDirectory

import safetensors.torch  # load_model lives in the torch submodule; import it explicitly
import torch.nn as nn
from accelerate import Accelerator
from megfile import (
    smart_copy,
    smart_exists,
    smart_listdir,
    smart_makedirs,
    smart_path_join,
)
from omegaconf import OmegaConf

sys.path.append(".")

from LHM.models import model_dict
from LHM.utils.hf_hub import wrap_model_hub
from LHM.utils.proxy import no_proxy


@no_proxy
def auto_load_model(cfg, model: nn.Module) -> int:
    """Load the requested (or latest) training checkpoint into `model`.

    Returns the global step of the loaded checkpoint.
    """
    ckpt_root = smart_path_join(
        cfg.saver.checkpoint_root,
        cfg.experiment.parent,
        cfg.experiment.child,
    )
    if not smart_exists(ckpt_root):
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")
    ckpt_dirs = smart_listdir(ckpt_root)
    if len(ckpt_dirs) == 0:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")
    ckpt_dirs.sort()

    # Use the requested global step if given, otherwise fall back to the latest checkpoint.
    load_step = (
        f"{cfg.convert.global_step}"
        if cfg.convert.global_step is not None
        else ckpt_dirs[-1]
    )
    load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors")

    # Remote (s3) checkpoints are first copied to a local temporary file.
    if load_model_path.startswith("s3"):
        tmpdir = TemporaryDirectory()
        tmp_model_path = smart_path_join(tmpdir.name, "tmp.safetensors")
        smart_copy(load_model_path, tmp_model_path)
        load_model_path = tmp_model_path

    print(f"Loading from {load_model_path}")
    try:
        safetensors.torch.load_model(model, load_model_path, strict=True)
    except Exception:
        # Fall back to a non-strict load so partially matching checkpoints still convert.
        traceback.print_exc()
        safetensors.torch.load_model(model, load_model_path, strict=False)

    return int(load_step)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./assets/config.yaml")
    args, unknown = parser.parse_known_args()
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(unknown)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    """
    [cfg.convert]
    global_step: int
    save_dir: str
    """
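    # Illustrative YAML for the section above (actual values depend on the experiment):
    #   convert:
    #     global_step: 100000
    #     save_dir: ./exps/releases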

    # Instantiate Accelerator (not used directly below).
    accelerator = Accelerator()

    hf_model_cls = wrap_model_hub(model_dict["human_lrm_sapdino_bh_sd3_5"])

    hf_model = hf_model_cls(OmegaConf.to_container(cfg.model))
    loaded_step = auto_load_model(cfg, hf_model)

    dump_path = smart_path_join(
        "./exps/releases",
        cfg.experiment.parent,
        cfg.experiment.child,
        f"step_{loaded_step:06d}",
    )
    print(f"Saving locally to {dump_path}")
    smart_makedirs(dump_path, exist_ok=True)
    hf_model.save_pretrained(
        save_directory=dump_path,
        config=hf_model.config,
    )