Spaces:
Build error
Build error
| import os | |
| import torch | |
| import sys | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) | |
| from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d # , batch_rot2aa, geodesic_loss_R | |
| def reset_loss_values(losses): | |
| # losses is a dict | |
| for key, val in losses.items(): | |
| val['value'] = 0.0 | |
| return losses | |
| def get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d): | |
| # optimed_orient_6d: (1, 1, 6) | |
| # optimed_pose_6d: (1, 34, 6) | |
| bs = optimed_pose_6d.shape[0] | |
| assert bs == 1 | |
| optimed_pose_with_glob_6d = torch.cat((optimed_orient_6d, optimed_pose_6d), dim=1) | |
| optimed_pose_with_glob = rot6d_to_rotmat(optimed_pose_with_glob_6d.reshape((-1, 6))).reshape((bs, -1, 3, 3)) | |
| return optimed_pose_with_glob | |