LAM / vhap /model /tracker.py
yuandong513
feat: init
17cd746
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
from vhap.config.base import import_module, PhotometricStageConfig, BaseTrackingConfig
from vhap.model.flame import FlameHead, FlameTexPCA, FlameTexPainted, FlameUvMask
from vhap.model.lbs import batch_rodrigues
from vhap.util.mesh import (
get_mtl_content,
get_obj_content,
normalize_image_points,
)
from vhap.util.log import get_logger
from vhap.util.visualization import plot_landmarks_2d
from torch.utils.tensorboard import SummaryWriter
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import cm
from typing import Literal
from functools import partial
import tyro
import yaml
from datetime import datetime
import threading
from typing import Optional
from collections import defaultdict
from copy import deepcopy
import time
import os
class FlameTracker:
def __init__(self, cfg: BaseTrackingConfig):
    """
    Builds the FLAME head model, the texture model, the UV mask and the renderer
    selected by the configuration.
    :param cfg: tracking configuration; device, model and render settings are read here
    :raises NotImplementedError: if cfg.render.backend is neither 'nvdiffrast' nor 'pytorch3d'
    """
    self.cfg = cfg
    self.device = cfg.device
    self.tb_writer = None  # no tensorboard writer is attached at construction time

    # ---- model ----
    model_cfg = cfg.model
    flame_head = FlameHead(
        model_cfg.n_shape,
        model_cfg.n_expr,
        add_teeth=model_cfg.add_teeth,
        remove_lip_inside=model_cfg.remove_lip_inside,
        face_clusters=model_cfg.tex_clusters,
    )
    self.flame = flame_head.to(self.device)

    # exactly one texture model is created, depending on tex_painted
    tex_size = model_cfg.tex_resolution
    if not model_cfg.tex_painted:
        self.flame_tex_pca = FlameTexPCA(model_cfg.n_tex, tex_size=tex_size).to(self.device)
    else:
        self.flame_tex_painted = FlameTexPainted(tex_size=tex_size).to(self.device)

    self.flame_uvmask = FlameUvMask().to(self.device)

    # ---- renderer (visualization and dense photometric energy) ----
    render_cfg = self.cfg.render
    backend = render_cfg.backend
    if backend == 'nvdiffrast':
        # imported lazily so that only the selected backend has to be installed
        from vhap.util.render_nvdiffrast import NVDiffRenderer
        self.render = NVDiffRenderer(
            use_opengl=render_cfg.use_opengl,
            lighting_type=render_cfg.lighting_type,
            lighting_space=render_cfg.lighting_space,
            disturb_rate_fg=render_cfg.disturb_rate_fg,
            disturb_rate_bg=render_cfg.disturb_rate_bg,
            fid2cid=self.flame.mask.fid2cid,
        )
        return
    if backend == 'pytorch3d':
        from vhap.util.render_pytorch3d import PyTorch3DRenderer
        self.render = PyTorch3DRenderer()
        return
    raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
def load_from_tracked_flame_params(self, fp):
    """
    Loads a checkpoint from a tracked_flame_params file. Counterpart to save_result().

    Values are copied in place into the already allocated parameters. Per-frame
    parameters are filled for as many frames as both the tracker and the checkpoint
    provide. Optional entries that are enabled in the config but missing in the
    file are skipped with a warning naming the missing key.
    :param fp: file (path or file object) handed to ``np.load``
    :return: None
    """
    report = np.load(fp)

    def load_param(param, ckpt_array):
        # write through .data[:] so the existing tensor object keeps its identity
        param.data[:] = torch.from_numpy(ckpt_array).to(param.device)

    def load_param_list(param_list, ckpt_array):
        # only the frames present on both sides are loaded
        for i in range(min(len(param_list), len(ckpt_array))):
            load_param(param_list[i], ckpt_array[i])

    def load_optional(key, attr_name, loader):
        # the attribute is looked up only when the key exists in the checkpoint
        if key in report:
            loader(getattr(self, attr_name), report[key])
        else:
            self.logger.warning(f"No {key} found in flame_params!")

    # mandatory parameters
    load_param_list(self.rotation, report["rotation"])
    load_param_list(self.translation, report["translation"])
    load_param_list(self.neck_pose, report["neck_pose"])
    load_param_list(self.jaw_pose, report["jaw_pose"])
    load_param_list(self.eyes_pose, report["eyes_pose"])
    load_param(self.shape, report["shape"])
    load_param_list(self.expr, report["expr"])
    load_param(self.lights, report["lights"])

    # the focal length is only a free parameter for uncalibrated data
    if not self.calibrated:
        load_param(self.focal_length, report["focal_length"])

    # optional parameters, depending on the model configuration
    if not self.cfg.model.tex_painted:
        load_optional("tex", "tex_pca", load_param)
    if self.cfg.model.tex_extra:
        load_optional("tex_extra", "tex_extra", load_param)
    if self.cfg.model.use_static_offset:
        load_optional("static_offset", "static_offset", load_param)
    if self.cfg.model.use_dynamic_offset:
        load_optional("dynamic_offset", "dynamic_offset", load_param_list)
def trimmed_decays(self, is_init):
    """
    Selects the decay entries for the requested phase.
    Entries whose key contains "init" are returned when is_init is truthy, all other
    entries otherwise; "_init" is removed from the returned keys.
    :param is_init: whether the "init" variants are wanted
    :return: dict mapping the trimmed key to the decay value
    """
    want_init = bool(is_init)
    return {
        name.replace("_init", ""): value
        for name, value in self.decays.items()
        if ("init" in name) == want_init
    }
def clear_cache(self):
    """Delegates to the renderer's clear_cache()."""
    renderer = self.render
    renderer.clear_cache()
def get_current_frame(self, frame_idx, include_keyframes=False):
    """
    Builds a batch from the dataset item of timestep frame_idx.
    With include_keyframes, the items of cfg.exp.keyframes are appended after it, so the
    data of frame_idx always sits at position 0. Tensor entries are concatenated along
    dim 0, all other entries are collected into lists. If "lmk2d_iris" is present it is
    appended to "lmk2d" along dim 1.
    :param frame_idx: timestep to load
    :param include_keyframes: whether the keyframes are added to the batch
    :return: dict with the merged sample
    """
    if include_keyframes:
        timesteps = [frame_idx, *self.cfg.exp.keyframes]
    else:
        timesteps = [frame_idx]

    loaded = [self.dataset.getitem_by_timestep(t) for t in timesteps]

    # merge the per-timestep dicts; the first item decides which keys exist and
    # whether an entry is treated as a tensor
    batch = {}
    for key, first_value in loaded[0].items():
        gathered = [item[key] for item in loaded]
        if isinstance(first_value, torch.Tensor):
            gathered = torch.cat(gathered, dim=0)
        batch[key] = gathered

    if "lmk2d_iris" in batch:
        batch["lmk2d"] = torch.cat([batch["lmk2d"], batch["lmk2d_iris"]], dim=1)
    return batch
def fill_cam_params_into_sample(self, sample):
    """
    Makes sure the sample carries "intrinsic" and "extrinsic".
    For calibrated data both keys must already be present (asserted). Otherwise they
    are written into the sample in place: the intrinsics are [f, f, cx, cy] with
    f = focal_length * max(h, w) and the principal point at the image center, the
    extrinsics are self.RT expanded to the batch size of sample["rgb"].
    :param sample: sample dict, modified in place
    """
    if self.calibrated:
        for key in ("intrinsic", "extrinsic"):
            assert key in sample
        return

    batch_size, _channels, height, width = sample["rgb"].shape
    # denormalize cam params
    focal = self.focal_length * max(height, width)
    principal_point = torch.tensor([[0.5*width], [0.5*height]]).to(focal)
    cx, cy = principal_point
    sample["intrinsic"] = torch.stack([focal, focal, cx, cy], dim=1)
    sample["extrinsic"] = self.RT[None, ...].expand(batch_size, -1, -1)
def configure_optimizer(self, params, lr_scale=1.0):
    """
    Creates an Adam optimizer for the given set of parameters.
    Parameters listed under a known group key get that group's learning rate; all
    remaining parameters form one final group that uses the base learning rate.
    Every learning rate is multiplied by lr_scale.
    :param params: dict mapping a parameter name to a list of tensors (not modified)
    :param lr_scale: factor applied to all learning rates
    :return: the torch.optim.Adam instance
    """
    # work on a copy because entries are popped below
    remaining = params.copy()
    lr_cfg = self.cfg.lr
    default_lr = lr_cfg.base

    # (param dict keys, learning rate) per group, in the order the groups are created
    groups = [
        (("translation",), lr_cfg.translation),
        (("expr",), lr_cfg.expr),
        (("lights",), lr_cfg.light),
    ]
    if not self.calibrated:
        groups.append((("cam",), lr_cfg.camera))
    if self.cfg.model.use_static_offset:
        groups.append((("static_offset",), lr_cfg.static_offset))
    if self.cfg.model.use_dynamic_offset:
        groups.append((("dynamic_offset",), lr_cfg.dynamic_offset))

    param_groups = []
    for keys, group_lr in groups:
        members = []
        for key in keys:
            if key in remaining:
                members += remaining.pop(key)
        if members:
            param_groups.append({"params": members, "lr": group_lr * lr_scale})

    # create default group with remaining params
    leftovers = []
    for tensors in remaining.values():
        leftovers += tensors
    param_groups.append({"params": leftovers})

    return torch.optim.Adam(param_groups, lr=default_lr * lr_scale)
