Nadine Rueegg
initial commit with code and data
753fd9a
raw
history blame contribute delete
803 Bytes
import os
import torch
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d # , batch_rot2aa, geodesic_loss_R
def reset_loss_values(losses):
    """Zero the ``'value'`` entry of every item in a loss dictionary.

    Args:
        losses: dict whose values are themselves mappings supporting item
            assignment (e.g. ``{'name': {'value': ..., ...}}``). Only the
            ``'value'`` entry of each inner mapping is written; any other
            entries are left as they are.

    Returns:
        The same ``losses`` object that was passed in, modified in place.
    """
    # The outer keys are not needed, so iterate the inner mappings directly.
    for entry in losses.values():
        entry['value'] = 0.0
    return losses
def get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d):
    """Join the orientation and pose 6D tensors and convert them via rot6d_to_rotmat.

    Args:
        optimed_orient_6d: 6D orientation tensor; the original note gives its
            shape as (1, 1, 6).
        optimed_pose_6d: 6D pose tensor; the original note gives its shape as
            (1, 34, 6).

    Returns:
        The output of ``rot6d_to_rotmat`` applied to the concatenated input
        flattened to (-1, 6), reshaped to (1, -1, 3, 3). The orientation
        entry comes first along dim 1, followed by the pose entries.
    """
    batch_size = optimed_pose_6d.shape[0]
    # Only a single sample is supported here.
    assert batch_size == 1
    # Orientation goes in front of the pose entries along dim 1.
    stacked_6d = torch.cat((optimed_orient_6d, optimed_pose_6d), dim=1)
    flat_6d = stacked_6d.reshape((-1, 6))
    rotmats = rot6d_to_rotmat(flat_6d)
    return rotmats.reshape((batch_size, -1, 3, 3))