File size: 803 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import os
import torch

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d  # , batch_rot2aa, geodesic_loss_R


def reset_loss_values(losses):
    # losses is a dict
    for key, val in losses.items(): 
        val['value'] = 0.0
    return losses

def get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d):
    # optimed_orient_6d: (1, 1, 6)
    # optimed_pose_6d:  (1, 34, 6)
    bs = optimed_pose_6d.shape[0]
    assert bs == 1
    optimed_pose_with_glob_6d = torch.cat((optimed_orient_6d, optimed_pose_6d), dim=1)
    optimed_pose_with_glob = rot6d_to_rotmat(optimed_pose_with_glob_6d.reshape((-1, 6))).reshape((bs, -1, 3, 3))
    return optimed_pose_with_glob