Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import einops | |
import torch.nn as nn | |
# import pytorch_lightning as pl | |
from yacs.config import CfgNode | |
from .vit import vit | |
from .smpl_head import SMPLTransformerDecoderHead | |
# class HMR2(pl.LightningModule): | |
class HMR2(nn.Module):
    """Backbone feature extractor followed by an SMPL transformer decoder head."""

    def __init__(self):
        """Build the backbone and the SMPL head with their default arguments.

        No configuration object is taken; both sub-modules are constructed
        without arguments.
        """
        super().__init__()
        # Create backbone feature extractor
        self.backbone = vit()
        # Create SMPL head
        self.smpl_head = SMPLTransformerDecoderHead()

    @staticmethod
    def _reshape_predictions(pred_smpl_params, pred_cam, batch_size):
        """Reshape the raw SMPL head outputs into the tuple returned to callers.

        Args:
            pred_smpl_params: Mapping with 'global_orient', 'body_pose' and
                'betas' tensors as produced by the SMPL head.
            pred_cam: Camera prediction from the SMPL head, passed through as is.
            batch_size (int): Leading dimension used for every reshape.

        Returns:
            tuple: (global_orient, body_pose, betas, pred_cam) where the first
            two are reshaped to (batch_size, -1, 3, 3) and betas to
            (batch_size, -1).
        """
        # The trailing 3x3 is presumably one rotation matrix per joint -- TODO confirm.
        global_orient = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
        body_pose = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
        betas = pred_smpl_params['betas'].reshape(batch_size, -1)
        return global_orient, body_pose, betas, pred_cam

    def decode(self, x):
        """Run only the SMPL head on precomputed conditioning features.

        Args:
            x: Features fed directly to ``self.smpl_head``; the batch size is
                read from ``x.shape[0]``.

        Returns:
            tuple: (global_orient, body_pose, betas, pred_cam), see
            ``_reshape_predictions``.
        """
        batch_size = x.shape[0]
        pred_smpl_params, pred_cam, _ = self.smpl_head(x)
        return self._reshape_predictions(pred_smpl_params, pred_cam, batch_size)

    def forward(self, x, encode=False, **kwargs):
        """Run the backbone and then either the SMPL head or only its transformer.

        Args:
            x: Input tensor with at least four dimensions; assumed to be an
                RGB image batch laid out as (batch, channels, height, width)
                -- TODO confirm against the caller.
            encode (bool): When True, return the output of the head's
                transformer instead of the SMPL predictions.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            When ``encode`` is False, the tuple
            (global_orient, body_pose, betas, pred_cam) as in ``decode``.
            When ``encode`` is True, the transformer output with dimension 1
            squeezed out.
        """
        batch_size = x.shape[0]

        # Drop 32 entries from each end of the last dimension: the ViT
        # backbone needs a different aspect ratio than the incoming tensor.
        conditioning_feats = self.backbone(x[:, :, :, 32:-32])

        if encode:
            # Flatten the spatial grid into a token sequence used as context.
            conditioning_feats = einops.rearrange(conditioning_feats, 'b c h w -> b (h w) c')
            # Single all-zero query token, created directly on the input's device.
            token = torch.zeros(batch_size, 1, 1, device=x.device)
            token_out = self.smpl_head.transformer(token, context=conditioning_feats)
            return token_out.squeeze(1)

        pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)
        return self._reshape_predictions(pred_smpl_params, pred_cam, batch_size)
def hmr2(checkpoint_pth):
    """Build an HMR2 model and load weights from ``checkpoint_pth`` when it exists.

    Args:
        checkpoint_pth: Filesystem path to a checkpoint whose top-level
            'state_dict' entry holds the model weights.

    Returns:
        HMR2: The constructed model. If the path does not exist the model is
        returned without any weights loaded, and a message is printed.
    """
    model = HMR2()

    if not os.path.exists(checkpoint_pth):
        # Best-effort by design: keep returning the model, but say so instead
        # of silently handing back a network with no loaded weights.
        print(f'Checkpoint not found, no weights loaded: {checkpoint_pth}')
        return model

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_pth, map_location='cpu')
    # strict=False tolerates keys that are missing from or extra to the model.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    print(f'Load backbone weight: {checkpoint_pth}')
    return model