Spaces:
Sleeping
Sleeping
import torch | |
def gmof(x, sigma): | |
""" | |
Geman-McClure error function | |
""" | |
x_squared = x ** 2 | |
sigma_squared = sigma ** 2 | |
return (sigma_squared * x_squared) / (sigma_squared + x_squared) | |
def compute_jitter(x): | |
""" | |
Compute jitter for the input tensor | |
""" | |
return torch.linalg.norm(x[:, 2:] + x[:, :-2] - 2 * x[:, 1:-1], dim=-1) | |
class SMPLifyLoss(torch.nn.Module): | |
def __init__(self, | |
res, | |
cam_intrinsics, | |
init_pose, | |
device, | |
**kwargs | |
): | |
super().__init__() | |
self.res = res | |
self.cam_intrinsics = cam_intrinsics | |
self.init_pose = torch.from_numpy(init_pose).float().to(device) | |
def forward(self, output, params, input_keypoints, bbox, | |
reprojection_weight=100., regularize_weight=60.0, | |
consistency_weight=10.0, sprior_weight=0.04, | |
smooth_weight=20.0, sigma=100): | |
pose, shape, cam = params | |
scale = bbox[..., 2:].unsqueeze(-1) * 200. | |
# Loss 1. Data term | |
pred_keypoints = output.full_joints2d[..., :17, :] | |
joints_conf = input_keypoints[..., -1:] | |
reprojection_error = gmof(pred_keypoints - input_keypoints[..., :-1], sigma) | |
reprojection_error = ((reprojection_error * joints_conf) / scale).mean() | |
# Loss 2. Regularization term | |
regularize_error = torch.linalg.norm(pose - self.init_pose, dim=-1).mean() | |
# Loss 3. Shape prior and consistency error | |
consistency_error = shape.std(dim=1).mean() | |
sprior_error = torch.linalg.norm(shape, dim=-1).mean() | |
shape_error = sprior_weight * sprior_error + consistency_weight * consistency_error | |
# Loss 4. Smooth loss | |
pose_diff = compute_jitter(pose).mean() | |
cam_diff = compute_jitter(cam).mean() | |
smooth_error = pose_diff + cam_diff | |
# Sum up losses | |
loss = { | |
'reprojection': reprojection_weight * reprojection_error, | |
'regularize': regularize_weight * regularize_error, | |
'shape': shape_error, | |
'smooth': smooth_weight * smooth_error | |
} | |
return loss | |
def create_closure(self, | |
optimizer, | |
smpl, | |
params, | |
bbox, | |
input_keypoints): | |
def closure(): | |
optimizer.zero_grad() | |
output = smpl(*params, cam_intrinsics=self.cam_intrinsics, bbox=bbox, res=self.res) | |
loss_dict = self.forward(output, params, input_keypoints, bbox) | |
loss = sum(loss_dict.values()) | |
loss.backward() | |
return loss | |
return closure |