|  | import numpy as np | 
					
						
						|  | import cv2, os, sys,torch | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from src.face3d.util.preprocess import align_img | 
					
						
						|  | from src.face3d.util.load_mats import load_lm3d | 
					
						
						|  | from src.face3d.models import networks | 
					
						
						|  | from src.face3d.extract_kp_videos import KeypointExtractor | 
					
						
						|  |  | 
					
						
						|  | from scipy.io import loadmat, savemat | 
					
						
						|  | from src.utils.croper import Croper | 
					
						
						|  |  | 
					
						
						|  | import warnings | 
					
						
						|  | warnings.filterwarnings("ignore") | 
					
						
						|  |  | 
					
						
						|  | def split_coeff(coeffs): | 
					
						
						|  | """ | 
					
						
						|  | Return: | 
					
						
						|  | coeffs_dict     -- a dict of torch.tensors | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | coeffs          -- torch.tensor, size (B, 256) | 
					
						
						|  | """ | 
					
						
						|  | id_coeffs = coeffs[:, :80] | 
					
						
						|  | exp_coeffs = coeffs[:, 80: 144] | 
					
						
						|  | tex_coeffs = coeffs[:, 144: 224] | 
					
						
						|  | angles = coeffs[:, 224: 227] | 
					
						
						|  | gammas = coeffs[:, 227: 254] | 
					
						
						|  | translations = coeffs[:, 254:] | 
					
						
						|  | return { | 
					
						
						|  | 'id': id_coeffs, | 
					
						
						|  | 'exp': exp_coeffs, | 
					
						
						|  | 'tex': tex_coeffs, | 
					
						
						|  | 'angle': angles, | 
					
						
						|  | 'gamma': gammas, | 
					
						
						|  | 'trans': translations | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CropAndExtract(): | 
					
						
						|  | def __init__(self, path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device): | 
					
						
						|  |  | 
					
						
						|  | self.croper = Croper(path_of_lm_croper) | 
					
						
						|  | self.kp_extractor = KeypointExtractor(device) | 
					
						
						|  | self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) | 
					
						
						|  | checkpoint = torch.load(path_of_net_recon_model, map_location=torch.device(device)) | 
					
						
						|  | self.net_recon.load_state_dict(checkpoint['net_recon']) | 
					
						
						|  | self.net_recon.eval() | 
					
						
						|  | self.lm3d_std = load_lm3d(dir_of_BFM_fitting) | 
					
						
						|  | self.device = device | 
					
						
						|  |  | 
					
						
						|  | def generate(self, input_path, save_dir): | 
					
						
						|  |  | 
					
						
						|  | pic_size = 256 | 
					
						
						|  | pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] | 
					
						
						|  |  | 
					
						
						|  | landmarks_path =  os.path.join(save_dir, pic_name+'_landmarks.txt') | 
					
						
						|  | coeff_path =  os.path.join(save_dir, pic_name+'.mat') | 
					
						
						|  | png_path =  os.path.join(save_dir, pic_name+'.png') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not os.path.isfile(input_path): | 
					
						
						|  | raise ValueError('input_path must be a valid path to video/image file') | 
					
						
						|  | elif input_path.split('.')[1] in ['jpg', 'png', 'jpeg']: | 
					
						
						|  |  | 
					
						
						|  | full_frames = [cv2.imread(input_path)] | 
					
						
						|  | fps = 25 | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | video_stream = cv2.VideoCapture(input_path) | 
					
						
						|  | fps = video_stream.get(cv2.CAP_PROP_FPS) | 
					
						
						|  | full_frames = [] | 
					
						
						|  | while 1: | 
					
						
						|  | still_reading, frame = video_stream.read() | 
					
						
						|  | if not still_reading: | 
					
						
						|  | video_stream.release() | 
					
						
						|  | break | 
					
						
						|  | full_frames.append(frame) | 
					
						
						|  | break | 
					
						
						|  | x_full_frames = [cv2.cvtColor(full_frames[0], cv2.COLOR_BGR2RGB) ] | 
					
						
						|  |  | 
					
						
						|  | if True: | 
					
						
						|  | x_full_frames, crop, quad = self.croper.crop(x_full_frames, xsize=pic_size) | 
					
						
						|  | clx, cly, crx, cry = crop | 
					
						
						|  | lx, ly, rx, ry = quad | 
					
						
						|  | lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) | 
					
						
						|  | oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx | 
					
						
						|  | original_size = (ox2 - ox1, oy2 - oy1) | 
					
						
						|  | else: | 
					
						
						|  | oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] | 
					
						
						|  | frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size,pic_size))) for frame in x_full_frames] | 
					
						
						|  | if len(frames_pil) == 0: | 
					
						
						|  | print('No face is detected in the input file') | 
					
						
						|  | return None, None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for frame in frames_pil: | 
					
						
						|  | cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not os.path.isfile(landmarks_path): | 
					
						
						|  | lm = self.kp_extractor.extract_keypoint(frames_pil, landmarks_path) | 
					
						
						|  | else: | 
					
						
						|  | print(' Using saved landmarks.') | 
					
						
						|  | lm = np.loadtxt(landmarks_path).astype(np.float32) | 
					
						
						|  | lm = lm.reshape([len(x_full_frames), -1, 2]) | 
					
						
						|  |  | 
					
						
						|  | if not os.path.isfile(coeff_path): | 
					
						
						|  |  | 
					
						
						|  | video_coeffs, full_coeffs = [],  [] | 
					
						
						|  | for idx in tqdm(range(len(frames_pil)), desc=' 3DMM Extraction In Video:'): | 
					
						
						|  | frame = frames_pil[idx] | 
					
						
						|  | W,H = frame.size | 
					
						
						|  | lm1 = lm[idx].reshape([-1, 2]) | 
					
						
						|  |  | 
					
						
						|  | if np.mean(lm1) == -1: | 
					
						
						|  | lm1 = (self.lm3d_std[:, :2]+1)/2. | 
					
						
						|  | lm1 = np.concatenate( | 
					
						
						|  | [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | lm1[:, -1] = H - 1 - lm1[:, -1] | 
					
						
						|  |  | 
					
						
						|  | trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) | 
					
						
						|  |  | 
					
						
						|  | trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32) | 
					
						
						|  | im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | full_coeff = self.net_recon(im_t) | 
					
						
						|  | coeffs = split_coeff(full_coeff) | 
					
						
						|  |  | 
					
						
						|  | pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} | 
					
						
						|  |  | 
					
						
						|  | pred_coeff = np.concatenate([ | 
					
						
						|  | pred_coeff['exp'], | 
					
						
						|  | pred_coeff['angle'], | 
					
						
						|  | pred_coeff['trans'], | 
					
						
						|  | trans_params[2:][None], | 
					
						
						|  | ], 1) | 
					
						
						|  | video_coeffs.append(pred_coeff) | 
					
						
						|  | full_coeffs.append(full_coeff.cpu().numpy()) | 
					
						
						|  |  | 
					
						
						|  | semantic_npy = np.array(video_coeffs)[:,0] | 
					
						
						|  |  | 
					
						
						|  | savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]}) | 
					
						
						|  |  | 
					
						
						|  | return coeff_path, png_path |