import math
import json
import multiprocessing
import concurrent.futures
from copy import deepcopy
from glob import glob
from pathlib import Path
from typing import Optional, Literal, Dict, List

import yaml
import tyro
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.multiprocessing
from torch.utils.data import DataLoader
import torchvision

from vhap.config.base import DataConfig, ModelConfig, import_module
from vhap.data.nerf_dataset import NeRFDataset
from vhap.model.flame import FlameHead
from vhap.util.mesh import get_obj_content
from vhap.util.render_nvdiffrast import NVDiffRenderer

# Share tensors between DataLoader workers via the file system to avoid
# "too many open files" errors.
torch.multiprocessing.set_sharing_strategy('file_system')

# Number of threads used to write images, masks, and meshes to disk.
max_threads = min(multiprocessing.cpu_count(), 8)


class NeRFDatasetWriter:
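    """Export a tracked multi-view video dataset to the NeRF convention.

    Writes RGB images and (optionally) foreground masks to `tgt_folder` and
    collects per-frame camera parameters into a `transforms.json` database.
    """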
    def __init__(self, cfg_data: DataConfig, tgt_folder: Path, subset: Optional[str] = None, scale_factor: Optional[float] = None, background_color: Optional[str] = None):
        self.cfg_data = cfg_data
        self.tgt_folder = tgt_folder

        print("==== Config: data ====")
        print(tyro.to_yaml(cfg_data))

        # NOTE: subset, scale_factor, and background_color are accepted for CLI
        # compatibility but are currently not used here. The exporter always
        # writes camera-to-world extrinsics, a white background, and alpha maps.
        cfg_data.target_extrinsic_type = 'c2w'
        cfg_data.background_color = 'white'
        cfg_data.use_alpha_map = True
        dataset = import_module(cfg_data._target)(cfg=cfg_data)
        self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=lambda x: x, num_workers=0)

    def write(self):
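        """Write all frames to disk and build transforms.json (plus a backup copy)."""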
        if not self.tgt_folder.exists():
            self.tgt_folder.mkdir(parents=True)

        db = {
            "frames": [],
        }

        print(f"Writing images to {self.tgt_folder}")
        worker_args = []
        timestep_indices = set()
        camera_indices = set()
        for i, item in tqdm(enumerate(self.dataloader), total=len(self.dataloader)):
            timestep_indices.add(item['timestep_index'])
            camera_indices.add(item['camera_index'])

            extrinsic = item['extrinsic']
            transform_matrix = torch.cat([extrinsic, torch.tensor([[0, 0, 0, 1]])], dim=0).numpy()

            intrinsic = item['intrinsic'].double().numpy()

            cx = intrinsic[0, 2]
            cy = intrinsic[1, 2]
            fl_x = intrinsic[0, 0]
            fl_y = intrinsic[1, 1]
            h = item['rgb'].shape[0]
            w = item['rgb'].shape[1]
            angle_x = math.atan(w / (fl_x * 2)) * 2
            angle_y = math.atan(h / (fl_y * 2)) * 2

            frame_item = {
                "timestep_index": item['timestep_index'],
                "timestep_index_original": item['timestep_index_original'],
                "timestep_id": item['timestep_id'],
                "camera_index": item['camera_index'],
                "camera_id": item['camera_id'],

                "cx": cx,
                "cy": cy,
                "fl_x": fl_x,
                "fl_y": fl_y,
                "h": h,
                "w": w,
                "camera_angle_x": angle_x,
                "camera_angle_y": angle_y,

                "transform_matrix": transform_matrix.tolist(),

                "file_path": f"images/{item['timestep_index']:05d}_{item['camera_index']:02d}.png",
            }

            path2data = {
                str(self.tgt_folder / frame_item['file_path']): item['rgb'],
            }

            if 'alpha_map' in item:
                frame_item['fg_mask_path'] = f"fg_masks/{item['timestep_index']:05d}_{item['camera_index']:02d}.png"
                path2data[str(self.tgt_folder / frame_item['fg_mask_path'])] = item['alpha_map']

            db['frames'].append(frame_item)
            worker_args.append([path2data])

            # Flush the write queue once enough jobs have accumulated or at the last frame.
            if len(worker_args) == max_threads or i == len(self.dataloader) - 1:
                with concurrent.futures.ThreadPoolExecutor(max_threads) as executor:
                    futures = [executor.submit(write_data, *args) for args in worker_args]
                    concurrent.futures.wait(futures)
                worker_args = []

        # Global intrinsics are taken from the last processed frame (all frames
        # are assumed to share the same image size and focal length).
        db.update({
            "cx": cx,
            "cy": cy,
            "fl_x": fl_x,
            "fl_y": fl_y,
            "h": h,
            "w": w,
            "camera_angle_x": angle_x,
            "camera_angle_y": angle_y,
        })

        db['timestep_indices'] = sorted(list(timestep_indices))
        db['camera_indices'] = sorted(list(camera_indices))

        write_json(db, self.tgt_folder)
        write_json(db, self.tgt_folder, division='backup')


class TrackedFLAMEDatasetWriter:
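    """Attach tracked FLAME results to an exported NeRF dataset.

    Depending on `mode`, writes per-timestep meshes and expression vectors
    ('mesh') or per-timestep FLAME parameter files ('param'), then updates
    transforms.json accordingly.
    """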
    def __init__(self, cfg_model: ModelConfig, src_folder: Path, tgt_folder: Path, mode: Literal['mesh', 'param'], epoch: int = -1):
        print("==== Config: model ====")
        print(tyro.to_yaml(cfg_model))

        self.cfg_model = cfg_model
        self.src_folder = src_folder
        self.tgt_folder = tgt_folder
        self.mode = mode

        db_backup_path = tgt_folder / "transforms_backup.json"
        assert db_backup_path.exists(), f"Could not find {db_backup_path}"
        print(f"Loading database from: {db_backup_path}")
        self.db = json.load(open(db_backup_path, "r"))

        # Pick the tracked FLAME parameters of the requested epoch
        # (or of the latest epoch when epoch == -1).
        paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))]
        epochs = [int(p.stem.split('_')[-1]) for p in paths]
        if epoch == -1:
            index = np.argmax(epochs)
        else:
            index = epochs.index(epoch)
        flame_params_path = paths[index]

        assert flame_params_path.exists(), f"Could not find {flame_params_path}"
        print(f"Loading FLAME parameters from: {flame_params_path}")
        self.flame_params = dict(np.load(flame_params_path))

        if "focal_length" in self.flame_params:
            self.focal_length = self.flame_params['focal_length'].item()
        else:
            self.focal_length = None

        self.M = self.relocate_flame_meshes(self.flame_params)

        print("Initializing FLAME model...")
        self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True)

    def relocate_flame_meshes(self, flame_param):
        """Relocate FLAME to the origin and return the transformation matrix to modify camera poses."""
        Ts = torch.tensor(flame_param['translation'])

        T_mean = Ts.mean(0)
        M = torch.eye(4)
        M[:3, 3] = -T_mean

        flame_param['translation'] = (M[:3, 3] + Ts).numpy()
        return M.numpy()

    def replace_cam_params(self, item):
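        """Replace per-frame camera parameters with a single canonical camera.

        The camera is placed at z=1 on the optical axis and uses the focal
        length optimized during tracking (stored normalized by max(h, w)).
        """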
        c2w = np.eye(4)
        c2w[2, 3] = 1
        item['transform_matrix'] = c2w

        h = item['h']
        w = item['w']
        fl_x = self.focal_length * max(h, w)
        fl_y = self.focal_length * max(h, w)
        angle_x = math.atan(w / (fl_x * 2)) * 2
        angle_y = math.atan(h / (fl_y * 2)) * 2

        item.update({
            "cx": w / 2,
            "cy": h / 2,
            "fl_x": fl_x,
            "fl_y": fl_y,
            "camera_angle_x": angle_x,
            "camera_angle_y": angle_y,

            "transform_matrix": c2w.tolist(),
        })

    def write(self):
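        """Write FLAME outputs for every timestep and update transforms.json."""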
        if self.mode == 'mesh':
            self.write_canonical_mesh()
            indices = self.db['timestep_indices']
            verts = infer_flame_params(self.flame_model, self.flame_params, indices)

            print(f"Writing FLAME expressions and meshes to: {self.tgt_folder}")
        elif self.mode == 'param':
            self.write_canonical_flame_param()
            print(f"Writing FLAME parameters to: {self.tgt_folder}")

        saved = [False] * len(self.db['timestep_indices'])
        worker_args = []
        for i, frame in tqdm(enumerate(self.db['frames']), total=len(self.db['frames'])):
            if self.focal_length is not None:
                self.replace_cam_params(frame)

            # Shift camera poses by the same transform used to recenter FLAME.
            frame['transform_matrix'] = (self.M @ np.array(frame['transform_matrix'])).tolist()

            ti_orig = frame['timestep_index_original']
            ti = frame['timestep_index']

            # Only queue FLAME data once per timestep (frames from different
            # cameras share the same timestep).
            if self.mode == 'mesh':
                frame['exp_path'] = f"flame/exp/{ti:05d}.txt"
                frame['mesh_path'] = f"meshes/{ti:05d}.obj"
                if not saved[ti]:
                    worker_args.append([self.tgt_folder, frame['exp_path'], self.flame_params['expr'][ti_orig], frame['mesh_path'], verts[ti_orig], self.flame_model.faces])
                    saved[ti] = True
                func = self.write_expr_and_mesh
            elif self.mode == 'param':
                frame['flame_param_path'] = f"flame_param/{ti:05d}.npz"
                if not saved[ti]:
                    worker_args.append([self.tgt_folder, frame['flame_param_path'], self.flame_params, ti_orig])
                    saved[ti] = True
                func = self.write_flame_param

            if len(worker_args) > 0:
                func(*worker_args.pop())

        write_json(self.db, self.tgt_folder)
        write_json(self.db, self.tgt_folder, division='backup_flame')

    def write_canonical_mesh(self):
        print("Running FLAME inference in the canonical space...")
        if 'static_offset' in self.flame_params:
            static_offset = torch.tensor(self.flame_params['static_offset'])
        else:
            static_offset = None
        # Canonical pose: neutral expression and pose, with a slightly opened jaw (0.3 rad).
        with torch.no_grad():
            ret = self.flame_model(
                torch.tensor(self.flame_params['shape'])[None, ...],
                torch.zeros(*self.flame_params['expr'][:1].shape),
                torch.zeros(*self.flame_params['rotation'][:1].shape),
                torch.zeros(*self.flame_params['neck_pose'][:1].shape),
                torch.tensor([[0.3, 0, 0]]),
                torch.zeros(*self.flame_params['eyes_pose'][:1].shape),
                torch.zeros(*self.flame_params['translation'][:1].shape),
                return_verts_cano=False,
                static_offset=static_offset,
            )
        verts = ret[0]

        cano_mesh_path = self.tgt_folder / 'canonical.obj'
        print(f"Writing canonical mesh to: {cano_mesh_path}")
        obj_data = get_obj_content(verts[0], self.flame_model.faces)
        write_data({cano_mesh_path: obj_data})

    @staticmethod
    def write_expr_and_mesh(tgt_folder, exp_path, expr, mesh_path, verts, faces):
        path2data = {}

        expr_data = '\n'.join([str(n) for n in expr])
        path2data[tgt_folder / exp_path] = expr_data

        obj_data = get_obj_content(verts, faces)
        path2data[tgt_folder / mesh_path] = obj_data
        write_data(path2data)

    def write_canonical_flame_param(self):
        flame_param = {
            'translation': np.zeros_like(self.flame_params['translation'][:1]),
            'rotation': np.zeros_like(self.flame_params['rotation'][:1]),
            'neck_pose': np.zeros_like(self.flame_params['neck_pose'][:1]),
            'jaw_pose': np.array([[0.3, 0, 0]]),  # slightly opened jaw in the canonical space
            'eyes_pose': np.zeros_like(self.flame_params['eyes_pose'][:1]),
            'shape': self.flame_params['shape'],
            'expr': np.zeros_like(self.flame_params['expr'][:1]),
        }
        if 'static_offset' in self.flame_params:
            flame_param['static_offset'] = self.flame_params['static_offset']

        cano_flame_param_path = self.tgt_folder / 'canonical_flame_param.npz'
        print(f"Writing canonical FLAME parameters to: {cano_flame_param_path}")
        write_data({cano_flame_param_path: flame_param})

    @staticmethod
    def write_flame_param(tgt_folder, flame_param_path, flame_params, tid):
        params = {
            'translation': flame_params['translation'][[tid]],
            'rotation': flame_params['rotation'][[tid]],
            'neck_pose': flame_params['neck_pose'][[tid]],
            'jaw_pose': flame_params['jaw_pose'][[tid]],
            'eyes_pose': flame_params['eyes_pose'][[tid]],
            'shape': flame_params['shape'],
            'expr': flame_params['expr'][[tid]],
        }

        if 'static_offset' in flame_params:
            params['static_offset'] = flame_params['static_offset']
        if 'dynamic_offset' in flame_params:
            params['dynamic_offset'] = flame_params['dynamic_offset'][[tid]]

        path2data = {tgt_folder / flame_param_path: params}
        write_data(path2data)


class MaskFromFLAME:
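    """Tighten foreground masks using the tracked FLAME mesh.

    Re-renders the FLAME mesh for each frame to build a head-and-neck mask,
    then re-writes the exported images and foreground masks with everything
    below the neck blended into the background color.
    """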
    def __init__(self, cfg_model: ModelConfig, tgt_folder, background_color: str) -> None:
        # NOTE: this class has no data config to fall back on, so default to a
        # white background (as used by NeRFDatasetWriter) when no color is given.
        background_color = 'white' if background_color is None else background_color
        if background_color == 'white':
            self.background_tensor = torch.tensor([255, 255, 255]).byte()
        elif background_color == 'black':
            self.background_tensor = torch.tensor([0, 0, 0]).byte()
        else:
            raise ValueError(f"Unknown background color: {background_color}")

        dataset = NeRFDataset(
            root_folder=tgt_folder,
            division=None,
            camera_convention_conversion=None,
            target_extrinsic_type='w2c',
            use_fg_mask=True,
            use_flame_param=True,
        )
        self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=None, num_workers=0)

        self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True)

        self.mesh_renderer = NVDiffRenderer(use_opengl=False)

    @torch.no_grad()
    def write(self):
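        """Re-render FLAME for every frame and overwrite images and masks in place."""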
        t2verts = {}
        worker_args = []
        print("Generating masks from FLAME...")
        for i, frame in enumerate(tqdm(self.dataloader)):
            # FLAME inference is shared across cameras of the same timestep.
            timestep = frame['timestep_index']
            if timestep not in t2verts:
                t2verts[timestep] = infer_flame_params(self.flame_model, frame['flame_param'], [0]).cuda()
            verts = t2verts[timestep]

            RT = frame['extrinsics'].cuda()[None]
            K = frame['intrinsics'].cuda()[None]
            h = frame['image_height']
            w = frame['image_width']

            mask = self.get_mask_tilted_line(verts, RT, K, h, w)

            # Blend everything outside the mask into the background color.
            img = frame['image'].cuda()
            img = img * mask[:, :, None] + self.background_tensor.cuda()[None, None, :] * (1 - mask)[:, :, None]

            path2data = {
                str(frame['image_path']): img.byte().cpu().numpy(),
            }

            if 'fg_mask_path' in frame and 'fg_mask' in frame:
                fg_mask = frame['fg_mask'].cuda()
                fg_mask = fg_mask * mask

                path2data.update({
                    str(frame['fg_mask_path']): fg_mask.byte().cpu().numpy(),
                })

            worker_args.append([path2data])

            # Flush the write queue once enough jobs have accumulated or at the last frame.
            if len(worker_args) == max_threads or i == len(self.dataloader) - 1:
                with concurrent.futures.ThreadPoolExecutor(max_threads) as executor:
                    futures = [executor.submit(write_data, *args) for args in worker_args]
                    concurrent.futures.wait(futures)
                worker_args = []

    def get_mask(self, verts, RT, K, h, w):
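        """Rasterize the FLAME mesh into an alpha mask, force rows above the
        projected 'neck_top' ring to foreground, and blur for soft edges.
        Kept as an alternative to `get_mask_tilted_line`, which is what `write` uses.
        """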
        faces = self.flame_model.faces.cuda()
        out_dict = self.mesh_renderer.render_without_texture(verts, faces, RT, K, (h, w))

        rgba_mesh = out_dict['rgba'].squeeze(0)  # (H, W, C)
        mask_mesh = rgba_mesh[..., 3]  # (H, W)

        # Project vertices to pixel coordinates to locate the neck ring.
        verts_clip = out_dict['verts_clip'][0]
        verts_ndc = verts_clip[:, :3] / verts_clip[:, -1:]
        xy = verts_ndc[:, :2]
        xy[:, 1] = -xy[:, 1]
        xy = (xy * 0.5 + 0.5) * torch.tensor([[w, h]]).cuda()
        vid_ring = self.flame_model.mask.get_vid_by_region(['neck_top'])
        xy_ring = xy[vid_ring]
        bottom_line = int(xy_ring[:, 1].min().item())

        # Force everything above the neck ring to foreground.
        mask = mask_mesh.clone()
        mask[:bottom_line] = 1

        # Feather the mask with a Gaussian blur (odd kernel, ~2% of image width).
        k = int(0.02 * w) // 2 * 2 + 1
        blur = torchvision.transforms.GaussianBlur(k, sigma=k)
        mask = blur(mask[None])[0]
        return mask

    def get_mask_tilted_line(self, verts, RT, K, h, w):
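        """Build a soft mask that keeps everything above a tilted line through the neck.

        The line's slope comes from the projected left/right neck points and it
        is anchored at the frontal bottom boundary point; the binary mask is
        then blurred for soft edges.
        """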
        verts_ndc = self.mesh_renderer.world_to_ndc(verts, RT, K, (h, w), flip_y=True)

        verts_xy = verts_ndc[0, :, :2]
        verts_xy = (verts_xy * 0.5 + 0.5) * torch.tensor([w, h]).cuda()

        verts_xy_left = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_right_point'])]
        verts_xy_right = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_left_point'])]
        verts_xy_bottom = verts_xy[self.flame_model.mask.get_vid_by_region(['front_middle_bottom_point_boundary'])]

        # Fit a line y = k * x + b through the neck points.
        delta_xy = verts_xy_left - verts_xy_right
        assert (delta_xy[:, 0] != 0).all()
        k = delta_xy[:, 1] / delta_xy[:, 0]
        b = verts_xy_bottom[:, 1] - k * verts_xy_bottom[:, 0]

        # Keep pixels above the line (smaller row index) as foreground.
        x = torch.arange(w).cuda()
        y = torch.arange(h).cuda()
        yx = torch.stack(torch.meshgrid(y, x, indexing='ij'), dim=-1)
        mask = ((k * yx[:, :, 1] + b - yx[:, :, 0]) > 0).float()

        # Feather the mask with a Gaussian blur (odd kernel, ~3% of image width).
        k = int(0.03 * w) // 2 * 2 + 1
        blur = torchvision.transforms.GaussianBlur(k, sigma=k)
        mask = blur(mask[None])[0]
        return mask


def infer_flame_params(flame_model: FlameHead, flame_params: Dict, indices: List):
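    """Run FLAME for the selected timesteps and return the posed vertices."""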
    if 'static_offset' in flame_params:
        static_offset = flame_params['static_offset']
        if isinstance(static_offset, np.ndarray):
            static_offset = torch.tensor(static_offset)
    else:
        static_offset = None
    # Convert remaining numpy arrays to tensors before calling the model.
    for k in flame_params:
        if isinstance(flame_params[k], np.ndarray):
            flame_params[k] = torch.tensor(flame_params[k])
    with torch.no_grad():
        ret = flame_model(
            flame_params['shape'][None, ...].expand(len(indices), -1),
            flame_params['expr'][indices],
            flame_params['rotation'][indices],
            flame_params['neck_pose'][indices],
            flame_params['jaw_pose'][indices],
            flame_params['eyes_pose'][indices],
            flame_params['translation'][indices],
            return_verts_cano=False,
            static_offset=static_offset,
        )
    verts = ret[0]
    return verts


def write_json(db, tgt_folder, division=None):
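    """Save the database as transforms.json (or transforms_<division>.json)."""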
    fname = "transforms.json" if division is None else f"transforms_{division}.json"
    json_path = tgt_folder / fname
    print(f"Writing database: {json_path}")
    with open(json_path, "w") as f:
        json.dump(db, f, indent=4)


def write_data(path2data):
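    """Write a {path: data} mapping to disk, dispatching on the file suffix."""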
    for path, data in path2data.items():
        path = Path(path)
        if not path.parent.exists():
            path.parent.mkdir(parents=True, exist_ok=True)

        if path.suffix in [".png", ".jpg"]:
            Image.fromarray(data).save(path)
        elif path.suffix in [".obj", ".txt"]:
            with open(path, "w") as f:
                f.write(data)
        elif path.suffix in [".npz"]:
            np.savez(path, **data)
        else:
            raise NotImplementedError(f"Unknown file type: {path.suffix}")


def split_json(tgt_folder: Path, train_ratio=0.7):
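    """Split transforms.json into train/val/test databases.

    Timesteps are split by `train_ratio`; when several cameras are available,
    one camera is held out for validation while the test split keeps all
    cameras on the held-out timesteps.
    """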
    db = json.load(open(tgt_folder / "transforms.json", "r"))

    db_train = {k: v for k, v in db.items() if k not in ['frames', 'timestep_indices', 'camera_indices']}
    db_train['frames'] = []
    db_val = deepcopy(db_train)
    db_test = deepcopy(db_train)

    # Split timesteps into train and test.
    nt = len(db['timestep_indices'])
    assert 0 < train_ratio <= 1
    nt_train = int(np.ceil(nt * train_ratio))
    nt_test = nt - nt_train

    timestep_indices = sorted(db['timestep_indices'])
    db_train['timestep_indices'] = timestep_indices[:nt_train]
    db_val['timestep_indices'] = timestep_indices[:nt_train]  # same timesteps as train, different camera
    db_test['timestep_indices'] = timestep_indices[nt_train:]

    # Split cameras: hold out one camera for validation when possible.
    if len(db['camera_indices']) > 1:
        if 8 in db['camera_indices']:
            # Hold out camera 8 for validation (it stays in the test split).
            db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
            db_val['camera_indices'] = [8]
            db_test['camera_indices'] = db['camera_indices']
        else:
            # Hold out the last camera for validation.
            db_train['camera_indices'] = db['camera_indices'][:-1]
            db_val['camera_indices'] = [db['camera_indices'][-1]]
            db_test['camera_indices'] = db['camera_indices']
    else:
        # Single-camera case: use the only camera for train and test.
        db_train['camera_indices'] = db['camera_indices']
        db_val['camera_indices'] = []
        db_test['camera_indices'] = db['camera_indices']

    # Assign frames to the splits.
    range_train = range(db_train['timestep_indices'][0], db_train['timestep_indices'][-1] + 1) if nt_train > 0 else []
    range_test = range(db_test['timestep_indices'][0], db_test['timestep_indices'][-1] + 1) if nt_test > 0 else []
    for f in db['frames']:
        if f['timestep_index'] in range_train:
            if f['camera_index'] in db_train['camera_indices']:
                db_train['frames'].append(f)
            elif f['camera_index'] in db_val['camera_indices']:
                db_val['frames'].append(f)
            else:
                raise ValueError(f"Unknown camera index: {f['camera_index']}")
        elif f['timestep_index'] in range_test:
            db_test['frames'].append(f)
            assert f['camera_index'] in db_test['camera_indices'], f"Unknown camera index: {f['camera_index']}"
        else:
            raise ValueError(f"Unknown timestep index: {f['timestep_index']}")

    write_json(db_train, tgt_folder, division='train')
    write_json(db_val, tgt_folder, division='val')
    write_json(db_test, tgt_folder, division='test')


def load_config(src_folder: Path):
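    """Locate and load config.yml from the tracking output folder (or its last subfolder)."""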
    config_path = src_folder / "config.yml"
    if not config_path.exists():
        # Fall back to the last run subfolder (sorted by name).
        src_folder = sorted(src_folder.iterdir())[-1]
        config_path = src_folder / "config.yml"
    assert config_path.exists(), f"File not found: {config_path}"

    cfg = yaml.load(config_path.read_text(), Loader=yaml.Loader)
    return src_folder, cfg


def check_epoch(src_folder: Path, epoch: int):
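    """Validate that tracked FLAME parameters exist for the requested epoch."""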
    paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))]
    assert len(paths) > 0, f"Could not find tracked_flame_params*.npz in {src_folder}"
    epochs = [int(p.stem.split('_')[-1]) for p in paths]
    if epoch == -1:
        index = np.argmax(epochs)
    else:
        try:
            index = epochs.index(epoch)
        except ValueError:
            raise ValueError(f"Could not find epoch {epoch} in {src_folder}")


def main(
    src_folder: Path,
    tgt_folder: Path,
    subset: Optional[str] = None,
    scale_factor: Optional[float] = None,
    background_color: Optional[str] = None,
    flame_mode: Literal['mesh', 'param'] = 'param',
    create_mask_from_mesh: bool = False,
    epoch: int = -1,
):
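    """Export a tracking run as a NeRF-style dataset with FLAME attachments.

    Illustrative usage (script name and exact flag spelling generated by tyro may differ):
        python export_as_nerf_dataset.py --src-folder output/<tracking_run> --tgt-folder export/<subject> --flame-mode param
    """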
print(f"Begin exportation from {src_folder}") |
|
assert src_folder.exists(), f"Folder not found: {src_folder}" |
|
src_folder, cfg = load_config(src_folder) |
|
|
|
check_epoch(src_folder, epoch) |
|
|
|
if epoch != -1: |
|
tgt_folder = Path(str(tgt_folder) + f"_epoch{epoch}") |
|
|
|
nerf_dataset_writer = NeRFDatasetWriter(cfg.data, tgt_folder, subset, scale_factor, background_color) |
|
nerf_dataset_writer.write() |
|
|
|
flame_dataset_writer = TrackedFLAMEDatasetWriter(cfg.model, src_folder, tgt_folder, mode=flame_mode, epoch=epoch) |
|
flame_dataset_writer.write() |
|
|
|
if create_mask_from_mesh: |
|
mask_generator = MaskFromFLAME(cfg.model, tgt_folder, background_color) |
|
mask_generator.write() |
|
|
|
split_json(tgt_folder) |
|
|
|
print("Finshed!") |
|
|
|
|
|
if __name__ == "__main__": |
|
tyro.cli(main) |