File size: 1,612 Bytes
1ba539f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from lib.config import cfg
from .nerf_net_utils import *
from .. import embedder
from . import if_clight_renderer_mmsk


class Renderer(if_clight_renderer_mmsk.Renderer):
    def __init__(self, net):
        super(Renderer, self).__init__(net)

    def prepare_inside_pts(self, wpts, batch):
        if 'R0_snap' not in batch:
            __import__('ipdb').set_trace()
            return raw

        # transform points from the world space to the smpl space
        Th = batch['Th']
        can_pts = wpts - Th[:, None, None]
        R = batch['R']
        can_pts = torch.matmul(can_pts, R)

        R0 = batch['R0_snap']
        Th0 = batch['Th0_snap']

        # transform pts from smpl coordinate to the world coordinate
        sh = can_pts.shape
        can_pts = can_pts.view(sh[0], -1, sh[3])
        pts = torch.matmul(can_pts, R0.transpose(2, 1)) + Th0[:, None]

        # project pts to image space
        R = batch['RT'][..., :3]
        T = batch['RT'][..., 3]
        pts = torch.matmul(pts, R.transpose(2, 1)) + T[:, None]
        pts = torch.matmul(pts, batch['K'].transpose(2, 1))
        pts2d = pts[..., :2] / pts[..., 2:]

        # ensure that pts2d is inside the image
        pts2d = pts2d.round().long()
        H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio)
        pts2d[..., 0] = torch.clamp(pts2d[..., 0], 0, W - 1)
        pts2d[..., 1] = torch.clamp(pts2d[..., 1], 0, H - 1)

        # remove the points outside the mask
        pts2d = pts2d[0]
        msk = batch['msk'][0]
        inside = msk[pts2d[:, 1], pts2d[:, 0]][None].bool()

        return inside