File size: 3,500 Bytes
17cd746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
from typing import Optional, Literal
from dataclasses import dataclass
import tyro
from vhap.config.base import (
StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig,
DataConfig, LossWeightConfig, BaseTrackingConfig,
)
from vhap.util.log import get_logger
logger = get_logger(__name__)
@dataclass()
class NersembleDataConfig(DataConfig):
_target: str = "vhap.data.nersemble_dataset.NeRSembleDataset"
calibrated: bool = True
image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200)
"""(height, width). Will be use to convert principle points when the image size is not included in the camera parameters."""
background_color: Optional[Literal['white', 'black']] = None
landmark_source: Optional[Literal["face-alignment", 'star']] = "star"
subject: str = ""
"""Subject ID. Such as 018, 218, 251, 253"""
use_color_correction: bool = True
"""Whether to use color correction to harmonize the color of the input images."""
@dataclass()
class NersembleLossWeightConfig(LossWeightConfig):
landmark: Optional[float] = 3. # should not be lower to avoid collapse
always_enable_jawline_landmarks: bool = False # allow disable_jawline_landmarks in StageConfig to work
reg_expr: float = 1e-2 # for best expressivness
reg_tex_tv: Optional[float] = 1e5 # 10x of the base value
@dataclass()
class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig):
optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset")
align_texture_except: tuple[str, ...] = ("boundary",)
align_boundary_except: tuple[str, ...] = ("boundary",)
"""Due to the limited flexibility in the lower neck region of FLAME, we relax the
alignment constraints for better alignment in the face region.
"""
@dataclass()
class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig):
align_texture_except: tuple[str, ...] = ("boundary",)
align_boundary_except: tuple[str, ...] = ("boundary",)
"""Due to the limited flexibility in the lower neck region of FLAME, we relax the
alignment constraints for better alignment in the face region.
"""
@dataclass()
class NersemblePipelineConfig(PipelineConfig):
rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig
rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig
@dataclass()
class NersembleTrackingConfig(BaseTrackingConfig):
data: NersembleDataConfig
w: NersembleLossWeightConfig
pipeline: NersemblePipelineConfig
def get_occluded(self):
occluded_table = {
'018': ('neck_lower',),
'218': ('neck_lower',),
'251': ('neck_lower', 'boundary'),
'253': ('neck_lower',),
}
if self.data.subject in occluded_table:
logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}")
self.model.occluded = occluded_table[self.data.subject]
if __name__ == "__main__":
config = tyro.cli(NersembleTrackingConfig)
print(tyro.to_yaml(config)) |