def initialize_frame(self, frame_idx):
    """
    Initializes the parameters of frame frame_idx.
    Frames after the first one are initialized from their predecessor; nothing is
    done for frame_idx <= 0.
    :param frame_idx: index of the frame to initialize
    """
    if not frame_idx > 0:
        return
    self.initialize_from_previous(frame_idx)
def initialize_from_previous(self, frame_idx):
    """
    Initializes the per-frame flame parameters (expression, neck/jaw/eyes pose,
    translation, rotation) of frame_idx with the values of frame_idx - 1.
    Nothing is done for frame_idx == 0.
    :param frame_idx: index of the frame to initialize
    :return: None
    """
    if frame_idx == 0:
        return

    per_frame_params = (
        self.expr,
        self.neck_pose,
        self.jaw_pose,
        self.translation,
        self.rotation,
        self.eyes_pose,
    )

    for param in per_frame_params:
        # Copy in place through ``.data[...]`` (same idiom as the checkpoint loader).
        # Assigning to ``param[frame_idx].data`` would only rebind the temporary that
        # indexing a tensor returns and leave the stored row unchanged. Writing
        # through .data also works when ``param`` is a sequence of per-frame tensors.
        param[frame_idx].data[...] = param[frame_idx - 1].detach()
def select_frame_indices(self, frame_idx, include_keyframes):
    """
    Returns the frame indices that form one batch: frame_idx first, followed by
    cfg.exp.keyframes when include_keyframes is set.
    """
    if not include_keyframes:
        return [frame_idx]
    return [frame_idx, *self.cfg.exp.keyframes]
def forward_flame(self, frame_idx, include_keyframes):
    """
    Evaluates the flame model with the stored parameters of the selected frames.
    The shape is shared by all frames of the batch; expression, poses and translation
    are gathered per frame. The albedo map is expanded to the batch size.
    :param frame_idx: frame to evaluate (first element of the batch)
    :param include_keyframes: whether the keyframes are evaluated as well
    :return: (verts, verts_cano, lmks, albedos)
    """
    indices = self.select_frame_indices(frame_idx, include_keyframes)
    batch_size = len(indices)

    if self.cfg.model.use_dynamic_offset:
        dynamic_offset = self.to_batch(self.dynamic_offset, indices)
    else:
        dynamic_offset = None

    shared_shape = self.shape[None, ...].expand(batch_size, -1)
    # positional order expected by the flame call: expr, rotation, neck, jaw, eyes, translation
    per_frame_inputs = [
        self.to_batch(param, indices)
        for param in (
            self.expr,
            self.rotation,
            self.neck_pose,
            self.jaw_pose,
            self.eyes_pose,
            self.translation,
        )
    ]
    flame_out = self.flame(
        shared_shape,
        *per_frame_inputs,
        return_verts_cano=True,
        static_offset=self.static_offset,
        dynamic_offset=dynamic_offset,
    )
    # only the first three outputs are used here
    verts, verts_cano, lmks = flame_out[0], flame_out[1], flame_out[2]
    albedos = self.get_albedo().expand(batch_size, -1, -1, -1)
    return verts, verts_cano, lmks, albedos
def get_base_texture(self):
    """
    Returns the base albedo map with a leading batch dimension of 1.
    The extra texture map is the base when it is enabled and not used as a residual;
    otherwise the base comes from the painted or the PCA texture model.
    """
    model_cfg = self.cfg.model
    if model_cfg.tex_extra and not model_cfg.residual_tex:
        return self.tex_extra[None, ...]
    if model_cfg.tex_painted:
        return self.flame_tex_painted()
    return self.flame_tex_pca(self.tex_pca[None, :])
def get_albedo(self):
    """
    Returns the albedo map: the base texture, plus the extra texture map when that
    map is configured as a residual. The base is bilinearly resized to the residual's
    resolution if the two differ.
    """
    base = self.get_base_texture()
    model_cfg = self.cfg.model
    if not (model_cfg.tex_extra and model_cfg.residual_tex):
        return base

    residual = self.tex_extra[None, :]
    if base.shape[-2:] != residual.shape[-2:]:
        base = F.interpolate(base, residual.shape[-2:], mode='bilinear')
    return base + residual
def rasterize_flame(
    self, sample, verts, faces, camera_index=None, train_mode=False
):
    """
    Rasterizes the flame head mesh with the cameras stored in the sample.
    :param sample: dict providing "intrinsic" and "extrinsic"
    :param verts: mesh vertices handed to the renderer
    :param faces: mesh faces handed to the renderer
    :param camera_index: if given, only this camera is used (the batch dimension is kept)
    :param train_mode: forwarded to the renderer
    :return: the dict returned by the renderer's rasterize()
    """
    # camera parameters; only the intrinsics are cloned
    intrinsics = sample["intrinsic"].clone().to(self.device)
    extrinsics = sample["extrinsic"].to(self.device)
    if camera_index is not None:
        selected = [camera_index]
        intrinsics = intrinsics[selected]
        extrinsics = extrinsics[selected]

    height, width = self.image_size
    resolution = height, width

    # rasterize fragments
    return self.render.rasterize(verts, faces, extrinsics, intrinsics, resolution, False, train_mode)
@torch.no_grad()
def get_background_color(self, gt_rgb, gt_alpha, stage):
    """
    Returns the background used for rendering.
    The mode is cfg.render.background_eval when stage is None (evaluation) and
    cfg.render.background_train otherwise. 'target' returns gt_rgb permuted to
    channels-last, 'white' and 'black' return a constant RGB list.
    :param gt_rgb: target image batch, only used for the 'target' mode
    :param gt_alpha: not used by this method
    :param stage: name of the current stage, or None during evaluation
    :raises NotImplementedError: for any other background mode
    """
    render_cfg = self.cfg.render
    background = render_cfg.background_eval if stage is None else render_cfg.background_train

    if background == 'target':
        # use gt_rgb as background
        return gt_rgb.permute(0, 2, 3, 1)
    for mode, color in (('white', [1, 1, 1]), ('black', [0, 0, 0])):
        if background == mode:
            return color
    raise NotImplementedError(f"Unknown background mode: {background}")
def render_rgba(
    self, rast_dict, verts, faces, albedos, lights, background_color=[1, 1, 1],
    align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False,
):
    """
    Renders the rgba image from the rasterization result and
    the optimized texture + lights.
    All returned images are permuted from channels-last to channels-first.
    :return: dict of rendered buffers; the pytorch3d backend only provides 'rgba'
    :raises NotImplementedError: for an unknown renderer backend
    """
    faces_uv = self.flame.textures_idx
    backend = self.cfg.render.backend

    if backend == 'nvdiffrast':
        # flip the v coordinate on a copy so the stored uvs stay untouched
        flipped_uv = self.flame.verts_uvs.clone()
        flipped_uv[:, 1] = 1 - flipped_uv[:, 1]
        buffers = self.render.render_rgba(
            rast_dict, verts, faces, flipped_uv, faces_uv, albedos, lights, background_color,
            align_texture_except_fid, align_boundary_except_vid, enable_disturbance
        )
        return {name: image.permute(0, 3, 1, 2) for name, image in buffers.items()}

    if backend == 'pytorch3d':
        batch_size = verts.shape[0]  # TODO: double check
        face_uv = self.flame.face_uvcoords.repeat(batch_size, 1, 1)
        textures = albedos.expand(batch_size, -1, -1, -1)
        rgba = self.render.render_rgba(
            rast_dict, verts, faces, face_uv, faces_uv, textures, lights, background_color
        )
        return {'rgba': rgba.permute(0, 3, 1, 2)}

    raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
def render_normal(self, rast_dict, verts, faces):
    """
    Renders the normal image from the rasterization result by delegating to the
    renderer's render_normal(); the per-face uv coordinates of the flame model are
    repeated once per mesh in the batch.
    """
    batch_size = verts.shape[0]
    batched_uv = self.flame.face_uvcoords.repeat(batch_size, 1, 1)
    return self.render.render_normal(rast_dict, verts, faces, batched_uv)
