# File size: 14,232 Bytes
# 17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# 
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 
# property and proprietary rights in and to this software and related documentation. 
# Any commercial use, reproduction, disclosure or distribution of this software and 
# related documentation without an express license agreement from Toyota Motor Europe NV/SA 
# is strictly prohibited.
#


from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Literal, Tuple
import tyro
import importlib
from vhap.util.log import get_logger
logger = get_logger(__name__)


def import_module(module_name: str):
    """Resolve a dotted path such as ``"package.module.ClassName"`` to the object it names.

    The text before the last dot is imported as a module and the text after it
    is looked up as an attribute of that module.

    Args:
        module_name: Fully qualified dotted path of the attribute to load.

    Returns:
        The attribute (typically a class) found on the imported module.

    Raises:
        ValueError: If ``module_name`` contains no dot and therefore cannot be
            split into a module path and an attribute name.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, separator, attr_name = module_name.rpartition(".")
    if not separator:
        # Without this check the failure is an opaque tuple-unpacking error.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got '{module_name}'"
        )
    return getattr(importlib.import_module(module_path), attr_name)


class Config:
    """Base class for the configuration sections, adding dictionary-style read access."""

    def __getitem__(self, __name: str):
        """Return the attribute with the given name, so ``cfg["x"]`` behaves like ``cfg.x``.

        Raises:
            AttributeError: If the instance has no attribute with that name.
        """
        if not hasattr(self, __name):
            raise AttributeError(f"{self.__class__.__name__} has no attribute '{__name}'")
        return getattr(self, __name)


@dataclass
class DataConfig(Config):
    """Dataset-related options."""

    root_folder: Path = ""
    """Root folder of the dataset."""
    sequence: str = ""
    """Name of the sequence."""
    _target: str = "vhap.data.video_dataset.VideoDataset"
    """Dotted path of the target dataset class."""
    division: Optional[str] = None
    subset: Optional[str] = None
    calibrated: bool = False
    """Whether the camera parameters are available."""
    align_cameras_to_axes: bool = True
    """Apply a global rotation to adjust how the cameras are distributed in space."""
    camera_convention_conversion: str = "opencv->opengl"
    target_extrinsic_type: Literal["w2c", "c2w"] = "w2c"
    n_downsample_rgb: Optional[int] = None
    """Load downsampled RGB images to save data IO time."""
    scale_factor: float = 1.0
    """Additional scaling transformation applied after the RGB downsampling."""
    background_color: Optional[Literal["white", "black"]] = "white"
    use_alpha_map: bool = False
    use_landmark: bool = True
    landmark_source: Optional[Literal["face-alignment", "star"]] = "star"


@dataclass
class ModelConfig(Config):
    """FLAME model options: parameter counts, offsets and texture settings."""

    n_shape: int = 300
    n_expr: int = 100
    n_tex: int = 100

    use_static_offset: bool = False
    """Optimize static offsets added to the FLAME vertices in the canonical space."""
    use_dynamic_offset: bool = False
    """Optimize dynamic offsets added to the FLAME vertices in the canonical space."""
    add_teeth: bool = True
    """Add teeth to the FLAME model."""
    remove_lip_inside: bool = False
    """Remove the inner part of the lips from the FLAME model."""

    tex_resolution: int = 2048
    """Resolution of the extra texture map."""
    tex_painted: bool = True
    """Use a painted texture map, instead of the PCA texture space, as the base texture map."""
    tex_extra: bool = True
    """Optimize an extra texture map, serving as the base or the residual texture map."""
    # tex_clusters: tuple[str, ...] = ("skin", "hair", "sclerae", "lips_tight", "boundary")
    tex_clusters: tuple[str, ...] = ("skin", "hair", "boundary", "lips_tight", "teeth", "sclerae", "irises")
    """Regions whose interior is expected to share a similar color."""
    residual_tex: bool = True
    """Treat the extra texture map as a residual component on top of the base texture."""
    occluded: tuple[str, ...] = ()  # to be used for updating stage configs in __post_init__
    """Regions occluded by hair or garments."""

    flame_params_path: Optional[Path] = None


@dataclass
class RenderConfig(Config):
    """Rendering options: backend, backgrounds, disturbance rates and lighting."""

    backend: Literal["nvdiffrast", "pytorch3d"] = "nvdiffrast"
    """Rendering backend."""
    use_opengl: bool = False
    """Use OpenGL for NVDiffRast."""
    background_train: Literal["white", "black", "target"] = "target"
    """Background color/image used for training."""
    disturb_rate_fg: Optional[float] = 0.5
    """Rate of disturbance for the foreground."""
    disturb_rate_bg: Optional[float] = 0.5
    """Rate of disturbance for the background. 0.6 best for multi-view, 0.3 best for single-view."""
    background_eval: Literal["white", "black", "target"] = "target"
    """Background color/image used for evaluation."""
    lighting_type: Literal["constant", "front", "front-range", "SH"] = "SH"
    """Type of lighting."""
    lighting_space: Literal["world", "camera"] = "world"
    """Space in which the lighting is defined."""


@dataclass
class LearningRateConfig(Config):
    """Learning rates, one per group of optimizable parameters."""

    base: float = 5e-3
    """Learning rate for shape, texture, rotation, eyes, neck and jaw."""
    translation: float = 1e-3
    expr: float = 5e-2
    static_offset: float = 5e-4
    dynamic_offset: float = 5e-4
    camera: float = 5e-3
    light: float = 5e-3


@dataclass
class LossWeightConfig(Config):
    """Weights of the loss terms, plus the region lists that some of them apply to."""

    landmark: Optional[float] = 10.0
    always_enable_jawline_landmarks: bool = True
    """Keep the landmark loss on the jawline landmarks at all times, ignoring disable_jawline_landmarks in stages."""

    photo: Optional[float] = 30.0

    reg_shape: float = 3e-1
    reg_expr: float = 3e-2
    reg_tex_pca: float = 1e-4  # will make it hard to model hair color when too high

    reg_tex_res: Optional[float] = None  # 1e2 (when w/o reg_var)
    """Regularization of the residual texture map."""
    reg_tex_res_clusters: Optional[float] = 1e1
    """Regularization of the residual texture map inside each texture cluster."""
    reg_tex_res_for: tuple[str, ...] = ("sclerae", "teeth")
    """Clusters for which the residual texture map is regularized."""
    reg_tex_tv: Optional[float] = 1e4  # important to split regions apart
    """Regularization of the total variation of the texture map."""

    reg_light: Optional[float] = None
    """Regularization of the lighting parameters."""
    reg_diffuse: Optional[float] = 1e2
    """Regularization of the lighting parameters through the diffuse term."""

    reg_offset: Optional[float] = 3e2
    """Regularization of the norm of the offsets."""
    reg_offset_relax_coef: float = 1.0
    """Coefficient used to relax reg_offset for the regions specified."""
    reg_offset_relax_for: tuple[str, ...] = ("hair", "ears")
    """Regions for which the offset loss is relaxed."""

    reg_offset_lap: Optional[float] = 1e6
    """Regularization of the change in laplacian coordinates caused by the offsets."""
    reg_offset_lap_relax_coef: float = 0.1
    """Coefficient used to relax reg_offset_lap for the regions specified."""
    reg_offset_lap_relax_for: tuple[str, ...] = ("hair", "ears")
    """Regions for which reg_offset_lap is relaxed."""

    reg_offset_rigid: Optional[float] = 3e2
    """Regularization that keeps the offsets as-rigid-as-possible."""
    reg_offset_rigid_for: tuple[str, ...] = ("left_ear", "right_ear", "neck", "left_eye", "right_eye", "lips_tight")
    """Regions in which the offsets are kept as-rigid-as-possible."""

    reg_offset_dynamic: Optional[float] = 3e5
    """Regularization that keeps the dynamic offsets temporally smooth."""

    blur_iter: int = 0
    """Number of iterations used for blurring the vertex weights."""

    smooth_trans: float = 3e2
    """Smoothness weight for the global translation."""
    smooth_rot: float = 3e1
    """Smoothness weight for the global rotation."""

    smooth_neck: float = 3e1
    """Smoothness weight for the neck joint."""
    smooth_jaw: float = 1e-1
    """Smoothness weight for the jaw joint."""
    smooth_eyes: float = 0
    """Smoothness weight for the eyes joints."""

    prior_neck: float = 3e-1
    """Regularization of the neck joint towards neutral."""
    prior_jaw: float = 3e-1
    """Regularization of the jaw joint towards neutral."""
    prior_eyes: float = 3e-2
    """Regularization of the eyes joints towards neutral."""
    

@dataclass
class LogConfig(Config):
    """Logging options: intervals, image format and which views to show."""

    interval_scalar: Optional[int] = 100
    """Step interval of scalar logging. An interval of stage_tracking.num_steps // 5 is used unless specified."""
    interval_media: Optional[int] = 500
    """Step interval of media logging. An interval of stage_tracking.num_steps is used unless specified."""
    image_format: Literal["jpg", "png"] = "jpg"
    """Format of the output images."""
    view_indices: Tuple[int, ...] = ()
    """View indices to log, specified manually."""
    max_num_views: int = 3
    """Maximum number of views to log."""
    stack_views_in_rows: bool = True


