"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
"""
import numpy as np
import torch
import trimesh

from .base_model import BaseModel
from . import networks
from .cropping import align_img
from .bfm import ParametricFaceModel
from .util.pytorch3d import MeshRenderer
class FaceReconModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth')
parser.set_defaults(
focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
)
return parser
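    # Hedged usage sketch (doctest-style, assuming argparse options as above);
    # set_defaults places these values in the namespace even without explicit
    # add_argument calls:
    #   >>> import argparse
    #   >>> parser = FaceReconModel.modify_commandline_options(argparse.ArgumentParser(), is_train=False)
    #   >>> opt, _ = parser.parse_known_args([])
    #   >>> (opt.focal, opt.center, opt.camera_d)
    #   (1015.0, 112.0, 10.0)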
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
self.model_names = ['net_recon']
self.parallel_names = self.model_names + ['renderer']
self.net_recon = networks.define_net_recon(
net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
)
self.facemodel = ParametricFaceModel(
bfm_folder=opt.bfm_folder, is_train=self.isTrain)
        # full vertical field of view (degrees) of the perspective camera
        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
        self.renderer = MeshRenderer(
            rasterize_fov=fov, znear=0.1, zfar=50, rasterize_size=int(2 * opt.center)
        )
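        # Sanity check for the camera geometry (hedged sketch, using the
        # focal/center defaults from modify_commandline_options):
        #   >>> 2 * np.arctan(112. / 1015.) * 180 / np.pi   # full FOV in degrees
        #   12.59...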
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
    def preprocess_img(self, im, lm, to_tensor=True):
        """Select 5 standard landmarks and align the input image for inference."""
        # pick the 5 standard landmarks used by align_img (assuming a 98-point
        # WFLW-style layout: eye centers, nose tip, mouth corners)
        stand_index = np.array([96, 97, 54, 76, 82])
        _, H = im.size
        lm = lm.reshape([-1, 2])
        lm = lm[stand_index, :]
        # flip the y coordinate: align_img expects the origin at the bottom-left
        lm[:, -1] = H - 1 - lm[:, -1]
trans_params, im, lm, _ = align_img(im, lm)
if to_tensor:
im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
lm = torch.tensor(lm).unsqueeze(0)
return im, lm, trans_params
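    # Hedged usage sketch for preprocess_img (assumes a PIL.Image and a
    # 98-point landmark array; 'face.jpg' and detect_landmarks are hypothetical):
    #   >>> from PIL import Image
    #   >>> im = Image.open('face.jpg').convert('RGB')
    #   >>> lm = detect_landmarks(im)            # hypothetical 98x2 detector output
    #   >>> im_t, lm_t, trans_params = model.preprocess_img(im, lm)
    #   >>> im_t.shape                           # (1, 3, H, W) float tensor in [0, 1]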
    def set_input(self, input):
        """Unpack input data from the dataloader and perform the necessary pre-processing steps.
        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        Returns the aligned image tensor (moved to self.device) and the alignment parameters.
        """
        input_img = input['imgs']
        lmdks = input['lms']
        # avoid reusing the name align_img, which shadows the imported function
        aligned_img, _, trans_params = self.preprocess_img(input_img, lmdks)
        aligned_img = aligned_img.to(self.device)
        return aligned_img, trans_params
def split_coeff(self, coeffs):
"""
Return:
coeffs_dict -- a dict of torch.tensors
Parameters:
            coeffs -- torch.tensor, size (B, 257)
"""
id_coeffs = coeffs[:, :80]
exp_coeffs = coeffs[:, 80: 144]
tex_coeffs = coeffs[:, 144: 224]
angles = coeffs[:, 224: 227]
gammas = coeffs[:, 227: 254]
translations = coeffs[:, 254:]
return {
'id': id_coeffs,
'exp': exp_coeffs,
'tex': tex_coeffs,
'angle': angles,
'gamma': gammas,
'trans': translations
}
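    # Shape check for split_coeff (hedged doctest-style sketch; the 257-dim
    # layout follows the slices above: 80 id + 64 exp + 80 tex + 3 angle +
    # 27 gamma + 3 trans):
    #   >>> d = model.split_coeff(torch.zeros(1, 257))
    #   >>> {k: v.shape[-1] for k, v in d.items()}
    #   {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27, 'trans': 3}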
    def optimize_parameters(self):
        # inference-only model: no losses or optimizers are defined, so this is a no-op
        return None
def forward(self, input):
self.input_img, trans_params = self.set_input(input)
output_coeff = self.net_recon(self.input_img)
pred_coeffs_dict = self.split_coeff(output_coeff)
        pred_coeffs = {key: pred_coeffs_dict[key].detach().cpu().numpy() for key in pred_coeffs_dict}
self.facemodel.to(self.device)
self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
self.facemodel.compute_for_render(output_coeff)
self.pred_mask, _, self.pred_face = self.renderer(
self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
return pred_coeffs, trans_params
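    # Hedged inference sketch for forward (input dict matches set_input above;
    # im and lm are the hypothetical image/landmarks from the preprocess_img
    # example):
    #   >>> with torch.no_grad():
    #   ...     pred_coeffs, trans_params = model.forward({'imgs': im, 'lms': lm})
    #   >>> sorted(pred_coeffs)                  # one numpy array per coefficient group
    #   ['angle', 'exp', 'gamma', 'id', 'tex', 'trans']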
    def save_mesh(self, name):
        # copy so the cached vertices stay in camera space if save_mesh is called again
        recon_shape = self.pred_vertex.detach().clone()
        recon_shape[..., -1] = 10 - recon_shape[..., -1]  # from camera space to world space (camera_d = 10)
        recon_shape = recon_shape.cpu().numpy()[0]
        recon_color = self.pred_color.detach().cpu().numpy()[0]
        tri = self.facemodel.face_buf.cpu().numpy()
        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri,
                               vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
        mesh.export(name)
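    # Hedged usage sketch: trimesh infers the format from the file extension,
    # so after a forward pass the mesh can be exported as e.g. Wavefront OBJ:
    #   >>> model.save_mesh('face_recon.obj')    # hypothetical output path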
def compute_visuals(self):
with torch.no_grad():
input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
output_vis_numpy = np.concatenate((input_img_numpy,
output_vis_numpy_raw), axis=-2)
output_vis_numpy = np.clip(output_vis_numpy, 0, 255)
self.output_vis = output_vis_numpy
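    # Hedged usage sketch: output_vis is a (B, H, 2*W, 3) float array in
    # [0, 255] (input and render side by side), so a preview can be saved
    # with PIL ('vis.png' is a hypothetical path):
    #   >>> from PIL import Image
    #   >>> model.compute_visuals()
    #   >>> Image.fromarray(model.output_vis[0].astype(np.uint8)).save('vis.png')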