Spaces:
Running
on
Zero
Running
on
Zero
File size: 14,232 Bytes
17cd746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 |
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Literal, Tuple
import tyro
import importlib
from vhap.util.log import get_logger
logger = get_logger(__name__)  # module-level logger (used by BaseTrackingConfig.get_occluded)
def import_module(module_name: str):
    """Import a module and return one attribute of it, given a dotted path.

    Despite its name, this returns the object bound to the *last* path
    component (typically a class), not the module itself. For example,
    ``import_module("vhap.data.video_dataset.VideoDataset")`` imports
    ``vhap.data.video_dataset`` and returns its ``VideoDataset`` attribute.

    Args:
        module_name: Dotted path of the form ``"package.module.Attribute"``.

    Returns:
        The attribute named by the final path component.

    Raises:
        ValueError: If ``module_name`` contains no ``.`` separator.
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the imported module has no such attribute.
    """
    module_path, sep, attr_name = module_name.rpartition(".")
    if not sep:
        # Previously this surfaced as an opaque "not enough values to unpack"
        # ValueError; keep the exception type but explain what was expected.
        raise ValueError(
            f"import_module expects a dotted path like 'package.module.Name', got {module_name!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
class Config:
    """Mixin that lets config objects be read with subscript syntax.

    ``cfg["field"]`` behaves like ``cfg.field``. A missing name raises
    ``AttributeError`` (not ``KeyError``) whose message names the concrete
    config class.
    """

    def __getitem__(self, name: str):
        if not hasattr(self, name):
            raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'")
        return getattr(self, name)
@dataclass()
class DataConfig(Config):
    """Dataset selection and loading options."""
    root_folder: Path = ''  # NOTE(review): annotated as Path but the default is the str '' — confirm callers before changing
    """The root folder for the dataset."""
    sequence: str = ''
    """The sequence name"""
    _target: str = "vhap.data.video_dataset.VideoDataset"
    """The target dataset class"""
    division: Optional[str] = None  # optional division selector; semantics defined by the dataset class (TODO confirm)
    subset: Optional[str] = None  # optional subset selector; semantics defined by the dataset class (TODO confirm)
    calibrated: bool = False
    """Whether the cameras parameters are available"""
    align_cameras_to_axes: bool = True
    """Adjust how cameras distribute in the space with a global rotation"""
    camera_convention_conversion: str = 'opencv->opengl'  # appears to be '<source>-><target>' camera convention; TODO confirm accepted values
    target_extrinsic_type: Literal['w2c', 'c2w'] = 'w2c'  # presumably world-to-camera vs camera-to-world extrinsics
    n_downsample_rgb: Optional[int] = None
    """Load from downsampled RGB images to save data IO time"""
    scale_factor: float = 1.0
    """Further apply a scaling transformation after the downsampling of RGB"""
    background_color: Optional[Literal['white', 'black']] = 'white'  # background color option ('white' or 'black'; None allowed)
    use_alpha_map: bool = False  # whether to use alpha maps (presumably foreground mattes) — TODO confirm
    use_landmark: bool = True  # whether to use facial landmarks (source chosen by landmark_source)
    landmark_source: Optional[Literal['face-alignment', 'star']] = "star"  # presumably which landmark detector's output to use
@dataclass()
class ModelConfig(Config):
    """FLAME head-model options: coefficient counts, vertex offsets, and texture."""
    n_shape: int = 300  # number of shape coefficients (presumably FLAME shape PCA dimensions)
    n_expr: int = 100  # number of expression coefficients (presumably FLAME expression PCA dimensions)
    n_tex: int = 100  # number of texture coefficients (presumably PCA texture dimensions; see tex_painted)
    use_static_offset: bool = False
    """Optimize static offsets on top of FLAME vertices in the canonical space"""
    use_dynamic_offset: bool = False
    """Optimize dynamic offsets on top of the FLAME vertices in the canonical space"""
    add_teeth: bool = True
    """Add teeth to the FLAME model"""
    remove_lip_inside: bool = False
    """Remove the inner part of the lips from the FLAME model"""
    tex_resolution: int = 2048
    """The resolution of the extra texture map"""
    tex_painted: bool = True
    """Use a painted texture map instead the pca texture space as the base texture map"""
    tex_extra: bool = True
    """Optimize an extra texture map as the base texture map or the residual texture map"""
    # Alternative (commented-out) cluster set kept for reference:
    # tex_clusters: tuple[str, ...] = ("skin", "hair", "sclerae", "lips_tight", "boundary")
    tex_clusters: tuple[str, ...] = ("skin", "hair", "boundary", "lips_tight", "teeth", "sclerae", "irises")
    """Regions that are supposed to share a similar color inside"""
    residual_tex: bool = True
    """Use the extra texture map as a residual component on top of the base texture"""
    occluded: tuple[str, ...] = () # to be used for updating stage configs in __post_init__
    """The regions that are occluded by the hair or garments"""
    flame_params_path: Optional[Path] = None  # optional path to FLAME parameters; its consumer is not visible in this file
@dataclass()
class RenderConfig(Config):
    """Rendering backend, background, and lighting options."""
    backend: Literal['nvdiffrast', 'pytorch3d'] = 'nvdiffrast'
    """The rendering backend"""
    use_opengl: bool = False
    """Use OpenGL for NVDiffRast"""
    background_train: Literal['white', 'black', 'target'] = 'target'
    """Background color/image for training"""
    disturb_rate_fg: Optional[float] = 0.5
    """The rate of disturbance for the foreground"""
    disturb_rate_bg: Optional[float] = 0.5  # NOTE(review): the help text recommends 0.6/0.3 but the default is 0.5
    """The rate of disturbance for the background. 0.6 best for multi-view, 0.3 best for single-view"""
    background_eval: Literal['white', 'black', 'target'] = 'target'
    """Background color/image for evaluation"""
    lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'SH'
    """The type of lighting"""
    lighting_space: Literal['world', 'camera'] = 'world'
    """The space of lighting"""
@dataclass()
class LearningRateConfig(Config):
    """Learning rates for each group of optimized parameters."""
    base: float = 5e-3  # shared learning rate for the parameter groups listed in the help text below
    """shape, texture, rotation, eyes, neck, jaw"""
    translation: float = 1e-3  # learning rate for translation
    expr: float = 5e-2  # learning rate for expression coefficients
    static_offset: float = 5e-4  # learning rate for static vertex offsets
    dynamic_offset: float = 5e-4  # learning rate for dynamic vertex offsets
    camera: float = 5e-3  # learning rate for camera parameters
    light: float = 5e-3  # learning rate for lighting parameters
@dataclass()
class LossWeightConfig(Config):
    """Weights of the loss and regularization terms used during tracking.

    Fields typed ``Optional[float]`` also accept ``None``; presumably a
    ``None`` weight disables the corresponding term (TODO confirm against
    the tracker that consumes this config).
    """
    landmark: Optional[float] = 10.  # weight of the landmark loss
    always_enable_jawline_landmarks: bool = True
    """Always enable the landmark loss for the jawline landmarks. Ignore disable_jawline_landmarks in stages."""
    photo: Optional[float] = 30.  # weight of the photometric loss
    reg_shape: float = 3e-1  # regularization weight for shape coefficients
    reg_expr: float = 3e-2  # regularization weight for expression coefficients
    reg_tex_pca: float = 1e-4  # will make it hard to model hair color when too high
    reg_tex_res: Optional[float] = None  # 1e2 (when w/o reg_var)
    """Regularize the residual texture map"""
    reg_tex_res_clusters: Optional[float] = 1e1
    """Regularize the residual texture map inside each texture cluster"""
    reg_tex_res_for: tuple[str, ...] = ("sclerae", "teeth")
    """Regularize the residual texture map for the clusters specified"""
    reg_tex_tv: Optional[float] = 1e4  # important to split regions apart
    """Regularize the total variation of the texture map"""
    reg_light: Optional[float] = None
    """Regularize lighting parameters"""
    reg_diffuse: Optional[float] = 1e2
    """Regularize lighting parameters by the diffuse term"""
    reg_offset: Optional[float] = 3e2
    """Regularize the norm of offsets"""
    reg_offset_relax_coef: float = 1.
    """The coefficient for relaxing reg_offset for the regions specified"""
    reg_offset_relax_for: tuple[str, ...] = ("hair", "ears")
    """Relax the offset loss for the regions specified"""
    reg_offset_lap: Optional[float] = 1e6
    """Regularize the difference of laplacian coordinate caused by offsets"""
    reg_offset_lap_relax_coef: float = 0.1
    """The coefficient for relaxing reg_offset_lap for the regions specified"""
    reg_offset_lap_relax_for: tuple[str, ...] = ("hair", "ears")
    """Relax the laplacian offset loss (reg_offset_lap) for the regions specified"""
    reg_offset_rigid: Optional[float] = 3e2
    """Regularize the offsets to be as-rigid-as-possible"""
    reg_offset_rigid_for: tuple[str, ...] = ("left_ear", "right_ear", "neck", "left_eye", "right_eye", "lips_tight")
    """Regularize the offsets to be as-rigid-as-possible for the regions specified"""
    reg_offset_dynamic: Optional[float] = 3e5
    """Regularize the dynamic offsets to be temporally smooth"""
    blur_iter: int = 0
    """The number of iterations for blurring vertex weights"""
    smooth_trans: float = 3e2
    """global translation"""
    smooth_rot: float = 3e1
    """global rotation"""
    smooth_neck: float = 3e1
    """neck joint"""
    smooth_jaw: float = 1e-1
    """jaw joint"""
    smooth_eyes: float = 0
    """eyes joints"""
    prior_neck: float = 3e-1
    """Regularize the neck joint towards neutral"""
    prior_jaw: float = 3e-1
    """Regularize the jaw joint towards neutral"""
    prior_eyes: float = 3e-2
    """Regularize the eyes joints towards neutral"""
@dataclass()
class LogConfig(Config):
    """Scalar and media logging options."""
    interval_scalar: Optional[int] = 100  # NOTE(review): default is not None, so the fallback in the help text presumably applies only when set to None
    """The step interval of scalar logging. Using an interval of stage_tracking.num_steps // 5 unless specified."""
    interval_media: Optional[int] = 500  # NOTE(review): default is not None, so the fallback in the help text presumably applies only when set to None
    """The step interval of media logging. Using an interval of stage_tracking.num_steps unless specified."""
    image_format: Literal['jpg', 'png'] = 'jpg'
    """Output image format"""
    view_indices: Tuple[int, ...] = ()
    """Manually specify the view indices for log"""
    max_num_views: int = 3
    """The maximum number of views for log"""
    stack_views_in_rows: bool = True  # presumably stacks logged views row-wise (True) rather than column-wise — TODO confirm
@dataclass()
class ExperimentConfig(Config):
    """Experiment-level options: output location, landmark reuse, keyframes, and photometric mode."""
    output_folder: Path = Path('output/track')  # presumably where tracking results are written
    reuse_landmarks: bool = True  # presumably reuse previously detected landmarks instead of re-detecting — TODO confirm
    keyframes: Tuple[int, ...] = tuple()  # frame indices treated as keyframes; usage is not visible in this file
    photometric: bool = False
    """enable photometric optimization, otherwise only landmark optimization"""
@dataclass()
class StageConfig(Config):
    """Options shared by every tracking stage.

    Per its help text, LossWeightConfig.always_enable_jawline_landmarks makes
    the tracker ignore ``disable_jawline_landmarks`` in stages.
    """
    disable_jawline_landmarks: bool = False
    """Disable the landmark loss for the jawline landmarks since they are not accurate"""
@dataclass()
class StageLmkInitRigidConfig(StageConfig):
    """The stage for initializing the rigid parameters"""
    num_steps: int = 300  # number of steps in this stage (set to 0 when skipped via begin_stage)
    optimizable_params: tuple[str, ...] = ("cam", "pose")  # parameter groups optimized in this stage
@dataclass()
class StageLmkInitAllConfig(StageConfig):
    """The stage for initializing all the parameters optimizable with landmark loss"""
    num_steps: int = 300  # number of steps in this stage (set to 0 when skipped via begin_stage)
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")  # parameter groups optimized in this stage
@dataclass()
class StageLmkSequentialTrackingConfig(StageConfig):
    """The stage for sequential tracking with landmark loss"""
    num_steps: int = 50  # number of steps in this stage (presumably per frame, since tracking is sequential — TODO confirm)
    optimizable_params: tuple[str, ...] = ("pose", "joints", "expr")  # parameter groups optimized in this stage
@dataclass()
class StageLmkGlobalTrackingConfig(StageConfig):
    """The stage for global tracking with landmark loss"""
    # NOTE: this stage is sized by num_epochs and has no num_steps field.
    num_epochs: int = 0  # number of epochs; presumably 0 disables this stage by default — TODO confirm
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")  # parameter groups optimized in this stage
@dataclass()
class PhotometricStageConfig(StageConfig):
    """Base class for stages that use the photometric loss.

    BaseTrackingConfig.__post_init__ appends every region in model.occluded to
    both ``align_*_except`` tuples of each pipeline stage that is an instance
    of this class.
    """
    align_texture_except: tuple[str, ...] = ()
    """Align the inner region of rendered FLAME to the image, except for the regions specified"""
    align_boundary_except: tuple[str, ...] = ("bottomline",) # necessary to avoid the bottomline of FLAME from being stretched to the bottom of the image
    """Align the boundary of FLAME to the image, except for the regions specified"""
@dataclass()
class StageRgbInitTextureConfig(PhotometricStageConfig):
    """The stage for initializing the texture map with photometric loss"""
    num_steps: int = 500  # number of steps in this stage (set to 0 when skipped via begin_stage)
    optimizable_params: tuple[str, ...] = ("cam", "shape", "texture", "lights")  # parameter groups optimized in this stage
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")  # overrides the base default ()
    align_boundary_except: tuple[str, ...] = ("hair", "boundary")  # overrides the base default ("bottomline",)
@dataclass()
class StageRgbInitAllConfig(PhotometricStageConfig):
    """The stage for initializing all the parameters except the offsets with photometric loss"""
    num_steps: int = 500  # number of steps in this stage (set to 0 when skipped via begin_stage)
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights")  # parameter groups optimized in this stage
    disable_jawline_landmarks: bool = True  # overrides the StageConfig default False
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")  # overrides the base default ()
    align_boundary_except: tuple[str, ...] = ("hair", "bottomline")  # overrides the base default ("bottomline",)
@dataclass()
class StageRgbInitOffsetConfig(PhotometricStageConfig):
    """The stage for initializing the offsets with photometric loss"""
    num_steps: int = 500  # number of steps in this stage (set to 0 when skipped via begin_stage)
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset")  # parameter groups optimized in this stage
    disable_jawline_landmarks: bool = True  # overrides the StageConfig default False
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")  # overrides the base default (); align_boundary_except keeps ("bottomline",)
@dataclass()
class StageRgbSequentialTrackingConfig(PhotometricStageConfig):
    """The stage for sequential tracking with photometric loss"""
    num_steps: int = 50  # number of steps in this stage (presumably per frame, since tracking is sequential — TODO confirm)
    optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "texture", "dynamic_offset")  # parameter groups optimized in this stage
    disable_jawline_landmarks: bool = True  # overrides the StageConfig default False