@dataclass
class ExperimentConfig(Config):
    """Experiment-level options such as the output location and the optimization mode."""

    output_folder: Path = Path("output/track")
    reuse_landmarks: bool = True
    keyframes: Tuple[int, ...] = ()
    photometric: bool = False
    """Enable photometric optimization; otherwise only landmark optimization is performed."""

@dataclass
class StageConfig(Config):
    """Base class of the per-stage configurations."""

    disable_jawline_landmarks: bool = False
    """Turn off the landmark loss for the jawline landmarks, since they are not accurate."""

@dataclass
class StageLmkInitRigidConfig(StageConfig):
    """Stage that initializes the rigid parameters."""

    num_steps: int = 300
    optimizable_params: tuple[str, ...] = ("cam", "pose")

@dataclass
class StageLmkInitAllConfig(StageConfig):
    """Stage that initializes every parameter optimizable with the landmark loss."""

    num_steps: int = 300
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")

@dataclass
class StageLmkSequentialTrackingConfig(StageConfig):
    """Stage that performs sequential tracking with the landmark loss."""

    num_steps: int = 50
    optimizable_params: tuple[str, ...] = ("pose", "joints", "expr")

@dataclass
class StageLmkGlobalTrackingConfig(StageConfig):
    """Stage that performs global tracking with the landmark loss."""

    num_epochs: int = 0
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")