def compute_lmk_energy(self, sample, pred_lmks, disable_jawline_landmarks=False):
    """
    Computes the landmark energy between the ground-truth 2D landmarks of the sample
    and the projected flame landmarks.
    The per-landmark L1 distance is weighted by the landmark confidence. If the sample
    provides 70 landmarks, the last two get twice their confidence; otherwise only the
    first 68 landmarks are compared.
    :param sample: dict with "rgb", "lmk2d" (x, y, confidence), "intrinsic", "extrinsic"
    :param pred_lmks: flame landmarks to project
    :param disable_jawline_landmarks: NOTE(review): accepted but not used in this method
    :return: (mean landmark loss, dict with "gt_lmk2d" and "pred_lmk2d")
    """
    img_size = sample["rgb"].shape[-2:]

    # ground truth: split into coordinates and confidence, then normalize the
    # coordinates in place (both are views of the same cloned tensor)
    target = sample["lmk2d"].clone().to(pred_lmks)
    gt_xy = target[:, :, :2]
    confidence = target[:, :, 2]
    gt_xy[:, :, 0], gt_xy[:, :, 1] = normalize_image_points(
        gt_xy[:, :, 0], gt_xy[:, :, 1], img_size
    )

    # prediction: project the flame landmarks with the sample's camera
    intrinsics = sample["intrinsic"].to(self.device)
    extrinsics = sample["extrinsic"].to(self.device)
    projected = self.render.world_to_ndc(pred_lmks, extrinsics, intrinsics, img_size, flip_y=True)
    pred_xy = projected[:, :, :2]

    if gt_xy.shape[1] == 70:
        residual = gt_xy - pred_xy
        confidence = confidence[:, :70]
        # eyes weighting
        confidence[:, 68:] *= 2
    else:
        residual = gt_xy[:, :68] - pred_xy[:, :68]
        confidence = confidence[:, :68]

    # general landmark term
    lmk_loss = residual.norm(p=1, dim=2) * confidence

    outputs = {"gt_lmk2d": gt_xy, "pred_lmk2d": pred_xy}
    return lmk_loss.mean(), outputs
def compute_photometric_energy(
    self,
    sample,
    verts,
    faces,
    albedos,
    rast_dict,
    step_i=None,
    stage=None,
    include_keyframes=False,
):
    """
    Computes the dense photometric energy.
    The mesh is rendered and compared with the target image; the loss is the summed
    absolute color error divided by the number of covered entries of the rendered
    alpha mask (expanded to three channels).
    :param sample: dict with "rgb" and optionally "alpha_map"
    :param verts: mesh vertices
    :param faces: mesh faces
    :param albedos: albedo map(s)
    :param rast_dict: rasterization result for the same mesh
    :param step_i: not used by this method
    :param stage: name of the current stage, None during evaluation
    :param include_keyframes: not used by this method
    :return: (photometric loss, dict with the render outputs plus "gt_rgb",
             "pred_rgb", "error_rgb" and "pred_alpha")
    """
    gt_rgb = sample["rgb"].to(verts)

    # Look up the same key that is read; the alpha map is optional and only moved
    # to the working device/dtype when it is a tensor.
    alpha_map = sample["alpha_map"] if "alpha_map" in sample else None
    gt_alpha = alpha_map.to(verts) if isinstance(alpha_map, torch.Tensor) else None

    lights = self.lights[None] if self.lights is not None else None
    bg_color = self.get_background_color(gt_rgb, gt_alpha, stage)

    # the stage-specific alignment exceptions and the disturbance are training-only
    is_training = stage is not None
    if is_training:
        stage_cfg = self.cfg.pipeline[stage]
        align_texture_except_fid = self.flame.mask.get_fid_by_region(
            stage_cfg.align_texture_except
        )
        align_boundary_except_vid = self.flame.mask.get_vid_by_region(
            stage_cfg.align_boundary_except
        )
    else:
        align_texture_except_fid = None
        align_boundary_except_vid = None

    render_out = self.render_rgba(
        rast_dict, verts, faces, albedos, lights, bg_color,
        align_texture_except_fid, align_boundary_except_vid,
        enable_disturbance=is_training,
    )

    rgba = render_out['rgba']
    pred_rgb = rgba[:, :3]
    pred_alpha = rgba[:, 3:]
    pred_mask = (rgba[:, [3]].detach() > 0).expand(-1, 3, -1, -1)

    # the render outputs are extended in place and returned
    results_dict = render_out

    # ---- rgb loss ----
    error_rgb = gt_rgb - pred_rgb
    color_loss = error_rgb.abs().sum() / pred_mask.detach().sum()
    results_dict.update(
        {
            "gt_rgb": gt_rgb,
            "pred_rgb": pred_rgb,
            "error_rgb": error_rgb,
            "pred_alpha": pred_alpha,
        }
    )

    # Only the color term is active; no silhouette/background (alpha) term is
    # added, so gt_alpha does not enter the loss.
    photo_loss = color_loss
    return photo_loss, results_dict
