Spaces:
Runtime error
Runtime error
| """This script defines the custom dataset for Deep3DFaceRecon_pytorch | |
| """ | |
| import os.path | |
| from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine | |
| from data.image_folder import make_dataset | |
| from PIL import Image | |
| import random | |
| import util.util as util | |
| import numpy as np | |
| import json | |
| import torch | |
| from scipy.io import loadmat, savemat | |
| import pickle | |
| from util.preprocess import align_img, estimate_norm | |
| from util.load_mats import load_lm3d | |
def default_flist_reader(flist):
    """Read a plain-text file list, one path per line.

    Every line is stripped of surrounding whitespace and kept whole as a
    path; no label column is split off. Blank (or whitespace-only) lines
    are skipped so that a trailing newline at the end of the file does not
    produce an empty path entry.

    Parameters:
        flist (str) -- path of the text file to read

    Returns:
        list of str -- the stripped, non-empty lines in file order
    """
    with open(flist, 'r') as rf:
        # Iterate the handle directly instead of materialising readlines().
        stripped = (line.strip() for line in rf)
        # An empty entry is not a usable path, so drop it here.
        return [impath for impath in stripped if impath]
def jason_flist_reader(flist):
    """Load a JSON file list and return its decoded content.

    Parameters:
        flist (str) -- path of the JSON file to read

    Returns:
        the object that ``json.load`` decodes from the file
    """
    with open(flist, 'r') as handle:
        return json.load(handle)
def parse_label(label):
    """Convert a label into a float32 torch tensor.

    Parameters:
        label -- anything ``np.array`` accepts (scalar, nested list, ndarray)

    Returns:
        torch.Tensor -- the label's values cast to float32
    """
    as_float32 = np.array(label).astype(np.float32)
    return torch.tensor(as_float32)
class FlistDataset(BaseDataset):
    """Dataset of face images driven by a text list of mask files.

    ``opt.flist`` names a text file with one mask path per line; each entry
    is joined onto ``opt.data_root``. For every mask, the matching image and
    landmark paths are derived from the mask path by string replacement
    (see ``__getitem__``).
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                                  This class reads ``bfm_folder``, ``flist``, ``data_root``,
                                  ``isTrain`` and ``use_aug`` from it.
        """
        BaseDataset.__init__(self, opt)
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = 'train' if opt.isTrain else 'val'
        # The suffix comes from the list's file name ("<tag>_xxx.txt" -> "<tag>").
        # Test the file name only: an underscore in a parent directory must not
        # cause the whole file name to be appended.
        flist_name = os.path.basename(opt.flist)
        if '_' in flist_name:
            self.name += '_' + flist_name.split('_')[0]

    def __getitem__(self, index):
        """Return one aligned sample and its metadata.

        Parameters:
            index (int) -- sample index; wrapped modulo the dataset size

        Returns a dictionary with the keys
            imgs (tensor)   -- the aligned (and, if enabled, augmented) image
            lms (tensor)    -- float32 landmarks after alignment/augmentation
            msks (tensor)   -- first channel of the transformed mask
            M (tensor)      -- float32 result of estimate_norm(lm, H)
            im_paths (str)  -- path of the source image
            aug_flag (bool) -- value of ``opt.use_aug and opt.isTrain``, i.e. whether
                               augmentation was applied
            dataset (str)   -- dataset name tag built in __init__
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range

        # NOTE(review): both replacements match every occurrence in the path,
        # not just a single directory component -- presumably the layout is
        # <root>/.../mask/<name>.<ext> with images one level up and landmarks
        # in a sibling "landmarks" directory; confirm against the data layout.
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # The first value returned by align_img is not needed here.
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        # Second entry of img.size (the height, if img is a PIL image).
        _, H = img.size
        M = estimate_norm(lm, H)

        transform = get_transform()
        img_tensor = transform(img)
        # Keep only the first slice along dim 0 (presumably one mask channel
        # of a channel-first tensor -- confirm against get_transform).
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one affine augmentation to the image, landmarks and mask.

        The affine matrices and flip flag come from a single
        ``get_affine_mat`` call, so all three outputs share the same transform.

        Parameters:
            img          -- image to augment (its ``size`` is read)
            lm           -- landmarks belonging to ``img``
            opt          -- options forwarded to ``get_affine_mat``
            msk          -- optional mask; left as None when not given

        Returns:
            (img, lm, msk) after augmentation
        """
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return self.size