Techt3o's picture
e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
f561f8b verified
raw
history blame
1.38 kB
import os, sys
import yaml
import torch
from loguru import logger
from configs import constants as _C
from .smpl import SMPL
def build_body_model(device, batch_size=1, gender='neutral', **kwargs):
    """Create an ``SMPL`` body model and move it to ``device``.

    Anything written to stdout while the model is constructed and moved is
    discarded.

    Args:
        device: Target passed to ``.to()`` on the constructed model.
        batch_size: Forwarded to ``SMPL`` as ``batch_size``.
        gender: Forwarded to ``SMPL`` as ``gender``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The ``SMPL`` instance returned by ``.to(device)``.
    """
    # Function-scope import keeps this block self-contained.
    import contextlib

    # The ``with`` closes the devnull handle instead of leaking it, and
    # redirect_stdout puts back whichever stream was active before the call
    # (not unconditionally sys.__stdout__), even if construction raises.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        body_model = SMPL(
            model_path=_C.BMODEL.FLDR,
            gender=gender,
            batch_size=batch_size,
            create_transl=False,
        ).to(device)
    return body_model
def build_network(cfg, smpl):
    """Build the ``Network`` model and load checkpoint weights if present.

    Args:
        cfg: Config object; this function reads ``cfg.MODEL_CONFIG`` (path to
            a YAML file), ``cfg.MODEL.BACKBONE``, ``cfg.DEVICE`` and
            ``cfg.TRAIN.CHECKPOINT``.
        smpl: Body model passed to ``Network`` as its first positional
            argument.

    Returns:
        The ``Network`` instance on ``cfg.DEVICE``. If no file exists at
        ``cfg.TRAIN.CHECKPOINT`` the network is returned with the weights it
        was constructed with and a warning is logged.
    """
    # Imported here rather than at module level, as in the original.
    from .wham import Network

    with open(cfg.MODEL_CONFIG, 'r') as f:
        model_config = yaml.safe_load(f)
    # The feature dimension is looked up from the configured backbone.
    model_config['d_feat'] = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]

    network = Network(smpl, **model_config).to(cfg.DEVICE)

    # Load Checkpoint
    if os.path.isfile(cfg.TRAIN.CHECKPOINT):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location='cpu' lets a checkpoint saved from a CUDA device be
        # read on a CPU-only host; load_state_dict then copies the values
        # into the network's parameters, which already live on cfg.DEVICE.
        checkpoint = torch.load(cfg.TRAIN.CHECKPOINT, map_location='cpu')
        # These SMPL entries are dropped from the checkpoint so the values
        # from the freshly built model are kept.
        ignore_keys = {
            'smpl.body_pose',
            'smpl.betas',
            'smpl.global_orient',
            'smpl.J_regressor_extra',
            'smpl.J_regressor_eval',
        }
        model_state_dict = {
            k: v for k, v in checkpoint['model'].items()
            if k not in ignore_keys
        }
        # strict=False: the filtered keys above (and any other mismatch) must
        # not raise.
        network.load_state_dict(model_state_dict, strict=False)
        logger.info(f"=> loaded checkpoint '{cfg.TRAIN.CHECKPOINT}' ")
    else:
        # Logged at warning level so the missing checkpoint is not lost among
        # info messages.
        logger.warning(f"=> Warning! no checkpoint found at '{cfg.TRAIN.CHECKPOINT}'.")
    return network