@dataclass
class PhotometricStageConfig(StageConfig):
    """Base class of the stage configurations that carry texture/boundary alignment exclusions."""

    align_texture_except: tuple[str, ...] = ()
    """Align the inner region of the rendered FLAME to the image, excluding the regions specified."""
    align_boundary_except: tuple[str, ...] = ("bottomline",)  # necessary to avoid the bottomline of FLAME from being stretched to the bottom of the image
    """Align the boundary of FLAME to the image, excluding the regions specified."""

@dataclass
class StageRgbInitTextureConfig(PhotometricStageConfig):
    """Stage that initializes the texture map with the photometric loss."""

    num_steps: int = 500
    optimizable_params: tuple[str, ...] = ("cam", "shape", "texture", "lights")
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
    align_boundary_except: tuple[str, ...] = ("hair", "boundary")

@dataclass
class StageRgbInitAllConfig(PhotometricStageConfig):
    """Stage that initializes every parameter except the offsets with the photometric loss."""

    num_steps: int = 500
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights")
    disable_jawline_landmarks: bool = True
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
    align_boundary_except: tuple[str, ...] = ("hair", "bottomline")

@dataclass
class StageRgbInitOffsetConfig(PhotometricStageConfig):
    """Stage that initializes the offsets with the photometric loss."""

    num_steps: int = 500
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset")
    disable_jawline_landmarks: bool = True
    align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")

@dataclass
class StageRgbSequentialTrackingConfig(PhotometricStageConfig):
    """Stage that performs sequential tracking with the photometric loss."""

    num_steps: int = 50
    optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "texture", "dynamic_offset")
    disable_jawline_landmarks: bool = True