def compute_regularization_energy(self, result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage):
    """
    Computes the energy term that penalizes strong deviations from the flame base model

    Which terms are produced depends on self.opt_dict (the parameter groups that are
    currently optimized), on the weights in self.cfg.w (a weight of None disables the
    term) and on the stage name.
    :param result_dict: outputs of the preceding energy terms; 'diffuse_detach_normal'
                        is read for the diffuse regularization
    :param verts: posed vertices, used here only as a template for the per-vertex weights
    :param verts_cano: canonical vertices, used by the laplacian offset term
    :param lmks: not used by this method
    :param albedos: not used by this method
    :param frame_idx: index of the current frame
    :param include_keyframes: whether the keyframes belong to the batch
    :param stage: name of the current stage (a string; it is searched for 'tracking')
    :return: dict mapping the term name to its energy
    """
    log_dict = {}

    # unit standard deviations, i.e. the squared parameters are penalized as they are
    std_tex = 1
    std_expr = 1
    std_shape = 1

    indices = self.select_frame_indices(frame_idx, include_keyframes)

    # pose smoothness term
    if self.opt_dict['pose'] and 'tracking' in stage:
        # the next frame is only used as a reference in the 'global_tracking' stage
        E_pose_smooth = self.compute_pose_smooth_energy(frame_idx, stage=='global_tracking')
        log_dict["pose_smooth"] = E_pose_smooth

    # joint regularization term
    if self.opt_dict['joints']:
        if 'tracking' in stage:
            joint_smooth = self.compute_joint_smooth_energy(frame_idx, stage=='global_tracking')
            log_dict["joint_smooth"] = joint_smooth

        joint_prior = self.compute_joint_prior_energy(frame_idx)
        log_dict["joint_prior"] = joint_prior

    # expression regularization
    if self.opt_dict['expr']:
        expr = self.to_batch(self.expr, indices)
        reg_expr = (expr / std_expr) ** 2
        log_dict["reg_expr"] = self.cfg.w.reg_expr * reg_expr.mean()

    # shape regularization
    if self.opt_dict['shape']:
        reg_shape = (self.shape / std_shape) ** 2
        log_dict["reg_shape"] = self.cfg.w.reg_shape * reg_shape.mean()

    # texture regularization
    if self.opt_dict['texture']:
        # texture space
        if not self.cfg.model.tex_painted:
            reg_tex_pca = (self.tex_pca / std_tex) ** 2
            log_dict["reg_tex_pca"] = self.cfg.w.reg_tex_pca * reg_tex_pca.mean()

        # texture map
        if self.cfg.model.tex_extra:
            if self.cfg.model.residual_tex:
                if self.cfg.w.reg_tex_res is not None:
                    # L2 penalty on the residual texture map
                    reg_tex_res = self.tex_extra ** 2
                    # reg_tex_res = self.tex_extra.abs()  # L1 loss can create noise textures

                    # if len(self.cfg.model.occluded) > 0:
                    #     mask = (~self.flame_uvmask.get_uvmask_by_region(self.cfg.model.occluded)).float()[None, ...]
                    #     reg_tex_res *= mask
                    log_dict["reg_tex_res"] = self.cfg.w.reg_tex_res * reg_tex_res.mean()

                if self.cfg.w.reg_tex_tv is not None:
                    # squared differences between vertically / horizontally adjacent texels
                    tex = self.get_albedo()[0]  # (3, H, W)
                    tv_y = (tex[..., :-1, :] - tex[..., 1:, :]) ** 2
                    tv_x = (tex[..., :, :-1] - tex[..., :, 1:]) ** 2
                    # NOTE(review): the sum needs both flattened terms to have the same
                    # length, which holds for a square texture - confirm
                    tv = tv_y.reshape(tv_y.shape[0], -1) + tv_x.reshape(tv_x.shape[0], -1)
                    # the weight is scaled by the squared data scale factor and divided
                    # by the squared rgb downsampling factor if one is set
                    w_reg_tex_tv = self.cfg.w.reg_tex_tv * self.cfg.data.scale_factor ** 2
                    if self.cfg.data.n_downsample_rgb is not None:
                        w_reg_tex_tv /= (self.cfg.data.n_downsample_rgb ** 2)
                    log_dict["reg_tex_tv"] = w_reg_tex_tv * tv.mean()

                if self.cfg.w.reg_tex_res_clusters is not None:
                    # residual penalty restricted to the uv regions in cfg.w.reg_tex_res_for
                    mask_sclerae = self.flame_uvmask.get_uvmask_by_region(self.cfg.w.reg_tex_res_for)[None, :, :]
                    reg_tex_res_clusters = self.tex_extra ** 2 * mask_sclerae
                    log_dict["reg_tex_res_clusters"] = self.cfg.w.reg_tex_res_clusters * reg_tex_res_clusters.mean()

    # lighting parameters regularization
    if self.opt_dict['lights']:
        if self.cfg.w.reg_light is not None and self.lights is not None:
            # keeps the lights close to self.lights_uniform
            reg_light = (self.lights - self.lights_uniform) ** 2
            log_dict["reg_light"] = self.cfg.w.reg_light * reg_light.mean()

        if self.cfg.w.reg_diffuse is not None and self.lights is not None:
            # penalizes diffuse values above 1 and the variance along dim 1
            # NOTE(review): 'diffuse_detach_normal' is presumably written by the
            # photometric step - confirm it is always present when lights are optimized
            diffuse = result_dict['diffuse_detach_normal']
            reg_diffuse = F.relu(diffuse.max() - 1) + diffuse.var(dim=1).mean()
            log_dict["reg_diffuse"] = self.cfg.w.reg_diffuse * reg_diffuse

    # offset regularization
    if self.opt_dict['static_offset'] or self.opt_dict['dynamic_offset']:
        if self.static_offset is not None or self.dynamic_offset is not None:
            # total offset = static offset + dynamic offset of the selected frames
            # (`0 += tensor` creates a new tensor, the stored offsets are not modified)
            offset = 0
            if self.static_offset is not None:
                offset += self.static_offset
            if self.dynamic_offset is not None:
                offset += self.to_batch(self.dynamic_offset, indices)

            if self.cfg.w.reg_offset_lap is not None:
                # laplacian loss
                # the mesh without offset is detached, so gradients only reach the offset
                vert_wo_offset = (verts_cano - offset).detach()
                reg_offset_lap = self.compute_laplacian_smoothing_loss(
                    vert_wo_offset, vert_wo_offset + offset
                )
                if len(self.cfg.w.reg_offset_lap_relax_for) > 0:
                    # per-vertex weights, rescaled in the listed regions
                    w = self.scale_vertex_weights_by_region(
                        weights=torch.ones_like(verts[:, :, :1]),
                        scale_factor=self.cfg.w.reg_offset_lap_relax_coef,
                        region=self.cfg.w.reg_offset_lap_relax_for,
                    )
                    reg_offset_lap *= w
                log_dict["reg_offset_lap"] = self.cfg.w.reg_offset_lap * reg_offset_lap.mean()

            if self.cfg.w.reg_offset is not None:
                # norm loss
                # reg_offset = offset.norm(dim=-1, keepdim=True)
                reg_offset = offset.abs()
                if len(self.cfg.w.reg_offset_relax_for) > 0:
                    w = self.scale_vertex_weights_by_region(
                        weights=torch.ones_like(verts[:, :, :1]),
                        scale_factor=self.cfg.w.reg_offset_relax_coef,
                        region=self.cfg.w.reg_offset_relax_for,
                    )
                    reg_offset *= w
                log_dict["reg_offset"] = self.cfg.w.reg_offset * reg_offset.mean()

            if self.cfg.w.reg_offset_rigid is not None:
                # penalizes the variance of the offset over the vertices of each listed region
                reg_offset_rigid = 0
                for region in self.cfg.w.reg_offset_rigid_for:
                    vids = self.flame.mask.get_vid_by_region([region])
                    reg_offset_rigid += offset[:, vids, :].var(dim=-2).mean()
                log_dict["reg_offset_rigid"] = self.cfg.w.reg_offset_rigid * reg_offset_rigid

            if self.cfg.w.reg_offset_dynamic is not None and self.dynamic_offset is not None and self.opt_dict['dynamic_offset']:
                # The dynamic offset is regularized to be temporally smooth
                if frame_idx == 0:
                    # first frame: pulled towards zero
                    reg_offset_d = torch.zeros_like(self.dynamic_offset[0])
                    offset_d = self.dynamic_offset[0]
                else:
                    # later frames: pulled towards the first and the previous frame
                    reg_offset_d = torch.stack([self.dynamic_offset[0], self.dynamic_offset[frame_idx - 1]])
                    offset_d = self.dynamic_offset[frame_idx]

                reg_offset_dynamic = ((offset_d - reg_offset_d) ** 2).mean()
                log_dict["reg_offset_dynamic"] = self.cfg.w.reg_offset_dynamic * reg_offset_dynamic

    return log_dict
def scale_vertex_weights_by_region(self, weights, scale_factor, region):
    """
    Multiplies the weights of the vertices in the given region by scale_factor
    (in place on the passed tensor) and then applies cfg.w.blur_iter rounds of
    `laplacian_matrix_negate_diag @ weights / 2`.
    :return: the resulting weights
    """
    region_vids = self.flame.mask.get_vid_by_region(region)
    weights[:, region_vids] *= scale_factor

    n_rounds = self.cfg.w.blur_iter
    for _ in range(n_rounds):
        blur_matrix = self.flame.laplacian_matrix_negate_diag[None, ...]
        weights = blur_matrix.bmm(weights) / 2
    return weights
def compute_pose_smooth_energy(self, frame_idx, use_next_frame=False):
    """
    Regularizes the global pose of the flame head model to be temporally smooth.
    Translation and rotation of frame_idx are compared (mean squared difference) with
    the detached values of the previous frame and, optionally, the next frame. The
    neighbor indices are clipped to the valid range of timesteps.
    :return: weighted translation term + weighted rotation term
    """
    last = self.n_timesteps - 1
    ref_indices = [np.clip(frame_idx - 1, 0, last)]
    if use_next_frame:
        ref_indices.append(np.clip(frame_idx + 1, 0, last))
    current = [frame_idx]

    def mean_sq_diff(param):
        # the references are detached: only the current frame receives gradients
        return ((param[current] - param[ref_indices].detach()) ** 2).mean()

    E_trans = mean_sq_diff(self.translation) * self.cfg.w.smooth_trans
    E_rot = mean_sq_diff(self.rotation) * self.cfg.w.smooth_rot
    return E_trans + E_rot
def compute_joint_smooth_energy(self, frame_idx, use_next_frame=False):
    """
    Regularizes the joints of the flame head model to be temporally smooth.
    Neck, jaw and eyes pose of frame_idx are compared (mean squared difference) with
    the detached values of the previous frame and, optionally, the next frame. The
    neighbor indices are clipped to the valid range of timesteps.
    :return: sum of the weighted neck, jaw and eyes terms
    """
    last = self.n_timesteps - 1
    ref_indices = [np.clip(frame_idx - 1, 0, last)]
    if use_next_frame:
        ref_indices.append(np.clip(frame_idx + 1, 0, last))
    current = [frame_idx]

    weights = self.cfg.w
    joints = (
        (self.neck_pose, weights.smooth_neck),
        (self.jaw_pose, weights.smooth_jaw),
        (self.eyes_pose, weights.smooth_eyes),
    )
    E_joint_smooth = 0
    for pose, weight in joints:
        E_joint_smooth += ((pose[current] - pose[ref_indices].detach()) ** 2).mean() * weight
    return E_joint_smooth
def compute_joint_prior_energy(self, frame_idx):
    """
    Regularizes the joints of the flame head model towards the neutral pose.
    For neck, jaw and each half of the eyes pose, the rotation matrix is compared with
    the one of a zero rotation. The jaw gets additional penalties; each eye entry
    additionally penalizes the difference between the two halves of the eyes pose.
    Every term is weighted by cfg.w["prior_<name>"].
    :return: the accumulated prior energy
    """
    current = [frame_idx]
    first_eye = self.eyes_pose[current, :3]
    second_eye = self.eyes_pose[current, 3:]
    joints = (
        ("neck", self.neck_pose[current, :]),
        ("jaw", self.jaw_pose[current, :]),
        ("eyes", first_eye),
        ("eyes", second_eye),
    )

    # joints are regularized towards neutral
    E_joint_prior = 0
    for name, pose in joints:
        # L2 distance between the rotation matrices of the zero pose (row 0)
        # and of the current pose
        neutral_and_pose = torch.cat([torch.zeros_like(pose), pose], dim=0)
        rotmats = batch_rodrigues(neutral_and_pose)
        penalty = ((rotmats[[0]] - rotmats[1:]) ** 2).mean()

        # additional regularization for physical plausibility
        if name == 'jaw':
            # penalize negative rotation along x axis of jaw
            penalty += F.relu(-pose[:, 0]).mean() * 10
            # penalize rotation along y and z axis of jaw
            penalty += (pose[:, 1:] ** 2).mean() * 3
        elif name == 'eyes':
            # penalize the difference between the two eyes
            penalty += ((first_eye - second_eye) ** 2).mean()

        E_joint_prior += penalty * self.cfg.w[f"prior_{name}"]
    return E_joint_prior
