import torch
import torch.nn as nn
import numpy as np
from lib.pymafx.core import constants

from lib.common.config import cfg
from lib.pymafx.utils.geometry import rot6d_to_rotmat, rotmat_to_rot6d, projection, rotation_matrix_to_angle_axis, compute_twist_rotation
from .maf_extractor import MAF_Extractor, Mesh_Sampler
from .smpl import SMPL, SMPL_MODEL_DIR, SMPL_MEAN_PARAMS, get_partial_smpl, SMPL_Family
from lib.smplx.lbs import batch_rodrigues
from .res_module import IUV_predict_layer
from .hr_module import get_hrnet_encoder
from .pose_resnet import get_resnet_encoder
from lib.pymafx.utils.imutils import j2d_processing
from lib.pymafx.utils.cam_params import homo_vector
from .attention import get_att_block

import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1


class Regressor(nn.Module):
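    """Iterative parameter regressor for the body, hand, and face sub-models.

    Given image features and the current parameter estimates, it predicts
    residual updates for pose (6D rotations), shape, and camera, and, in
    SMPL-X mode, for hand poses, face rotations, and expression coefficients.
    """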
    def __init__(
        self,
        feat_dim,
        smpl_mean_params,
        use_cam_feats=False,
        feat_dim_hand=0,
        feat_dim_face=0,
        bhf_names=['body'],
        smpl_models={}
    ):
        super().__init__()

        # parameter dimensions: 24 body joints and 15 per-hand joints in the
        # 6D rotation representation; the face has 3 rotations (jaw + two eyes)
        # plus a 10-dim expression code
        npose = 24 * 6
        shape_dim = 10
        cam_dim = 3
        hand_dim = 15 * 6
        face_dim = 3 * 6 + 10

        self.body_feat_dim = feat_dim

        self.smpl_mode = (cfg.MODEL.MESH_MODEL == 'smpl')
        self.smplx_mode = (cfg.MODEL.MESH_MODEL == 'smplx')
        self.use_cam_feats = use_cam_feats

        cam_feat_len = 4 if self.use_cam_feats else 0

        self.bhf_names = bhf_names
        self.hand_only_mode = (cfg.TRAIN.BHF_MODE == 'hand_only')
        self.face_only_mode = (cfg.TRAIN.BHF_MODE == 'face_only')
        self.body_hand_mode = (cfg.TRAIN.BHF_MODE == 'body_hand')
        self.full_body_mode = (cfg.TRAIN.BHF_MODE == 'full_body')

        # if self.use_cam_feats:
        #     assert cfg.MODEL.USE_IWP_CAM is False
        if 'body' in self.bhf_names:
            self.fc1 = nn.Linear(feat_dim + npose + cam_feat_len + shape_dim + cam_dim, 1024)
            self.drop1 = nn.Dropout()
            self.fc2 = nn.Linear(1024, 1024)
            self.drop2 = nn.Dropout()
            self.decpose = nn.Linear(1024, npose)
            self.decshape = nn.Linear(1024, 10)
            self.deccam = nn.Linear(1024, 3)
            nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
            nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
            nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        if not self.smpl_mode:
            if self.hand_only_mode:
                self.part_names = ['rhand']
            elif self.face_only_mode:
                self.part_names = ['face']
            elif self.body_hand_mode:
                self.part_names = ['lhand', 'rhand']
            elif self.full_body_mode:
                self.part_names = ['lhand', 'rhand', 'face']
            else:
                self.part_names = []

            if 'rhand' in self.part_names:
                # self.fc1_hand = nn.Linear(feat_dim_hand + hand_dim + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024)
                self.fc1_hand = nn.Linear(feat_dim_hand + hand_dim, 1024)
                self.drop1_hand = nn.Dropout()
                self.fc2_hand = nn.Linear(1024, 1024)
                self.drop2_hand = nn.Dropout()

                # self.declhand = nn.Linear(1024, 15*6)
                self.decrhand = nn.Linear(1024, 15 * 6)
                # nn.init.xavier_uniform_(self.declhand.weight, gain=0.01)
                nn.init.xavier_uniform_(self.decrhand.weight, gain=0.01)

                if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST:
                    rh_cam_dim = 3
                    rh_orient_dim = 6
                    rh_shape_dim = 10
                    self.fc3_hand = nn.Linear(
                        1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024
                    )
                    self.drop3_hand = nn.Dropout()

                    self.decshape_rhand = nn.Linear(1024, 10)
                    self.decorient_rhand = nn.Linear(1024, 6)
                    self.deccam_rhand = nn.Linear(1024, 3)
                    nn.init.xavier_uniform_(self.decshape_rhand.weight, gain=0.01)
                    nn.init.xavier_uniform_(self.decorient_rhand.weight, gain=0.01)
                    nn.init.xavier_uniform_(self.deccam_rhand.weight, gain=0.01)

            if 'face' in self.part_names:
                self.fc1_face = nn.Linear(feat_dim_face + face_dim, 1024)
                self.drop1_face = nn.Dropout()
                self.fc2_face = nn.Linear(1024, 1024)
                self.drop2_face = nn.Dropout()

                self.dechead = nn.Linear(1024, 3 * 6)
                self.decexp = nn.Linear(1024, 10)
                nn.init.xavier_uniform_(self.dechead.weight, gain=0.01)
                nn.init.xavier_uniform_(self.decexp.weight, gain=0.01)

                if cfg.MODEL.MESH_MODEL == 'flame':
                    rh_cam_dim = 3
                    rh_orient_dim = 6
                    rh_shape_dim = 10
                    self.fc3_face = nn.Linear(
                        1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024
                    )
                    self.drop3_face = nn.Dropout()

                    self.decshape_face = nn.Linear(1024, 10)
                    self.decorient_face = nn.Linear(1024, 6)
                    self.deccam_face = nn.Linear(1024, 3)
                    nn.init.xavier_uniform_(self.decshape_face.weight, gain=0.01)
                    nn.init.xavier_uniform_(self.decorient_face.weight, gain=0.01)
                    nn.init.xavier_uniform_(self.deccam_face.weight, gain=0.01)

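            # hand-visibility head: fuses the 1024-dim body, left-hand, and
            # right-hand hidden features into two logits (one per hand) that
            # later gate the wrist/elbow integration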
            if self.smplx_mode and cfg.MODEL.PyMAF.PRED_VIS_H:
                self.fc1_vis = nn.Linear(1024 + 1024 + 1024, 1024)
                self.drop1_vis = nn.Dropout()
                self.fc2_vis = nn.Linear(1024, 1024)
                self.drop2_vis = nn.Dropout()
                self.decvis = nn.Linear(1024, 2)
                nn.init.xavier_uniform_(self.decvis.weight, gain=0.01)

        if 'body' in smpl_models:
            self.smpl = smpl_models['body']
        if 'hand' in smpl_models:
            self.mano = smpl_models['hand']
        if 'face' in smpl_models:
            self.flame = smpl_models['face']

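        # a neutral SMPL body model is kept to query global joint rotations
        # (elbows) when transplanting wrist poses from the hand sub-networks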
        if cfg.MODEL.PyMAF.OPT_WRIST:
            self.body_model = SMPL(model_path=SMPL_MODEL_DIR, batch_size=64, create_transl=False)

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)
        self.register_buffer('init_orient', init_pose[:, :6])

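        # sign pattern equivalent to conjugating a rotation matrix with
        # diag(-1, 1, 1): flattened entries 1, 2, 3, 6 change sign, mirroring
        # the rotation across the x = 0 plane; used to map hand poses between
        # the left- and right-hand (MANO) conventions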
        self.flip_vector = torch.ones((1, 9), dtype=torch.float32)
        self.flip_vector[:, [1, 2, 3, 6]] *= -1
        self.flip_vector = self.flip_vector.reshape(1, 3, 3)

        if not self.smpl_mode:
            lhand_mean_rot6d = rotmat_to_rot6d(
                batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view(
                    [-1, 3, 3]
                )
            )
            rhand_mean_rot6d = rotmat_to_rot6d(
                batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view(
                    [-1, 3, 3]
                )
            )
            init_lhand = lhand_mean_rot6d.reshape(-1).unsqueeze(0)
            init_rhand = rhand_mean_rot6d.reshape(-1).unsqueeze(0)
            # init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0)
            init_face = rotmat_to_rot6d(torch.stack([torch.eye(3)] * 3)).reshape(-1).unsqueeze(0)
            init_exp = torch.zeros(10).unsqueeze(0)

        if self.smplx_mode or 'hand' in bhf_names:
            # init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0)
            self.register_buffer('init_lhand', init_lhand)
            self.register_buffer('init_rhand', init_rhand)
        if self.smplx_mode or 'face' in bhf_names:
            self.register_buffer('init_face', init_face)
            self.register_buffer('init_exp', init_exp)

    def forward(
        self,
        x=None,
        n_iter=1,
        J_regressor=None,
        rw_cam={},
        init_mode=False,
        global_iter=-1,
        **kwargs
    ):
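        """Refine the current parameter estimates over n_iter feedback steps.

        When init_mode is True, the feedback loop is skipped and the mean
        (initial) parameters are decoded directly into the output meshes.
        """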
        if x is not None:
            batch_size = x.shape[0]
        elif 'xc_rhand' in kwargs:
            batch_size = kwargs['xc_rhand'].shape[0]
        elif 'xc_face' in kwargs:
            batch_size = kwargs['xc_face'].shape[0]
        else:
            raise ValueError('cannot infer batch size: pass x, xc_rhand, or xc_face')

        if 'body' in self.bhf_names:
            if 'init_pose' not in kwargs:
                kwargs['init_pose'] = self.init_pose.expand(batch_size, -1)
            if 'init_shape' not in kwargs:
                kwargs['init_shape'] = self.init_shape.expand(batch_size, -1)
            if 'init_cam' not in kwargs:
                kwargs['init_cam'] = self.init_cam.expand(batch_size, -1)

            pred_cam = kwargs['init_cam']
            pred_pose = kwargs['init_pose']
            pred_shape = kwargs['init_shape']

        if self.full_body_mode or self.body_hand_mode:
            if cfg.MODEL.PyMAF.OPT_WRIST:
                pred_rotmat_body = rot6d_to_rotmat(
                    pred_pose.reshape(batch_size, -1, 6)
                )    # .view(batch_size, 24, 3, 3)
            if cfg.MODEL.PyMAF.PRED_VIS_H:
                pred_vis_hands = None

        # if self.full_body_mode or 'hand' in self.bhf_names:
        if self.smplx_mode or 'hand' in self.bhf_names:
            if 'init_lhand' not in kwargs:
                # initialize the left hand with the **right**-hand mean pose:
                # left-hand parameters are regressed in the mirrored right-hand
                # space and flipped back via flip_vector afterwards
                kwargs['init_lhand'] = self.init_rhand.expand(batch_size, -1)
            if 'init_rhand' not in kwargs:
                kwargs['init_rhand'] = self.init_rhand.expand(batch_size, -1)

            pred_lhand, pred_rhand = kwargs['init_lhand'], kwargs['init_rhand']

            if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST:
                if 'init_orient_rh' not in kwargs:
                    kwargs['init_orient_rh'] = self.init_orient.expand(batch_size, -1)
                if 'init_shape_rh' not in kwargs:
                    kwargs['init_shape_rh'] = self.init_shape.expand(batch_size, -1)
                if 'init_cam_rh' not in kwargs:
                    kwargs['init_cam_rh'] = self.init_cam.expand(batch_size, -1)
                pred_orient_rh = kwargs['init_orient_rh']
                pred_shape_rh = kwargs['init_shape_rh']
                pred_cam_rh = kwargs['init_cam_rh']
                if cfg.MODEL.PyMAF.OPT_WRIST:
                    if 'init_orient_lh' not in kwargs:
                        kwargs['init_orient_lh'] = self.init_orient.expand(batch_size, -1)
                    if 'init_shape_lh' not in kwargs:
                        kwargs['init_shape_lh'] = self.init_shape.expand(batch_size, -1)
                    if 'init_cam_lh' not in kwargs:
                        kwargs['init_cam_lh'] = self.init_cam.expand(batch_size, -1)
                    pred_orient_lh = kwargs['init_orient_lh']
                    pred_shape_lh = kwargs['init_shape_lh']
                    pred_cam_lh = kwargs['init_cam_lh']
                if cfg.MODEL.MESH_MODEL == 'mano':
                    pred_cam = torch.cat([pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:]], dim=1)

        # if self.full_body_mode or 'face' in self.bhf_names:
        if self.smplx_mode or 'face' in self.bhf_names:
            if 'init_face' not in kwargs:
                kwargs['init_face'] = self.init_face.expand(batch_size, -1)
            if 'init_exp' not in kwargs:
                kwargs['init_exp'] = self.init_exp.expand(batch_size, -1)

            pred_face = kwargs['init_face']
            pred_exp = kwargs['init_exp']

            if cfg.MODEL.MESH_MODEL == 'flame' or cfg.MODEL.PyMAF.OPT_WRIST:
                if 'init_orient_fa' not in kwargs:
                    kwargs['init_orient_fa'] = self.init_orient.expand(batch_size, -1)
                pred_orient_fa = kwargs['init_orient_fa']
                if 'init_shape_fa' not in kwargs:
                    kwargs['init_shape_fa'] = self.init_shape.expand(batch_size, -1)
                if 'init_cam_fa' not in kwargs:
                    kwargs['init_cam_fa'] = self.init_cam.expand(batch_size, -1)
                pred_shape_fa = kwargs['init_shape_fa']
                pred_cam_fa = kwargs['init_cam_fa']
                if cfg.MODEL.MESH_MODEL == 'flame':
                    pred_cam = torch.cat([pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:]], dim=1)

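        # iterative error feedback: each iteration concatenates the features
        # with the current estimates and regresses residual parameter updates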
        if not init_mode:
            for i in range(n_iter):
                if 'body' in self.bhf_names:
                    xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
                    if self.use_cam_feats:
                        if cfg.MODEL.USE_IWP_CAM:
                            # for IWP camera, simply use pre-defined values
                            vfov = torch.ones((batch_size, 1)).to(xc) * 0.8
                            crop_ratio = torch.ones((batch_size, 1)).to(xc) * 0.3
                            crop_center = torch.ones((batch_size, 2)).to(xc) * 0.5
                        else:
                            vfov = rw_cam['vfov'][:, None]
                            crop_ratio = rw_cam['crop_ratio'][:, None]
                            crop_center = rw_cam['bbox_center'] / torch.cat(
                                [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1
                            )
                        xc = torch.cat([xc, vfov, crop_ratio, crop_center], 1)

                    xc = self.fc1(xc)
                    xc = self.drop1(xc)
                    xc = self.fc2(xc)
                    xc = self.drop2(xc)

                    pred_cam = self.deccam(xc) + pred_cam
                    pred_pose = self.decpose(xc) + pred_pose
                    pred_shape = self.decshape(xc) + pred_shape

                if not self.smpl_mode:
                    if self.hand_only_mode:
                        xc_rhand = kwargs['xc_rhand']
                        xc_rhand = torch.cat([xc_rhand, pred_rhand], 1)
                    elif self.face_only_mode:
                        xc_face = kwargs['xc_face']
                        xc_face = torch.cat([xc_face, pred_face, pred_exp], 1)
                    elif self.body_hand_mode:
                        xc_lhand, xc_rhand = kwargs['xc_lhand'], kwargs['xc_rhand']
                        xc_lhand = torch.cat([xc_lhand, pred_lhand], 1)
                        xc_rhand = torch.cat([xc_rhand, pred_rhand], 1)
                    elif self.full_body_mode:
                        xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs['xc_rhand'], kwargs['xc_face']
                        xc_lhand = torch.cat([xc_lhand, pred_lhand], 1)
                        xc_rhand = torch.cat([xc_rhand, pred_rhand], 1)
                        xc_face = torch.cat([xc_face, pred_face, pred_exp], 1)

                    if 'lhand' in self.part_names:
                        # the hand decoders are shared between both hands; the
                        # left hand is regressed in the mirrored right-hand space
                        xc_lhand = self.drop1_hand(self.fc1_hand(xc_lhand))
                        xc_lhand = self.drop2_hand(self.fc2_hand(xc_lhand))
                        pred_lhand = self.decrhand(xc_lhand) + pred_lhand

                        if cfg.MODEL.PyMAF.OPT_WRIST:
                            xc_lhand = torch.cat(
                                [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1
                            )
                            xc_lhand = self.drop3_hand(self.fc3_hand(xc_lhand))

                            pred_shape_lh = self.decshape_rhand(xc_lhand) + pred_shape_lh
                            pred_orient_lh = self.decorient_rhand(xc_lhand) + pred_orient_lh
                            pred_cam_lh = self.deccam_rhand(xc_lhand) + pred_cam_lh

                    if 'rhand' in self.part_names:
                        xc_rhand = self.drop1_hand(self.fc1_hand(xc_rhand))
                        xc_rhand = self.drop2_hand(self.fc2_hand(xc_rhand))
                        pred_rhand = self.decrhand(xc_rhand) + pred_rhand

                        if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST:
                            xc_rhand = torch.cat(
                                [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1
                            )
                            xc_rhand = self.drop3_hand(self.fc3_hand(xc_rhand))

                            pred_shape_rh = self.decshape_rhand(xc_rhand) + pred_shape_rh
                            pred_orient_rh = self.decorient_rhand(xc_rhand) + pred_orient_rh
                            pred_cam_rh = self.deccam_rhand(xc_rhand) + pred_cam_rh

                            if cfg.MODEL.MESH_MODEL == 'mano':
                                pred_cam = torch.cat(
                                    [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1
                                )

                    if 'face' in self.part_names:
                        xc_face = self.drop1_face(self.fc1_face(xc_face))
                        xc_face = self.drop2_face(self.fc2_face(xc_face))
                        pred_face = self.dechead(xc_face) + pred_face
                        pred_exp = self.decexp(xc_face) + pred_exp

                        if cfg.MODEL.MESH_MODEL == 'flame':
                            xc_face = torch.cat(
                                [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1
                            )
                            xc_face = self.drop3_face(self.fc3_face(xc_face))

                            pred_shape_fa = self.decshape_face(xc_face) + pred_shape_fa
                            pred_orient_fa = self.decorient_face(xc_face) + pred_orient_fa
                            pred_cam_fa = self.deccam_face(xc_face) + pred_cam_fa

                            if cfg.MODEL.MESH_MODEL == 'flame':
                                pred_cam = torch.cat(
                                    [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1
                                )

                    if self.full_body_mode or self.body_hand_mode:
                        if cfg.MODEL.PyMAF.PRED_VIS_H:
                            xc_vis = torch.cat([xc, xc_lhand, xc_rhand], 1)

                            xc_vis = self.drop1_vis(self.fc1_vis(xc_vis))
                            xc_vis = self.drop2_vis(self.fc2_vis(xc_vis))
                            pred_vis_hands = self.decvis(xc_vis)

                            pred_vis_lhand = pred_vis_hands[:, 0] > cfg.MODEL.PyMAF.HAND_VIS_TH
                            pred_vis_rhand = pred_vis_hands[:, 1] > cfg.MODEL.PyMAF.HAND_VIS_TH

                        if cfg.MODEL.PyMAF.OPT_WRIST:

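                            # transplant the hand sub-networks' global wrist
                            # orientations into the body pose: the target wrist
                            # rotation is re-expressed relative to the elbow's
                            # global rotation before replacing the body wrist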
                            pred_rotmat_body = rot6d_to_rotmat(
                                pred_pose.reshape(batch_size, -1, 6)
                            )    # .view(batch_size, 24, 3, 3)
                            pred_lwrist = pred_rotmat_body[:, 20]
                            pred_rwrist = pred_rotmat_body[:, 21]

                            pred_gl_body, body_joints = self.body_model.get_global_rotation(
                                global_orient=pred_rotmat_body[:, 0:1],
                                body_pose=pred_rotmat_body[:, 1:]
                            )
                            pred_gl_lelbow = pred_gl_body[:, 18]
                            pred_gl_relbow = pred_gl_body[:, 19]

                            target_gl_lwrist = rot6d_to_rotmat(
                                pred_orient_lh.reshape(batch_size, -1, 6)
                            )
                            target_gl_lwrist *= self.flip_vector.to(target_gl_lwrist.device)
                            target_gl_rwrist = rot6d_to_rotmat(
                                pred_orient_rh.reshape(batch_size, -1, 6)
                            )

                            opt_lwrist = torch.bmm(pred_gl_lelbow.transpose(1, 2), target_gl_lwrist)
                            opt_rwrist = torch.bmm(pred_gl_relbow.transpose(1, 2), target_gl_rwrist)

                            if cfg.MODEL.PyMAF.ADAPT_INTEGR:
                                # if cfg.MODEL.PyMAF.ADAPT_INTEGR and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1):
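                                # keep the wrist twist about the elbow->wrist
                                # axis within +/-0.4*pi rad; the excess is
                                # removed from the wrist and moved to the elbow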
                                tpose_joints = self.smpl.get_tpose(betas=pred_shape)
                                lelbow_twist_axis = nn.functional.normalize(
                                    tpose_joints[:, 20] - tpose_joints[:, 18], dim=1
                                )
                                relbow_twist_axis = nn.functional.normalize(
                                    tpose_joints[:, 21] - tpose_joints[:, 19], dim=1
                                )

                                lelbow_twist, lelbow_twist_angle = compute_twist_rotation(
                                    opt_lwrist, lelbow_twist_axis
                                )
                                relbow_twist, relbow_twist_angle = compute_twist_rotation(
                                    opt_rwrist, relbow_twist_axis
                                )

                                min_angle = -0.4 * float(np.pi)
                                max_angle = 0.4 * float(np.pi)

                                lelbow_twist_angle[lelbow_twist_angle == torch.clamp(
                                    lelbow_twist_angle, min_angle, max_angle
                                )] = 0
                                relbow_twist_angle[relbow_twist_angle == torch.clamp(
                                    relbow_twist_angle, min_angle, max_angle
                                )] = 0
                                lelbow_twist_angle[lelbow_twist_angle > max_angle] -= max_angle
                                lelbow_twist_angle[lelbow_twist_angle < min_angle] -= min_angle
                                relbow_twist_angle[relbow_twist_angle > max_angle] -= max_angle
                                relbow_twist_angle[relbow_twist_angle < min_angle] -= min_angle

                                lelbow_twist = batch_rodrigues(
                                    lelbow_twist_axis * lelbow_twist_angle
                                )
                                relbow_twist = batch_rodrigues(
                                    relbow_twist_axis * relbow_twist_angle
                                )

                                opt_lwrist = torch.bmm(lelbow_twist.transpose(1, 2), opt_lwrist)
                                opt_rwrist = torch.bmm(relbow_twist.transpose(1, 2), opt_rwrist)

                                # left elbow: 18
                                opt_lelbow = torch.bmm(pred_rotmat_body[:, 18], lelbow_twist)
                                # right elbow: 19
                                opt_relbow = torch.bmm(pred_rotmat_body[:, 19], relbow_twist)

                                if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (
                                    cfg.MODEL.PyMAF.N_ITER - 1
                                ):
                                    opt_lwrist_filtered = [
                                        opt_lwrist[_i]
                                        if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20]
                                        for _i in range(batch_size)
                                    ]
                                    opt_rwrist_filtered = [
                                        opt_rwrist[_i]
                                        if pred_vis_rhand[_i] else pred_rotmat_body[_i, 21]
                                        for _i in range(batch_size)
                                    ]
                                    opt_lelbow_filtered = [
                                        opt_lelbow[_i]
                                        if pred_vis_lhand[_i] else pred_rotmat_body[_i, 18]
                                        for _i in range(batch_size)
                                    ]
                                    opt_relbow_filtered = [
                                        opt_relbow[_i]
                                        if pred_vis_rhand[_i] else pred_rotmat_body[_i, 19]
                                        for _i in range(batch_size)
                                    ]

                                    opt_lwrist = torch.stack(opt_lwrist_filtered)
                                    opt_rwrist = torch.stack(opt_rwrist_filtered)
                                    opt_lelbow = torch.stack(opt_lelbow_filtered)
                                    opt_relbow = torch.stack(opt_relbow_filtered)

                                pred_rotmat_body = torch.cat(
                                    [
                                        pred_rotmat_body[:, :18],
                                        opt_lelbow.unsqueeze(1),
                                        opt_relbow.unsqueeze(1),
                                        opt_lwrist.unsqueeze(1),
                                        opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
                                    ], 1
                                )
                            else:
                                if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (
                                    cfg.MODEL.PyMAF.N_ITER - 1
                                ):
                                    opt_lwrist_filtered = [
                                        opt_lwrist[_i]
                                        if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20]
                                        for _i in range(batch_size)
                                    ]
                                    opt_rwrist_filtered = [
                                        opt_rwrist[_i]
                                        if pred_vis_rhand[_i] else pred_rotmat_body[_i, 21]
                                        for _i in range(batch_size)
                                    ]

                                    opt_lwrist = torch.stack(opt_lwrist_filtered)
                                    opt_rwrist = torch.stack(opt_rwrist_filtered)

                                pred_rotmat_body = torch.cat(
                                    [
                                        pred_rotmat_body[:, :20],
                                        opt_lwrist.unsqueeze(1),
                                        opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:]
                                    ], 1
                                )

        if self.hand_only_mode:
            pred_rotmat_rh = rot6d_to_rotmat(
                torch.cat([pred_orient_rh, pred_rhand], dim=1).reshape(batch_size, -1, 6)
            )    # .view(batch_size, 16, 3, 3)
            assert pred_rotmat_rh.shape[1] == 1 + 15
        elif self.face_only_mode:
            pred_rotmat_fa = rot6d_to_rotmat(
                torch.cat([pred_orient_fa, pred_face], dim=1).reshape(batch_size, -1, 6)
            )    # .view(batch_size, 4, 3, 3)
            assert pred_rotmat_fa.shape[1] == 1 + 3
        elif self.full_body_mode or self.body_hand_mode:
            if cfg.MODEL.PyMAF.OPT_WRIST:
                pred_rotmat = pred_rotmat_body
            else:
                pred_rotmat = rot6d_to_rotmat(
                    pred_pose.reshape(batch_size, -1, 6)
                )    # .view(batch_size, 24, 3, 3)
            assert pred_rotmat.shape[1] == 24
        else:
            pred_rotmat = rot6d_to_rotmat(
                pred_pose.reshape(batch_size, -1, 6)
            )    # .view(batch_size, 24, 3, 3)
            assert pred_rotmat.shape[1] == 24

        # if self.full_body_mode:
        if self.smplx_mode:
            if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1):
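                # hands predicted as invisible are reset to the mean
                # (right-hand) pose instead of keeping the regressed one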
                pred_lhand_filtered = [
                    pred_lhand[_i] if pred_vis_lhand[_i] else self.init_rhand[0]
                    for _i in range(batch_size)
                ]
                pred_rhand_filtered = [
                    pred_rhand[_i] if pred_vis_rhand[_i] else self.init_rhand[0]
                    for _i in range(batch_size)
                ]
                pred_lhand_filtered = torch.stack(pred_lhand_filtered)
                pred_rhand_filtered = torch.stack(pred_rhand_filtered)
                pred_hf6d = torch.cat([pred_lhand_filtered, pred_rhand_filtered, pred_face],
                                      dim=1).reshape(batch_size, -1, 6)
            else:
                pred_hf6d = torch.cat([pred_lhand, pred_rhand, pred_face],
                                      dim=1).reshape(batch_size, -1, 6)
            pred_hfrotmat = rot6d_to_rotmat(pred_hf6d)
            assert pred_hfrotmat.shape[1] == (15 * 2 + 3)

            # flip the left-hand pose back from the mirrored right-hand space
            pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to(
                pred_hfrotmat.device
            ).unsqueeze(0)
            pred_rhand_rotmat = pred_hfrotmat[:, 15:30]
            pred_face_rotmat = pred_hfrotmat[:, 30:]

        if self.hand_only_mode:
            pred_output = self.mano(
                betas=pred_shape_rh,
                right_hand_pose=pred_rotmat_rh[:, 1:],
                global_orient=pred_rotmat_rh[:, 0].unsqueeze(1),
                pose2rot=False,
            )
        elif self.face_only_mode:
            pred_output = self.flame(
                betas=pred_shape_fa,
                global_orient=pred_rotmat_fa[:, 0].unsqueeze(1),
                jaw_pose=pred_rotmat_fa[:, 1:2],
                leye_pose=pred_rotmat_fa[:, 2:3],
                reye_pose=pred_rotmat_fa[:, 3:4],
                expression=pred_exp,
                pose2rot=False,
            )
        else:
            smplx_kwargs = {}
            # if self.full_body_mode:
            if self.smplx_mode:
                smplx_kwargs['left_hand_pose'] = pred_lhand_rotmat
                smplx_kwargs['right_hand_pose'] = pred_rhand_rotmat
                smplx_kwargs['jaw_pose'] = pred_face_rotmat[:, 0:1]
                smplx_kwargs['leye_pose'] = pred_face_rotmat[:, 1:2]
                smplx_kwargs['reye_pose'] = pred_face_rotmat[:, 2:3]
                smplx_kwargs['expression'] = pred_exp

            pred_output = self.smpl(
                betas=pred_shape,
                body_pose=pred_rotmat[:, 1:],
                global_orient=pred_rotmat[:, 0].unsqueeze(1),
                pose2rot=False,
                **smplx_kwargs,
            )

            pred_vertices = pred_output.vertices
            pred_joints = pred_output.joints

        if self.hand_only_mode:
            pred_joints_full = pred_output.rhand_joints
        elif self.face_only_mode:
            pred_joints_full = pred_output.face_joints
        elif self.smplx_mode:
            pred_joints_full = torch.cat(
                [
                    pred_joints, pred_output.lhand_joints, pred_output.rhand_joints,
                    pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints
                ],
                dim=1
            )
        else:
            pred_joints_full = pred_joints
        pred_keypoints_2d = projection(
            pred_joints_full, {
                **rw_cam, 'cam_sxy': pred_cam
            }, iwp_mode=cfg.MODEL.USE_IWP_CAM
        )
        if cfg.MODEL.USE_IWP_CAM:
            # Normalize keypoints to [-1,1]
            pred_keypoints_2d = pred_keypoints_2d / (224. / 2.)
        else:
            pred_keypoints_2d = j2d_processing(pred_keypoints_2d, rw_cam['kps_transf'])

        len_b_kp = len(constants.JOINT_NAMES)
        output = {}
        if self.smpl_mode or self.smplx_mode:
            if J_regressor is not None:
                kp_3d = torch.matmul(J_regressor, pred_vertices)
                pred_pelvis = kp_3d[:, [0], :].clone()
                kp_3d = kp_3d[:, constants.H36M_TO_J14, :]
                kp_3d = kp_3d - pred_pelvis
            else:
                kp_3d = pred_joints
            pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
            output.update(
                {
                    'theta': torch.cat([pred_cam, pred_shape, pose], dim=1),
                    'verts': pred_vertices,
                    'kp_2d': pred_keypoints_2d[:, :len_b_kp],
                    'kp_3d': kp_3d,
                    'pred_joints': pred_joints,
                    'smpl_kp_3d': pred_output.smpl_joints,
                    'rotmat': pred_rotmat,
                    'pred_cam': pred_cam,
                    'pred_shape': pred_shape,
                    'pred_pose': pred_pose,
                }
            )
            # if self.full_body_mode:
            if self.smplx_mode:
                # assert pred_keypoints_2d.shape[1] == 144
                len_h_kp = len(constants.HAND_NAMES)
                len_f_kp = len(constants.FACIAL_LANDMARKS)
                len_feet_kp = 2 * len(constants.FOOT_NAMES)
                output.update(
                    {
                        # pred_keypoints_2d layout: [body | lhand | rhand | face | feet]
                        'smplx_verts': pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None,
                        'pred_lhand': pred_lhand,
                        'pred_rhand': pred_rhand,
                        'pred_face': pred_face,
                        'pred_exp': pred_exp,
                        'verts_lh': pred_output.lhand_vertices,
                        'verts_rh': pred_output.rhand_vertices,
                        # 'pred_arm_rotmat': pred_arm_rotmat,
                        # 'pred_hfrotmat': pred_hfrotmat,
                        'pred_lhand_rotmat': pred_lhand_rotmat,
                        'pred_rhand_rotmat': pred_rhand_rotmat,
                        'pred_face_rotmat': pred_face_rotmat,
                        'pred_lhand_kp3d': pred_output.lhand_joints,
                        'pred_rhand_kp3d': pred_output.rhand_joints,
                        'pred_face_kp3d': pred_output.face_joints,
                        'pred_lhand_kp2d': pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp],
                        'pred_rhand_kp2d': pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2],
                        'pred_face_kp2d': pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 + len_f_kp],
                        'pred_feet_kp2d': pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp + len_h_kp * 2 + len_f_kp + len_feet_kp],
                    }
                )
                if cfg.MODEL.PyMAF.OPT_WRIST:
                    output.update(
                        {
                            'pred_orient_lh': pred_orient_lh,
                            'pred_shape_lh': pred_shape_lh,
                            'pred_orient_rh': pred_orient_rh,
                            'pred_shape_rh': pred_shape_rh,
                            'pred_cam_fa': pred_cam_fa,
                            'pred_cam_lh': pred_cam_lh,
                            'pred_cam_rh': pred_cam_rh,
                        }
                    )
                if cfg.MODEL.PyMAF.PRED_VIS_H:
                    output.update({'pred_vis_hands': pred_vis_hands})
        elif self.hand_only_mode:
            # hand mesh out
            assert pred_keypoints_2d.shape[1] == 21
            output.update(
                {
                    'theta': pred_cam,
                    'pred_cam': pred_cam,
                    'pred_rhand': pred_rhand,
                    'pred_rhand_rotmat': pred_rotmat_rh[:, 1:],
                    'pred_orient_rh': pred_orient_rh,
                    'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0],
                    'verts_rh': pred_output.rhand_vertices,
                    'pred_cam_rh': pred_cam_rh,
                    'pred_shape_rh': pred_shape_rh,
                    'pred_rhand_kp3d': pred_output.rhand_joints,
                    'pred_rhand_kp2d': pred_keypoints_2d,
                }
            )
        elif self.face_only_mode:
            # face mesh out
            assert pred_keypoints_2d.shape[1] == 68
            output.update(
                {
                    'theta': pred_cam,
                    'pred_cam': pred_cam,
                    'pred_face': pred_face,
                    'pred_exp': pred_exp,
                    'pred_face_rotmat': pred_rotmat_fa[:, 1:],
                    'pred_orient_fa': pred_orient_fa,
                    'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0],
                    'verts_fa': pred_output.flame_vertices,
                    'pred_cam_fa': pred_cam_fa,
                    'pred_shape_fa': pred_shape_fa,
                    'pred_face_kp3d': pred_output.face_joints,
                    'pred_face_kp2d': pred_keypoints_2d,
                }
            )
        return output


def get_attention_modules(
    module_keys, img_feature_dim_list, hidden_feat_dim, n_iter, num_attention_heads=1
):
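    """Build, for each part in module_keys, one spatial-alignment attention
    block per refinement iteration. img_feature_dim_list maps each part name
    to its per-iteration spatial feature dimensions."""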

    align_attention = nn.ModuleDict()
    for k in module_keys:
        align_attention[k] = nn.ModuleList()
        for i in range(n_iter):
            align_attention[k].append(
                get_att_block(
                    img_feature_dim=img_feature_dim_list[k][i],
                    hidden_feat_dim=hidden_feat_dim,
                    num_attention_heads=num_attention_heads
                )
            )

    return align_attention


def get_fusion_modules(module_keys, ma_feat_dim, grid_feat_dim, n_iter, out_feat_len):
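    """Build, for each part, one linear layer per iteration that fuses grid
    features with mesh-aligned features into a fixed-length vector."""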

    feat_fusion = nn.ModuleDict()
    for k in module_keys:
        feat_fusion[k] = nn.ModuleList()
        for i in range(n_iter):
            feat_fusion[k].append(nn.Linear(grid_feat_dim + ma_feat_dim[k], out_feat_len[k]))

    return feat_fusion


class PyMAF(nn.Module):
    """ PyMAF based Regression Network for Human Mesh Recovery / Full-body Mesh Recovery
    PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021
    PyMAF-X: Towards Well-aligned Full-body Model Regression from Monocular Images, arXiv:2207.06400, 2022
    """
    def __init__(
        self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True, device=torch.device('cuda')
    ):
        super().__init__()

        self.device = device

        self.smpl_mode = (cfg.MODEL.MESH_MODEL == 'smpl')
        self.smplx_mode = (cfg.MODEL.MESH_MODEL == 'smplx')

        assert cfg.TRAIN.BHF_MODE in [
            'body_only', 'hand_only', 'face_only', 'body_hand', 'full_body'
        ]
        self.hand_only_mode = (cfg.TRAIN.BHF_MODE == 'hand_only')
        self.face_only_mode = (cfg.TRAIN.BHF_MODE == 'face_only')
        self.body_hand_mode = (cfg.TRAIN.BHF_MODE == 'body_hand')
        self.full_body_mode = (cfg.TRAIN.BHF_MODE == 'full_body')

        bhf_names = []
        if cfg.TRAIN.BHF_MODE in ['body_only', 'body_hand', 'full_body']:
            bhf_names.append('body')
        if cfg.TRAIN.BHF_MODE in ['hand_only', 'body_hand', 'full_body']:
            bhf_names.append('hand')
        if cfg.TRAIN.BHF_MODE in ['face_only', 'full_body']:
            bhf_names.append('face')
        self.bhf_names = bhf_names

        self.part_module_names = {'body': {}, 'hand': {}, 'face': {}, 'link': {}}

        # the hand/face ('limb') parts to be handled under the current BHF mode
        if self.hand_only_mode:
            self.part_names = ['rhand']
        elif self.face_only_mode:
            self.part_names = ['face']
        elif self.body_hand_mode:
            self.part_names = ['lhand', 'rhand']
        elif self.full_body_mode:
            self.part_names = ['lhand', 'rhand', 'face']
        else:
            self.part_names = []

        # joint index info
        if not self.smpl_mode:
            h_root_idx = constants.HAND_NAMES.index('wrist')
            h_idx = constants.HAND_NAMES.index('middle1')
            f_idx = constants.FACIAL_LANDMARKS.index('nose_middle')
            self.hf_center_idx = {'lhand': h_idx, 'rhand': h_idx, 'face': f_idx}
            self.hf_root_idx = {'lhand': h_root_idx, 'rhand': h_root_idx, 'face': f_idx}

            lh_idx_coco = constants.COCO_KEYPOINTS.index('left_wrist')
            rh_idx_coco = constants.COCO_KEYPOINTS.index('right_wrist')
            f_idx_coco = constants.COCO_KEYPOINTS.index('nose')
            self.hf_root_idx_coco = {'lhand': lh_idx_coco, 'rhand': rh_idx_coco, 'face': f_idx_coco}

        # create parametric mesh models
        self.smpl_family = {}
        if self.hand_only_mode and cfg.MODEL.MESH_MODEL == 'mano':
            self.smpl_family['hand'] = SMPL_Family(model_type='mano')
            self.smpl_family['body'] = SMPL_Family(model_type='smplx')
        elif self.face_only_mode and cfg.MODEL.MESH_MODEL == 'flame':
            self.smpl_family['face'] = SMPL_Family(model_type='flame')
            self.smpl_family['body'] = SMPL_Family(model_type='smplx')
        else:
            self.smpl_family['body'] = SMPL_Family(
                model_type=cfg.MODEL.MESH_MODEL, all_gender=cfg.MODEL.ALL_GENDER
            )

        self.init_mesh_output = None
        self.batch_size = 1

        self.encoders = nn.ModuleDict()
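        # without MAF, the encoders return a single global feature vector
        # instead of the spatial feature pyramid used for point-wise sampling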
        self.global_mode = not cfg.MODEL.PyMAF.MAF_ON

        # build encoders
        global_feat_dim = 2048
        bhf_ma_feat_dim = {}
        # encoder for the body part
        if 'body' in bhf_names:
            # if self.smplx_mode or 'hr' in cfg.MODEL.PyMAF.BACKBONE:
            if cfg.MODEL.PyMAF.BACKBONE == 'res50':
                body_encoder = get_resnet_encoder(
                    cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode
                )
                body_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS)
            elif cfg.MODEL.PyMAF.BACKBONE == 'hr48':
                body_encoder = get_hrnet_encoder(
                    cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode
                )
                body_sfeat_dim = list(cfg.HR_MODEL.EXTRA.STAGE4.NUM_CHANNELS)
                body_sfeat_dim.reverse()
                body_sfeat_dim = body_sfeat_dim[1:]
            else:
                raise NotImplementedError
            self.encoders['body'] = body_encoder
            self.part_module_names['body'].update({'encoders.body': self.encoders['body']})

            self.mesh_sampler = Mesh_Sampler(type='smpl')
            self.part_module_names['body'].update({'mesh_sampler': self.mesh_sampler})

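            # mesh-aligned feature length: one MLP output vector per vertex of
            # the downsampled SMPL mesh (zero when using grid features only)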
            if not cfg.MODEL.PyMAF.GRID_FEAT:
                ma_feat_dim = self.mesh_sampler.Dmap.shape[0] * cfg.MODEL.PyMAF.MLP_DIM[-1]
            else:
                ma_feat_dim = 0
            bhf_ma_feat_dim['body'] = ma_feat_dim

            dp_feat_dim = body_sfeat_dim[-1]
            self.with_uv = cfg.LOSS.POINT_REGRESSION_WEIGHTS > 0
            if cfg.MODEL.PyMAF.AUX_SUPV_ON:
                assert cfg.MODEL.PyMAF.MAF_ON
                self.dp_head = IUV_predict_layer(feat_dim=dp_feat_dim)
                self.part_module_names['body'].update({'dp_head': self.dp_head})

        # encoders for the hand / face parts
        if 'hand' in self.bhf_names or 'face' in self.bhf_names:
            for hf in ['hand', 'face']:
                if hf in bhf_names:
                    if cfg.MODEL.PyMAF.HF_BACKBONE == 'res50':
                        self.encoders[hf] = get_resnet_encoder(
                            cfg,
                            init_weight=(not cfg.MODEL.EVAL_MODE),
                            global_mode=self.global_mode
                        )
                        self.part_module_names[hf].update({f'encoders.{hf}': self.encoders[hf]})
                        hf_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS)
                    else:
                        raise NotImplementedError

            if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON:
                assert cfg.MODEL.PyMAF.MAF_ON
                self.dp_head_hf = nn.ModuleDict()
                if 'hand' in bhf_names:
                    self.dp_head_hf['hand'] = IUV_predict_layer(
                        feat_dim=hf_sfeat_dim[-1], mode='pncc'
                    )
                    self.part_module_names['hand'].update(
                        {'dp_head_hf.hand': self.dp_head_hf['hand']}
                    )
                if 'face' in bhf_names:
                    self.dp_head_hf['face'] = IUV_predict_layer(
                        feat_dim=hf_sfeat_dim[-1], mode='pncc'
                    )
                    self.part_module_names['face'].update(
                        {'dp_head_hf.face': self.dp_head_hf['face']}
                    )

            smpl2limb_vert_faces = get_partial_smpl()

            self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
            self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # grid points for grid feature extraction
        grid_size = 21
        xv, yv = torch.meshgrid(
            [torch.linspace(-1, 1, grid_size),
             torch.linspace(-1, 1, grid_size)]
        )
        grid_points = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0)
        self.register_buffer('grid_points', grid_points)
        grid_feat_dim = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1]

        # the fusion of grid and mesh-aligned features
        self.fuse_grid_align = cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT or cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC
        assert not (cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC)

        if self.fuse_grid_align:
            self.att_starts = cfg.MODEL.PyMAF.GRID_ALIGN.ATT_STARTS
            n_iter_att = cfg.MODEL.PyMAF.N_ITER - self.att_starts
            att_feat_dim_idx = -cfg.MODEL.PyMAF.GRID_ALIGN.ATT_FEAT_IDX
            num_att_heads = cfg.MODEL.PyMAF.GRID_ALIGN.ATT_HEAD
            hidden_feat_dim = cfg.MODEL.PyMAF.MLP_DIM[att_feat_dim_idx]
            bhf_att_feat_dim = {'body': 2048}

        if 'hand' in self.bhf_names:
            self.mano_sampler = Mesh_Sampler(type='mano', level=1)
            self.mano_ds_len = self.mano_sampler.Dmap.shape[0]
            self.part_module_names['hand'].update({'mano_sampler': self.mano_sampler})

            bhf_ma_feat_dim.update({'hand': self.mano_ds_len * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]})

            if self.fuse_grid_align:
                bhf_att_feat_dim.update({'hand': 1024})

        if 'face' in self.bhf_names:
            bhf_ma_feat_dim.update(
                {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]}
            )
            if self.fuse_grid_align:
                bhf_att_feat_dim.update({'face': 1024})

        # spatial alignment attention
        if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
            hfimg_feat_dim_list = {}
            if 'body' in bhf_names:
                hfimg_feat_dim_list['body'] = body_sfeat_dim[-n_iter_att:]

            if 'hand' in self.bhf_names or 'face' in self.bhf_names:
                if 'hand' in bhf_names:
                    hfimg_feat_dim_list['hand'] = hf_sfeat_dim[-n_iter_att:]
                if 'face' in bhf_names:
                    hfimg_feat_dim_list['face'] = hf_sfeat_dim[-n_iter_att:]

            self.align_attention = get_attention_modules(
                bhf_names,
                hfimg_feat_dim_list,
                hidden_feat_dim,
                n_iter=n_iter_att,
                num_attention_heads=num_att_heads
            )

            for part in bhf_names:
                self.part_module_names[part].update(
                    {f'align_attention.{part}': self.align_attention[part]}
                )

        if self.fuse_grid_align:
            self.att_feat_reduce = get_fusion_modules(
                bhf_names,
                bhf_ma_feat_dim,
                grid_feat_dim,
                n_iter=n_iter_att,
                out_feat_len=bhf_att_feat_dim
            )
            for part in bhf_names:
                self.part_module_names[part].update(
                    {f'att_feat_reduce.{part}': self.att_feat_reduce[part]}
                )

        # build regressor for parameter prediction
        self.regressor = nn.ModuleList()
        for i in range(cfg.MODEL.PyMAF.N_ITER):
            ref_infeat_dim = 0
            if 'body' in self.bhf_names:
                if cfg.MODEL.PyMAF.MAF_ON:
                    if self.fuse_grid_align:
                        if i >= self.att_starts:
                            ref_infeat_dim = bhf_att_feat_dim['body']
                        elif i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
                            ref_infeat_dim = grid_feat_dim
                        else:
                            ref_infeat_dim = ma_feat_dim
                    else:
                        if i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
                            ref_infeat_dim = grid_feat_dim
                        else:
                            ref_infeat_dim = ma_feat_dim
                else:
                    ref_infeat_dim = global_feat_dim

            if self.smpl_mode:
                self.regressor.append(
                    Regressor(
                        feat_dim=ref_infeat_dim,
                        smpl_mean_params=smpl_mean_params,
                        use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
                        smpl_models=self.smpl_family
                    )
                )
            else:
                if cfg.MODEL.PyMAF.MAF_ON:
                    if 'hand' in self.bhf_names or 'face' in self.bhf_names:
                        if i == 0:
                            feat_dim_hand = grid_feat_dim if 'hand' in self.bhf_names else None
                            feat_dim_face = grid_feat_dim if 'face' in self.bhf_names else None
                        else:
                            if self.fuse_grid_align:
                                feat_dim_hand = bhf_att_feat_dim['hand'] if 'hand' in self.bhf_names else None
                                feat_dim_face = bhf_att_feat_dim['face'] if 'face' in self.bhf_names else None
                            else:
                                feat_dim_hand = bhf_ma_feat_dim['hand'] if 'hand' in self.bhf_names else None
                                feat_dim_face = bhf_ma_feat_dim['face'] if 'face' in self.bhf_names else None
                    else:
                        feat_dim_hand = ref_infeat_dim
                        feat_dim_face = ref_infeat_dim
                else:
                    ref_infeat_dim = global_feat_dim
                    feat_dim_hand = global_feat_dim
                    feat_dim_face = global_feat_dim

                self.regressor.append(
                    Regressor(
                        feat_dim=ref_infeat_dim,
                        smpl_mean_params=smpl_mean_params,
                        use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT,
                        feat_dim_hand=feat_dim_hand,
                        feat_dim_face=feat_dim_face,
                        bhf_names=bhf_names,
                        smpl_models=self.smpl_family
                    )
                )

            # assign sub-regressor to each part
            for dec_name, dec_module in self.regressor[-1].named_children():
                if 'hand' in dec_name:
                    self.part_module_names['hand'].update(
                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
                    )
                elif 'face' in dec_name or 'head' in dec_name or 'exp' in dec_name:
                    self.part_module_names['face'].update(
                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
                    )
                elif 'res' in dec_name or 'vis' in dec_name:
                    self.part_module_names['link'].update(
                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
                    )
                elif 'body' in self.part_module_names:    # effectively a catch-all: 'body' is always a key
                    self.part_module_names['body'].update(
                        {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}
                    )

        # mesh-aligned feature extractor
        self.maf_extractor = nn.ModuleDict()
        for part in bhf_names:
            self.maf_extractor[part] = nn.ModuleList()
            filter_channels_default = cfg.MODEL.PyMAF.MLP_DIM if part == 'body' else cfg.MODEL.PyMAF.HF_MLP_DIM
            sfeat_dim = body_sfeat_dim if part == 'body' else hf_sfeat_dim
            for i in range(cfg.MODEL.PyMAF.N_ITER):
                # start the MLP at the first channel width smaller than the
                # incoming spatial feature, so widths decrease along the MLP
                filter_start = 0
                for f_i, f_dim in enumerate(filter_channels_default):
                    if sfeat_dim[i] > f_dim:
                        filter_start = f_i
                        break
                filter_channels = [sfeat_dim[i]] + filter_channels_default[filter_start:]

                if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and i >= self.att_starts:
                    self.maf_extractor[part].append(
                        MAF_Extractor(
                            filter_channels=filter_channels_default[att_feat_dim_idx:],
                            iwp_cam_mode=cfg.MODEL.USE_IWP_CAM
                        )
                    )
                else:
                    self.maf_extractor[part].append(
                        MAF_Extractor(
                            filter_channels=filter_channels, iwp_cam_mode=cfg.MODEL.USE_IWP_CAM
                        )
                    )
            self.part_module_names[part].update({f'maf_extractor.{part}': self.maf_extractor[part]})

        # check all modules have been added to part_module_names
        model_dict_all = dict.fromkeys(self.state_dict().keys())
        for key in self.part_module_names.keys():
            for name in list(model_dict_all.keys()):
                for k in self.part_module_names[key].keys():
                    if name.startswith(k):
                        del model_dict_all[name]
                # if name.startswith('regressor.') and '.smpl.' in name:
                #     del model_dict_all[name]
                # if name.startswith('regressor.') and '.mano.' in name:
                #     del model_dict_all[name]
                if name.startswith('regressor.') and '.init_' in name:
                    del model_dict_all[name]
                if name == 'grid_points':
                    del model_dict_all[name]
        assert (len(model_dict_all.keys()) == 0)

    def init_mesh(self, batch_size, J_regressor=None, rw_cam={}):
        """ initialize the mesh model with default poses and shapes
        """
        if self.init_mesh_output is None or self.batch_size != batch_size:
            self.init_mesh_output = self.regressor[0](
                torch.zeros(batch_size), J_regressor=J_regressor, rw_cam=rw_cam, init_mode=True
            )
            self.batch_size = batch_size
        return self.init_mesh_output

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """
        Deconv_layer used in Simple Baselines:
        Xiao et al. Simple Baselines for Human Pose Estimation and Tracking
        https://github.com/microsoft/human-pose-estimation.pytorch
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        def _get_deconv_cfg(deconv_kernel, index):
            if deconv_kernel == 4:
                padding = 1
                output_padding = 0
            elif deconv_kernel == 3:
                padding = 1
                output_padding = 1
            elif deconv_kernel == 2:
                padding = 0
                output_padding = 0
            else:
                raise ValueError(f'unsupported deconv kernel size: {deconv_kernel}')

            return deconv_kernel, padding, output_padding

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = _get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias
                )
            )
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

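    # The mesh models held in self.smpl_family are not registered as submodules of this
    # nn.Module, so .to() / .cuda() are overridden to move them together with the network.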
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        for m in ['body', 'hand', 'face']:
            if m in self.smpl_family:
                self.smpl_family[m].model.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        super().cuda(*args, **kwargs)
        for m in ['body', 'hand', 'face']:
            if m in self.smpl_family:
                self.smpl_family[m].model.cuda(*args, **kwargs)
        return self

    def forward(self, batch={}, J_regressor=None, rw_cam={}):
        '''
        Args:
            batch: input dictionary, including
                   images: 'img_{part}', for part in body, hand, and face if applicable
                   inverse affine transformations of the hand/face crops: '{part}_theta_inv' for part in lhand, rhand, and face if applicable
            J_regressor: joint regression matrix
            rw_cam: real-world camera information, used when cfg.MODEL.USE_IWP_CAM is False
        Returns:
            out_dict: dictionary containing lists of the predicted parameters
            vis_feat_list: list containing features for visualization
        '''

        # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand',
        # 'lhand_theta_inv', 'img_rhand', 'rhand_theta_inv', 'img_face', 'face_theta_inv']

        # extract spatial features or global features
        # run encoder for body
        if 'body' in self.bhf_names:
            img_body = batch['img_body']
            batch_size = img_body.shape[0]
            s_feat_body, g_feat = self.encoders['body'](img_body)
            if cfg.MODEL.PyMAF.MAF_ON:
                assert len(s_feat_body) == cfg.MODEL.PyMAF.N_ITER

        # run encoders for hand / face
        if 'hand' in self.bhf_names or 'face' in self.bhf_names:
            limb_feat_dict = {}
            limb_gfeat_dict = {}
            if 'face' in self.bhf_names:
                img_face = batch['img_face']
                batch_size = img_face.shape[0]
                limb_feat_dict['face'], limb_gfeat_dict['face'] = self.encoders['face'](img_face)

            if 'hand' in self.bhf_names:
                if 'lhand' in self.part_names:
                    img_rhand = batch['img_rhand']
                    batch_size = img_rhand.shape[0]
                    # flip left hand images
                    img_lhand = torch.flip(batch['img_lhand'], [3])
                    img_hands = torch.cat([img_rhand, img_lhand])
                    s_feat_hands, g_feat_hands = self.encoders['hand'](img_hands)
                    limb_feat_dict['rhand'] = [feat[:batch_size] for feat in s_feat_hands]
                    limb_feat_dict['lhand'] = [feat[batch_size:] for feat in s_feat_hands]
                    if g_feat_hands is not None:
                        limb_gfeat_dict['rhand'] = g_feat_hands[:batch_size]
                        limb_gfeat_dict['lhand'] = g_feat_hands[batch_size:]
                else:
                    img_rhand = batch['img_rhand']
                    batch_size = img_rhand.shape[0]
                    limb_feat_dict['rhand'], limb_gfeat_dict['rhand'] = self.encoders['hand'](
                        img_rhand
                    )

            if cfg.MODEL.PyMAF.MAF_ON:
                for k in limb_feat_dict.keys():
                    assert len(limb_feat_dict[k]) == cfg.MODEL.PyMAF.N_ITER

        out_dict = {}

        # grid-pattern points
        grid_points = torch.transpose(self.grid_points.expand(batch_size, -1, -1), 1, 2)
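        # Fixed uniform grid expanded to the batch and laid out as (B, N, 2); used
        # whenever grid-based sampling is needed (first iteration, GRID_FEAT, or
        # grid-align fusion).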

        # initial parameters
        mesh_output = self.init_mesh(batch_size, J_regressor, rw_cam)

        out_dict['mesh_out'] = [mesh_output]
        out_dict['dp_out'] = []

        # for visualization
        vis_feat_list = []

        # dense prediction during training
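        # Auxiliary dense supervision (skipped in eval mode): predict IUV / dense
        # correspondence maps from the last encoder feature map for the training loss.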
        if not cfg.MODEL.EVAL_MODE:
            if 'body' in self.bhf_names:
                if cfg.MODEL.PyMAF.AUX_SUPV_ON:
                    iuv_out_dict = self.dp_head(s_feat_body[-1])
                    out_dict['dp_out'].append(iuv_out_dict)
            elif self.hand_only_mode:
                if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON:
                    out_dict['rhand_dpout'] = []
                    dphand_out_dict = self.dp_head_hf['hand'](limb_feat_dict['rhand'][-1])
                    out_dict['rhand_dpout'].append(dphand_out_dict)
            elif self.face_only_mode:
                if cfg.MODEL.PyMAF.HF_AUX_SUPV_ON:
                    out_dict['face_dpout'] = []
                    dpface_out_dict = self.dp_head_hf['face'](limb_feat_dict['face'][-1])
                    out_dict['face_dpout'].append(dpface_out_dict)

        # parameter predictions
        for rf_i in range(cfg.MODEL.PyMAF.N_ITER):
            current_states = {}
            if 'body' in self.bhf_names:
                pred_cam = mesh_output['pred_cam'].detach()
                pred_shape = mesh_output['pred_shape'].detach()
                pred_pose = mesh_output['pred_pose'].detach()

                current_states['init_cam'] = pred_cam
                current_states['init_shape'] = pred_shape
                current_states['init_pose'] = pred_pose

                pred_smpl_verts = mesh_output['verts'].detach()

                if cfg.MODEL.PyMAF.MAF_ON:
                    s_feat_i = s_feat_body[rf_i]

            # re-project mesh on the image plane
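            # Project the current 3D part predictions to 2D with the current camera.
            # Each entry of proj_hf_pts has the part's root keypoint prepended at index 0;
            # the points are later used to sample mesh-aligned features.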
            if self.hand_only_mode:
                pred_cam = mesh_output['pred_cam'].detach()
                pred_rhand_v = self.mano_sampler(mesh_output['verts_rh'])
                pred_rhand_proj = projection(
                    pred_rhand_v, {
                        **rw_cam, 'cam_sxy': pred_cam
                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
                )
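                # With the IWP camera, projections are in pixel units of the 224x224
                # crop; dividing by 224/2 normalizes them to [-1, 1] for feature sampling.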
                if cfg.MODEL.USE_IWP_CAM:
                    pred_rhand_proj = pred_rhand_proj / (224. / 2.)
                else:
                    pred_rhand_proj = j2d_processing(pred_rhand_proj, rw_cam['kps_transf'])
                proj_hf_center = {
                    'rhand':
                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1)
                }
                proj_hf_pts = {
                    'rhand': torch.cat([proj_hf_center['rhand'], pred_rhand_proj], dim=1)
                }
            elif self.face_only_mode:
                pred_cam = mesh_output['pred_cam'].detach()
                pred_face_v = mesh_output['pred_face_kp3d']
                pred_face_proj = projection(
                    pred_face_v, {
                        **rw_cam, 'cam_sxy': pred_cam
                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
                )
                if cfg.MODEL.USE_IWP_CAM:
                    pred_face_proj = pred_face_proj / (224. / 2.)
                else:
                    pred_face_proj = j2d_processing(pred_face_proj, rw_cam['kps_transf'])
                proj_hf_center = {
                    'face': mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1)
                }
                proj_hf_pts = {'face': torch.cat([proj_hf_center['face'], pred_face_proj], dim=1)}
            elif self.body_hand_mode:
                pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand])
                pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
                pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
                pred_hand_proj = projection(
                    pred_hand_v, {
                        **rw_cam, 'cam_sxy': pred_cam
                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
                )
                if cfg.MODEL.USE_IWP_CAM:
                    pred_hand_proj = pred_hand_proj / (224. / 2.)
                else:
                    pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf'])

                proj_hf_center = {
                    'lhand':
                        mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
                    'rhand':
                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
                }
                proj_hf_pts = {
                    'lhand':
                        torch.cat(
                            [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
                        ),
                    'rhand':
                        torch.cat(
                            [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
                        ),
                }
            elif self.full_body_mode:
                pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand])
                pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand])
                pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1)
                pred_hand_proj = projection(
                    pred_hand_v, {
                        **rw_cam, 'cam_sxy': pred_cam
                    }, iwp_mode=cfg.MODEL.USE_IWP_CAM
                )
                if cfg.MODEL.USE_IWP_CAM:
                    pred_hand_proj = pred_hand_proj / (224. / 2.)
                else:
                    pred_hand_proj = j2d_processing(pred_hand_proj, rw_cam['kps_transf'])

                proj_hf_center = {
                    'lhand':
                        mesh_output['pred_lhand_kp2d'][:, self.hf_root_idx['lhand']].unsqueeze(1),
                    'rhand':
                        mesh_output['pred_rhand_kp2d'][:, self.hf_root_idx['rhand']].unsqueeze(1),
                    'face':
                        mesh_output['pred_face_kp2d'][:, self.hf_root_idx['face']].unsqueeze(1)
                }
                proj_hf_pts = {
                    'lhand':
                        torch.cat(
                            [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1
                        ),
                    'rhand':
                        torch.cat(
                            [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1
                        ),
                    'face':
                        torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1)
                }

            # extract mesh-aligned features for the hand / face part
            if 'hand' in self.bhf_names or 'face' in self.bhf_names:
                limb_rf_i = rf_i
                hand_face_feat = {}

                for hf_i, part_name in enumerate(self.part_names):
                    if 'hand' in part_name:
                        hf_key = 'hand'
                    elif 'face' in part_name:
                        hf_key = 'face'

                    if cfg.MODEL.PyMAF.MAF_ON:
                        if cfg.MODEL.PyMAF.HF_BACKBONE == 'res50':
                            limb_feat_i = limb_feat_dict[part_name][limb_rf_i]
                        else:
                            raise NotImplementedError

                        limb_reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts)

                        if limb_rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
                            limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
                                grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim
                            )
                        else:
                            if self.hand_only_mode or self.face_only_mode:
                                proj_hf_pts_crop = proj_hf_pts[part_name][:, :, :2]

                                proj_hf_v_center = proj_hf_pts_crop[:, 0].unsqueeze(1)

                                if cfg.MODEL.PyMAF.HF_BOX_CENTER:
                                    part_box_ul = torch.min(proj_hf_pts_crop, dim=1)[0].unsqueeze(1)
                                    part_box_br = torch.max(proj_hf_pts_crop, dim=1)[0].unsqueeze(1)
                                    part_box_center = (part_box_ul + part_box_br) / 2.
                                    proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] - part_box_center
                                else:
                                    proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:]

                            elif self.full_body_mode or self.body_hand_mode:
                                # convert projection points to the space of cropped hand/face images
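                                # '{part}_theta_inv' inverts the affine crop transform;
                                # homo_vector appends a homogeneous 1 so the affine can
                                # be applied with a single bmm.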
                                theta_i_inv = batch[f'{part_name}_theta_inv']
                                proj_hf_pts_crop = torch.bmm(
                                    theta_i_inv,
                                    homo_vector(proj_hf_pts[part_name][:, :, :2]).permute(0, 2, 1)
                                ).permute(0, 2, 1)

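                                # Left-hand crops were flipped horizontally before encoding,
                                # so the projected left-hand points are mirrored in x to match.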
                                if part_name == 'lhand':
                                    flip_x = torch.tensor([-1, 1])[None, None, :].to(proj_hf_pts_crop)
                                    proj_hf_pts_crop *= flip_x

                                if cfg.MODEL.PyMAF.HF_BOX_CENTER:
                                    # align projection points with the cropped image center
                                    part_box_ul = torch.min(proj_hf_pts_crop, dim=1)[0].unsqueeze(1)
                                    part_box_br = torch.max(proj_hf_pts_crop, dim=1)[0].unsqueeze(1)
                                    part_box_center = (part_box_ul + part_box_br) / 2.
                                    proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:] - part_box_center
                                else:
                                    proj_hf_pts_crop_ctd = proj_hf_pts_crop[:, 1:]

                                # 0 is the root point
                                proj_hf_v_center = proj_hf_pts_crop[:, 0].unsqueeze(1)

                            limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
                                proj_hf_pts_crop_ctd.detach(),
                                im_feat=limb_feat_i,
                                reduce_dim=limb_reduce_dim
                            )

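                        # Grid-align fusion: concatenate grid-sampled and mesh-aligned
                        # features per point, fuse them with self-attention when USE_ATT
                        # (with USE_FC they go straight to the reduction layers), then
                        # reduce to a fixed-length vector for the regressor.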
                        if self.fuse_grid_align and limb_rf_i >= self.att_starts:

                            limb_grid_feature_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling(
                                grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim
                            )
                            limb_grid_ref_feat_ctd = torch.cat(
                                [limb_grid_feature_ctd, limb_ref_feat_ctd], dim=-1
                            ).permute(0, 2, 1)

                            if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
                                att_ref_feat_ctd = self.align_attention[hf_key][
                                    limb_rf_i - self.att_starts](limb_grid_ref_feat_ctd)[0]
                            elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC:
                                att_ref_feat_ctd = limb_grid_ref_feat_ctd

                            att_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].reduce_dim(
                                att_ref_feat_ctd.permute(0, 2, 1)
                            ).view(batch_size, -1)
                            limb_ref_feat_ctd = self.att_feat_reduce[hf_key][
                                limb_rf_i - self.att_starts](att_ref_feat_ctd)

                        else:
                            # limb_ref_feat = limb_ref_feat.view(batch_size, -1)
                            limb_ref_feat_ctd = limb_ref_feat_ctd.view(batch_size, -1)
                        hand_face_feat[part_name] = limb_ref_feat_ctd
                    else:
                        hand_face_feat[part_name] = limb_gfeat_dict[part_name]

            # extract mesh-aligned features for the body part
            if 'body' in self.bhf_names:
                if cfg.MODEL.PyMAF.MAF_ON:
                    reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts)
                    if rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT:
                        ref_feature = self.maf_extractor['body'][rf_i].sampling(
                            grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim
                        )
                    else:
                        # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration
                        pred_smpl_verts_ds = self.mesh_sampler.downsample(
                            pred_smpl_verts
                        )    # [B, 431, 3]
                        ref_feature = self.maf_extractor['body'][rf_i](
                            pred_smpl_verts_ds,
                            im_feat=s_feat_i,
                            cam={
                                **rw_cam, 'cam_sxy': pred_cam
                            },
                            add_att=True,
                            reduce_dim=reduce_dim
                        )    # [B, 431 * n_feat]

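                    # Same grid + mesh-aligned feature fusion as in the hand/face branch above.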
                    if self.fuse_grid_align and rf_i >= self.att_starts:
                        if rf_i > 0 and not cfg.MODEL.PyMAF.GRID_FEAT:
                            grid_feature = self.maf_extractor['body'][rf_i].sampling(
                                grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim
                            )
                            grid_ref_feat = torch.cat([grid_feature, ref_feature], dim=-1)
                        else:
                            grid_ref_feat = ref_feature
                        grid_ref_feat = grid_ref_feat.permute(0, 2, 1)

                        if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT:
                            att_ref_feat = self.align_attention['body'][
                                rf_i - self.att_starts](grid_ref_feat)[0]
                        elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC:
                            att_ref_feat = grid_ref_feat

                        att_ref_feat = self.maf_extractor['body'][rf_i].reduce_dim(
                            att_ref_feat.permute(0, 2, 1)
                        )
                        att_ref_feat = att_ref_feat.view(batch_size, -1)

                        ref_feature = self.att_feat_reduce['body'][
                            rf_i - self.att_starts](att_ref_feat)
                    else:
                        ref_feature = ref_feature.view(batch_size, -1)
                else:
                    ref_feature = g_feat
            else:
                ref_feature = None

            if not self.smpl_mode:
                if self.hand_only_mode:
                    current_states['xc_rhand'] = hand_face_feat['rhand']
                elif self.face_only_mode:
                    current_states['xc_face'] = hand_face_feat['face']
                elif self.body_hand_mode:
                    current_states['xc_lhand'] = hand_face_feat['lhand']
                    current_states['xc_rhand'] = hand_face_feat['rhand']
                elif self.full_body_mode:
                    current_states['xc_lhand'] = hand_face_feat['lhand']
                    current_states['xc_rhand'] = hand_face_feat['rhand']
                    current_states['xc_face'] = hand_face_feat['face']

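                # Iterative feedback: from the second iteration on, the previous (detached)
                # part predictions seed the regressor as its initial states.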
                if rf_i > 0:
                    for part in self.part_names:
                        current_states[f'init_{part}'] = mesh_output[f'pred_{part}'].detach()
                        if part == 'face':
                            current_states['init_exp'] = mesh_output['pred_exp'].detach()
                    if self.hand_only_mode:
                        current_states['init_shape_rh'] = mesh_output['pred_shape_rh'].detach()
                        current_states['init_orient_rh'] = mesh_output['pred_orient_rh'].detach()
                        current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach()
                    elif self.face_only_mode:
                        current_states['init_shape_fa'] = mesh_output['pred_shape_fa'].detach()
                        current_states['init_orient_fa'] = mesh_output['pred_orient_fa'].detach()
                        current_states['init_cam_fa'] = mesh_output['pred_cam_fa'].detach()
                    elif self.full_body_mode or self.body_hand_mode:
                        if cfg.MODEL.PyMAF.OPT_WRIST:
                            current_states['init_shape_lh'] = mesh_output['pred_shape_lh'].detach()
                            current_states['init_orient_lh'] = mesh_output['pred_orient_lh'].detach()
                            current_states['init_cam_lh'] = mesh_output['pred_cam_lh'].detach()

                            current_states['init_shape_rh'] = mesh_output['pred_shape_rh'].detach()
                            current_states['init_orient_rh'] = mesh_output['pred_orient_rh'].detach()
                            current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach()

            # update mesh parameters
            mesh_output = self.regressor[rf_i](
                ref_feature,
                n_iter=1,
                J_regressor=J_regressor,
                rw_cam=rw_cam,
                global_iter=rf_i,
                **current_states
            )

            out_dict['mesh_out'].append(mesh_output)

        return out_dict, vis_feat_list


def pymaf_net(smpl_mean_params, pretrained=True, device=torch.device('cuda')):
    """ Constructs a PyMAF model with a ResNet-50 backbone.
    Args:
        pretrained (bool): if True, the backbone is initialized with ImageNet pre-trained weights
    """
    model = PyMAF(smpl_mean_params, pretrained, device)
    return model
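

# A minimal usage sketch (illustrative only, not part of the training pipeline; assumes
# a CUDA device and a config whose BHF_MODE only requires a body image in the batch):
#   model = pymaf_net(SMPL_MEAN_PARAMS, pretrained=True).to('cuda')
#   batch = {'img_body': torch.randn(1, 3, 224, 224, device='cuda')}
#   out_dict, _ = model(batch)
#   preds = out_dict['mesh_out']    # per-iteration predictions; the last entry is final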