@dataclass
class StageRgbGlobalTrackingConfig(PhotometricStageConfig):
    """Stage that performs global tracking with the photometric loss."""

    num_epochs: int = 30
    optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset", "dynamic_offset")
    disable_jawline_landmarks: bool = True

@dataclass
class PipelineConfig(Config):
    """The stage configurations of the pipeline, declared landmark stages first, then RGB stages."""

    lmk_init_rigid: StageLmkInitRigidConfig
    lmk_init_all: StageLmkInitAllConfig
    lmk_sequential_tracking: StageLmkSequentialTrackingConfig
    lmk_global_tracking: StageLmkGlobalTrackingConfig
    rgb_init_texture: StageRgbInitTextureConfig
    rgb_init_all: StageRgbInitAllConfig
    rgb_init_offset: StageRgbInitOffsetConfig
    rgb_sequential_tracking: StageRgbSequentialTrackingConfig
    rgb_global_tracking: StageRgbGlobalTrackingConfig

    
@dataclass
class BaseTrackingConfig(Config):
    """Top-level tracking configuration that groups all sections and finalizes the stage settings."""

    data: DataConfig
    model: ModelConfig
    render: RenderConfig
    log: LogConfig
    exp: ExperimentConfig
    lr: LearningRateConfig
    w: LossWeightConfig
    pipeline: PipelineConfig

    begin_stage: Optional[str] = None
    """Begin from the specified stage for debugging. Matched against the lowercased class name of each stage config, e.g. 'stagergbinitallconfig'."""
    begin_frame_idx: int = 0
    """Begin from the specified frame index for debugging"""
    async_func: bool = True
    """Allow asynchronous function calls for speed up"""
    device: Literal['cuda', 'cpu'] = 'cuda'

    def get_occluded(self):
        """Set ``model.occluded`` from a per-sequence lookup table when the sequence is listed.

        The table is empty here, so this is a no-op until entries are added.
        """
        occluded_table = {
        }
        if self.data.sequence in occluded_table:
            logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.sequence]}")
            self.model.occluded = occluded_table[self.data.sequence]

    def __post_init__(self):
        """Derive the dependent settings once all sections have been populated.

        In order: look up occluded regions for the sequence, mark the hair as
        occluded when no offsets are optimized, append the occluded regions to
        the alignment exclusions of every photometric stage, and disable the
        stages that precede ``begin_stage``.
        """
        self.get_occluded()

        if not self.model.use_static_offset and not self.model.use_dynamic_offset:
            # disable boundary alignment for the hair region if no offset is used
            self.model.occluded = (*self.model.occluded, 'hair')

        # Stages in declaration order, keyed by their field name in PipelineConfig.
        stages = vars(self.pipeline)

        for cfg_stage in stages.values():
            if isinstance(cfg_stage, PhotometricStageConfig):
                cfg_stage.align_texture_except = (*cfg_stage.align_texture_except, *self.model.occluded)
                cfg_stage.align_boundary_except = (*cfg_stage.align_boundary_except, *self.model.occluded)

        if self.begin_stage is not None:
            skip = True
            for cfg_stage in stages.values():
                if cfg_stage.__class__.__name__.lower() == self.begin_stage:
                    skip = False
                if skip:
                    cfg_stage.num_steps = 0
                    # Global-tracking stages are sized by num_epochs, so zeroing
                    # num_steps alone would leave them enabled.
                    if hasattr(cfg_stage, 'num_epochs'):
                        cfg_stage.num_epochs = 0
            if skip:
                # Nothing matched, so every stage above has been disabled.
                valid_names = ", ".join(stage.__class__.__name__.lower() for stage in stages.values())
                logger.warning(
                    f"begin_stage '{self.begin_stage}' matches no pipeline stage, so all stages are skipped. "
                    f"Valid values: {valid_names}"
                )


if __name__ == "__main__":
    # Build the full tracking configuration from the command line and print it as YAML.
    config = tyro.cli(BaseTrackingConfig)
    yaml_text = tyro.to_yaml(config)
    print(yaml_text)