import torch

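def gmof(x, sigma):
    """
    Geman-McClure robust error function.

    Returns sigma^2 * x^2 / (sigma^2 + x^2), which behaves like x^2 for
    small residuals and saturates at sigma^2 for large ones, limiting
    the influence of outliers.
    """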
    x_squared = x ** 2
    sigma_squared = sigma ** 2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)


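def compute_jitter(x):
    """
    Compute jitter for the input tensor as the norm (over the last dim)
    of the second-order finite difference along dim 1, which is assumed
    to be the temporal axis.
    """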
    return torch.linalg.norm(x[:, 2:] + x[:, :-2] - 2 * x[:, 1:-1], dim=-1)


class SMPLifyLoss(torch.nn.Module):
    def __init__(self,
                 res,
                 cam_intrinsics,
                 init_pose,
                 device,
                 **kwargs
                 ):
        
        super().__init__()
        
        self.res = res
        self.cam_intrinsics = cam_intrinsics
        self.init_pose = torch.from_numpy(init_pose).float().to(device)
        
    def forward(self, output, params, input_keypoints, bbox,
                reprojection_weight=100., regularize_weight=60.0,
                consistency_weight=10.0, sprior_weight=0.04,
                smooth_weight=20.0, sigma=100):
        
        pose, shape, cam = params
        scale = bbox[..., 2:].unsqueeze(-1) * 200.
        
        # Loss 1. Data term
        pred_keypoints = output.full_joints2d[..., :17, :]
        joints_conf = input_keypoints[..., -1:]
        reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma)
        reprojection_error = ((reprojection_error * joints_conf) / scale).mean()
        
        # Loss 2. Regularization term
        regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean()
        
        # Loss 3. Shape prior (norm of shape) and consistency error (std of shape along dim 1)
        consistency_error = shape.std(dim=1).mean()
        sprior_error = torch.linalg.norm(shape, dim=-1).mean()
        shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error
        
        # Loss 4. Smoothness term (jitter of pose and camera parameters)
        pose_diff = compute_jitter(pose).mean()
        cam_diff = compute_jitter(cam).mean()
        smooth_error = pose_diff + cam_diff
        
        # Collect the weighted loss terms; the caller sums them
        loss = {
            'reprojection': reprojection_weight * reprojection_error,
            'regularize': regularize_weight * regularize_error,
            'shape': shape_error,
            'smooth': smooth_weight * smooth_error
        }
        
        return loss
        
    def create_closure(self,
                       optimizer,
                       smpl,
                       params,
                       bbox,
                       input_keypoints):
        
        def closure():
            # Re-evaluates the model and loss; meant to be passed to optimizer.step(closure)
            optimizer.zero_grad()
            output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res)
            
            loss_dict = self.forward(output, params, input_keypoints, bbox)
            loss = sum(loss_dict.values())
            loss.backward()
            return loss
        
        return closure