def compute_laplacian_smoothing_loss(self, verts, offset_verts):
    """
    Squared difference between the laplacian coordinates of offset_verts and those of
    verts, summed over the last dimension (which is kept with size 1).
    The laplacian matrix and the laplacian of verts are detached, so gradients only
    flow through offset_verts.
    """
    laplacian = self.flame.laplacian_matrix[None, ...].detach()  # (1, V, V)
    reference_lap = laplacian.bmm(verts).detach()
    displaced_lap = laplacian.bmm(offset_verts)
    squared = (displaced_lap - reference_lap) ** 2
    return squared.sum(dim=-1, keepdim=True)
def compute_energy(
    self,
    sample,
    frame_idx,
    include_keyframes=False,
    step_i=None,
    stage=None,
):
    """
    Compute total energy for frame frame_idx.
    The landmark term is added when cfg.w.landmark is set, the photometric term when
    cfg.w.photo is set and the stage is None or a photometric stage, and the
    regularization terms whenever a stage is given.
    :param sample: batch for this frame; "rgb" and "num_cameras" are read here
    :param frame_idx: index of the frame
    :param include_keyframes: if key frames shall be included when predicting the per
                              frame energy
    :param step_i: forwarded to the photometric energy
    :param stage: name of the current stage, None during evaluation
    :return: total energy, dict of the individual energies (incl. "total"), vertices,
             faces, landmarks, albedos and the dict of intermediate results
    """
    energies = {}
    result_dict = {"gt_rgb": sample["rgb"]}

    verts, verts_cano, lmks, albedos = self.forward_flame(frame_idx, include_keyframes)
    faces = self.flame.faces

    # the number of cameras may arrive as a list; its first entry is used then
    num_cameras = sample["num_cameras"]
    if isinstance(num_cameras, list):
        num_cameras = num_cameras[0]
    # albedos = self.repeat_n_times(albedos, num_cameras)  # only needed for pytorch3d renderer

    weights = self.cfg.w

    # ---- landmark term ----
    if weights.landmark is not None:
        lmks_per_view = self.repeat_n_times(lmks, num_cameras)
        use_stage_setting = not weights.always_enable_jawline_landmarks and stage is not None
        if use_stage_setting:
            disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
        else:
            disable_jawline_landmarks = False
        E_lmk, lmk_results = self.compute_lmk_energy(sample, lmks_per_view, disable_jawline_landmarks)
        energies["lmk"] = weights.landmark * E_lmk
        result_dict.update(lmk_results)

    # ---- photometric term ----
    is_photometric = stage is None or isinstance(self.cfg.pipeline[stage], PhotometricStageConfig)
    if is_photometric and weights.photo is not None:
        verts_per_view = self.repeat_n_times(verts, num_cameras)
        rast_dict = self.rasterize_flame(
            sample, verts_per_view, self.flame.faces, train_mode=True
        )
        # NOTE(review): the rasterization uses the per-view vertices, while the
        # photometric energy receives the per-frame vertices
        E_photo, photo_results = self.compute_photometric_energy(
            sample,
            verts,
            faces,
            albedos,
            rast_dict,
            step_i,
            stage,
            include_keyframes,
        )
        result_dict.update(photo_results)
        energies["photo"] = weights.photo * E_photo

    # ---- regularization terms (training only) ----
    if stage is not None:
        energies.update(
            self.compute_regularization_energy(
                result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage
            )
        )

    E_total = torch.stack(list(energies.values())).sum()
    energies["total"] = E_total
    return E_total, energies, verts, faces, lmks, albedos, result_dict
@staticmethod
def to_batch(x, indices):
return torch.stack([x[i] for i in indices])
@staticmethod
def repeat_n_times(x: torch.Tensor, n: int):
"""Expand a tensor from shape [F, ...] to [F*n, ...]"""
return x.unsqueeze(1).repeat_interleave(n, dim=1).reshape(-1, *x.shape[1:])
@torch.no_grad()
def log_scalars(
    self,
    log_dict,
    frame_idx,
    session: Literal["train", "eval"] = "train",
    stage=None,
    frame_step=None,
):
    """
    Logs the scalars in log_dict to tensorboard (if a writer is attached) and to
    self.logger.
    The focal length is added to log_dict (in place) when the camera is optimized in
    the given stage. Keys starting with "decay" are left out of the text message.
    The tensorboard step is self.global_step for "train" and frame_idx otherwise.
    :param log_dict: dict of scalar values
    :param frame_idx: index of the current frame
    :param session: "train" or "eval"
    :param stage: name of the current stage; required for the "train" session
    :param frame_step: step shown in the message; self.global_step is shown if None
    """
    camera_is_optimized = (
        not self.calibrated
        and stage is not None
        and 'cam' in self.cfg.pipeline[stage].optimizable_params
    )
    if camera_is_optimized:
        log_dict["focal_length"] = self.focal_length.squeeze(0)

    global_step = self.global_step if session == "train" else frame_idx

    message_parts = []
    for name, value in log_dict.items():
        if not name.startswith("decay"):
            message_parts.append("{}: {:.4f} ".format(name, value))
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(f"{session}/{name}", value, global_step)
    log_msg = "".join(message_parts)

    if session == "train":
        assert stage is not None
        if frame_step is None:
            msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {self.global_step}: "
        else:
            msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {frame_step}: "
    elif session == "eval":
        msg_prefix = f"[{session}] frame {frame_idx}: "
    self.logger.info(msg_prefix + log_msg)
def save_obj_with_texture(self, vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path):
    """
    Writes a textured mesh as three files: the texture image, an MTL file that
    references the texture by file name, and an OBJ file that references the MTL by
    file name.
    texture_path and mtl_path must offer a ``.name`` attribute (e.g. pathlib paths).
    """
    # texture image (a leading dimension of size 1 is dropped)
    texture = albedos.squeeze(0)
    torchvision.utils.save_image(texture, texture_path)

    # material file
    with open(mtl_path, 'w') as mtl_file:
        mtl_file.write(get_mtl_content(texture_path.name))

    # mesh file
    with open(obj_path, 'w') as obj_file:
        obj_file.write(get_obj_content(vertices, faces, uv_coordinates, uv_indices, mtl_path.name))
def async_func(func):
    """
    Decorator to run a function asynchronously.
    The decision is taken per call from ``self.cfg.async_func``, where self is the
    first positional argument. If it is set, the call is started in a new thread and
    None is returned; otherwise the function runs in the calling thread and its
    result is returned.
    """
    # imported locally: the file only imports `partial` from functools
    from functools import wraps

    @wraps(func)  # keep name and docstring of the wrapped function
    def wrapper(*args, **kwargs):
        self = args[0]
        if self.cfg.async_func:
            thread = threading.Thread(target=func, args=args, kwargs=kwargs)
            thread.start()
            return None
        return func(*args, **kwargs)

    return wrapper
# torch.no_grad() is applied to the function that the worker thread executes: the
# grad mode is thread-local, so disabling it only around the thread start would not
# affect the thread itself.
@async_func
@torch.no_grad()
def log_media(
    self,
    verts: torch.Tensor,
    faces: torch.Tensor,
    lmks: torch.Tensor,
    albedos: torch.Tensor,
    output_dict: dict,
    sample: dict,
    frame_idx: int,
    session: str,
    stage: Optional[str] = None,
    frame_step: Optional[int] = None,
    epoch=None,
):
    """
    Writes the current tracking visualization to disk: an image grid and the textured
    mesh (texture image, MTL and OBJ file). The time taken is reported via self.logger.
    :param verts: mesh vertices; a leading dimension of size 1 is dropped for the export
    :param faces: mesh faces
    :param lmks: landmarks, forwarded to the visualization
    :param albedos: albedo map, saved as the mesh texture
    :param output_dict: intermediate results used by the visualization
    :param sample: current batch
    :param frame_idx: index of the current frame
    :param session: session name, used for the output path and the log message
    :param stage: name of the current stage, None during evaluation
    :param frame_step: optimization step within the frame, if any
    :param epoch: epoch, forwarded to the output path
    :return: None
    """
    tic = time.time()
    prepare_output_path = partial(
        self.prepare_output_path,
        session=session,
        frame_idx=frame_idx,
        stage=stage,
        step=frame_step,
        epoch=epoch,
    )

    # ---- images ----
    if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
        disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
    else:
        disable_jawline_landmarks = False
    img = self.visualize_tracking(
        verts, lmks, albedos, output_dict, sample,
        disable_jawline_landmarks=disable_jawline_landmarks,
    )
    img_path = prepare_output_path(folder_name="image_grid", file_type=self.cfg.log.image_format)
    torchvision.utils.save_image(img, img_path)

    # ---- meshes ----
    texture_path = prepare_output_path(folder_name="mesh", file_type=self.cfg.log.image_format)
    mtl_path = prepare_output_path(folder_name="mesh", file_type="mtl")
    obj_path = prepare_output_path(folder_name="mesh", file_type="obj")

    vertices = verts.squeeze(0).detach().cpu().numpy()
    faces = faces.detach().cpu().numpy()
    uv_coordinates = self.flame.verts_uvs.cpu().numpy()
    uv_indices = self.flame.textures_idx.cpu().numpy()
    self.save_obj_with_texture(
        vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path
    )

    # ---- report ----
    toc = time.time() - tic
    if stage is not None:
        msg_prefix = f"[{session}-{stage}] frame {frame_idx}"
    else:
        msg_prefix = f"[{session}] frame {frame_idx}"
    if frame_step is not None:
        msg_prefix += f" step {frame_step}"
    self.logger.info(f"{msg_prefix}: Logging media took {toc:.2f}s")
