Techt3o's picture
ca5705cc9c8581d916aca37e6759c44f0b1e70429e49ce83e658a0517cd3d6fe
c87d1bc verified
raw
history blame
2.97 kB
import torch
def gmof(x, sigma):
    """Apply the Geman-McClure robust penalty element-wise.

    Computes ``sigma**2 * x**2 / (sigma**2 + x**2)``. For residuals much
    smaller than ``sigma`` this is close to ``x**2``; for large residuals it
    saturates towards ``sigma**2``, limiting the influence of outliers.

    Args:
        x: Residual tensor (or number).
        sigma: Scale of the robust penalty.

    Returns:
        The penalty, with the same shape as ``x``.
    """
    sq_res = x ** 2
    sq_sigma = sigma ** 2
    numerator = sq_sigma * sq_res
    denominator = sq_sigma + sq_res
    return numerator / denominator
def compute_jitter(x):
    """Return the norm of the second-order finite difference along dim 1.

    For each interior index ``t`` along dim 1 this computes
    ``|| x[:, t+1] + x[:, t-1] - 2 * x[:, t] ||`` over the last axis, so the
    result is 2 shorter than ``x`` along dim 1. Dim 1 is presumably the time
    axis (hence "jitter") -- confirm against callers.

    Args:
        x: Tensor with at least two dimensions.

    Returns:
        Tensor of per-position second-difference magnitudes (last axis
        reduced).
    """
    before = x[:, :-2]
    centre = x[:, 1:-1]
    after = x[:, 2:]
    second_diff = after + before - 2 * centre
    return torch.linalg.norm(second_diff, dim=-1)
class SMPLifyLoss(torch.nn.Module):
    """Weighted objective for fitting SMPL parameters to 2D keypoints.

    ``forward`` returns a dict of four weighted terms: reprojection,
    regularization towards an initial pose, shape prior/consistency, and
    temporal smoothness. ``create_closure`` builds a closure that sums them and
    back-propagates, in the form accepted by ``optimizer.step(closure)``.
    """

    def __init__(self,
                 res,
                 cam_intrinsics,
                 init_pose,
                 device,
                 **kwargs
                 ):
        """
        Args:
            res: Forwarded unchanged as ``res=`` to the SMPL call made in
                ``create_closure``.
            cam_intrinsics: Forwarded unchanged as ``cam_intrinsics=`` to the
                SMPL call made in ``create_closure``.
            init_pose: Pose the regularization term pulls towards. Either a
                NumPy array (wrapped with ``torch.from_numpy``) or a torch
                tensor (detached and copied).
            device: Device the stored ``init_pose`` is moved to.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.res = res
        self.cam_intrinsics = cam_intrinsics
        if isinstance(init_pose, torch.Tensor):
            # Snapshot the tensor: aliasing a tensor that is being optimized
            # (e.g. the pose parameter itself) would make the regularization
            # term identically zero.
            init_pose = init_pose.detach().clone()
        else:
            init_pose = torch.from_numpy(init_pose)
        self.init_pose = init_pose.float().to(device)

    @staticmethod
    def _mean_jitter(x):
        """Mean of ``compute_jitter(x)``, or 0 when that tensor is empty.

        The jitter is empty e.g. when ``x`` has fewer than three entries
        along dim 1; the mean of an empty tensor is NaN, which would poison
        the total loss.
        """
        jitter = compute_jitter(x)
        if jitter.numel() == 0:
            return x.new_zeros(())
        return jitter.mean()

    def forward(self, output, params, input_keypoints, bbox,
                reprojection_weight=100., regularize_weight=60.0,
                consistency_weight=10.0, sprior_weight=0.04,
                smooth_weight=20.0, sigma=100):
        """Compute the individual weighted loss terms.

        Args:
            output: SMPL output; only the first 17 joints of
                ``output.full_joints2d`` (``[..., :17, :]``) are read.
            params: ``(pose, shape, cam)`` tuple of the optimized tensors.
            input_keypoints: Detected 2D keypoints. The last channel is used
                as a per-joint confidence and the remaining channels as
                coordinates (presumably ``(..., 17, 3)`` -- confirm).
            bbox: Bounding boxes. ``bbox[..., 2:] * 200`` is used as a scale
                that normalizes the reprojection error (presumably
                ``(cx, cy, s)`` with ``s = size / 200`` -- confirm).
            reprojection_weight, regularize_weight, consistency_weight,
            sprior_weight, smooth_weight: Multipliers for the respective
                terms.
            sigma: Geman-McClure scale applied to the reprojection residual.

        Returns:
            dict with keys ``'reprojection'``, ``'regularize'``, ``'shape'``
            and ``'smooth'``. Each value is a scalar tensor already
            multiplied by its weight(s).
        """
        pose, shape, cam = params
        # Trailing axis lets the per-box scale broadcast against the
        # per-joint reprojection errors.
        scale = bbox[..., 2:].unsqueeze(-1) * 200.

        # Loss 1. Data term: robust 2D reprojection error, weighted by the
        # keypoint confidence and normalized by the box scale.
        pred_keypoints = output.full_joints2d[..., :17, :]
        joints_conf = input_keypoints[..., -1:]
        reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma)
        reprojection_error = ((reprojection_error * joints_conf) / scale).mean()

        # Loss 2. Regularization term: stay close to the initial pose.
        regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean()

        # Loss 3. Shape prior and consistency error.
        # std along dim 1 is NaN with a single (or no) entry on that axis; a
        # lone shape vector is trivially consistent, so contribute 0 instead.
        if shape.shape[1] > 1:
            consistency_error = shape.std(dim=1).mean()
        else:
            consistency_error = shape.new_zeros(())
        sprior_error = torch.linalg.norm(shape, dim=-1).mean()
        shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error

        # Loss 4. Smooth loss: second-order differences of pose and camera.
        smooth_error = self._mean_jitter(pose) + self._mean_jitter(cam)

        # Sum up losses
        loss = {
            'reprojection': reprojection_weight * reprojection_error,
            'regularize': regularize_weight * regularize_error,
            'shape': shape_error,
            'smooth': smooth_weight * smooth_error,
        }
        return loss

    def create_closure(self,
                       optimizer,
                       smpl,
                       params,
                       bbox,
                       input_keypoints,
                       **loss_kwargs):
        """Build an optimizer closure that evaluates and back-propagates the loss.

        Args:
            optimizer: Optimizer whose gradients are zeroed on each call.
            smpl: Callable invoked as ``smpl(*params, cam_intrinsics=...,
                bbox=..., res=...)``; its result is passed to ``forward``.
            params: Sequence of tensors unpacked into ``smpl`` and passed to
                ``forward`` (which unpacks it as ``pose, shape, cam``).
            bbox: Forwarded to both ``smpl`` and ``forward``.
            input_keypoints: Forwarded to ``forward``.
            **loss_kwargs: Optional weight / ``sigma`` overrides forwarded to
                ``forward``; omitted values keep ``forward``'s defaults.

        Returns:
            A zero-argument callable that returns the summed scalar loss after
            calling ``backward`` on it, in the form accepted by
            ``optimizer.step(closure)``.
        """
        def closure():
            optimizer.zero_grad()
            output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res)
            loss_dict = self.forward(output, params, input_keypoints, bbox, **loss_kwargs)
            loss = sum(loss_dict.values())
            loss.backward()
            return loss
        return closure