# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Streaming images and labels from datasets created with dataset_tool.py."""

import io
import os
import json
import zipfile

import numpy as np
import PIL.Image
import torch
import cv2

import dnnlib

try:
    import pyspng
except ImportError:
    pyspng = None

mouth_idx = list(range(22, 52))  # not referenced in this file


class Dataset(torch.utils.data.Dataset):
    def __init__(self,
        name,                 # Name of the dataset.
        raw_shape,            # Shape of the raw image data (NCHW).
        max_size=None,        # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels=True,      # Enable conditioning labels? False = label dimension is zero.
        xflip=False,          # Artificially double the size of the dataset via x-flips. Applied after max_size.
        load_obj=True,
        return_name=False,
        random_seed=0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None
        self.load_obj = load_obj
        self.return_name = return_name

        # Apply max_size.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
            self._raw_labels_std = self._raw_labels.std(0)
        return self._raw_labels

    def close(self):  # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx, resolution=None):  # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_verts_ply(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def _load_geo(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self):  # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        image = self._load_raw_image(self._raw_idx[idx], resolution=self.resolution)
        # assert isinstance(image, np.ndarray)
        # assert list(image.shape) == self.image_shape
        # assert image.dtype == np.uint8
        label_cam = self.get_label(idx)
        mesh_cond = self.get_vert(self._raw_idx[idx])
        if self._xflip[idx]:
            assert image.ndim == 3  # CHW
            image = image[:, :, ::-1]
            if self._use_labels:
                # Mirror the camera labels to match the horizontally flipped image.
                label_1 = label_cam[0:25]
                label_2 = label_cam[25:]
                assert label_1.shape == (25,)
                assert label_2.shape == (25,)
                label_1[[1, 2, 3, 4, 8]] *= -1
                label_2[[1, 2, 3, 4, 8]] *= -1
                label_cam = np.concatenate([label_1, label_2], axis=-1)
        if self.return_name:
            return self._image_fnames[self._raw_idx[idx]], image.copy(), label_cam, mesh_cond
        else:
            return image.copy(), label_cam, mesh_cond

    def load_random_data(self):
        # NOTE: relies on self.random_sample_num and get_image() being provided elsewhere.
        gen_cond_sample_idx = [np.random.randint(self.__len__()) for _ in range(self.random_sample_num)]
        all_gen_c = np.stack([self.get_label(i) for i in gen_cond_sample_idx])
        all_gen_v = [self.get_vert(i) for i in gen_cond_sample_idx]
        all_gt_img = np.stack([self.get_image(i).astype(np.float32) / 127.5 - 1 for i in gen_cond_sample_idx])
        return all_gen_c, all_gen_v, all_gt_img

    def get_by_name(self, name):
        raw_idx = self._image_fnames.index(name)
        image = self._load_raw_image(raw_idx, resolution=self.resolution)
        mesh_cond = self.get_vert(raw_idx)
        label = self._get_raw_labels()[raw_idx]
        cam = self._raw_cams[raw_idx]
        label_cam = np.concatenate([label, cam], axis=-1)
        return image.copy(), label_cam, mesh_cond

    def get_label(self, idx):  # to be overridden by subclass
        raise NotImplementedError

    def get_vert(self, vert_dir, zip_file_select):  # to be overridden by subclass
        raise NotImplementedError

    def get_details(self, idx):
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    def get_label_std(self):
        return self._raw_labels_std

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3  # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    # @property
    # def gen_label_dim(self):
    #     return 25  # 25 for camera params only

    @property
    def has_labels(self):
        return any(x != 0 for x in [25])  # hard-coded: always evaluates to True

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
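

# --- Illustrative sketch (not part of the original file) --------------------
# A minimal, hypothetical example of how the abstract hooks above could be
# implemented for a small in-memory dataset. The class name and attributes
# (_ExampleArrayDataset, images, labels) are placeholders for illustration only;
# the actual subclass used by this project is ImageFolderDataset below.
class _ExampleArrayDataset(Dataset):
    def __init__(self, images, labels, **super_kwargs):
        # images: uint8 array of shape [N, C, H, W]; labels: float32 array [N, label_dim].
        self._images = images
        self._labels = labels
        super().__init__(name='example', raw_shape=list(images.shape), **super_kwargs)

    def _load_raw_image(self, raw_idx, resolution=None):
        return self._images[raw_idx]

    def _load_raw_labels(self):
        return self._labels

    def get_label(self, idx):
        return self._get_raw_labels()[self._raw_idx[idx]].copy()

    def get_vert(self, raw_idx):
        # No mesh conditioning in this toy example.
        return torch.zeros(0)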


class ImageFolderDataset(Dataset):
    def __init__(self,
        path,                   # Path to directory or zip.
        data_label_path,
        label_file_vfhq,
        label_file_ffhq,
        mesh_path_ffhq,
        motion_path_ffhq,
        mesh_path_vfhq,
        motion_path_vfhq,
        mesh_path_ffhq_label,
        mesh_path_vfhq_label,
        resolution=512,
        static=False,
        **super_kwargs,
    ):
        self._path = path
        self._mesh_ffhq = mesh_path_ffhq
        self._motion_ffhq = np.load(motion_path_ffhq)
        # self._label_ffhq = np.load(label_file_ffhq)
        self._mesh_vfhq = mesh_path_vfhq
        self._motion_vfhq = np.load(motion_path_vfhq, allow_pickle=True)
        # self._label_vfhq = np.load(label_file_vfhq)
        self._data_label_path = data_label_path
        self.data_json = json.loads(open(data_label_path).read())
        PIL.Image.init()
        self._raw_cams_ffhq = json.loads(open(label_file_ffhq).read())['labels']
        self.mesh_path_ffhq_label = json.loads(open(mesh_path_ffhq_label).read())
        self.mesh_path_vfhq_label = json.loads(open(mesh_path_vfhq_label).read())
        # self._image_fnames = list(dict(json.loads(open(self._raw_cams_ffhq).read())['labels']).keys())
        # self._raw_cams_vfhq = self._load_raw_label(self._mesh_path_vfhq, 'labels')
        self.all_input_ids = list(self.data_json.keys())
        self._type = 'zip'
        name = os.path.splitext(os.path.basename(self._path))[0]

        # Zip handles are opened lazily in getdata()/get_*_motion() rather than here.
        # self.base_file_zip = zipfile.ZipFile(path)
        # self.vfhq_verts_zip = zipfile.ZipFile(mesh_path_vfhq)
        # self.ffhq_verts_zip = zipfile.ZipFile(mesh_path_ffhq)
        self.base_file_zip = None
        self.vfhq_verts_zip = None
        self.ffhq_verts_zip = None
        self._zipfile = None  # used by _get_zipfile()/close(); must exist before __del__ runs

        # raw_shape = [len(self._image_fnames)] + [3, resolution, resolution]
        raw_shape = [len(self.all_input_ids)] + [3, 512, 512]
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return len(self.all_input_ids)

    # phase_real_z, phase_real_latent, phase_real_c_1_d, phase_real_c_2_d, phase_real_c_3_d, phase_real_v_1_d, phase_real_v_2_d, phase_real_v_s, motion_1, motion_2, motion_ffhq, model_list
    def __getitem__(self, idx):
        # Retry with a different random index if a sample fails to load.
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def getdata(self, idx):
        base_dir = self.all_input_ids[idx]
        model_name = self.data_json[base_dir]
        latent_dir = os.path.join('multi_style', base_dir, '0.pt')
        latent_dit_dir = os.path.join('multi_style', base_dir, '0_dit.pt')

        if self.base_file_zip is None:
            self.base_file_zip = zipfile.ZipFile(self._path)
        with self._open_file(self.base_file_zip, latent_dir) as f:
            phase_real_z = torch.load(f).float()
        # with base_file_zip.open(latent_dir, 'r') as f:
        #     buffer = io.BytesIO(f.read())
        with self._open_file(self.base_file_zip, latent_dit_dir) as f:
            phase_real_latent = torch.load(f).float()
        # with base_file_zip.open(latent_dit_dir, 'r') as f:
        #     # buffer = io.BytesIO(f.read())
        #     phase_real_latent = torch.load(f).float()
        #     f.close()
        # phase_real_z = torch.load(latent_dir).float()
        # phase_real_latent = torch.load(latent_dit_dir).float()

        phase_real_c_1_d, phase_real_c_2_d, phase_real_c_3_d = self.get_label(idx)
        motion_ffhq, phase_real_v_s = self.get_ffhq_motion()
        motion_1, motion_2, phase_real_v_1_d, phase_real_v_2_d = self.get_vfhq_motion()

        return {
            "model_name": model_name,
            "phase_real_z": phase_real_z,
            "phase_real_latent": phase_real_latent,
            "phase_real_c_1_d": phase_real_c_1_d.unsqueeze(0),
            "phase_real_c_2_d": phase_real_c_2_d.unsqueeze(0),
            "phase_real_c_3_d": phase_real_c_3_d.unsqueeze(0),
            "phase_real_v_s": phase_real_v_s.unsqueeze(0),
            "motion_ffhq": motion_ffhq.unsqueeze(0),
            "motion_1": motion_1.unsqueeze(0),
            "motion_2": motion_2.unsqueeze(0),
            "phase_real_v_1_d": phase_real_v_1_d.unsqueeze(0),
            "phase_real_v_2_d": phase_real_v_2_d.unsqueeze(0),
        }

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, zip_file, fname, path=None):
        if not path:
            path = self._path
        if self._type == 'dir':
            return open(os.path.join(path, fname), 'rb')
        if self._type == 'zip':
            return zip_file.open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_label(self, json_path, sub_key=None):
        with open(json_path, 'rb') as f:
            labels = json.load(f)
        if sub_key is not None:
            labels = labels[sub_key]
        labels = dict(labels)
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        return np.array(labels).astype(np.float32)

    def _load_raw_image_core(self, fname, path=None, resolution=None):
        # Pass the dataset's own zip handle so the call matches _open_file's signature.
        with self._open_file(self._get_zipfile(), fname, path) as f:
            image = PIL.Image.open(f)
            if resolution:
                image = image.resize((resolution, resolution))
            image = np.array(image)  # .astype(np.float32)
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image

    def _load_raw_motion(self, raw_idx, resolution=None):
        fname = self._image_fnames[raw_idx]
        image = self._load_raw_image_core(fname, resolution=resolution)  # [C, H, W]
        return image

    def _load_vfhq_raw_labels(self):
        labels = self._load_raw_label(os.path.join(self._path, self.label_file), 'labels')
        return labels

    def _load_ffhq_raw_labels(self):
        labels = self._load_raw_label(os.path.join(self._path, self.label_file), 'labels')
        return labels

    def get_vert(self, vert_dir, zip_file_select):
        with zip_file_select.open(vert_dir, 'r') as f:
            # buffer = io.BytesIO(f.read())
            uvcoords_image = np.load(f)[..., :3]
        # uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
        # [HW3]: the first two channels have a data range of (-1, 1), the third channel is
        # the face_mask, and the last channel is the render_mask.
        uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
        uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
        # out = {'uvcoords_image': torch.tensor(uvcoords_image.copy()).float()}
        return torch.tensor(uvcoords_image.copy()).float()

    def load_random_data(self):
        # NOTE: relies on self.random_sample_num and get_image() being provided elsewhere.
        gen_cond_sample_idx = [np.random.randint(self.__len__()) for _ in range(self.random_sample_num)]
        all_gen_c = np.stack([self.get_label(i) for i in gen_cond_sample_idx])
        all_gen_v = [self.get_vert(i) for i in gen_cond_sample_idx]
        all_gt_img = np.stack([self.get_image(i).astype(np.float32) / 127.5 - 1 for i in gen_cond_sample_idx])
        return all_gen_c, all_gen_v, all_gt_img

    def get_label(self, idx):
        # Sample three random FFHQ camera labels; the passed idx is not used.
        # all_nums = self._raw_cams_ffhq
        gen_cond_sample_idx = [np.random.randint(len(self._raw_cams_ffhq)) for _ in range(3)]
        cam = [self._raw_cams_ffhq[i][1] for i in gen_cond_sample_idx]
        return (
            torch.tensor(np.array(cam[0]).astype(np.float32)).float(),
            torch.tensor(np.array(cam[1]).astype(np.float32)).float(),
            torch.tensor(np.array(cam[2]).astype(np.float32)).float(),
        )

    def get_ffhq_motion(self):
        # Sample one random FFHQ motion entry and its corresponding UV-coordinate render from the mesh zip.
        assert len(self.mesh_path_ffhq_label) == self._motion_ffhq.shape[0]
        gen_cond_sample_idx = np.random.randint(self._motion_ffhq.shape[0])
        motion = self._motion_ffhq[gen_cond_sample_idx]
        vert_dir = os.path.join('orthRender256x256_face_eye', self.mesh_path_ffhq_label[gen_cond_sample_idx])
        if self.ffhq_verts_zip is None:
            self.ffhq_verts_zip = zipfile.ZipFile(self._mesh_ffhq)
        vert = self.get_vert(vert_dir, self.ffhq_verts_zip)
        return torch.tensor(motion).float(), vert

    def get_vfhq_motion(self):
        # Sample one VFHQ sequence, then two random entries from it: their motion vectors
        # and the corresponding UV-coordinate renders.
        assert len(self.mesh_path_vfhq_label) == self._motion_vfhq.shape[0]
        gen_cond_sample_idx_row = np.random.randint(self._motion_vfhq.shape[0])
        motions = self._motion_vfhq[gen_cond_sample_idx_row]
        verts = self.mesh_path_vfhq_label[gen_cond_sample_idx_row]
        assert motions.shape[0] == len(verts)
        gen_cond_sample_idx_col = np.random.randint(motions.shape[0], size=2)
        motions_1 = motions[gen_cond_sample_idx_col[0]]
        motions_2 = motions[gen_cond_sample_idx_col[1]]
        verts_1_dir = os.path.join('orthRender256x256_face_eye', verts[gen_cond_sample_idx_col[0]])
        verts_2_dir = os.path.join('orthRender256x256_face_eye', verts[gen_cond_sample_idx_col[1]])
        if self.vfhq_verts_zip is None:
            self.vfhq_verts_zip = zipfile.ZipFile(self._mesh_vfhq)
        verts_1 = self.get_vert(verts_1_dir, self.vfhq_verts_zip)
        verts_2 = self.get_vert(verts_2_dir, self.vfhq_verts_zip)
        return torch.tensor(motions_1).float(), torch.tensor(motions_2).float(), verts_1, verts_2
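

# --- Illustrative usage sketch (not part of the original file) --------------
# A minimal sketch of how this dataset might be wrapped in a PyTorch DataLoader.
# All paths below are hypothetical placeholders; the real archives and label
# files are supplied by the surrounding training code. Zip handles are opened
# lazily inside getdata()/get_*_motion(), so worker processes each open their own.
if __name__ == '__main__':
    dataset = ImageFolderDataset(
        path='data/latents.zip',
        data_label_path='data/data_labels.json',
        label_file_vfhq='data/vfhq_labels.json',
        label_file_ffhq='data/ffhq_labels.json',
        mesh_path_ffhq='data/ffhq_meshes.zip',
        motion_path_ffhq='data/ffhq_motion.npy',
        mesh_path_vfhq='data/vfhq_meshes.zip',
        motion_path_vfhq='data/vfhq_motion.npy',
        mesh_path_ffhq_label='data/ffhq_mesh_labels.json',
        mesh_path_vfhq_label='data/vfhq_mesh_labels.json',
        use_labels=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    print(batch['phase_real_z'].shape, batch['phase_real_latent'].shape)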