@torch.no_grad()
def visualize_tracking(
    self,
    verts,
    lmks,
    albedos,
    output_dict,
    sample,
    return_imgs_seperately=False,
    disable_jawline_landmarks=False,
):
    """
    Visualizes the tracking result.
    For the selected views, one tile per available entry of output_dict is created
    (target image, prediction, error map, cluster ids, albedo, normal, diffuse,
    anti-aliased render, alpha maps, alpha overlay, landmarks).
    The views are cfg.log.view_indices if given; otherwise views are picked with a
    constant stride derived from cfg.log.max_num_views.
    :param verts: not used by this method
    :param lmks: not used by this method
    :param albedos: not used by this method
    :param output_dict: intermediate results; "gt_rgb" is required
    :param sample: current batch; "rgb" is read to count the views
    :param return_imgs_seperately: return the list of tiles instead of a grid
    :param disable_jawline_landmarks: hide the jawline landmarks in the landmark tile
    :return: image grid, or the list of tiles if return_imgs_seperately is set
    """
    # ---- choose the views ----
    if len(self.cfg.log.view_indices) > 0:
        view_indices = torch.tensor(self.cfg.log.view_indices)
    else:
        num_views = sample["rgb"].shape[0]
        max_num_views = self.cfg.log.max_num_views
        if num_views > 1 and max_num_views > 1:
            # the stride must be at least 1: with fewer views than max_num_views the
            # integer division yields 0, which torch.arange rejects
            step = max(1, (num_views - 1) // (max_num_views - 1))
            view_indices = torch.arange(0, num_views, step=step)
        else:
            # a single view, or a limit of one view
            view_indices = torch.tensor([0])
    num_views_log = len(view_indices)

    imgs = []

    def add_tiles(images):
        # one tile of shape (1, C, H, W) per view
        imgs.extend(image[None] for image in images)

    # rgb
    gt_rgb = output_dict["gt_rgb"][view_indices].cpu()
    transfm = torchvision.transforms.Resize(gt_rgb.shape[-2:])
    add_tiles(gt_rgb)

    def fetch(key):
        # selected views of an entry, on the CPU, resized to the target resolution
        return transfm(output_dict[key][view_indices].cpu())

    def seismic_colormap():
        # cm.get_cmap is not available in recent matplotlib releases; prefer the
        # colormap registry and fall back for old versions without it
        try:
            from matplotlib import colormaps
        except ImportError:
            return cm.get_cmap("seismic")
        return colormaps["seismic"]

    def signed_error_to_rgb(error):
        # mean over channels, map to [0, 1] around 0.5 and colorize
        error = error.mean(dim=1) / 2 + 0.5
        colored = seismic_colormap()(error.cpu().numpy())
        return torch.from_numpy(colored[..., :3]).to(gt_rgb).permute(0, 3, 1, 2)

    if "pred_rgb" in output_dict:
        add_tiles(torch.clip(fetch("pred_rgb"), min=0, max=1))

    if "error_rgb" in output_dict:
        add_tiles(signed_error_to_rgb(fetch("error_rgb")))

    # cluster id
    if "cid" in output_dict:
        cid = fetch("cid")
        cid = cid / cid.max()
        cid = cid.expand(-1, 3, -1, -1).clone()
        pred_alpha = fetch("pred_alpha").expand(-1, 3, -1, -1)
        cid[pred_alpha == 0] = 1  # background
        add_tiles(cid)

    # albedo
    if "albedo" in output_dict:
        albedo = torch.clip(fetch("albedo"), min=0, max=1)
        pred_alpha = fetch("pred_alpha").expand(-1, 3, -1, -1)
        albedo[pred_alpha == 0] = 1  # background
        add_tiles(albedo)

    # normal
    if "normal" in output_dict:
        add_tiles(torch.clip(fetch("normal") / 2 + 0.5, min=0, max=1))

    # diffuse
    diffuse = None
    if self.cfg.render.lighting_type != 'constant' and "diffuse" in output_dict:
        diffuse = torch.clip(fetch("diffuse"), min=0, max=1)
        add_tiles(diffuse)

    # aa
    if "aa" in output_dict:
        add_tiles(torch.clip(fetch("aa"), min=0, max=1))

    # alpha
    if "gt_alpha" in output_dict:
        add_tiles(fetch("gt_alpha").expand(-1, 3, -1, -1))

    if "pred_alpha" in output_dict:
        pred_alpha = fetch("pred_alpha").expand(-1, 3, -1, -1)
        color_alpha = torch.tensor([0.2, 0.5, 1])[None, :, None, None]
        fg_mask = (pred_alpha > 0).float()
        if diffuse is not None:
            fg_mask *= diffuse
        # blend the tinted render over the target image inside the foreground
        w = 0.7
        overlay_alpha = fg_mask * (w * color_alpha * pred_alpha + (1 - w) * gt_rgb) \
            + (1 - fg_mask) * gt_rgb
        add_tiles(overlay_alpha)

    if "error_alpha" in output_dict:
        add_tiles(signed_error_to_rgb(fetch("error_alpha")))

    # landmark
    vis_lmk = self.visualize_landmarks(gt_rgb, output_dict, view_indices, disable_jawline_landmarks)
    if vis_lmk is not None:
        add_tiles(vis_lmk)

    # ----------------
    if return_imgs_seperately:
        return imgs

    num_types = len(imgs) // len(view_indices)
    if self.cfg.log.stack_views_in_rows:
        # reorder from type-major to view-major so that every row shows one view
        imgs = [imgs[j * num_views_log + i] for i in range(num_views_log) for j in range(num_types)]
        imgs = torch.cat(imgs, dim=0).cpu()
        return torchvision.utils.make_grid(imgs, nrow=num_types)
    imgs = torch.cat(imgs, dim=0).cpu()
    return torchvision.utils.make_grid(imgs, nrow=num_views_log)
@torch.no_grad()
def visualize_landmarks(self, gt_rgb, output_dict, view_indices=torch.tensor([0]), disable_jawline_landmarks=False):
    """
    Draws the 2D landmarks found in `output_dict` onto a copy of `gt_rgb`.

    Ground-truth landmarks ('gt_lmk2d') are plotted in green first, predicted
    landmarks ('pred_lmk2d') are plotted in red on top of the same images.

    :param gt_rgb: image batch; its last two dimensions are read as (height, width)
        and its first dimension is indexed once per entry of `view_indices`
    :param output_dict: dict that may contain 'gt_lmk2d' and/or 'pred_lmk2d'. The
        coordinates are mapped to pixels with `(x * 0.5 + 0.5) * (w, h)`, so they are
        presumably normalized to [-1, 1] -- TODO confirm against the producer
    :param view_indices: indices used to select views from the landmark tensors
    :param disable_jawline_landmarks: keep landmarks 17..67 instead of 0..67
    :return: the annotated copy of `gt_rgb`, or None if neither key is present
    """
    height, width = gt_rgb.shape[-2:]
    unit = height / 750  # size unit handed to plot_landmarks_2d, relative to a 750 px high image
    pixel_scale = torch.tensor([[[width, height]]])
    first_lmk = 17 if disable_jawline_landmarks else 0

    canvas = None
    for key, color in (("gt_lmk2d", "green"), ("pred_lmk2d", "red")):
        if key not in output_dict:
            continue
        lmk2d = (output_dict[key][view_indices].cpu() * 0.5 + 0.5) * pixel_scale
        lmk2d = lmk2d[:, first_lmk:68]
        # the canvas is created lazily so that None is returned when nothing is drawn
        if canvas is None:
            canvas = gt_rgb.clone()
        for view in range(len(view_indices)):
            canvas[view] = plot_landmarks_2d(
                canvas[view].clone(),
                lmk2d[[view]],
                colors=color,
                unit=unit,
                input_float=True,
            ).to(canvas[view])
    return canvas
@torch.no_grad()
def evaluate(self, make_visualization=True, epoch=0):
    """
    Saves the current parameters and evaluates the energy on every timestep.

    For each of the `self.n_timesteps` frames the energy is computed and its scalars
    are logged under the "eval" session. The mean of the per-frame 'photo' entries of
    the log dict is finally written to `self.tb_writer` as "eval_mean/photo".

    :param make_visualization: if True, `log_media` is additionally called per frame
    :param epoch: forwarded to `save_result`/`log_media` and used as the step of the
        mean-loss scalar
    :return:
    """
    # always save parameters before evaluation
    self.save_result(epoch=epoch)

    self.logger.info("Started Evaluation")
    photo_losses = []
    for frame_idx in range(self.n_timesteps):
        sample = self.get_current_frame(frame_idx, include_keyframes=False)
        self.clear_cache()
        self.fill_cam_params_into_sample(sample)
        # the total energy (first entry) is not needed during evaluation
        _, log_dict, verts, faces, lmks, albedos, output_dict = self.compute_energy(sample, frame_idx)

        self.log_scalars(log_dict, frame_idx, session="eval")
        photo_losses.append(log_dict["photo"].item())

        if not make_visualization:
            continue
        self.log_media(
            verts, faces, lmks, albedos, output_dict, sample, frame_idx,
            session="eval",
            epoch=epoch,
        )

    self.tb_writer.add_scalar("eval_mean/photo", np.mean(photo_losses), epoch)
def prepare_output_path(self, session, frame_idx, folder_name, file_type, stage=None, step=None, epoch=None):
    """
    Builds the output file path for one frame and creates its parent folder.

    The folder is `<out_dir>/<session>[_<epoch>]/<folder_name>`. The file name is
    `frame_<frame_idx>.<file_type>`, or `frame_<frame_idx>_<step>_<stage>.<file_type>`
    when a stage is given.

    :param session: name of the top-level sub-folder
    :param frame_idx: frame index, zero-padded to 5 digits in the file name
    :param folder_name: sub-folder below the session folder
    :param file_type: file extension without the leading dot
    :param stage: optional stage name; requires `step` to be set as well
    :param step: step number, zero-padded to 3 digits (only used together with `stage`)
    :param epoch: if not None, appended to the session folder name
    :return: path of the output file (the file itself is not created)
    """
    session_folder = session if epoch is None else f'{session}_{epoch}'
    output_folder = self.out_dir / session_folder / folder_name
    os.makedirs(output_folder, exist_ok=True)

    if stage is None:
        return output_folder / "frame_{:05d}.{}".format(frame_idx, file_type)

    assert step is not None
    return output_folder / "frame_{:05d}_{:03d}_{}.{}".format(frame_idx, step, stage, file_type)
def save_result(self, fname=None, epoch=None):
    """
    Saves tracked/optimized flame parameters as an `.npz` archive in `self.out_dir`.

    Lists of tensors are stacked and tensors are detached, moved to the CPU and
    converted to numpy before export; all other values are stored as they are.

    :param fname: file name without extension, defaults to "tracked_flame_params"
    :param epoch: if not None, `_<epoch>` is appended to the file name
    :return:
    """
    model_cfg = self.cfg.model

    # parameters that are always exported
    params = {
        "rotation": self.rotation,
        "translation": self.translation,
        "neck_pose": self.neck_pose,
        "jaw_pose": self.jaw_pose,
        "eyes_pose": self.eyes_pose,
        "shape": self.shape,
        "expr": self.expr,
        "timestep_id": np.array(self.dataset.timestep_ids),
        "n_processed_frames": self.frame_idx,
    }

    # parameters that only exist for certain configurations
    if not self.calibrated:
        params["focal_length"] = self.focal_length
    if not model_cfg.tex_painted:
        params["tex"] = self.tex_pca
    if model_cfg.tex_extra:
        params["tex_extra"] = self.tex_extra
    if self.lights is not None:
        params["lights"] = self.lights
    if model_cfg.use_static_offset:
        params["static_offset"] = self.static_offset
    if model_cfg.use_dynamic_offset:
        params["dynamic_offset"] = self.dynamic_offset

    export_dict = {}
    for name, value in params.items():
        if isinstance(value, np.ndarray):
            export_dict[name] = value
            continue
        if isinstance(value, list):
            value = torch.stack(value)
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        export_dict[name] = value
    export_dict["image_size"] = np.array(self.image_size)

    stem = "tracked_flame_params" if fname is None else fname
    if epoch is not None:
        stem = f"{stem}_{epoch}"
    np.savez(self.out_dir / f'{stem}.npz', **export_dict)
class GlobalTracker(FlameTracker):
def __init__(self, cfg: BaseTrackingConfig):
    """
    Sets up logging, the dataset and the optimizable parameters of the tracker.

    A new time-stamped folder below `cfg.exp.output_folder` receives the TensorBoard
    events, a dump of the configuration (`config.yml`) and the log output.

    :param cfg: tracking configuration; it is forwarded to the parent constructor
    """
    super().__init__(cfg)

    self.calibrated = cfg.data.calibrated

    # logging
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = cfg.exp.output_folder / timestamp
    run_dir.mkdir(parents=True, exist_ok=True)

    self.frame_idx = cfg.begin_frame_idx
    self.out_dir = run_dir
    self.tb_writer = SummaryWriter(run_dir)

    self.log_interval_scalar = cfg.log.interval_scalar
    self.log_interval_media = cfg.log.interval_media

    # keep a copy of the configuration next to the results
    (run_dir / 'config.yml').write_text(yaml.dump(cfg), "utf8")
    print(tyro.to_yaml(cfg))

    self.logger = get_logger(__name__, root=True, log_dir=run_dir)

    # data
    # FlameTracker expects all views of a frame in a batch, which is undertaken by the
    # dataset. Therefore batching is disabled for the dataloader
    dataset_cls = import_module(cfg.data._target)
    self.dataset = dataset_cls(
        cfg=cfg.data,
        img_to_tensor=True,
        batchify_all_views=True,  # important to optimized all views together
    )

    self.image_size = self.dataset[0]["rgb"].shape[-2:]
    self.n_timesteps = len(self.dataset)

    # parameters
    self.init_params()

    flame_params_path = cfg.model.flame_params_path
    if flame_params_path is not None:
        self.load_from_tracked_flame_params(flame_params_path)
def init_params(self):
    """
    Allocates all tracked parameters on `self.device` and enables gradients on them.

    All tensors start at zero, except the first row of the SH lighting coefficients
    (sqrt(4 * pi)), the focal length (1.5) and the camera extrinsics `self.RT`.
    Per-frame tensors have `self.n_timesteps` as their first dimension.

    The attributes `self.lights`, `self.static_offset` and `self.dynamic_offset` are
    always defined and are None when the corresponding feature is disabled, whereas
    `self.tex_extra`, `self.focal_length` and `self.RT` only exist when enabled.

    :return:
    """
    model_cfg = self.cfg.model
    n_frames = self.n_timesteps

    def zeros(*size):
        return torch.zeros(*size).to(self.device)

    # flame model params
    self.shape = zeros(model_cfg.n_shape)
    self.expr = zeros(n_frames, model_cfg.n_expr)

    # joint axis angles
    self.neck_pose = zeros(n_frames, 3)
    self.jaw_pose = zeros(n_frames, 3)
    self.eyes_pose = zeros(n_frames, 6)

    # rigid pose
    self.translation = zeros(n_frames, 3)
    self.rotation = zeros(n_frames, 3)

    train_tensors = [
        self.shape, self.translation, self.rotation,
        self.neck_pose, self.jaw_pose, self.eyes_pose, self.expr,
    ]

    # texture and lighting params
    self.tex_pca = zeros(model_cfg.n_tex)
    if not model_cfg.tex_painted:
        train_tensors.append(self.tex_pca)

    if model_cfg.tex_extra:
        res = model_cfg.tex_resolution
        self.tex_extra = zeros(3, res, res)
        train_tensors.append(self.tex_extra)

    if self.cfg.render.lighting_type == 'SH':
        # only the first SH coefficient is non-zero (presumably uniform ambient light)
        self.lights_uniform = zeros(9, 3)
        self.lights_uniform[0] = torch.tensor([np.sqrt(4 * np.pi)]).expand(3).float().to(self.device)
        self.lights = self.lights_uniform.clone()
        train_tensors.append(self.lights)
    else:
        self.lights = None

    # vertex offsets
    n_verts = self.flame.v_template.shape[0]
    if model_cfg.use_static_offset:
        self.static_offset = zeros(1, n_verts, 3)
        train_tensors.append(self.static_offset)
    else:
        self.static_offset = None

    if model_cfg.use_dynamic_offset:
        self.dynamic_offset = zeros(n_frames, n_verts, 3)
        # Append the tensor itself. Extending the list with the tensor (`+=`) would
        # iterate over it and register temporary per-frame views, so that
        # `self.dynamic_offset` would never get `requires_grad` enabled.
        train_tensors.append(self.dynamic_offset)
    else:
        self.dynamic_offset = None

    # camera definition
    if not self.calibrated:
        # K contains focal length and principle point
        self.focal_length = torch.tensor([1.5]).to(self.device)
        self.RT = torch.eye(3, 4).to(self.device)
        self.RT[2, 3] = -1  # (0, 0, -1) in w2c corresponds to (0, 0, 1) in c2w
        train_tensors.append(self.focal_length)

    for tensor in train_tensors:
        tensor.requires_grad = True
def optimize(self):
    """
    Optimizes flame parameters on all frames of the dataset: first sequentially,
    frame by frame, and afterwards globally with random sampling.
    :return:
    """
    self.global_step = 0
    photometric = self.cfg.exp.photometric

    # sequential optimization of timesteps
    self.logger.info(f"Start sequential tracking FLAME in {self.n_timesteps} frames")
    sequential_loader = DataLoader(self.dataset, batch_size=None, shuffle=False, num_workers=0)
    tracking_stage = 'rgb_sequential_tracking' if photometric else 'lmk_sequential_tracking'
    for sample in sequential_loader:
        timestep = sample["timestep_index"][0].item()

        # the first timestep additionally runs the initialization stages
        if timestep == 0:
            init_stages = ['lmk_init_rigid', 'lmk_init_all']
            if photometric:
                init_stages += ['rgb_init_texture', 'rgb_init_all']
                if self.cfg.model.use_static_offset:
                    init_stages.append('rgb_init_offset')
            for init_stage in init_stages:
                self.optimize_stage(init_stage, sample)

        self.optimize_stage(tracking_stage, sample)
        self.initialize_next_timtestep(timestep)

    self.evaluate(make_visualization=False, epoch=0)

    # global optimization with random sampling
    self.logger.info("Start global optimization of all frames")
    shuffled_loader = DataLoader(self.dataset, batch_size=None, shuffle=True, num_workers=0)
    global_stage = 'rgb_global_tracking' if photometric else 'lmk_global_tracking'
    self.optimize_stage(stage=global_stage, dataloader=shuffled_loader, lr_scale=0.1)

    self.logger.info("All done.")
def optimize_stage(
    self,
    stage: Literal[
        'lmk_init_rigid', 'lmk_init_all', 'rgb_init_texture', 'rgb_init_all', 'rgb_init_offset',
        'lmk_sequential_tracking', 'rgb_sequential_tracking',
        'lmk_global_tracking', 'rgb_global_tracking',
    ],
    sample = None,
    dataloader = None,
    lr_scale = 1.0,
):
    """
    Runs one optimization stage, either on a single sample or on a whole dataloader.

    A fresh optimizer is configured for the parameters of `stage`. With `sample`, the
    sample is optimized for `cfg.pipeline[stage].num_steps` iterations. Otherwise the
    `dataloader` is iterated for `cfg.pipeline[stage].num_epochs` epochs, the learning
    rate is decayed by a factor of 0.9 after every epoch and an evaluation with
    visualization is run after every 10th epoch.

    :param stage: key into `cfg.pipeline`; the annotation lists every stage name that
        `optimize` passes, including the landmark-only tracking stages
    :param sample: a single sample to optimize; takes precedence over `dataloader`
    :param dataloader: iterable of samples, required when `sample` is None
    :param lr_scale: forwarded to `configure_optimizer`
    :return:
    """
    params = self.get_train_parameters(stage)
    optimizer = self.configure_optimizer(params, lr_scale=lr_scale)

    if sample is not None:
        num_steps = self.cfg.pipeline[stage].num_steps
        for _ in range(num_steps):
            self.optimize_iter(sample, optimizer, stage)
        return

    assert dataloader is not None
    num_epochs = self.cfg.pipeline[stage].num_epochs
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    for epoch_i in range(num_epochs):
        self.logger.info(f"EPOCH {epoch_i+1} / {num_epochs}")
        for batch in dataloader:
            self.optimize_iter(batch, optimizer, stage)
        scheduler.step()

        if (epoch_i + 1) % 10 == 0:
            self.evaluate(make_visualization=True, epoch=epoch_i+1)
def optimize_iter(self, sample, optimizer, stage):
    """
    Performs a single optimization step on `sample` and logs the result.

    Scalars and media are logged whenever `self.global_step + 1` is a multiple of the
    respective logging interval; `self.global_step` is incremented afterwards.

    :param sample: sample dict; its first 'timestep_index' entry is used as frame index
    :param optimizer: optimizer holding the parameters of the current stage
    :param stage: name of the current stage, forwarded to the energy and the logging
    :return:
    """
    # compute loss and update parameters
    self.clear_cache()
    timestep_index = sample["timestep_index"][0]
    self.fill_cam_params_into_sample(sample)

    E_total, log_dict, verts, faces, lmks, albedos, output_dict = self.compute_energy(
        sample, frame_idx=timestep_index, stage=stage,
    )

    optimizer.zero_grad()
    E_total.backward()
    optimizer.step()

    # log energy terms and visualize
    step_number = self.global_step + 1
    log_kwargs = dict(session="train", stage=stage, frame_step=self.global_step)
    if step_number % self.log_interval_scalar == 0:
        self.log_scalars(log_dict, timestep_index, **log_kwargs)

    if step_number % self.log_interval_media == 0:
        self.log_media(
            verts, faces, lmks, albedos, output_dict, sample, timestep_index,
            **log_kwargs,
        )

    # drop the references to the intermediate results before the next iteration
    del verts, faces, lmks, albedos, output_dict
    self.global_step += 1
def get_train_parameters(
    self,
    stage: Literal[
        'lmk_init_rigid', 'lmk_init_all', 'rgb_init_texture', 'rgb_init_all', 'rgb_init_offset',
        'lmk_sequential_tracking', 'rgb_sequential_tracking',
        'lmk_global_tracking', 'rgb_global_tracking',
    ],
):
    """
    Collects the parameters to be optimized in the given stage.

    As a side effect `self.opt_dict` is rebuilt: it maps every name listed in
    `cfg.pipeline[stage].optimizable_params` to True and any other name to False.

    :param stage: key into `cfg.pipeline`; the annotation lists every stage name that
        `optimize` passes, including 'rgb_init_texture' and the landmark-only stages
    :return: dict mapping a parameter group name to a list of tensors; groups that are
        not optimized in this stage (or are disabled in the config) are absent
    """
    # dict to keep track of which parameters are optimized
    self.opt_dict = defaultdict(bool)
    for name in self.cfg.pipeline[stage].optimizable_params:
        self.opt_dict[name] = True

    model_cfg = self.cfg.model
    params = defaultdict(list)  # dict to collect parameters to be optimized

    # shared properties
    if self.opt_dict["cam"] and not self.calibrated:
        params["cam"] = [self.focal_length]
    if self.opt_dict["shape"]:
        params["shape"] = [self.shape]
    if self.opt_dict["texture"]:
        if not model_cfg.tex_painted:
            params["tex"] = [self.tex_pca]
        if model_cfg.tex_extra:
            params["tex_extra"] = [self.tex_extra]
    if self.opt_dict["static_offset"] and model_cfg.use_static_offset:
        params["static_offset"] = [self.static_offset]
    if self.opt_dict["lights"] and self.lights is not None:
        params["lights"] = [self.lights]

    # per-frame properties
    if self.opt_dict["pose"]:
        params["translation"].append(self.translation)
        params["rotation"].append(self.rotation)
    if self.opt_dict["joints"]:
        params["eyes"].append(self.eyes_pose)
        params["neck"].append(self.neck_pose)
        params["jaw"].append(self.jaw_pose)
    if self.opt_dict["expr"]:
        params["expr"].append(self.expr)
    if self.opt_dict["dynamic_offset"] and model_cfg.use_dynamic_offset:
        params["dynamic_offset"].append(self.dynamic_offset)

    return params
def initialize_next_timtestep(self, timestep):
    """
    Copies the per-frame parameters of `timestep` into the slots of `timestep + 1`.

    Pose, joint and expression parameters are always copied; the dynamic offset only
    when it is enabled in the config. Nothing happens for the last timestep.

    :param timestep: index of the frame whose parameters serve as initialization
    :return:
    """
    if timestep >= self.n_timesteps - 1:
        return

    names = ["translation", "rotation", "neck_pose", "jaw_pose", "eyes_pose", "expr"]
    if self.cfg.model.use_dynamic_offset:
        names.append("dynamic_offset")

    for name in names:
        per_frame = getattr(self, name)
        # write through `.data` so that the in-place copy is not tracked by autograd
        per_frame[timestep + 1].data.copy_(per_frame[timestep])