@dataclass()
class StageRgbGlobalTrackingConfig(PhotometricStageConfig):
    """The stage for global tracking with photometric loss"""
    # NOTE: this stage is sized by num_epochs and has no num_steps field.
    num_epochs: int = 30  # number of epochs for global tracking
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset", "dynamic_offset")  # parameter groups optimized in this stage
    disable_jawline_landmarks: bool = True  # overrides the StageConfig default False
@dataclass()
class PipelineConfig(Config):
    """Per-stage configurations of the tracking pipeline.

    BaseTrackingConfig.__post_init__ walks these fields in declaration order
    and skips every stage before ``begin_stage``, so the order here is
    significant (presumably it is the execution order).
    """
    lmk_init_rigid: StageLmkInitRigidConfig
    lmk_init_all: StageLmkInitAllConfig
    lmk_sequential_tracking: StageLmkSequentialTrackingConfig
    lmk_global_tracking: StageLmkGlobalTrackingConfig
    rgb_init_texture: StageRgbInitTextureConfig
    rgb_init_all: StageRgbInitAllConfig
    rgb_init_offset: StageRgbInitOffsetConfig
    rgb_sequential_tracking: StageRgbSequentialTrackingConfig
    rgb_global_tracking: StageRgbGlobalTrackingConfig
