Spaces:
Running
on
Zero
Running
on
Zero
# | |
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual | |
# property and proprietary rights in and to this software and related documentation. | |
# Any commercial use, reproduction, disclosure or distribution of this software and | |
# related documentation without an express license agreement from Toyota Motor Europe NV/SA | |
# is strictly prohibited. | |
# | |
import math | |
from typing import Optional, Literal, Dict, List | |
from glob import glob | |
import concurrent.futures | |
import multiprocessing | |
from copy import deepcopy | |
import yaml | |
import json | |
import tyro | |
from pathlib import Path | |
from tqdm import tqdm | |
from PIL import Image | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
import torchvision | |
# from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle | |
from vhap.config.base import DataConfig, ModelConfig, import_module | |
from vhap.data.nerf_dataset import NeRFDataset | |
from vhap.model.flame import FlameHead | |
from vhap.util.mesh import get_obj_content | |
from vhap.util.render_nvdiffrast import NVDiffRenderer | |
# to prevent "OSError: [Errno 24] Too many open files" | |
import torch.multiprocessing | |
torch.multiprocessing.set_sharing_strategy('file_system') | |
max_threads = min(multiprocessing.cpu_count(), 8) | |
class NeRFDatasetWriter: | |
def __init__(self, cfg_data: DataConfig, tgt_folder: Path, subset:Optional[str]=None, scale_factor: Optional[float]=None, background_color: Optional[str]=None): | |
self.cfg_data = cfg_data | |
self.tgt_folder = tgt_folder | |
print("==== Config: data ====") | |
print(tyro.to_yaml(cfg_data)) | |
cfg_data.target_extrinsic_type = 'c2w' | |
cfg_data.background_color = 'white' | |
cfg_data.use_alpha_map = True | |
dataset = import_module(cfg_data._target)(cfg=cfg_data) | |
self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=lambda x: x, num_workers=0) | |
def write(self): | |
if not self.tgt_folder.exists(): | |
self.tgt_folder.mkdir(parents=True) | |
db = { | |
"frames": [], | |
} | |
print(f"Writing images to {self.tgt_folder}") | |
worker_args = [] | |
timestep_indices = set() | |
camera_indices = set() | |
for i, item in tqdm(enumerate(self.dataloader), total=len(self.dataloader)): | |
# print(item.keys()) | |
timestep_indices.add(item['timestep_index']) | |
camera_indices.add(item['camera_index']) | |
extrinsic = item['extrinsic'] | |
transform_matrix = torch.cat([extrinsic, torch.tensor([[0,0,0,1]])], dim=0).numpy() | |
intrinsic = item['intrinsic'].double().numpy() | |
cx = intrinsic[0, 2] | |
cy = intrinsic[1, 2] | |
fl_x = intrinsic[0, 0] | |
fl_y = intrinsic[1, 1] | |
h = item['rgb'].shape[0] | |
w = item['rgb'].shape[1] | |
angle_x = math.atan(w / (fl_x * 2)) * 2 | |
angle_y = math.atan(h / (fl_y * 2)) * 2 | |
frame_item = { | |
"timestep_index": item['timestep_index'], | |
"timestep_index_original": item['timestep_index_original'], | |
"timestep_id": item['timestep_id'], | |
"camera_index": item['camera_index'], | |
"camera_id": item['camera_id'], | |
"cx": cx, | |
"cy": cy, | |
"fl_x": fl_x, | |
"fl_y": fl_y, | |
"h": h, | |
"w": w, | |
"camera_angle_x": angle_x, | |
"camera_angle_y": angle_y, | |
"transform_matrix": transform_matrix.tolist(), | |
"file_path": f"images/{item['timestep_index']:05d}_{item['camera_index']:02d}.png", | |
} | |
path2data = { | |
str(self.tgt_folder / frame_item['file_path']): item['rgb'], | |
} | |
if 'alpha_map' in item: | |
frame_item['fg_mask_path'] = f"fg_masks/{item['timestep_index']:05d}_{item['camera_index']:02d}.png" | |
path2data[str(self.tgt_folder / frame_item['fg_mask_path'])] = item['alpha_map'] | |
db['frames'].append(frame_item) | |
worker_args.append([path2data]) | |
#--- no threading | |
# if len(worker_args) > 0: | |
# write_data(path2data) | |
#--- threading | |
if len(worker_args) == max_threads or i == len(self.dataloader)-1: | |
with concurrent.futures.ThreadPoolExecutor(max_threads) as executor: | |
futures = [executor.submit(write_data, *args) for args in worker_args] | |
concurrent.futures.wait(futures) | |
worker_args = [] | |
# add shared intrinsic parameters to be compatible with other nerf libraries | |
db.update({ | |
"cx": cx, | |
"cy": cy, | |
"fl_x": fl_x, | |
"fl_y": fl_y, | |
"h": h, | |
"w": w, | |
"camera_angle_x": angle_x, | |
"camera_angle_y": angle_y | |
}) | |
# add indices to ease filtering | |
db['timestep_indices'] = sorted(list(timestep_indices)) | |
db['camera_indices'] = sorted(list(camera_indices)) | |
write_json(db, self.tgt_folder) | |
write_json(db, self.tgt_folder, division='backup') | |
class TrackedFLAMEDatasetWriter: | |
def __init__(self, cfg_model: ModelConfig, src_folder: Path, tgt_folder: Path, mode: Literal['mesh', 'param'], epoch: int = -1): | |
print("---- Config: model ----") | |
print(tyro.to_yaml(cfg_model)) | |
self.cfg_model = cfg_model | |
self.src_folder = src_folder | |
self.tgt_folder = tgt_folder | |
self.mode = mode | |
db_backup_path = tgt_folder / "transforms_backup.json" | |
assert db_backup_path.exists(), f"Could not find {db_backup_path}" | |
print(f"Loading database from: {db_backup_path}") | |
self.db = json.load(open(db_backup_path, "r")) | |
paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))] | |
epochs = [int(p.stem.split('_')[-1]) for p in paths] | |
if epoch == -1: | |
index = np.argmax(epochs) | |
else: | |
index = epochs.index(epoch) | |
flame_params_path = paths[index] | |
assert flame_params_path.exists(), f"Could not find {flame_params_path}" | |
print(f"Loading FLAME parameters from: {flame_params_path}") | |
self.flame_params = dict(np.load(flame_params_path)) | |
if "focal_length" in self.flame_params: | |
self.focal_length = self.flame_params['focal_length'].item() | |
else: | |
self.focal_length = None | |
# Relocate FLAME to the origin and return the transformation matrix to modify camera poses. | |
self.M = self.relocate_flame_meshes(self.flame_params) | |
print("Initializing FLAME model...") | |
self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True) | |
def relocate_flame_meshes(self, flame_param): | |
""" Relocate FLAME to the origin and return the transformation matrix to modify camera poses. """ | |
# Rs = torch.tensor(flame_param['rotation']) | |
Ts = torch.tensor(flame_param['translation']) | |
# R_mean = axis_angle_to_matrix(Rs.mean(0)) | |
T_mean = Ts.mean(0) | |
M = torch.eye(4) | |
# M[:3, :3] = R_mean.transpose(-1, -2) | |
M[:3, 3] = -T_mean | |
# flame_param['rotation'] = (matrix_to_axis_angle(M[None, :3, :3] @ axis_angle_to_matrix(Rs))).numpy() | |
flame_param['translation'] = (M[:3, 3] + Ts).numpy() | |
return M.numpy() | |
def replace_cam_params(self, item): | |
c2w = np.eye(4) | |
c2w[2, 3] = 1 # place the camera at (0, 0, 1) in the world coordinate by default | |
item['transform_matrix'] = c2w | |
h = item['h'] | |
w = item['w'] | |
fl_x = self.focal_length * max(h, w) | |
fl_y = self.focal_length * max(h, w) | |
angle_x = math.atan(w / (fl_x * 2)) * 2 | |
angle_y = math.atan(h / (fl_y * 2)) * 2 | |
item.update({ | |
"cx": w / 2, | |
"cy": h / 2, | |
"fl_x": fl_x, | |
"fl_y": fl_y, | |
"camera_angle_x": angle_x, | |
"camera_angle_y": angle_y, | |
"transform_matrix": c2w.tolist(), | |
}) | |
def write(self): | |
if self.mode == 'mesh': | |
self.write_canonical_mesh() | |
indices = self.db['timestep_indices'] | |
verts = infer_flame_params(self.flame_model, self.flame_params, indices) | |
print(f"Writing FLAME expressions and meshes to: {self.tgt_folder}") | |
elif self.mode == 'param': | |
self.write_canonical_flame_param() | |
print(f"Writing FLAME parameters to: {self.tgt_folder}") | |
saved = [False] * len(self.db['timestep_indices']) # avoid writing the same mesh multiple times | |
num_processes = 0 | |
worker_args = [] | |
for i, frame in tqdm(enumerate(self.db['frames']), total=len(self.db['frames'])): | |
if self.focal_length is not None: | |
self.replace_cam_params(frame) | |
# modify the camera extrinsics to place the tracked FLAME at the origin | |
frame['transform_matrix'] = (self.M @ np.array(frame['transform_matrix'])).tolist() | |
ti_orig = frame['timestep_index_original'] # use ti_orig when loading FLAME parameters | |
ti = frame['timestep_index'] # use ti when saving files | |
# write FLAME mesh or parameters | |
if self.mode == 'mesh': | |
frame['exp_path'] = f"flame/exp/{ti:05d}.txt" | |
frame['mesh_path'] = f"meshes/{ti:05d}.obj" | |
if not saved[ti]: | |
worker_args.append([self.tgt_folder, frame['exp_path'], self.flame_params['expr'][ti_orig], frame['mesh_path'], verts[ti_orig], self.flame_model.faces]) | |
saved[ti] = True | |
func = self.write_expr_and_mesh | |
elif self.mode == 'param': | |
frame['flame_param_path'] = f"flame_param/{ti:05d}.npz" | |
if not saved[ti]: | |
worker_args.append([self.tgt_folder, frame['flame_param_path'], self.flame_params, ti_orig]) | |
saved[ti] = True | |
func = self.write_flame_param | |
#--- no multiprocessing | |
if len(worker_args) > 0: | |
func(*worker_args.pop()) | |
#--- multiprocessing | |
# if len(worker_args) == num_processes or i == len(self.db['frames'])-1: | |
# pool = multiprocessing.Pool(processes=num_processes) | |
# pool.starmap(func, worker_args) | |
# pool.close() | |
# pool.join() | |
# worker_args = [] | |
write_json(self.db, self.tgt_folder) | |
write_json(self.db, self.tgt_folder, division='backup_flame') | |
def write_canonical_mesh(self): | |
print(f"Inferencing FLAME in the canonical space...") | |
if 'static_offset' in self.flame_params: | |
static_offset = torch.tensor(self.flame_params['static_offset']) | |
else: | |
static_offset = None | |
with torch.no_grad(): | |
ret = self.flame_model( | |
torch.tensor(self.flame_params['shape'])[None, ...], | |
torch.zeros(*self.flame_params['expr'][:1].shape), | |
torch.zeros(*self.flame_params['rotation'][:1].shape), | |
torch.zeros(*self.flame_params['neck_pose'][:1].shape), | |
torch.tensor([[0.3, 0, 0]]), | |
torch.zeros(*self.flame_params['eyes_pose'][:1].shape), | |
torch.zeros(*self.flame_params['translation'][:1].shape), | |
return_verts_cano=False, | |
static_offset=static_offset, | |
) | |
verts = ret[0] | |
cano_mesh_path = self.tgt_folder / 'canonical.obj' | |
print(f"Writing canonical mesh to: {cano_mesh_path}") | |
obj_data = get_obj_content(verts[0], self.flame_model.faces) | |
write_data({cano_mesh_path: obj_data}) | |
def write_expr_and_mesh(tgt_folder, exp_path, expr, mesh_path, verts, faces): | |
path2data = {} | |
expr_data = '\n'.join([str(n) for n in expr]) | |
path2data[tgt_folder / exp_path] = expr_data | |
obj_data = get_obj_content(verts, faces) | |
path2data[tgt_folder / mesh_path] = obj_data | |
write_data(path2data) | |
def write_canonical_flame_param(self): | |
flame_param = { | |
'translation': np.zeros_like(self.flame_params['translation'][:1]), | |
'rotation': np.zeros_like(self.flame_params['rotation'][:1]), | |
'neck_pose': np.zeros_like(self.flame_params['neck_pose'][:1]), | |
'jaw_pose': np.array([[0.3, 0, 0]]), # open mouth | |
'eyes_pose': np.zeros_like(self.flame_params['eyes_pose'][:1]), | |
'shape': self.flame_params['shape'], | |
'expr': np.zeros_like(self.flame_params['expr'][:1]), | |
} | |
if 'static_offset' in self.flame_params: | |
flame_param['static_offset'] = self.flame_params['static_offset'] | |
cano_flame_param_path = self.tgt_folder / 'canonical_flame_param.npz' | |
print(f"Writing canonical FLAME parameters to: {cano_flame_param_path}") | |
write_data({cano_flame_param_path: flame_param}) | |
def write_flame_param(tgt_folder, flame_param_path, flame_params, tid): | |
params = { | |
'translation': flame_params['translation'][[tid]], | |
'rotation': flame_params['rotation'][[tid]], | |
'neck_pose': flame_params['neck_pose'][[tid]], | |
'jaw_pose': flame_params['jaw_pose'][[tid]], | |
'eyes_pose': flame_params['eyes_pose'][[tid]], | |
'shape': flame_params['shape'], | |
'expr': flame_params['expr'][[tid]], | |
} | |
if 'static_offset' in flame_params: | |
params['static_offset'] = flame_params['static_offset'] | |
if 'dynamic_offset' in flame_params: | |
params['dynamic_offset'] = flame_params['dynamic_offset'][[tid]] | |
path2data = {tgt_folder / flame_param_path: params} | |
write_data(path2data) | |
class MaskFromFLAME: | |
def __init__(self, cfg_model: ModelConfig, tgt_folder, background_color: str) -> None: | |
background_color = self.cfg_data.background_color if background_color is None else background_color | |
if background_color == 'white': | |
self.background_tensor = torch.tensor([255, 255, 255]).byte() | |
elif background_color == 'black': | |
self.background_tensor = torch.tensor([0, 0, 0]).byte() | |
else: | |
raise ValueError(f"Unknown background color: {background_color}") | |
dataset = NeRFDataset( | |
root_folder=tgt_folder, | |
division=None, | |
camera_convention_conversion=None, | |
target_extrinsic_type='w2c', | |
use_fg_mask=True, | |
use_flame_param=True, | |
) | |
self.dataloader = DataLoader(dataset, shuffle=False, batch_size=None, collate_fn=None, num_workers=0) | |
self.flame_model = FlameHead(cfg_model.n_shape, cfg_model.n_expr, add_teeth=True) | |
self.mesh_renderer = NVDiffRenderer(use_opengl=False) | |
def write(self): | |
t2verts = {} | |
worker_args = [] | |
print(f"Generating masks from FLAME...") | |
for i, frame in enumerate(tqdm(self.dataloader)): | |
# get FLAME vertices | |
timestep = frame['timestep_index'] | |
if timestep not in t2verts: | |
t2verts[timestep] = infer_flame_params(self.flame_model, frame['flame_param'], [0]).cuda() | |
verts = t2verts[timestep] | |
# render to get forground mask | |
RT = frame['extrinsics'].cuda()[None] | |
K = frame['intrinsics'].cuda()[None] | |
h = frame['image_height'] | |
w = frame['image_width'] | |
# mask = self.get_mask(verts, RT, K, h, w) | |
mask = self.get_mask_tilted_line(verts, RT, K, h, w) | |
# edit the image and mask with dilated FLAME mask | |
img = frame['image'].cuda() | |
img = img * mask[:, :, None] + self.background_tensor.cuda()[None, None, :] * (1-mask)[:, :, None] | |
# overwrite the original images | |
path2data = { | |
str(frame['image_path']): img.byte().cpu().numpy(), | |
} | |
if 'fg_mask_path' in frame and 'fg_mask' in frame: | |
fg_mask = frame['fg_mask'].cuda() | |
fg_mask = fg_mask * mask | |
# overwrite the original masks | |
path2data.update({ | |
str(frame['fg_mask_path']): fg_mask.byte().cpu().numpy(), | |
}) | |
# # write to new folder | |
# path2data.update({ | |
# str(frame['fg_mask_path']).replace('fg_masks', 'fg_masks_'): fg_mask.byte().cpu().numpy(), | |
# }) | |
write_data(path2data) | |
worker_args.append([path2data]) | |
#--- no threading | |
# if len(worker_args) > 0: | |
# write_data(path2data) | |
#--- threading | |
if len(worker_args) == max_threads or i == len(self.dataloader)-1: | |
with concurrent.futures.ThreadPoolExecutor(max_threads) as executor: | |
futures = [executor.submit(write_data, *args) for args in worker_args] | |
concurrent.futures.wait(futures) | |
worker_args = [] | |
def get_mask(self, verts, RT, K, h, w): | |
faces = self.flame_model.faces.cuda() | |
out_dict = self.mesh_renderer.render_without_texture(verts, faces, RT, K, (h, w)) | |
rgba_mesh = out_dict['rgba'].squeeze(0) # (H, W, C) | |
mask_mesh = rgba_mesh[..., 3] # (H, W) | |
# get the bottom line of the neck and disable mask for the upper part | |
verts_clip = out_dict['verts_clip'][0] | |
verts_ndc = verts_clip[:, :3] / verts_clip[:, -1:] | |
xy = verts_ndc[:, :2] | |
xy[:, 1] = -xy[:, 1] | |
xy = (xy * 0.5 + 0.5) * torch.tensor([[h, w]]).cuda() | |
vid_ring = self.flame_model.mask.get_vid_by_region(['neck_top']) | |
xy_ring = xy[vid_ring] | |
bottom_line = int(xy_ring[:, 1].min().item()) | |
mask = mask_mesh.clone() | |
mask[:bottom_line] = 1 | |
# anti-aliasing with gaussian kernel | |
k = int(0.02 * w)//2 * 2 + 1 | |
blur = torchvision.transforms.GaussianBlur(k, sigma=k) | |
mask = blur(mask[None])[0] #.clamp(0, 1) | |
return mask | |
def get_mask_tilted_line(self, verts, RT, K, h, w): | |
verts_ndc = self.mesh_renderer.world_to_ndc(verts, RT, K, (h, w), flip_y=True) | |
verts_xy = verts_ndc[0, :, :2] | |
verts_xy = (verts_xy * 0.5 + 0.5) * torch.tensor([w, h]).cuda() | |
verts_xy_left = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_right_point'])] | |
verts_xy_right = verts_xy[self.flame_model.mask.get_vid_by_region(['neck_left_point'])] | |
verts_xy_bottom = verts_xy[self.flame_model.mask.get_vid_by_region(['front_middle_bottom_point_boundary'])] | |
delta_xy = verts_xy_left - verts_xy_right | |
assert (delta_xy[:, 0] != 0).all() | |
k = delta_xy[:, 1] / delta_xy[:, 0] | |
b = verts_xy_bottom[:, 1] - k * verts_xy_bottom[:, 0] | |
x = torch.arange(w).cuda() | |
y = torch.arange(h).cuda() | |
yx = torch.stack(torch.meshgrid(y, x, indexing='ij'), dim=-1) | |
mask = ((k * yx[:, :, 1] + b - yx[:, :, 0]) > 0).float() | |
# anti-aliasing with gaussian kernel | |
k = int(0.03 * w)//2 * 2 + 1 | |
blur = torchvision.transforms.GaussianBlur(k, sigma=k) | |
mask = blur(mask[None])[0] #.clamp(0, 1) | |
return mask | |
def infer_flame_params(flame_model: FlameHead, flame_params: Dict, indices:List): | |
if 'static_offset' in flame_params: | |
static_offset = flame_params['static_offset'] | |
if isinstance(static_offset, np.ndarray): | |
static_offset = torch.tensor(static_offset) | |
else: | |
static_offset = None | |
for k in flame_params: | |
if isinstance(flame_params[k], np.ndarray): | |
flame_params[k] = torch.tensor(flame_params[k]) | |
with torch.no_grad(): | |
ret = flame_model( | |
flame_params['shape'][None, ...].expand(len(indices), -1), | |
flame_params['expr'][indices], | |
flame_params['rotation'][indices], | |
flame_params['neck_pose'][indices], | |
flame_params['jaw_pose'][indices], | |
flame_params['eyes_pose'][indices], | |
flame_params['translation'][indices], | |
return_verts_cano=False, | |
static_offset=static_offset, | |
) | |
verts = ret[0] | |
return verts | |
def write_json(db, tgt_folder, division=None): | |
fname = "transforms.json" if division is None else f"transforms_{division}.json" | |
json_path = tgt_folder / fname | |
print(f"Writing database: {json_path}") | |
with open(json_path, "w") as f: | |
json.dump(db, f, indent=4) | |
def write_data(path2data): | |
for path, data in path2data.items(): | |
path = Path(path) | |
if not path.parent.exists(): | |
path.parent.mkdir(parents=True, exist_ok=True) | |
if path.suffix in [".png", ".jpg"]: | |
Image.fromarray(data).save(path) | |
elif path.suffix in [".obj"]: | |
with open(path, "w") as f: | |
f.write(data) | |
elif path.suffix in [".txt"]: | |
with open(path, "w") as f: | |
f.write(data) | |
elif path.suffix in [".npz"]: | |
np.savez(path, **data) | |
else: | |
raise NotImplementedError(f"Unknown file type: {path.suffix}") | |
def split_json(tgt_folder: Path, train_ratio=0.7): | |
db = json.load(open(tgt_folder / "transforms.json", "r")) | |
# init db for each division | |
db_train = {k: v for k, v in db.items() if k not in ['frames', 'timestep_indices', 'camera_indices']} | |
db_train['frames'] = [] | |
db_val = deepcopy(db_train) | |
db_test = deepcopy(db_train) | |
# divide timesteps | |
nt = len(db['timestep_indices']) | |
assert 0 < train_ratio <= 1 | |
nt_train = int(np.ceil(nt * train_ratio)) | |
nt_test = nt - nt_train | |
# record number of timesteps | |
timestep_indices = sorted(db['timestep_indices']) | |
db_train['timestep_indices'] = timestep_indices[:nt_train] | |
db_val['timestep_indices'] = timestep_indices[:nt_train] # validation set share the same timesteps with training set | |
db_test['timestep_indices'] = timestep_indices[nt_train:] | |
if len(db['camera_indices']) > 1: | |
# when having multiple cameras, leave one camera for validation (novel-view sythesis) | |
if 8 in db['camera_indices']: | |
# use camera 8 for validation (front-view of the NeRSemble dataset) | |
db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8] | |
db_val['camera_indices'] = [8] | |
db_test['camera_indices'] = db['camera_indices'] | |
else: | |
# use the last camera for validation | |
db_train['camera_indices'] = db['camera_indices'][:-1] | |
db_val['camera_indices'] = [db['camera_indices'][-1]] | |
db_test['camera_indices'] = db['camera_indices'] | |
else: | |
# when only having one camera, we create an empty validation set | |
db_train['camera_indices'] = db['camera_indices'] | |
db_val['camera_indices'] = [] | |
db_test['camera_indices'] = db['camera_indices'] | |
# fill data by timestep index | |
range_train = range(db_train['timestep_indices'][0], db_train['timestep_indices'][-1]+1) if nt_train > 0 else [] | |
range_test = range(db_test['timestep_indices'][0], db_test['timestep_indices'][-1]+1) if nt_test > 0 else [] | |
for f in db['frames']: | |
if f['timestep_index'] in range_train: | |
if f['camera_index'] in db_train['camera_indices']: | |
db_train['frames'].append(f) | |
elif f['camera_index'] in db_val['camera_indices']: | |
db_val['frames'].append(f) | |
else: | |
raise ValueError(f"Unknown camera index: {f['camera_index']}") | |
elif f['timestep_index'] in range_test: | |
db_test['frames'].append(f) | |
assert f['camera_index'] in db_test['camera_indices'], f"Unknown camera index: {f['camera_index']}" | |
else: | |
raise ValueError(f"Unknown timestep index: {f['timestep_index']}") | |
write_json(db_train, tgt_folder, division='train') | |
write_json(db_val, tgt_folder, division='val') | |
write_json(db_test, tgt_folder, division='test') | |
def load_config(src_folder: Path): | |
config_path = src_folder / "config.yml" | |
if not config_path.exists(): | |
src_folder = sorted(src_folder.iterdir())[-1] | |
config_path = src_folder / "config.yml" | |
assert config_path.exists(), f"File not found: {config_path}" | |
cfg = yaml.load(config_path.read_text(), Loader=yaml.Loader) | |
# assert isinstance(cfg, BaseTrackingConfig) | |
return src_folder, cfg | |
def check_epoch(src_folder: Path, epoch: int): | |
paths = [Path(p) for p in glob(str(src_folder / "tracked_flame_params*.npz"))] | |
epochs = [int(p.stem.split('_')[-1]) for p in paths] | |
if epoch == -1: | |
index = np.argmax(epochs) | |
else: | |
try: | |
index = epochs.index(epoch) | |
except ValueError: | |
raise ValueError(f"Could not find epoch {epoch} in {src_folder}") | |
def main( | |
src_folder: Path, | |
tgt_folder: Path, | |
subset: Optional[str]=None, | |
scale_factor: Optional[float]=None, | |
background_color: Optional[str]=None, | |
flame_mode: Literal['mesh', 'param']='param', | |
create_mask_from_mesh: bool=False, | |
epoch: int=-1, | |
): | |
print(f"Begin exportation from {src_folder}") | |
assert src_folder.exists(), f"Folder not found: {src_folder}" | |
src_folder, cfg = load_config(src_folder) | |
check_epoch(src_folder, epoch) | |
if epoch != -1: | |
tgt_folder = Path(str(tgt_folder) + f"_epoch{epoch}") | |
nerf_dataset_writer = NeRFDatasetWriter(cfg.data, tgt_folder, subset, scale_factor, background_color) | |
nerf_dataset_writer.write() | |
flame_dataset_writer = TrackedFLAMEDatasetWriter(cfg.model, src_folder, tgt_folder, mode=flame_mode, epoch=epoch) | |
flame_dataset_writer.write() | |
if create_mask_from_mesh: | |
mask_generator = MaskFromFLAME(cfg.model, tgt_folder, background_color) | |
mask_generator.write() | |
split_json(tgt_folder) | |
print("Finshed!") | |
if __name__ == "__main__": | |
tyro.cli(main) |