@dataclass()
class BaseTrackingConfig(Config):
    """Top-level tracking configuration (parsed from the command line with tyro in ``__main__``).

    ``__post_init__`` derives occlusion-dependent alignment settings for the
    photometric stages and, when ``begin_stage`` is set, disables every
    pipeline stage that precedes it.
    """
    data: DataConfig
    model: ModelConfig
    render: RenderConfig
    log: LogConfig
    exp: ExperimentConfig
    lr: LearningRateConfig
    w: LossWeightConfig
    pipeline: PipelineConfig
    begin_stage: Optional[str] = None
    """Begin from the specified stage for debugging. Accepts a pipeline field name (e.g. 'rgb_init_texture') or the lower-cased stage class name (e.g. 'stagergbinittextureconfig')"""
    begin_frame_idx: int = 0
    """Begin from the specified frame index for debugging"""
    async_func: bool = True
    """Allow asynchronous function calls for speed up"""
    device: Literal['cuda', 'cpu'] = 'cuda'  # compute device ('cuda' or 'cpu')

    def get_occluded(self):
        """Override ``self.model.occluded`` from a per-sequence lookup table.

        The table is currently empty, so this is a no-op until entries mapping
        a sequence name to a tuple of region names are added.
        """
        occluded_table = {
        }
        if self.data.sequence in occluded_table:
            logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.sequence]}")
            self.model.occluded = occluded_table[self.data.sequence]

    def __post_init__(self):
        """Derive dependent settings once all fields are set.

        1. Apply per-sequence occlusion overrides (``get_occluded``).
        2. If neither static nor dynamic offsets are enabled, add ``'hair'``
           to ``model.occluded``.
        3. Append every occluded region to ``align_texture_except`` and
           ``align_boundary_except`` of each photometric pipeline stage.
        4. If ``begin_stage`` is set, disable every stage preceding it.

        Steps 2 and 3 mutate the nested configs in place and are not
        idempotent: running them twice appends the regions twice.

        Raises:
            ValueError: If ``begin_stage`` does not name any pipeline stage.
        """
        self.get_occluded()
        if not self.model.use_static_offset and not self.model.use_dynamic_offset:
            # disable boundary alignment for the hair region if no offset is used
            self.model.occluded = tuple(list(self.model.occluded) + ['hair'])
        for cfg_stage in self.pipeline.__dict__.values():
            if isinstance(cfg_stage, PhotometricStageConfig):
                cfg_stage.align_texture_except = tuple(list(cfg_stage.align_texture_except) + list(self.model.occluded))
                cfg_stage.align_boundary_except = tuple(list(cfg_stage.align_boundary_except) + list(self.model.occluded))
        if self.begin_stage is not None:
            self._skip_stages_before(self.begin_stage)

    def _skip_stages_before(self, begin_stage: str) -> None:
        """Disable every pipeline stage declared before ``begin_stage``.

        ``begin_stage`` may be the pipeline field name (e.g.
        ``"rgb_init_texture"``) or the lower-cased stage class name (e.g.
        ``"stagergbinittextureconfig"``). Skipped stages get ``num_steps = 0``;
        stages that also have ``num_epochs`` (the global-tracking stages,
        which have no ``num_steps`` field) get ``num_epochs = 0`` as well,
        since zeroing ``num_steps`` alone leaves their epoch count untouched.

        Raises:
            ValueError: If ``begin_stage`` matches no stage, instead of
                silently disabling the whole pipeline.
        """
        # Dataclass __init__ assigns fields in declaration order, so this
        # list follows the order declared in PipelineConfig.
        stages = list(self.pipeline.__dict__.items())
        begin_idx = next(
            (
                idx
                for idx, (field_name, cfg_stage) in enumerate(stages)
                if begin_stage in (field_name, cfg_stage.__class__.__name__.lower())
            ),
            None,
        )
        if begin_idx is None:
            valid = ", ".join(field_name for field_name, _ in stages)
            raise ValueError(
                f"Unknown begin_stage {begin_stage!r}; expected one of: {valid} "
                f"(or the lower-cased stage class name)"
            )
        for _, cfg_stage in stages[:begin_idx]:
            cfg_stage.num_steps = 0
            if hasattr(cfg_stage, 'num_epochs'):
                cfg_stage.num_epochs = 0
if __name__ == "__main__":
    # Build the full tracking config from command-line arguments via tyro,
    # then dump the resolved configuration to stdout as YAML.
    cli_config = tyro.cli(BaseTrackingConfig)
    yaml_text = tyro.to_yaml(cli_config)
    print(yaml_text)