import ast
import io
import json
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from multiprocessing import Value

import braceexpand
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision.datasets as datasets
import webdataset as wds
import PIL
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample

ImageFile.LOAD_TRUNCATED_IMAGES = True

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None
def vis_landmark_on_img(img, shape, linewidth=8):
    '''
    Visualize 68-point facial landmarks on an image (drawn in place).
    '''
| def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth): | |
| for i in idx_list: | |
| cv2.line(img, (shape[i][0], shape[i][1]), (shape[i + 1][0], shape[i + 1][1]), color, lineWidth) | |
| if (loop): | |
| cv2.line(img, (shape[idx_list[0]][0], shape[idx_list[0]][1]), | |
| (shape[idx_list[-1] + 1][0], shape[idx_list[-1] + 1][1]), color, lineWidth) | |
| draw_curve(list(range(0, 16)), color=(255, 144, 25)) # jaw | |
| draw_curve(list(range(17, 21)), color=(50, 205, 50)) # eye brow | |
| draw_curve(list(range(22, 26)), color=(50, 205, 50)) | |
| draw_curve(list(range(27, 35)), color=(208, 224, 63)) # nose | |
| draw_curve(list(range(36, 41)), loop=True, color=(71, 99, 255)) # eyes | |
| draw_curve(list(range(42, 47)), loop=True, color=(71, 99, 255)) | |
| draw_curve(list(range(48, 59)), loop=True, color=(238, 130, 238)) # mouth | |
| draw_curve(list(range(60, 67)), loop=True, color=(238, 130, 238)) | |
| return img.astype("uint8") | |
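# A minimal usage sketch (not part of the training pipeline): it assumes a
# 68-point landmark array in pixel coordinates, which is what the index
# ranges inside draw_curve above expect.
def _demo_vis_landmark_on_img(height=512, width=512):
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # random integer landmarks, for illustration only
    landmarks = np.random.randint(0, min(height, width), size=(68, 2)).tolist()
    return vis_landmark_on_img(canvas, landmarks)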
| def imshow_keypoints(img, | |
| pose_result, | |
| skeleton=None, | |
| kpt_score_thr=0.3, | |
| pose_kpt_color=None, | |
| pose_link_color=None, | |
| radius=4, | |
| thickness=1, | |
| show_keypoint_weight=False, | |
| height=None, | |
| width=None): | |
| """Draw keypoints and links on an image. | |
| Args: | |
| img (str or Tensor): The image to draw poses on. If an image array | |
| is given, id will be modified in-place. | |
| pose_result (list[kpts]): The poses to draw. Each element kpts is | |
| a set of K keypoints as an Kx3 numpy.ndarray, where each | |
| keypoint is represented as x, y, score. | |
| kpt_score_thr (float, optional): Minimum score of keypoints | |
| to be shown. Default: 0.3. | |
| pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
| the keypoint will not be drawn. | |
| pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
| links will not be drawn. | |
| thickness (int): Thickness of lines. | |
| """ | |
| # img = mmcv.imread(img) | |
| # img_h, img_w, _ = img.shape | |
| if img is None: | |
| img = np.zeros((height, width, 3), dtype=np.uint8) | |
| img_h, img_w = height, width | |
| else: | |
| img_h, img_w, _ = img.shape | |
| for kpts in pose_result: | |
| kpts = np.array(kpts, copy=False) | |
| # draw each point on image | |
| if pose_kpt_color is not None: | |
| assert len(pose_kpt_color) == len(kpts) | |
| for kid, kpt in enumerate(kpts): | |
| x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
| if kpt_score > kpt_score_thr: | |
| color = tuple(int(c) for c in pose_kpt_color[kid]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
| radius, color, -1) | |
| transparency = max(0, min(1, kpt_score)) | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
| color, -1) | |
| # draw links | |
| if skeleton is not None and pose_link_color is not None: | |
| assert len(pose_link_color) == len(skeleton) | |
| for sk_id, sk in enumerate(skeleton): | |
| pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
| pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
| # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
| # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
| # and pos2[1] > 0 and pos2[1] < img_h | |
| # and kpts[sk[0], 2] > kpt_score_thr | |
| # and kpts[sk[1], 2] > kpt_score_thr): | |
| if (kpts[sk[0], 2] > kpt_score_thr | |
| and kpts[sk[1], 2] > kpt_score_thr): | |
| color = tuple(int(c) for c in pose_link_color[sk_id]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| X = (pos1[0], pos2[0]) | |
| Y = (pos1[1], pos2[1]) | |
| mX = np.mean(X) | |
| mY = np.mean(Y) | |
| length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
| angle = math.degrees( | |
| math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
| stickwidth = thickness | |
| polygon = cv2.ellipse2Poly( | |
| (int(mX), int(mY)), | |
| (int(length / 2), int(stickwidth)), int(angle), 0, | |
| 360, 1) | |
| cv2.fillConvexPoly(img_copy, polygon, color) | |
| # transparency = max( | |
| # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
| transparency = 1 | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.line(img, pos1, pos2, color, thickness=thickness) | |
| return img | |
| def imshow_keypoints_body(img, | |
| pose_result, | |
| skeleton=None, | |
| kpt_score_thr=0.3, | |
| pose_kpt_color=None, | |
| pose_link_color=None, | |
| radius=4, | |
| thickness=1, | |
| show_keypoint_weight=False, | |
| height=None, | |
| width=None): | |
| """Draw keypoints and links on an image. | |
| Args: | |
| img (str or Tensor): The image to draw poses on. If an image array | |
| is given, id will be modified in-place. | |
| pose_result (list[kpts]): The poses to draw. Each element kpts is | |
| a set of K keypoints as an Kx3 numpy.ndarray, where each | |
| keypoint is represented as x, y, score. | |
| kpt_score_thr (float, optional): Minimum score of keypoints | |
| to be shown. Default: 0.3. | |
| pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
| the keypoint will not be drawn. | |
| pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
| links will not be drawn. | |
| thickness (int): Thickness of lines. | |
| """ | |
| # img = mmcv.imread(img) | |
| # img_h, img_w, _ = img.shape | |
| if img is None: | |
| img = np.zeros((height, width, 3), dtype=np.uint8) | |
| img_h, img_w = height, width | |
| else: | |
| img_h, img_w, _ = img.shape | |
| for kpts in pose_result: | |
| kpts = np.array(kpts, copy=False) | |
| # draw each point on image | |
| if pose_kpt_color is not None: | |
| assert len(pose_kpt_color) == len(kpts) | |
| for kid, kpt in enumerate(kpts): | |
| if kid in [17, 18, 19, 20, 21, 22]: | |
| continue | |
| if kid in [13, 14, 15, 16]: | |
| if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): | |
| continue | |
| x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
| if kpt_score > kpt_score_thr: | |
| color = tuple(int(c) for c in pose_kpt_color[kid]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
| radius, color, -1) | |
| transparency = max(0, min(1, kpt_score)) | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
| color, -1) | |
| # draw links | |
| if skeleton is not None and pose_link_color is not None: | |
| assert len(pose_link_color) == len(skeleton) | |
| for sk_id, sk in enumerate(skeleton): | |
| if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: | |
| continue | |
| if sk[0] in [13, 14, 15, 16]: | |
| if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): | |
| continue | |
| if sk[1] in [13, 14, 15, 16]: | |
| if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): | |
| continue | |
| pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
| pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
| # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
| # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
| # and pos2[1] > 0 and pos2[1] < img_h | |
| # and kpts[sk[0], 2] > kpt_score_thr | |
| # and kpts[sk[1], 2] > kpt_score_thr): | |
| if (kpts[sk[0], 2] > kpt_score_thr | |
| and kpts[sk[1], 2] > kpt_score_thr): | |
| color = tuple(int(c) for c in pose_link_color[sk_id]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| X = (pos1[0], pos2[0]) | |
| Y = (pos1[1], pos2[1]) | |
| mX = np.mean(X) | |
| mY = np.mean(Y) | |
| length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
| angle = math.degrees( | |
| math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
| stickwidth = thickness | |
| polygon = cv2.ellipse2Poly( | |
| (int(mX), int(mY)), | |
| (int(length / 2), int(stickwidth)), int(angle), 0, | |
| 360, 1) | |
| cv2.fillConvexPoly(img_copy, polygon, color) | |
| # transparency = max( | |
| # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
| transparency = 1 | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.line(img, pos1, pos2, color, thickness=thickness) | |
| return img | |
| def imshow_keypoints_whole(img, | |
| pose_result, | |
| skeleton=None, | |
| kpt_score_thr=0.3, | |
| pose_kpt_color=None, | |
| pose_link_color=None, | |
| radius=4, | |
| thickness=1, | |
| show_keypoint_weight=False, | |
| height=None, | |
| width=None): | |
| """Draw keypoints and links on an image. | |
| Args: | |
| img (str or Tensor): The image to draw poses on. If an image array | |
| is given, id will be modified in-place. | |
| pose_result (list[kpts]): The poses to draw. Each element kpts is | |
| a set of K keypoints as an Kx3 numpy.ndarray, where each | |
| keypoint is represented as x, y, score. | |
| kpt_score_thr (float, optional): Minimum score of keypoints | |
| to be shown. Default: 0.3. | |
| pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
| the keypoint will not be drawn. | |
| pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
| links will not be drawn. | |
| thickness (int): Thickness of lines. | |
| """ | |
| # img = mmcv.imread(img) | |
| # img_h, img_w, _ = img.shape | |
| if img is None: | |
| img = np.zeros((height, width, 3), dtype=np.uint8) | |
| img_h, img_w = height, width | |
| else: | |
| img_h, img_w, _ = img.shape | |
| for kpts in pose_result: | |
| kpts = np.array(kpts, copy=False) | |
| # draw each point on image | |
| if pose_kpt_color is not None: | |
| assert len(pose_kpt_color) == len(kpts) | |
| for kid, kpt in enumerate(kpts): | |
| if kid in [17, 18, 19, 20, 21, 22]: | |
| continue | |
| if kid in [13, 14, 15, 16]: | |
| if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): | |
| continue | |
| x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
| if kpt_score > kpt_score_thr: | |
| color = tuple(int(c) for c in pose_kpt_color[kid]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
| radius, color, -1) | |
| transparency = max(0, min(1, kpt_score)) | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
| color, -1) | |
| # draw links | |
| if skeleton is not None and pose_link_color is not None: | |
| assert len(pose_link_color) == len(skeleton) | |
| for sk_id, sk in enumerate(skeleton): | |
| if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: | |
| continue | |
| if sk[0] in [13, 14, 15, 16]: | |
| if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): | |
| continue | |
| if sk[1] in [13, 14, 15, 16]: | |
| if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): | |
| continue | |
| pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
| pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
| # if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
| # and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
| # and pos2[1] > 0 and pos2[1] < img_h | |
| # and kpts[sk[0], 2] > kpt_score_thr | |
| # and kpts[sk[1], 2] > kpt_score_thr): | |
| if (kpts[sk[0], 2] > kpt_score_thr | |
| and kpts[sk[1], 2] > kpt_score_thr): | |
| color = tuple(int(c) for c in pose_link_color[sk_id]) | |
| if show_keypoint_weight: | |
| img_copy = img.copy() | |
| X = (pos1[0], pos2[0]) | |
| Y = (pos1[1], pos2[1]) | |
| mX = np.mean(X) | |
| mY = np.mean(Y) | |
| length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
| angle = math.degrees( | |
| math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
| stickwidth = thickness | |
| polygon = cv2.ellipse2Poly( | |
| (int(mX), int(mY)), | |
| (int(length / 2), int(stickwidth)), int(angle), 0, | |
| 360, 1) | |
| cv2.fillConvexPoly(img_copy, polygon, color) | |
| # transparency = max( | |
| # 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
| transparency = 1 | |
| cv2.addWeighted( | |
| img_copy, | |
| transparency, | |
| img, | |
| 1 - transparency, | |
| 0, | |
| dst=img) | |
| else: | |
| cv2.line(img, pos1, pos2, color, thickness=thickness) | |
| return img | |
| def draw_whole_body_skeleton( | |
| img, | |
| pose, | |
| radius=4, | |
| thickness=1, | |
| kpt_score_thr=0.3, | |
| height=None, | |
| width=None, | |
| ): | |
| palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
| [230, 230, 0], [255, 153, 255], [153, 204, 255], | |
| [255, 102, 255], [255, 51, 255], [102, 178, 255], | |
| [51, 153, 255], [255, 153, 153], [255, 102, 102], | |
| [255, 51, 51], [153, 255, 153], [102, 255, 102], | |
| [51, 255, 51], [0, 255, 0], [0, 0, 255], | |
| [255, 0, 0], [255, 255, 255]]) | |
| # below are for the whole body keypoints | |
| skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
| [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
| [8, 10], [1, 2], [0, 1], [0, 2], | |
| [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18], | |
| [15, 19], [16, 20], [16, 21], [16, 22], [91, 92], | |
| [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], | |
| [97, 98], [98, 99], [91, 100], [100, 101], [101, 102], | |
| [102, 103], [91, 104], [104, 105], [105, 106], | |
| [106, 107], [91, 108], [108, 109], [109, 110], | |
| [110, 111], [112, 113], [113, 114], [114, 115], | |
| [115, 116], [112, 117], [117, 118], [118, 119], | |
| [119, 120], [112, 121], [121, 122], [122, 123], | |
| [123, 124], [112, 125], [125, 126], [126, 127], | |
| [127, 128], [112, 129], [129, 130], [130, 131], | |
| [131, 132]] | |
| pose_link_color = palette[[ | |
| 0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 | |
| ] + [16, 16, 16, 16, 16, 16] + [ | |
| 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
| 16 | |
| ] + [ | |
| 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
| 16 | |
| ]] | |
| pose_kpt_color = palette[ | |
| [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + | |
| [0, 0, 0, 0, 0, 0] + [19] * (68 + 42)] | |
| draw = imshow_keypoints_whole(img, pose, skeleton, | |
| kpt_score_thr=0.3, | |
| pose_kpt_color=pose_kpt_color, | |
| pose_link_color=pose_link_color, | |
| radius=radius, | |
| thickness=thickness, | |
| show_keypoint_weight=True, | |
| height=height, | |
| width=width) | |
| return draw | |
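# A minimal sketch of calling draw_whole_body_skeleton on a blank canvas. It
# assumes the 133-keypoint COCO-WholeBody layout (x, y, score per point),
# which is what the skeleton and color tables above are indexed against.
def _demo_draw_whole_body_skeleton(height=512, width=512):
    kpts = np.zeros((133, 3), dtype=np.float32)
    kpts[:, 0] = np.random.uniform(0, width, size=133)   # x
    kpts[:, 1] = np.random.uniform(0, height, size=133)  # y
    kpts[:, 2] = 1.0                                      # confidence
    return draw_whole_body_skeleton(None, [kpts], height=height, width=width)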
| def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10): | |
| humansd_skeleton=[ | |
| [0,0,1], | |
| [1,0,2], | |
| [2,1,3], | |
| [3,2,4], | |
| [4,3,5], | |
| [5,4,6], | |
| [6,5,7], | |
| [7,6,8], | |
| [8,7,9], | |
| [9,8,10], | |
| [10,5,11], | |
| [11,6,12], | |
| [12,11,13], | |
| [13,12,14], | |
| [14,13,15], | |
| [15,14,16], | |
| ] | |
| # humansd_skeleton_width=10 | |
| humansd_color=sns.color_palette("hls", len(humansd_skeleton)) | |
| def plot_kpts(img_draw, kpts, color, edgs,width): | |
| for idx, kpta, kptb in edgs: | |
| if kpts[kpta,2]>mmpose_detection_thresh and \ | |
| kpts[kptb,2]>mmpose_detection_thresh : | |
| line_color = tuple([int(255*color_i) for color_i in color[idx]]) | |
| cv2.line(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), (int(kpts[kptb,0]),int(kpts[kptb,1])), line_color,width) | |
| cv2.circle(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), width//2, line_color, -1) | |
| cv2.circle(img_draw, (int(kpts[kptb,0]),int(kpts[kptb,1])), width//2, line_color, -1) | |
| if image is None: | |
| pose_image = np.zeros((height, width, 3), dtype=np.uint8) | |
| else: | |
| pose_image = np.array(image, dtype=np.uint8) | |
| for person_i in range(len(pose)): | |
| if np.sum(pose[person_i])>0: | |
| plot_kpts(pose_image, pose[person_i],humansd_color,humansd_skeleton,humansd_skeleton_width) | |
| return pose_image | |
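# A minimal sketch for draw_humansd_skeleton, assuming a single person in the
# 17-keypoint COCO body layout (x, y, score) that humansd_skeleton indexes.
def _demo_draw_humansd_skeleton(height=512, width=512):
    pose = np.zeros((1, 17, 3), dtype=np.float32)
    pose[0, :, 0] = np.random.uniform(0, width, size=17)
    pose[0, :, 1] = np.random.uniform(0, height, size=17)
    pose[0, :, 2] = 1.0
    return draw_humansd_skeleton(None, pose, height=height, width=width)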
| def draw_controlnet_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None): | |
| if image is None: | |
| canvas = np.zeros((height, width, 3), dtype=np.uint8) | |
| else: | |
| H, W, C = image.shape | |
| canvas = np.array(image, dtype=np.uint8) | |
| for pose_i in range(len(pose)): | |
| present_pose=pose[pose_i] | |
| candidate=[ | |
| [present_pose[0,0],present_pose[0,1],present_pose[0,2],0], | |
| [(present_pose[6,0]+present_pose[5,0])/2,(present_pose[6,1]+present_pose[5,1])/2,(present_pose[6,2]+present_pose[5,2])/2,1] if present_pose[6,2]>mmpose_detection_thresh and present_pose[5,2]>mmpose_detection_thresh else [-1,-1,0,1], | |
| [present_pose[6,0],present_pose[6,1],present_pose[6,2],2], | |
| [present_pose[8,0],present_pose[8,1],present_pose[8,2],3], | |
| [present_pose[10,0],present_pose[10,1],present_pose[10,2],4], | |
| [present_pose[5,0],present_pose[5,1],present_pose[5,2],5], | |
| [present_pose[7,0],present_pose[7,1],present_pose[7,2],6], | |
| [present_pose[9,0],present_pose[9,1],present_pose[9,2],7], | |
| [present_pose[12,0],present_pose[12,1],present_pose[12,2],8], | |
| [present_pose[14,0],present_pose[14,1],present_pose[14,2],9], | |
| [present_pose[16,0],present_pose[16,1],present_pose[16,2],10], | |
| [present_pose[11,0],present_pose[11,1],present_pose[11,2],11], | |
| [present_pose[13,0],present_pose[13,1],present_pose[13,2],12], | |
| [present_pose[15,0],present_pose[15,1],present_pose[15,2],13], | |
| [present_pose[2,0],present_pose[2,1],present_pose[2,2],14], | |
| [present_pose[1,0],present_pose[1,1],present_pose[1,2],15], | |
| [present_pose[4,0],present_pose[4,1],present_pose[4,2],16], | |
| [present_pose[3,0],present_pose[3,1],present_pose[3,2],17], | |
| ] | |
| stickwidth = 4 | |
| limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ | |
| [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ | |
| [1, 16], [16, 18], [3, 17], [6, 18]] | |
| colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ | |
| [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ | |
| [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] | |
| for i in range(17): | |
| if candidate[limbSeq[i][0]-1][2]>mmpose_detection_thresh and candidate[limbSeq[i][1]-1][2]>mmpose_detection_thresh: | |
| Y=[candidate[limbSeq[i][1]-1][0],candidate[limbSeq[i][0]-1][0]] | |
| X=[candidate[limbSeq[i][1]-1][1],candidate[limbSeq[i][0]-1][1]] | |
| mX = np.mean(X) | |
| mY = np.mean(Y) | |
| length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 | |
| angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) | |
| polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) | |
| cur_canvas = canvas.copy() | |
| cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) | |
| canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) | |
| for i in range(18): | |
| if candidate[i][2]>mmpose_detection_thresh: | |
| x, y = candidate[i][0:2] | |
| cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) | |
| return canvas | |
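# A minimal sketch for draw_controlnet_skeleton. It assumes mmpose-style
# COCO-17 keypoints (x, y, score); the function itself remaps them to the
# 18-point OpenPose ordering used by ControlNet-style conditioning.
def _demo_draw_controlnet_skeleton(height=512, width=512):
    pose = np.zeros((1, 17, 3), dtype=np.float32)
    pose[0, :, 0] = np.random.uniform(0, width, size=17)
    pose[0, :, 1] = np.random.uniform(0, height, size=17)
    pose[0, :, 2] = 1.0
    return draw_controlnet_skeleton(None, pose, height=height, width=width)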
| def draw_body_skeleton( | |
| img, | |
| pose, | |
| radius=4, | |
| thickness=1, | |
| kpt_score_thr=0.3, | |
| height=None, | |
| width=None, | |
| ): | |
| palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
| [230, 230, 0], [255, 153, 255], [153, 204, 255], | |
| [255, 102, 255], [255, 51, 255], [102, 178, 255], | |
| [51, 153, 255], [255, 153, 153], [255, 102, 102], | |
| [255, 51, 51], [153, 255, 153], [102, 255, 102], | |
| [51, 255, 51], [0, 255, 0], [0, 0, 255], | |
| [255, 0, 0], [255, 255, 255]]) | |
| # below are for the body keypoints | |
| # skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
| # [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
| # [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], | |
| # [3, 5], [4, 6]] | |
| skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
| [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
| [8, 10], [3, 4], | |
| [3, 5], [4, 6]] | |
| # pose_link_color = palette[[ | |
| # 0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 | |
| # ]] | |
| # pose_kpt_color = palette[[ | |
| # 16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0 | |
| # ]] | |
| pose_link_color = palette[[ | |
| 12, 16, 1, 5, 9, 13, 19, 15, 11, 7, 3, 18, 14, 8, 0 | |
| ]] | |
| pose_kpt_color = palette[[ | |
| 19, 15, 11, 7, 3, 18, 14, 10, 6, 2, 17, 13, 9, 5, 1, 16, 12 | |
| ]] | |
| draw = imshow_keypoints_body(img, pose, skeleton, | |
| kpt_score_thr=0.3, | |
| pose_kpt_color=pose_kpt_color, | |
| pose_link_color=pose_link_color, | |
| radius=radius, | |
| thickness=thickness, | |
| show_keypoint_weight=True, | |
| height=height, | |
| width=width) | |
| return draw | |
| def draw_face_skeleton( | |
| img, | |
| pose, | |
| radius=4, | |
| thickness=1, | |
| kpt_score_thr=0.3, | |
| height=None, | |
| width=None, | |
| ): | |
| palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
| [230, 230, 0], [255, 153, 255], [153, 204, 255], | |
| [255, 102, 255], [255, 51, 255], [102, 178, 255], | |
| [51, 153, 255], [255, 153, 153], [255, 102, 102], | |
| [255, 51, 51], [153, 255, 153], [102, 255, 102], | |
| [51, 255, 51], [0, 255, 0], [0, 0, 255], | |
| [255, 0, 0], [255, 255, 255]]) | |
| # below are for the face keypoints | |
| skeleton = [] | |
| pose_link_color = palette[[]] | |
| pose_kpt_color = palette[[19] * 68] | |
| kpt_score_thr = 0 | |
| draw = imshow_keypoints(img, pose, skeleton, | |
| kpt_score_thr=kpt_score_thr, | |
| pose_kpt_color=pose_kpt_color, | |
| pose_link_color=pose_link_color, | |
| radius=radius, | |
| thickness=thickness, | |
| show_keypoint_weight=True, | |
| height=height, | |
| width=width) | |
| return draw | |
| def draw_hand_skeleton( | |
| img, | |
| pose, | |
| radius=4, | |
| thickness=1, | |
| kpt_score_thr=0.3, | |
| height=None, | |
| width=None, | |
| ): | |
| palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
| [230, 230, 0], [255, 153, 255], [153, 204, 255], | |
| [255, 102, 255], [255, 51, 255], [102, 178, 255], | |
| [51, 153, 255], [255, 153, 153], [255, 102, 102], | |
| [255, 51, 51], [153, 255, 153], [102, 255, 102], | |
| [51, 255, 51], [0, 255, 0], [0, 0, 255], | |
| [255, 0, 0], [255, 255, 255]]) | |
| # hand option 1 | |
| skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], | |
| [7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13], | |
| [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], | |
| [18, 19], [19, 20]] | |
| pose_link_color = palette[[ | |
| 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
| 16 | |
| ]] | |
| pose_kpt_color = palette[[ | |
| 0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, | |
| 16, 16 | |
| ]] | |
| # # hand option 2 | |
| # skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9], | |
| # [9, 10], [10, 11], [12, 13], [13, 14], [14, 15], | |
| # [16, 17], [17, 18], [18, 19], [3, 20], [7, 20], | |
| # [11, 20], [15, 20], [19, 20]] | |
| # pose_link_color = palette[[ | |
| # 0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, | |
| # 16 | |
| # ]] | |
| # pose_kpt_color = palette[[ | |
| # 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
| # 16, 0 | |
| # ]] | |
| draw = imshow_keypoints(img, pose, skeleton, | |
| kpt_score_thr=kpt_score_thr, | |
| pose_kpt_color=pose_kpt_color, | |
| pose_link_color=pose_link_color, | |
| radius=radius, | |
| thickness=thickness, | |
| show_keypoint_weight=True, | |
| height=height, | |
| width=width) | |
| return draw | |
| class CsvDataset(Dataset): | |
| def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None): | |
| logging.debug(f'Loading csv data from {input_filename}.') | |
| df = pd.read_csv(input_filename, sep=sep) | |
| self.images = df[img_key].tolist() | |
| self.captions = df[caption_key].tolist() | |
| self.transforms = transforms | |
| logging.debug('Done loading data.') | |
| self.tokenize = tokenizer | |
| def __len__(self): | |
| return len(self.captions) | |
| def __getitem__(self, idx): | |
| images = self.transforms(Image.open(str(self.images[idx]))) | |
| texts = self.tokenize([str(self.captions[idx])])[0] | |
| return images, texts | |
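# A minimal sketch of constructing CsvDataset. The TSV path, the column
# names, and the open_clip tokenizer are assumptions for illustration; only
# the (images, texts) tuple contract above is fixed by this class.
def _demo_csv_dataset(csv_path="data/train.tsv"):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    try:
        import open_clip
        tokenize = open_clip.get_tokenizer("ViT-B-32")
    except ImportError:
        def tokenize(texts):  # hypothetical stand-in with the same call shape
            return torch.zeros(len(texts), 77, dtype=torch.long)
    dataset = CsvDataset(csv_path, preprocess, img_key="filepath",
                         caption_key="title", tokenizer=tokenize)
    return DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)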
| class SharedEpoch: | |
| def __init__(self, epoch: int = 0): | |
| self.shared_epoch = Value('i', epoch) | |
| def set_value(self, epoch): | |
| self.shared_epoch.value = epoch | |
| def get_value(self): | |
| return self.shared_epoch.value | |
@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None
| def set_epoch(self, epoch): | |
| if self.shared_epoch is not None: | |
| self.shared_epoch.set_value(epoch) | |
| if self.sampler is not None and isinstance(self.sampler, DistributedSampler): | |
| self.sampler.set_epoch(epoch) | |
| def expand_urls(urls, weights=None): | |
| if weights is None: | |
| expanded_urls = wds.shardlists.expand_urls(urls) | |
| return expanded_urls, None | |
| if isinstance(urls, str): | |
| urllist = urls.split("::") | |
| weights = weights.split('::') | |
        assert len(weights) == len(urllist), f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
| weights = [float(weight) for weight in weights] | |
| all_urls, all_weights = [], [] | |
| for url, weight in zip(urllist, weights): | |
| expanded_url = list(braceexpand.braceexpand(url)) | |
| expanded_weights = [weight for _ in expanded_url] | |
| all_urls.extend(expanded_url) | |
| all_weights.extend(expanded_weights) | |
| return all_urls, all_weights | |
| else: | |
| all_urls = list(urls) | |
| return all_urls, weights | |
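# A small sketch of how expand_urls handles "::"-separated shard specs with
# per-source upsampling weights; the paths below are placeholders.
def _demo_expand_urls():
    urls = "data/laion/{00000..00002}.tar::data/coyo/{00000..00001}.tar"
    weights = "1.0::2.0"
    all_urls, all_weights = expand_urls(urls, weights)
    # -> 3 laion shards with weight 1.0 followed by 2 coyo shards with weight 2.0
    return all_urls, all_weights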
| def get_dataset_size(shards): | |
| shards_list, _ = expand_urls(shards) | |
| dir_path = os.path.dirname(shards_list[0]) | |
| sizes_filename = os.path.join(dir_path, 'sizes.json') | |
| len_filename = os.path.join(dir_path, '__len__') | |
| if os.path.exists(sizes_filename): | |
| sizes = json.load(open(sizes_filename, 'r')) | |
| total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) | |
| elif os.path.exists(len_filename): | |
| # FIXME this used to be eval(open(...)) but that seemed rather unsafe | |
| total_size = ast.literal_eval(open(len_filename, 'r').read()) | |
| else: | |
| total_size = None # num samples undefined | |
| # some common dataset sizes (at time of authors last download) | |
| # CC3M (train): 2905954 | |
| # CC12M: 10968539 | |
| # LAION-400M: 407332084 | |
| # LAION-2B (english): 2170337258 | |
| num_shards = len(shards_list) | |
| return total_size, num_shards | |
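# A minimal sketch of the metadata layout get_dataset_size expects: a
# sizes.json next to the shards mapping shard basename -> sample count (or a
# plain __len__ file with the total). The directory and counts are hypothetical.
def _demo_get_dataset_size(shard_dir="data/laion"):
    os.makedirs(shard_dir, exist_ok=True)
    with open(os.path.join(shard_dir, "sizes.json"), "w") as f:
        json.dump({"00000.tar": 9876, "00001.tar": 10021}, f)
    # expands the brace pattern, then sums the per-shard counts -> (19897, 2)
    return get_dataset_size(os.path.join(shard_dir, "{00000..00001}.tar"))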
| def get_imagenet(args, preprocess_fns, split): | |
| assert split in ["train", "val", "v2"] | |
| is_train = split == "train" | |
| preprocess_train, preprocess_val = preprocess_fns | |
| if split == "v2": | |
| from imagenetv2_pytorch import ImageNetV2Dataset | |
| dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) | |
| else: | |
| if is_train: | |
| data_path = args.imagenet_train | |
| preprocess_fn = preprocess_train | |
| else: | |
| data_path = args.imagenet_val | |
| preprocess_fn = preprocess_val | |
| assert data_path | |
| dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) | |
| if is_train: | |
| idxs = np.zeros(len(dataset.targets)) | |
| target_array = np.array(dataset.targets) | |
| k = 50 | |
| for c in range(1000): | |
| m = target_array == c | |
| n = len(idxs[m]) | |
| arr = np.zeros(n) | |
| arr[:k] = 1 | |
| np.random.shuffle(arr) | |
| idxs[m] = arr | |
| idxs = idxs.astype('int') | |
| sampler = SubsetRandomSampler(np.where(idxs)[0]) | |
| else: | |
| sampler = None | |
| dataloader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=args.batch_size, | |
| num_workers=args.workers, | |
| sampler=sampler, | |
| ) | |
| return DataInfo(dataloader=dataloader, sampler=sampler) | |
| def count_samples(dataloader): | |
| os.environ["WDS_EPOCH"] = "0" | |
| n_elements, n_batches = 0, 0 | |
| for images, texts in dataloader: | |
| n_batches += 1 | |
| n_elements += len(images) | |
| assert len(images) == len(texts) | |
| return n_elements, n_batches | |
| def filter_no_caption_or_no_image(sample): | |
| has_caption = ('txt' in sample) | |
| has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) | |
| return has_caption and has_image | |
| def filter_no_image_or_no_ldmk(sample): | |
| has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) | |
| has_ldmk = ('ldmk' in sample) | |
| return has_image and has_ldmk | |
| def log_and_continue(exn): | |
| """Call in an exception handler to ignore any exception, issue a warning, and continue.""" | |
| logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') | |
| return True | |
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group key/value pairs from a stream of tar file entries into samples.
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
| current_sample = None | |
| for filesample in data: | |
| assert isinstance(filesample, dict) | |
| fname, value = filesample["fname"], filesample["data"] | |
| prefix, suffix = keys(fname) | |
| if prefix is None: | |
| continue | |
| if lcase: | |
| suffix = suffix.lower() | |
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
        # begins; rare, but it can happen since prefixes aren't unique across tar files in that dataset
| if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: | |
| if valid_sample(current_sample): | |
| yield current_sample | |
| current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) | |
| if suffixes is None or suffix in suffixes: | |
| current_sample[suffix] = value | |
| if valid_sample(current_sample): | |
| yield current_sample | |
| def tarfile_to_samples_nothrow(src, handler=log_and_continue): | |
| # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw | |
| streams = url_opener(src, handler=handler) | |
| files = tar_file_expander(streams, handler=handler) | |
| samples = group_by_keys_nothrow(files, handler=handler) | |
| return samples | |
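# A minimal sketch of where tarfile_to_samples_nothrow sits in a webdataset
# pipeline: it replaces wds.tarfile_to_samples so that a duplicated key or a
# corrupt tar member drops the sample instead of raising. The shard path is a
# placeholder.
def _demo_nothrow_pipeline(shards="data/laion/{00000..00001}.tar", batch_size=8):
    return wds.DataPipeline(
        wds.SimpleShardList(shards),
        wds.split_by_worker,
        tarfile_to_samples_nothrow,
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
        wds.to_tuple("image", "text"),
        wds.batched(batch_size, partial=True),
    )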
def pytorch_worker_seed(increment=0):
    """Get the dataloader worker seed from PyTorch."""
| worker_info = get_worker_info() | |
| if worker_info is not None: | |
| # favour using the seed already created for pytorch dataloader workers if it exists | |
| seed = worker_info.seed | |
| if increment: | |
| # space out seed increments so they can't overlap across workers in different iterations | |
| seed += increment * max(1, worker_info.num_workers) | |
| return seed | |
| # fallback to wds rank based seed | |
| return wds.utils.pytorch_worker_seed() | |
| _SHARD_SHUFFLE_SIZE = 2000 | |
| _SHARD_SHUFFLE_INITIAL = 500 | |
| _SAMPLE_SHUFFLE_SIZE = 5000 | |
| _SAMPLE_SHUFFLE_INITIAL = 1000 | |
| class detshuffle2(wds.PipelineStage): | |
| def __init__( | |
| self, | |
| bufsize=1000, | |
| initial=100, | |
| seed=0, | |
| epoch=-1, | |
| ): | |
| self.bufsize = bufsize | |
| self.initial = initial | |
| self.seed = seed | |
| self.epoch = epoch | |
| def run(self, src): | |
| if isinstance(self.epoch, SharedEpoch): | |
| epoch = self.epoch.get_value() | |
| else: | |
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation, as different workers may wrap at different times (or not at all).
| self.epoch += 1 | |
| epoch = self.epoch | |
| rng = random.Random() | |
        if self.seed < 0:
            # If the seed is negative, use the worker's seed; this differs across all nodes/workers.
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch.
            seed = self.seed + epoch
| rng.seed(seed) | |
| return _shuffle(src, self.bufsize, self.initial, rng) | |
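# A minimal sketch of detshuffle2 in a shard pipeline: driven by a SharedEpoch,
# the shuffle order is deterministic per epoch and identical across ranks and
# workers, which the training pipelines below rely on. The shard path is a
# placeholder.
def _demo_detshuffle2(shards="data/laion/{00000..00003}.tar"):
    shared_epoch = SharedEpoch(epoch=0)
    return wds.DataPipeline(
        wds.SimpleShardList(shards),
        detshuffle2(bufsize=_SHARD_SHUFFLE_SIZE, initial=_SHARD_SHUFFLE_INITIAL,
                    seed=0, epoch=shared_epoch),
        wds.split_by_node,
        wds.split_by_worker,
    )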
| class ResampledShards2(IterableDataset): | |
| """An iterable dataset yielding a list of urls.""" | |
| def __init__( | |
| self, | |
| urls, | |
| weights=None, | |
| nshards=sys.maxsize, | |
| worker_seed=None, | |
| deterministic=False, | |
| epoch=-1, | |
| ): | |
| """Sample shards from the shard list with replacement. | |
| :param urls: a list of URLs as a Python list or brace notation string | |
| """ | |
| super().__init__() | |
| urls, weights = expand_urls(urls, weights) | |
| self.urls = urls | |
| self.weights = weights | |
| if self.weights is not None: | |
| assert len(self.urls) == len(self.weights), f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match." | |
| assert isinstance(self.urls[0], str) | |
| self.nshards = nshards | |
| self.rng = random.Random() | |
| self.worker_seed = worker_seed | |
| self.deterministic = deterministic | |
| self.epoch = epoch | |
| def __iter__(self): | |
| """Return an iterator over the shards.""" | |
| if isinstance(self.epoch, SharedEpoch): | |
| epoch = self.epoch.get_value() | |
| else: | |
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation, as different workers may wrap at different times (or not at all).
| self.epoch += 1 | |
| epoch = self.epoch | |
| if self.deterministic: | |
| # reset seed w/ epoch if deterministic | |
| if self.worker_seed is None: | |
| # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id | |
| seed = pytorch_worker_seed(epoch) | |
| else: | |
| seed = self.worker_seed() + epoch | |
| self.rng.seed(seed) | |
| for _ in range(self.nshards): | |
| if self.weights is None: | |
| yield dict(url=self.rng.choice(self.urls)) | |
| else: | |
| yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) | |
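# A minimal sketch of ResampledShards2: it yields shard dicts sampled with
# replacement, optionally weighted per source. The paths and weights are
# placeholders; itertools.islice just caps the infinite stream for the demo.
def _demo_resampled_shards():
    import itertools
    urls = "data/laion/{00000..00002}.tar::data/coyo/{00000..00001}.tar"
    shards = ResampledShards2(urls, weights="1.0::2.0", deterministic=False)
    return [s["url"] for s in itertools.islice(iter(shards), 5)]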
| def get_wds_dataset_filter(args, preprocess_img): | |
| input_shards = args.train_data | |
| assert input_shards is not None | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| def replicate_img(sample): | |
| import copy | |
| sample["original"] = copy.copy(sample["image"]) | |
| return sample | |
| def decode_byte_to_rgb(sample): | |
| # import io | |
| # import PIL | |
| # from PIL import ImageFile | |
| # ImageFile.LOAD_TRUNCATED_IMAGES = True | |
        with io.BytesIO(sample["image"]) as stream:
            try:
                img = PIL.Image.open(stream)
                img.load()
                img = img.convert("RGB")
                sample["image"] = img
                return sample
            except Exception:
                logging.warning("Encountered a broken image; replacing it with a blank placeholder.")
                sample["image"] = Image.new('RGB', (512, 512))
                return sample
| # at this point we have an iterator over all the shards | |
| pipeline.extend([ | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| tarfile_to_samples_nothrow, | |
| wds.select(filter_no_caption_or_no_image), | |
| # wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map(replicate_img), | |
| wds.map(decode_byte_to_rgb), | |
| # wds.map_dict(image=preprocess_img, text=lambda x: x.encode('utf-8'), \ | |
| # __key__=lambda x: x.encode('utf-8'), __url__=lambda x: x.encode('utf-8')), | |
| wds.map_dict(image=preprocess_img), | |
| wds.to_tuple("original", "image", "text", "__key__", "__url__", "json"), | |
| wds.batched(args.batch_size, partial=True) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| drop_last=False | |
| ) | |
| return DataInfo(dataloader=dataloader) | |
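# A minimal sketch of driving get_wds_dataset_filter. The argparse.Namespace
# fields and the shard path are assumptions; the shards are expected to carry
# jpg/txt/json entries per sample, since the pipeline keeps the raw json.
def _demo_filter_loader():
    from argparse import Namespace
    args = Namespace(train_data="data/laion/{00000..00009}.tar",
                     batch_size=32, workers=4)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    data = get_wds_dataset_filter(args, preprocess)
    for original, image, text, key, url, meta in data.dataloader:
        # original: raw encoded bytes, image: preprocessed tensors
        break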
| def get_wds_dataset_cond_face(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
| def preprocess_image(sample): | |
| # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
| resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
| sample["image"] = resize_transform(sample["image"]) | |
| sample["ldmk"] = resize_transform(sample["ldmk"]) | |
| transform_list = [] | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
| if main_args.center_crop or not is_train: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
| else: | |
| if image_height < main_args.resolution or image_width < main_args.resolution: | |
| raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution and image_height == main_args.resolution: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
| if is_train and torch.rand(1) < 0.5: | |
| transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
| transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
| train_transforms = transforms.Compose(transform_list) | |
| sample["image"] = train_transforms(sample["image"]) | |
| sample["ldmk"] = train_transforms(sample["ldmk"]) | |
| return sample | |
| # def extract_ldmk(sample): | |
| # image_height, image_width = sample["image"].height, sample["image"].width | |
| # preds = fa.get_landmarks(np.array(sample["image"])) | |
| # lands = [] | |
| # if preds is not None: | |
| # for pred in preds: | |
| # land = pred.reshape(-1, 3)[:,:2].astype(int) | |
| # lands.append(land) | |
| # lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") | |
| # if len(lands) > 0: | |
| # for land in lands: | |
| # lms_color_map = vis_landmark_on_img(lms_color_map, land) | |
| # # print(lms_color_map.shape) | |
| # sample["ldmk"] = Image.fromarray(lms_color_map) | |
| # return sample | |
| def visualize_ldmk(sample): | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| lands = np.frombuffer(sample["ldmk"], dtype=np.float32) | |
| lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") | |
| if len(lands) > 0: | |
| lands = lands.reshape(-1, 68, 3).astype(int) | |
| for i in range(lands.shape[0]): | |
| lms_color_map = vis_landmark_on_img(lms_color_map, lands[i]) | |
| # print(lms_color_map.shape) | |
| sample["ldmk"] = Image.fromarray(lms_color_map) | |
| return sample | |
| def filter_ldmk_none(sample): | |
| return not (sample["ldmk"] == -1).all() | |
| def filter_low_res(sample): | |
| if filter_lowres: | |
| string_json = sample["json"].decode('utf-8') | |
| dict_json = json.loads(string_json) | |
| if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| min_length = min(dict_json["height"], dict_json["width"]) | |
| return min_length >= main_args.resolution | |
| else: | |
| return True | |
| return True | |
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.select(filter_low_res), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| # wds.map(extract_ldmk), | |
| wds.map(visualize_ldmk), | |
| wds.map(preprocess_image), | |
| wds.select(filter_ldmk_none), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict( | |
| text=lambda text: tokenizer(text, \ | |
| max_length=tokenizer.model_max_length, \ | |
| padding="max_length", truncation=True, \ | |
| return_tensors='pt')['input_ids'], | |
| ), | |
| wds.to_tuple("image", "text", "ldmk"), | |
| # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
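# A minimal sketch of wiring get_wds_dataset_cond_face. Every field below is
# an assumption for illustration: the shard path, the sample count, and the
# Hugging Face CLIP tokenizer (the function only relies on the tokenizer
# accepting max_length/padding/truncation and returning input_ids). The shards
# are expected to carry jpg/txt/ldmk/json entries per sample.
def _demo_cond_face_loader():
    from argparse import Namespace
    from transformers import CLIPTokenizer
    args = Namespace(train_data="data/faces/{00000..00099}.tar", val_data=None,
                     train_num_samples=100000, val_num_samples=0,
                     dataset_resampled=False, train_data_upsampling_factors=None,
                     seed=0, batch_size=16, workers=4, world_size=1)
    main_args = Namespace(resolution=512, center_crop=False, random_flip=True)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    data = get_wds_dataset_cond_face(args, main_args, is_train=True, tokenizer=tokenizer)
    data.set_epoch(0)
    image, text, ldmk = next(iter(data.dataloader))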
| def get_wds_dataset_depth(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_mface=False, filter_wpose=False): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
    def decode_image(sample):
        with io.BytesIO(sample["omni_depth"]) as stream:
            try:
                img = PIL.Image.open(stream)
                img.load()
                img = img.convert("RGB")
                sample["depth"] = img
            except Exception:
                logging.warning("Encountered a broken depth image; replacing it with a blank placeholder.")
                sample["depth"] = Image.new('RGB', (512, 512))
        return sample
| train_transforms = transforms.Compose( | |
| [ | |
| transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), | |
| transforms.CenterCrop(main_args.resolution) if main_args.center_crop else transforms.RandomCrop(main_args.resolution), | |
| transforms.RandomHorizontalFlip() if main_args.random_flip else transforms.Lambda(lambda x: x), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.5], [0.5]), | |
| ] | |
| ) | |
| def filter_depth_none(sample): | |
| return not (sample["depth"] == -1).all() | |
| def filter_low_res(sample): | |
| if filter_lowres: | |
| string_json = sample["json"].decode('utf-8') | |
| dict_json = json.loads(string_json) | |
| if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| min_length = min(dict_json["height"], dict_json["width"]) | |
| return min_length >= main_args.resolution | |
| else: | |
| return True | |
| return True | |
| def filter_multi_face(sample): | |
| if filter_mface: | |
| face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) | |
| if face_kp.shape[0] > 1: | |
| return False | |
| return True | |
| def filter_whole_skeleton(sample): | |
| if filter_wpose: | |
| height, width = sample["image"].height, sample["image"].width | |
| body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) | |
| if (body_kp[:, 0] > 0).all() and (body_kp[:, 0] < width).all() and (body_kp[:, 1] > 0).all() and (body_kp[:, 1] < height).all(): | |
| return True | |
| else: | |
| return False | |
| return True | |
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.select(filter_multi_face), | |
| wds.select(filter_low_res), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| wds.select(filter_whole_skeleton), | |
| wds.map(decode_image), | |
| wds.map_dict(depth=train_transforms), | |
| wds.select(filter_depth_none), | |
| # wds.map_dict(depth=train_transforms, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict( | |
| text=lambda text: tokenizer(text, \ | |
| max_length=tokenizer.model_max_length, \ | |
| padding="max_length", truncation=True, \ | |
| return_tensors='pt')['input_ids']), | |
| wds.to_tuple("depth", "text"), | |
| # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
| def get_wds_dataset_depth2canny(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
    def decode_image(sample):
        with io.BytesIO(sample["omni_depth"]) as stream:
            try:
                img = PIL.Image.open(stream)
                img.load()
                img = img.convert("RGB")
                sample["depth"] = img
            except Exception:
                logging.warning("Encountered a broken depth image; replacing it with a blank placeholder.")
                sample["depth"] = Image.new('RGB', (512, 512))
        return sample
| def add_canny(sample): | |
| canny = np.array(sample["image"]) | |
| low_threshold = 100 | |
| high_threshold = 200 | |
| canny = cv2.Canny(canny, low_threshold, high_threshold) | |
| canny = canny[:, :, None] | |
| canny = np.concatenate([canny, canny, canny], axis=2) | |
| sample["canny"] = Image.fromarray(canny) | |
| return sample | |
| def preprocess_image(sample): | |
| # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
| if grid_dnc: | |
| resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
| else: | |
| resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
| sample["image"] = resize_transform(sample["image"]) | |
| sample["canny"] = resize_transform(sample["canny"]) | |
| sample["depth"] = resize_transform(sample["depth"]) | |
| transform_list = [] | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| if grid_dnc: | |
| i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
| else: | |
| i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
| if main_args.center_crop or not is_train: | |
| if grid_dnc: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
| else: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
| else: | |
| if grid_dnc: | |
| if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
| raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
| else: | |
| if image_height < main_args.resolution or image_width < main_args.resolution: | |
| raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution and image_height == main_args.resolution: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
| if is_train and torch.rand(1) < 0.5: | |
| transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
| transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
| train_transforms = transforms.Compose(transform_list) | |
| sample["image"] = train_transforms(sample["image"]) | |
| sample["canny"] = train_transforms(sample["canny"]) | |
| sample["depth"] = train_transforms(sample["depth"]) | |
| return sample | |
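| # Note on preprocess_image above: the crop offsets (i, j) and the flip decision are | |
| # sampled once per call and baked into a single transforms.Compose, so the image, | |
| # canny and depth maps stay spatially aligned after cropping and flipping. | |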
| def random_mask(sample): | |
| if is_train and dropout: | |
| random_num = torch.rand(1) | |
| if random_num < 0.1: | |
| sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| return sample | |
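| # With `dropout` enabled, roughly 10% of training samples have their depth condition | |
| # replaced by a constant -1 tensor (an all-black image after Normalize([0.5], [0.5])), | |
| # which presumably serves as the "no condition" signal during training. | |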
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| wds.map(decode_image), | |
| wds.map(add_canny), | |
| wds.map(preprocess_image), | |
| wds.map(random_mask), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict( | |
| text=lambda text: tokenizer(text, \ | |
| max_length=tokenizer.model_max_length, \ | |
| padding="max_length", truncation=True, \ | |
| return_tensors='pt')['input_ids'], | |
| ), | |
| wds.to_tuple("canny", "text", "depth"), | |
| # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
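| # Illustrative arithmetic for the training branch above (hypothetical numbers): with | |
| # num_samples=1_000_000, batch_size=32, world_size=8 and workers=4, global_batch_size=256, | |
| # num_batches=ceil(1e6/256)=3907 and num_worker_batches=ceil(3907/4)=977, so num_batches | |
| # becomes 977*4=3908 and num_samples 3908*256=1_000_448; each worker then iterates | |
| # 977 batches per epoch via with_epoch(). | |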
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
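| # Hypothetical usage sketch (the surrounding argument objects are assumptions, not | |
| # defined in this file): | |
| # data_info = get_wds_dataset_depth2canny(args, main_args, is_train=True, epoch=0, | |
| #                                         tokenizer=tokenizer, dropout=True) | |
| # for canny, text_ids, depth in data_info.dataloader: | |
| #     # canny/depth: (B, 3, H, W) tensors in [-1, 1]; text_ids: tokenized captions | |
| #     ... | |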
| def get_wds_dataset_depth2normal(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
| def decode_image(sample): | |
| with io.BytesIO(sample["omni_normal"]) as stream: | |
| try: | |
| img = PIL.Image.open(stream) | |
| img.load() | |
| img = img.convert("RGB") | |
| sample["normal"] = img | |
| except Exception: | |
| print("A broken normal image was encountered; replacing it with a placeholder") | |
| image = Image.new('RGB', (512, 512)) | |
| sample["normal"] = image | |
| with io.BytesIO(sample["omni_depth"]) as stream: | |
| try: | |
| img = PIL.Image.open(stream) | |
| img.load() | |
| img = img.convert("RGB") | |
| sample["depth"] = img | |
| except Exception: | |
| print("A broken depth image was encountered; replacing it with a placeholder") | |
| image = Image.new('RGB', (512, 512)) | |
| sample["depth"] = image | |
| return sample | |
| def preprocess_image(sample): | |
| # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
| if grid_dnc: | |
| resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
| else: | |
| resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
| sample["image"] = resize_transform(sample["image"]) | |
| sample["normal"] = resize_transform(sample["normal"]) | |
| sample["depth"] = resize_transform(sample["depth"]) | |
| transform_list = [] | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| if grid_dnc: | |
| i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
| else: | |
| i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
| if main_args.center_crop or not is_train: | |
| if grid_dnc: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
| else: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
| else: | |
| if grid_dnc: | |
| if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
| raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
| else: | |
| if image_height < main_args.resolution or image_width < main_args.resolution: | |
| raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution and image_height == main_args.resolution: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
| if is_train and torch.rand(1) < 0.5: | |
| transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
| transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
| train_transforms = transforms.Compose(transform_list) | |
| sample["image"] = train_transforms(sample["image"]) | |
| sample["normal"] = train_transforms(sample["normal"]) | |
| sample["depth"] = train_transforms(sample["depth"]) | |
| return sample | |
| def random_mask(sample): | |
| if is_train and dropout: | |
| random_num = torch.rand(1) | |
| if random_num < 0.1: | |
| sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| return sample | |
| def filter_low_res(sample): | |
| if filter_lowres: | |
| string_json = sample["json"].decode('utf-8') | |
| dict_json = json.loads(string_json) | |
| if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| min_length = min(dict_json["height"], dict_json["width"]) | |
| return min_length >= main_args.resolution | |
| else: | |
| return True | |
| return True | |
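| # filter_low_res assumes an optional JSON sidecar per sample, e.g. (illustrative) | |
| # {"height": 768, "width": 1024, ...}; when filter_lowres is set, samples whose shorter | |
| # side is below main_args.resolution are dropped, and samples without size metadata | |
| # are kept. | |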
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.select(filter_low_res), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| wds.map(decode_image), | |
| wds.map(preprocess_image), | |
| wds.map(random_mask), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict( | |
| text=lambda text: tokenizer(text, \ | |
| max_length=tokenizer.model_max_length, \ | |
| padding="max_length", truncation=True, \ | |
| return_tensors='pt')['input_ids'], | |
| ), | |
| wds.to_tuple("normal", "text", "depth"), | |
| # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
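| # This loader mirrors get_wds_dataset_depth2canny but yields (normal, text, depth) | |
| # tuples and can drop low-resolution samples. Hypothetical usage (argument objects | |
| # are assumptions): | |
| # data_info = get_wds_dataset_depth2normal(args, main_args, is_train=True, | |
| #                                          tokenizer=tokenizer, filter_lowres=True) | |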
| # def get_wds_dataset_cond_sdxl(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
| # input_shards = args.train_data if is_train else args.val_data | |
| # assert input_shards is not None | |
| # resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| # num_samples, num_shards = get_dataset_size(input_shards) | |
| # if not num_samples: | |
| # if is_train: | |
| # num_samples = args.train_num_samples | |
| # if not num_samples: | |
| # raise RuntimeError( | |
| # 'Currently, number of dataset samples must be specified for training dataset. ' | |
| # 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| # else: | |
| # num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| # shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| # if resampled: | |
| # pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| # else: | |
| # assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| # pipeline = [wds.SimpleShardList(input_shards)] | |
| # # at this point we have an iterator over all the shards | |
| # if is_train: | |
| # if not resampled: | |
| # pipeline.extend([ | |
| # detshuffle2( | |
| # bufsize=_SHARD_SHUFFLE_SIZE, | |
| # initial=_SHARD_SHUFFLE_INITIAL, | |
| # seed=args.seed, | |
| # epoch=shared_epoch, | |
| # ), | |
| # wds.split_by_node, | |
| # wds.split_by_worker, | |
| # ]) | |
| # pipeline.extend([ | |
| # # at this point, we have an iterator over the shards assigned to each worker at each node | |
| # tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| # wds.shuffle( | |
| # bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| # initial=_SAMPLE_SHUFFLE_INITIAL, | |
| # ), | |
| # ]) | |
| # else: | |
| # pipeline.extend([ | |
| # wds.split_by_worker, | |
| # # at this point, we have an iterator over the shards assigned to each worker | |
| # wds.tarfile_to_samples(handler=log_and_continue), | |
| # ]) | |
| # def pose2img(sample): | |
| # height, width = sample["image"].height, sample["image"].width | |
| # min_length = min(height, width) | |
| # radius_body = max(int(4. * min_length / main_args.resolution), 4) | |
| # thickness_body = max(int(2. * min_length / main_args.resolution), 2) | |
| # radius_face = max(int(2. * min_length / main_args.resolution), 2) | |
| # thickness_face = max(int(1. * min_length / main_args.resolution), 1) | |
| # radius_hand = max(int(2. * min_length / main_args.resolution), 2) | |
| # thickness_hand = max(int(1. * min_length / main_args.resolution), 1) | |
| # # if "getty" in sample["__url__"]: | |
| # # radius_body *= 4 | |
| # # thickness_body *= 4 | |
| # # radius_face *= 4 | |
| # # thickness_face *= 4 | |
| # # radius_hand *= 4 | |
| # # thickness_hand *= 4 | |
| # body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) | |
| # body_kpconf = np.frombuffer(sample["body_kpconf"], dtype=np.float32) | |
| # body_all = np.concatenate([body_kp, body_kpconf[:, np.newaxis]], axis=1) | |
| # body_all = body_all[np.newaxis, ...] | |
| # body_draw = draw_body_skeleton( | |
| # img=None, | |
| # pose=body_all, | |
| # radius=radius_body, | |
| # thickness=thickness_body, | |
| # height=height, | |
| # width=width | |
| # ) | |
| # body_draw = Image.fromarray(body_draw) | |
| # face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) | |
| # face_kpconf = np.frombuffer(sample["face_kpconf"], dtype=np.float32).reshape(-1, 98) | |
| # face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
| # face_draw = draw_face_skeleton( | |
| # # img=np.array(img), | |
| # img=None, | |
| # pose=face_all, | |
| # radius=radius_face, | |
| # thickness=thickness_face, | |
| # height=height, | |
| # width=width | |
| # ) | |
| # face_draw = Image.fromarray(face_draw) | |
| # hand_kp = np.frombuffer(sample["hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
| # hand_kpconf = np.frombuffer(sample["hand_kpconf"], dtype=np.float32).reshape(-1, 21) | |
| # hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
| # hand_draw = draw_hand_skeleton( | |
| # # img=np.array(img), | |
| # img=None, | |
| # pose=hand_all, | |
| # radius=radius_hand, | |
| # thickness=thickness_hand, | |
| # height=height, | |
| # width=width | |
| # ) | |
| # hand_draw = Image.fromarray(hand_draw) | |
| # sample["body"] = body_draw | |
| # sample["face"] = face_draw | |
| # sample["hand"] = hand_draw | |
| # return sample | |
| # def decode_image(sample): | |
| # with io.BytesIO(sample["omni_normal"]) as stream: | |
| # try: | |
| # img = PIL.Image.open(stream) | |
| # img.load() | |
| # img = img.convert("RGB") | |
| # sample["normal"] = img | |
| # except: | |
| # print("A broken image is encountered, replace w/ a placeholder") | |
| # image = Image.new('RGB', (512, 512)) | |
| # sample["normal"] = image | |
| # with io.BytesIO(sample["omni_depth"]) as stream: | |
| # try: | |
| # img = PIL.Image.open(stream) | |
| # img.load() | |
| # img = img.convert("RGB") | |
| # sample["depth"] = img | |
| # except: | |
| # print("A broken image is encountered, replace w/ a placeholder") | |
| # image = Image.new('RGB', (512, 512)) | |
| # sample["depth"] = image | |
| # return sample | |
| # def add_canny(sample): | |
| # canny = np.array(sample["image"]) | |
| # low_threshold = 100 | |
| # high_threshold = 200 | |
| # canny = cv2.Canny(canny, low_threshold, high_threshold) | |
| # canny = canny[:, :, None] | |
| # canny = np.concatenate([canny, canny, canny], axis=2) | |
| # sample["canny"] = Image.fromarray(canny) | |
| # return sample | |
| # def decode_text(sample): | |
| # sample["blip"] = sample["blip"].decode("utf-8") | |
| # sample["blip_raw"] = sample["blip"] | |
| # sample["text_raw"] = sample["text"] | |
| # return sample | |
| # def augment_text(sample): | |
| # if is_train and string_concat: | |
| # sample["text"] = sample["text"] + " " + sample["blip"] | |
| # if is_train and string_substitute: | |
| # sample["text"] = sample["blip"] | |
| # return sample | |
| # def preprocess_image(sample): | |
| # # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
| # if grid_dnc: | |
| # resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
| # else: | |
| # resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
| # sample["image"] = resize_transform(sample["image"]) | |
| # sample["normal"] = resize_transform(sample["normal"]) | |
| # sample["depth"] = resize_transform(sample["depth"]) | |
| # sample["canny"] = resize_transform(sample["canny"]) | |
| # sample["body"] = resize_transform(sample["body"]) | |
| # sample["face"] = resize_transform(sample["face"]) | |
| # sample["hand"] = resize_transform(sample["hand"]) | |
| # transform_list = [] | |
| # image_height, image_width = sample["image"].height, sample["image"].width | |
| # if grid_dnc: | |
| # i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
| # j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
| # else: | |
| # i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
| # j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
| # if main_args.center_crop or not is_train: | |
| # sample["description"]["crop_tl_h"] = (image_height - main_args.resolution) // 2 | |
| # sample["description"]["crop_tl_w"] = (image_width - main_args.resolution) // 2 | |
| # if grid_dnc: | |
| # transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
| # else: | |
| # transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
| # else: | |
| # if grid_dnc: | |
| # if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
| # raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
| # elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
| # i, j = 0, 0 | |
| # transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
| # else: | |
| # if image_height < main_args.resolution or image_width < main_args.resolution: | |
| # raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
| # elif image_width == main_args.resolution and image_height == main_args.resolution: | |
| # i, j = 0, 0 | |
| # transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
| # sample["description"]["crop_tl_h"] = i | |
| # sample["description"]["crop_tl_w"] = j | |
| # if is_train and torch.rand(1) < 0.5: | |
| # transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
| # transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
| # train_transforms = transforms.Compose(transform_list) | |
| # sample["image"] = train_transforms(sample["image"]) | |
| # sample["normal"] = train_transforms(sample["normal"]) | |
| # sample["depth"] = train_transforms(sample["depth"]) | |
| # sample["canny"] = train_transforms(sample["canny"]) | |
| # sample["body"] = train_transforms(sample["body"]) | |
| # sample["face"] = train_transforms(sample["face"]) | |
| # sample["hand"] = train_transforms(sample["hand"]) | |
| # return sample | |
| # def random_mask(sample): | |
| # if is_train and dropout: | |
| # random_num = torch.rand(1) | |
| # if random_num < 0.1: | |
| # sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
| # sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| # sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
| # sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
| # sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
| # sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
| # elif random_num > 0.9: | |
| # pass | |
| # else: | |
| # if torch.rand(1) < 0.5: | |
| # sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
| # if torch.rand(1) < 0.5: | |
| # sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| # if torch.rand(1) < 0.8: | |
| # sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
| # if torch.rand(1) < 0.5: | |
| # sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
| # if torch.rand(1) < 0.5: | |
| # sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
| # if torch.rand(1) < 0.2: | |
| # sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
| # return sample | |
| # def make_grid_dnc(sample): | |
| # if grid_dnc: | |
| # resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| # resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| # resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| # resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| # grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), | |
| # torch.cat([resized_normal, resized_canny], dim=2)], dim=1) | |
| # assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution | |
| # sample["image"] = grid | |
| # return sample | |
| # def filter_low_res(sample): | |
| # if main_args.filter_res is None: | |
| # main_args.filter_res = main_args.resolution | |
| # if filter_lowres: | |
| # string_json = sample["json"].decode('utf-8') | |
| # dict_json = json.loads(string_json) | |
| # if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| # min_length = min(dict_json["height"], dict_json["width"]) | |
| # return min_length >= main_args.filter_res | |
| # else: | |
| # return True | |
| # return True | |
| # def add_original_hw(sample): | |
| # image_height, image_width = sample["image"].height, sample["image"].width | |
| # sample["description"] = {"h": image_height, "w": image_width} | |
| # return sample | |
| # def add_description(sample): | |
| # # string_json = sample["json"].decode('utf-8') | |
| # # dict_json = json.loads(string_json) | |
| # dict_json = sample["json"] | |
| # if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| # sample["description"]["h"] = dict_json["height"] | |
| # sample["description"]["w"] = dict_json["width"] | |
| # return sample | |
| # pipeline.extend([ | |
| # wds.select(filter_no_caption_or_no_image), | |
| # wds.select(filter_low_res), | |
| # wds.decode("pilrgb", handler=log_and_continue), | |
| # wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| # wds.map(add_original_hw), | |
| # wds.map(decode_text), | |
| # wds.map(augment_text), | |
| # wds.map(pose2img), | |
| # wds.map(decode_image), | |
| # wds.map(add_canny), | |
| # wds.map(preprocess_image), | |
| # wds.map(make_grid_dnc), | |
| # wds.map(random_mask), | |
| # wds.map(add_description), | |
| # # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| # # wds.map_dict( | |
| # # text=lambda text: tokenizer(text, \ | |
| # # max_length=tokenizer.model_max_length, \ | |
| # # padding="max_length", truncation=True, \ | |
| # # return_tensors='pt')['input_ids'], | |
| # # blip=lambda blip: tokenizer(blip, \ | |
| # # max_length=tokenizer.model_max_length, \ | |
| # # padding="max_length", truncation=True, \ | |
| # # return_tensors='pt')['input_ids'] | |
| # # ), | |
| # wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", "body", "face", "hand", "normal", "depth", "canny", "description"), | |
| # # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| # wds.batched(args.batch_size, partial=not is_train) | |
| # ]) | |
| # dataset = wds.DataPipeline(*pipeline) | |
| # if is_train: | |
| # if not resampled: | |
| # assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # round_fn = math.floor if floor else math.ceil | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = round_fn(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| # num_batches = num_worker_batches * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # dataloader = wds.WebLoader( | |
| # dataset, | |
| # batch_size=None, | |
| # shuffle=False, | |
| # num_workers=args.workers, | |
| # persistent_workers=True, | |
| # ) | |
| # # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # # if is_train: | |
| # # # roll over and repeat a few samples to get same number of full batches on each node | |
| # # global_batch_size = args.batch_size * args.world_size | |
| # # num_batches = math.ceil(num_samples / global_batch_size) | |
| # # num_workers = max(1, args.workers) | |
| # # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # # num_samples = num_batches * global_batch_size | |
| # # dataloader = dataloader.with_epoch(num_batches) | |
| # # else: | |
| # # # last batches are partial, eval is done on single (master) node | |
| # # num_batches = math.ceil(num_samples / args.batch_size) | |
| # # add meta-data to dataloader instance for convenience | |
| # dataloader.num_batches = num_batches | |
| # dataloader.num_samples = num_samples | |
| # return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
| def get_wds_dataset_cond(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_res=512, filter_mface=False, filter_wpose=False): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
| def pose2img(sample, scale): | |
| height, width = sample["image"].height, sample["image"].width | |
| # min_length = min(height, width) | |
| # radius_body = int(4. * min_length / main_args.resolution) | |
| # thickness_body = int(4. * min_length / main_args.resolution) | |
| # radius_face = int(1.5 * min_length / main_args.resolution) | |
| # thickness_face = int(2. * min_length / main_args.resolution) | |
| # radius_hand = int(1.5 * min_length / main_args.resolution) | |
| # thickness_hand = int(2. * min_length / main_args.resolution) | |
| # if "getty" in sample["__url__"]: | |
| # radius_body *= 4 | |
| # thickness_body *= 4 | |
| # radius_face *= 4 | |
| # thickness_face *= 4 | |
| # radius_hand *= 4 | |
| # thickness_hand *= 4 | |
| try: | |
| location = np.frombuffer(sample["location"], dtype=np.float32) | |
| body_kp = np.frombuffer(sample["new_i_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
| x_coord = (body_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
| y_coord = (body_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
| body_kp = np.stack([x_coord, y_coord], axis=2) | |
| body_kp = body_kp * scale | |
| # body_kp[:, :, 0] -= j | |
| # body_kp[:, :, 1] -= i | |
| body_kpconf = np.frombuffer(sample["new_i_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
| body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
| except Exception:  # fall back to the un-normalized "new_body_kp" keypoints | |
| body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
| body_kp = body_kp * scale | |
| # body_kp[:, :, 0] -= j | |
| # body_kp[:, :, 1] -= i | |
| body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
| body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
| # body_ratio = 0. | |
| # for i_body in range(body_kp.shape[0]): | |
| # body_ratio = max((np.max(body_kp[i_body, :, 0]) - np.min(body_kp[i_body, :, 0])) / min_length, body_ratio) | |
| # print(body_ratio) | |
| # body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
| # body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
| # body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
| # body_draw = draw_controlnet_skeleton(image=None, pose=body_all, height=height, width=width) | |
| # body_draw = draw_humansd_skeleton(image=None, pose=body_all, height=height, width=width, humansd_skeleton_width=int(10. * body_ratio * min_length / main_args.resolution)) | |
| body_draw = draw_humansd_skeleton( | |
| # image=np.array(sample["image"]), | |
| image=None, | |
| pose=body_all, | |
| height=height, | |
| width=width, | |
| humansd_skeleton_width=int(10 * main_args.resolution / 512), | |
| ) | |
| # body_draw = draw_body_skeleton( | |
| # img=None, | |
| # pose=body_all, | |
| # radius=radius_body, | |
| # thickness=thickness_body, | |
| # height=height, | |
| # width=width | |
| # ) | |
| body_draw = Image.fromarray(body_draw) | |
| try: | |
| location = np.frombuffer(sample["location"], dtype=np.float32) | |
| face_kp = np.frombuffer(sample["new_i_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
| x_coord = (face_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
| y_coord = (face_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
| face_kp = np.stack([x_coord, y_coord], axis=2) | |
| face_kp = face_kp * scale | |
| # face_kp[:, :, 0] -= j | |
| # face_kp[:, :, 1] -= i | |
| face_kpconf = np.frombuffer(sample["new_i_face_kp_score"], dtype=np.float32).reshape(-1, 68) | |
| face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
| except Exception:  # fall back to the un-normalized "new_face_kp" keypoints | |
| face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
| face_kp = face_kp * scale | |
| # face_kp[:, :, 0] -= j | |
| # face_kp[:, :, 1] -= i | |
| face_kpconf = np.frombuffer(sample["new_face_kp_score"], dtype=np.float32).reshape(-1, 68) | |
| face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
| face_draw = draw_face_skeleton( | |
| # img=np.array(sample["image"]), | |
| img=None, | |
| pose=face_all, | |
| # radius=radius_face, | |
| # thickness=thickness_face, | |
| height=height, | |
| width=width, | |
| ) | |
| face_draw = Image.fromarray(face_draw) | |
| try: | |
| location = np.frombuffer(sample["location"], dtype=np.float32) | |
| hand_kp = np.frombuffer(sample["new_i_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
| x_coord = (hand_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
| y_coord = (hand_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
| hand_kp = np.stack([x_coord, y_coord], axis=2) | |
| hand_kp = hand_kp * scale | |
| # hand_kp[:, :, 0] -= j | |
| # hand_kp[:, :, 1] -= i | |
| hand_kpconf = np.frombuffer(sample["new_i_hand_kp_score"], dtype=np.float32).reshape(-1, 21) | |
| hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
| except Exception:  # fall back to the un-normalized "new_hand_kp" keypoints | |
| hand_kp = np.frombuffer(sample["new_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
| hand_kp = hand_kp * scale | |
| # hand_kp[:, :, 0] -= j | |
| # hand_kp[:, :, 1] -= i | |
| hand_kpconf = np.frombuffer(sample["new_hand_kp_score"], dtype=np.float32).reshape(-1, 21) | |
| hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
| hand_draw = draw_hand_skeleton( | |
| # img=np.array(sample["image"]), | |
| img=None, | |
| pose=hand_all, | |
| # radius=radius_hand, | |
| # thickness=thickness_hand, | |
| height=height, | |
| width=width, | |
| ) | |
| hand_draw = Image.fromarray(hand_draw) | |
| # whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
| # whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
| # whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
| try: | |
| location = np.frombuffer(sample["location"], dtype=np.float32) | |
| whole_kp = np.frombuffer(sample["new_i_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
| x_coord = (whole_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
| y_coord = (whole_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
| whole_kp = np.stack([x_coord, y_coord], axis=2) | |
| whole_kp = whole_kp * scale | |
| # whole_kp[:, :, 0] -= j | |
| # whole_kp[:, :, 1] -= i | |
| whole_kpconf = np.frombuffer(sample["new_i_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
| whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
| except Exception:  # fall back to the un-normalized "new_wholebody_kp" keypoints | |
| whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
| whole_kp = whole_kp * scale | |
| # whole_kp[:, :, 0] -= j | |
| # whole_kp[:, :, 1] -= i | |
| whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
| whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
| whole_draw = draw_whole_body_skeleton( | |
| # img=np.array(sample["image"]), | |
| img=None, | |
| pose=whole_all, | |
| # radius=radius_body, | |
| # thickness=thickness_body, | |
| height=height, | |
| width=width, | |
| ) | |
| whole_draw = Image.fromarray(whole_draw) | |
| sample["body"] = body_draw | |
| sample["face"] = face_draw | |
| sample["hand"] = hand_draw | |
| if main_args.change_whole_to_body: | |
| sample["whole"] = body_draw | |
| else: | |
| sample["whole"] = whole_draw | |
| return sample | |
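| # pose2img above prefers the crop-normalized "new_i_*" keypoints: they are shifted by | |
| # location[0:2], divided by location[2:4], rescaled by location[7:9] and finally | |
| # multiplied by the resize `scale`; when those keys are absent it falls back to the | |
| # plain "new_*" keypoints scaled by `scale` alone. | |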
| def decode_image(sample): | |
| with io.BytesIO(sample["omni_normal"]) as stream: | |
| try: | |
| img = PIL.Image.open(stream) | |
| img.load() | |
| img = img.convert("RGB") | |
| img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
| sample["normal"] = img | |
| except Exception: | |
| print("A broken normal image was encountered; replacing it with a placeholder") | |
| image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
| sample["normal"] = image | |
| with io.BytesIO(sample["omni_depth"]) as stream: | |
| try: | |
| img = PIL.Image.open(stream) | |
| img.load() | |
| img = img.convert("RGB") | |
| img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
| sample["depth"] = img | |
| except Exception: | |
| print("A broken depth image was encountered; replacing it with a placeholder") | |
| image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
| sample["depth"] = image | |
| with io.BytesIO(sample["midas_depth"]) as stream: | |
| try: | |
| img = PIL.Image.open(stream) | |
| img.load() | |
| img = img.convert("RGB") | |
| img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
| sample["midas_depth"] = img | |
| except Exception: | |
| print("A broken MiDaS depth image was encountered; replacing it with a placeholder") | |
| image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
| sample["midas_depth"] = image | |
| return sample | |
| def add_canny(sample): | |
| canny = np.array(sample["image"]) | |
| low_threshold = 100 | |
| high_threshold = 200 | |
| canny = cv2.Canny(canny, low_threshold, high_threshold) | |
| canny = canny[:, :, None] | |
| canny = np.concatenate([canny, canny, canny], axis=2) | |
| sample["canny"] = Image.fromarray(canny) | |
| return sample | |
| def decode_text(sample): | |
| try: | |
| sample["blip"] = sample["blip"].decode("utf-8") | |
| sample["blip_raw"] = sample["blip"] | |
| except Exception:  # "blip" caption missing or undecodable; fall back to the main caption | |
| sample["blip"] = sample["text"] | |
| sample["blip_raw"] = sample["text"].encode("utf-8") | |
| sample["text_raw"] = sample["text"] | |
| return sample | |
| def augment_text(sample): | |
| if is_train and string_concat: | |
| sample["text"] = sample["text"] + " " + sample["blip"] | |
| if is_train and string_substitute: | |
| if main_args.rv_prompt: | |
| sample["text"] = "RAW photo, " + sample["blip"] + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" | |
| else: | |
| sample["text"] = sample["blip"] | |
| return sample | |
| def dropout_text(sample): | |
| if is_train: | |
| try: | |
| random_num = torch.rand(1) | |
| if random_num < main_args.dropout_text: | |
| sample["text"] = sample["text_raw"] = "" | |
| except Exception:  # dropout_text may be unset on main_args | |
| pass | |
| return sample | |
| def preprocess_image(sample): | |
| # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
| if grid_dnc: | |
| resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
| else: | |
| resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
| scale = main_args.resolution * 1. / min(sample["image"].height, sample["image"].width) | |
| sample["image"] = resize_transform(sample["image"]) | |
| sample["normal"] = resize_transform(sample["normal"]) | |
| sample["depth"] = resize_transform(sample["depth"]) | |
| sample["midas_depth"] = resize_transform(sample["midas_depth"]) | |
| sample["canny"] = resize_transform(sample["canny"]) | |
| # sample["body"] = resize_transform(sample["body"]) | |
| # sample["face"] = resize_transform(sample["face"]) | |
| # sample["hand"] = resize_transform(sample["hand"]) | |
| # sample["whole"] = resize_transform(sample["whole"]) | |
| transform_list = [] | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| if grid_dnc: | |
| i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
| else: | |
| i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
| j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
| if main_args.center_crop or not is_train: | |
| sample["description"]["crop_tl_h"] = i = (image_height - main_args.resolution) // 2 | |
| sample["description"]["crop_tl_w"] = j = (image_width - main_args.resolution) // 2 | |
| if grid_dnc: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
| else: | |
| transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
| else: | |
| if grid_dnc: | |
| if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
| raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
| else: | |
| if image_height < main_args.resolution or image_width < main_args.resolution: | |
| raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
| elif image_width == main_args.resolution and image_height == main_args.resolution: | |
| i, j = 0, 0 | |
| transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
| sample["description"]["crop_tl_h"] = i | |
| sample["description"]["crop_tl_w"] = j | |
| sample = pose2img(sample, scale) | |
| if is_train and torch.rand(1) < 0.5: | |
| transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
| transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
| train_transforms = transforms.Compose(transform_list) | |
| sample["image"] = train_transforms(sample["image"]) | |
| sample["normal"] = train_transforms(sample["normal"]) | |
| sample["depth"] = train_transforms(sample["depth"]) | |
| sample["midas_depth"] = train_transforms(sample["midas_depth"]) | |
| sample["canny"] = train_transforms(sample["canny"]) | |
| sample["body"] = train_transforms(sample["body"]) | |
| sample["face"] = train_transforms(sample["face"]) | |
| sample["hand"] = train_transforms(sample["hand"]) | |
| sample["whole"] = train_transforms(sample["whole"]) | |
| return sample | |
| def random_mask(sample): | |
| sample["normal_ori"] = sample["normal"].clone() | |
| sample["depth_ori"] = sample["depth"].clone() | |
| sample["midas_depth_ori"] = sample["midas_depth"].clone() | |
| sample["canny_ori"] = sample["canny"].clone() | |
| sample["body_ori"] = sample["body"].clone() | |
| sample["face_ori"] = sample["face"].clone() | |
| sample["hand_ori"] = sample["hand"].clone() | |
| sample["whole_ori"] = sample["whole"].clone() | |
| mask_list = [] | |
| if is_train and dropout: | |
| random_num = torch.rand(1) | |
| if random_num < 0.15: | |
| sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
| sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) | |
| sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
| sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
| sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
| sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
| sample["whole"] = torch.ones_like(sample["whole"]) * (-1) | |
| mask_list = ["normal", "depth", "midas_depth", "canny", "body", "face", "hand", "whole"] | |
| elif random_num > 0.9: | |
| pass | |
| else: | |
| if torch.rand(1) < 0.5: | |
| sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
| mask_list.append("normal") | |
| if torch.rand(1) < 0.5: | |
| sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
| mask_list.append("depth") | |
| if torch.rand(1) < 0.5: | |
| sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) | |
| mask_list.append("midas_depth") | |
| if torch.rand(1) < 0.8: | |
| sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
| mask_list.append("canny") | |
| if torch.rand(1) < 0.5: | |
| sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
| mask_list.append("body") | |
| if torch.rand(1) < 0.5: | |
| sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
| mask_list.append("face") | |
| if torch.rand(1) < 0.2: | |
| sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
| mask_list.append("hand") | |
| if torch.rand(1) < 0.5: | |
| sample["whole"] = torch.ones_like(sample["whole"]) * (-1) | |
| mask_list.append("whole") | |
| sample["normal_dt"] = sample["normal"].clone() | |
| sample["depth_dt"] = sample["depth"].clone() | |
| sample["midas_depth_dt"] = sample["midas_depth"].clone() | |
| sample["canny_dt"] = sample["canny"].clone() | |
| sample["body_dt"] = sample["body"].clone() | |
| sample["face_dt"] = sample["face"].clone() | |
| sample["hand_dt"] = sample["hand"].clone() | |
| sample["whole_dt"] = sample["whole"].clone() | |
| mask_list = [x for x in mask_list if x in main_args.cond_type] | |
| if len(mask_list) > 0: | |
| target = random.choice(mask_list) | |
| sample[target + "_dt"] = sample[target + "_ori"].clone() | |
| else: | |
| if len(main_args.cond_type) > 0: | |
| target = random.choice(main_args.cond_type) | |
| sample[target + "_dt"] = torch.ones_like(sample[target]) * (-1) | |
| return sample | |
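| # random_mask above keeps three views of every condition: "<name>_ori" (never masked), | |
| # "<name>" (randomly set to all -1 when training with dropout), and "<name>_dt", a copy | |
| # of the masked view in which one randomly chosen condition from main_args.cond_type is | |
| # restored (or, if nothing was masked, one is blanked out); its downstream use is not | |
| # shown in this file. | |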
| def make_grid_dnc(sample): | |
| if grid_dnc: | |
| resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| resized_canny = transforms.functional.resize(sample["body"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| # resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
| grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), | |
| torch.cat([resized_normal, resized_canny], dim=2)], dim=1) | |
| assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution | |
| sample["image"] = grid | |
| return sample | |
| def filter_low_res(sample): | |
| if main_args.filter_res is None: | |
| main_args.filter_res = main_args.resolution | |
| if filter_lowres: | |
| # string_json = sample["json"].decode('utf-8') | |
| # dict_json = json.loads(string_json) | |
| dict_json = sample["json"] | |
| if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| min_length = min(dict_json["height"], dict_json["width"]) | |
| return min_length >= main_args.filter_res | |
| else: | |
| min_length = min(sample["image"].height, sample["image"].width) | |
| return min_length >= main_args.filter_res | |
| return True | |
| def filter_watermark(sample): | |
| if main_args.filter_wm: | |
| if sample["description"]["watermark"] >= 100: | |
| return False | |
| return True | |
| def add_original_hw(sample): | |
| image_height, image_width = sample["image"].height, sample["image"].width | |
| sample["description"] = {"h": image_height, "w": image_width} | |
| return sample | |
| def add_description(sample): | |
| # string_json = sample["json"].decode('utf-8') | |
| # dict_json = json.loads(string_json) | |
| try: | |
| dict_json = sample["json"] | |
| if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
| sample["description"]["h"] = dict_json["height"] | |
| sample["description"]["w"] = dict_json["width"] | |
| # try: | |
| if "coyo" in sample["__url__"]: | |
| sample["description"]["aes"] = torch.tensor(sample["json"]["aesthetic_score_laion_v2"] * 1e2) | |
| sample["description"]["watermark"] = torch.tensor(sample["json"]["watermark_score"] * 1e3) | |
| elif "laion" in sample["__url__"]: | |
| sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) | |
| sample["description"]["watermark"] = torch.tensor(np.frombuffer(sample["watermark_score"], dtype=np.float32) * 1e3) | |
| elif "getty" in sample["__url__"]: | |
| sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) | |
| sample["description"]["watermark"] = torch.tensor(float(sample["json"]["display_sizes"][-1]["is_watermarked"] or 0) * 1e3) | |
| elif "fake" in sample["__url__"]: | |
| sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) | |
| sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) | |
| except Exception:  # metadata missing or unparsable; use fallback scores | |
| # sample["description"]["h"] = | |
| # sample["description"]["w"] = | |
| sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) | |
| sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) | |
| # except: | |
| # sample["description"]["aes"] = 0. | |
| # sample["description"]["watermark"] = 0. | |
| return sample | |
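| # add_description attaches the original height/width plus per-source aesthetic and | |
| # watermark scores scaled by 1e2 and 1e3; coyo/laion/getty shards read them from the | |
| # sample metadata, while "fake" shards and unparsable samples fall back to random | |
| # values in [5.5, 6.0] and [0.0, 0.1] before scaling. | |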
| def filter_multi_face(sample): | |
| if filter_mface: | |
| face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
| if face_kp.shape[0] > 1: | |
| return False | |
| return True | |
| def filter_whole_skeleton(sample): | |
| if filter_wpose: | |
| height, width = sample["image"].height, sample["image"].width | |
| area = height * width | |
| body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
| body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
| if (body_kp.shape[0] == 1) and (body_kpconf > 0.5).all() and (body_kp[0, :15, 0] > 0).all() \ | |
| and (body_kp[0, :15, 0] < width).all() and (body_kp[0, :15, 1] > 0).all() and \ | |
| (body_kp[0, :15, 1] < height).all(): | |
| x_min = max(np.amin(body_kp[0, :, 0]), 0) | |
| x_max = min(np.amax(body_kp[0, :, 0]), width) | |
| y_min = max(np.amin(body_kp[0, :, 1]), 0) | |
| y_max = min(np.amax(body_kp[0, :, 1]), height) | |
| if (x_max - x_min) * (y_max - y_min) / area > 0.2: | |
| return True | |
| else: | |
| return False | |
| else: | |
| return False | |
| return True | |
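| # filter_whole_skeleton keeps a sample only when exactly one person is detected, every | |
| # body-keypoint score exceeds 0.5, the first 15 keypoints fall inside the frame, and | |
| # the keypoint bounding box covers more than 20% of the image area. | |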
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.select(filter_multi_face), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| wds.select(filter_whole_skeleton), | |
| wds.select(filter_low_res), | |
| wds.map(add_original_hw), | |
| wds.map(decode_text), | |
| wds.map(augment_text), | |
| wds.map(dropout_text), | |
| # wds.map(pose2img), | |
| wds.map(decode_image), | |
| wds.map(add_canny), | |
| wds.map(preprocess_image), | |
| wds.map(make_grid_dnc), | |
| wds.map(random_mask), | |
| wds.map(add_description), | |
| wds.select(filter_watermark), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict( | |
| text=lambda text: tokenizer(text, | |
| max_length=tokenizer.model_max_length, | |
| padding="max_length", truncation=True, | |
| return_tensors='pt')['input_ids'], | |
| blip=lambda blip: tokenizer(blip, | |
| max_length=tokenizer.model_max_length, | |
| padding="max_length", truncation=True, | |
| return_tensors='pt')['input_ids'] | |
| ), | |
| wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", | |
| "body", "face", "hand", "normal", "depth", "midas_depth", "canny", "whole", "description", | |
| "body_ori", "face_ori", "hand_ori", "normal_ori", "depth_ori", "midas_depth_ori", "canny_ori", "whole_ori", | |
| "body_dt", "face_dt", "hand_dt", "normal_dt", "depth_dt", "midas_depth_dt", "canny_dt", "whole_dt"), | |
| # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
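| # num_batches/num_samples are recomputed from the rounded per-worker batch count so that every | |
| # rank and dataloader worker iterates the same number of batches each epoch | |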
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
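| # get_wds_dataset_img: image-only variant of the loader above; it skips captions and the extra | |
| # conditioning maps and yields batches containing preprocessed images only. | |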
| def get_wds_dataset_img(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info is present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp"), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict(image=preprocess_img), | |
| wds.to_tuple("image"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
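| # get_wds_dataset: standard image-text loader; decodes image/caption pairs, tokenizes the | |
| # caption to fixed-length input_ids, and batches (image, text) tuples. | |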
| def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): | |
| input_shards = args.train_data if is_train else args.val_data | |
| assert input_shards is not None | |
| resampled = getattr(args, 'dataset_resampled', False) and is_train | |
| num_samples, num_shards = get_dataset_size(input_shards) | |
| if not num_samples: | |
| if is_train: | |
| num_samples = args.train_num_samples | |
| if not num_samples: | |
| raise RuntimeError( | |
| 'Currently, number of dataset samples must be specified for training dataset. ' | |
| 'Please specify via `--train-num-samples` if no dataset length info is present.') | |
| else: | |
| num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
| shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
| if resampled: | |
| pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
| else: | |
| assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
| pipeline = [wds.SimpleShardList(input_shards)] | |
| # at this point we have an iterator over all the shards | |
| if is_train: | |
| if not resampled: | |
| pipeline.extend([ | |
| detshuffle2( | |
| bufsize=_SHARD_SHUFFLE_SIZE, | |
| initial=_SHARD_SHUFFLE_INITIAL, | |
| seed=args.seed, | |
| epoch=shared_epoch, | |
| ), | |
| wds.split_by_node, | |
| wds.split_by_worker, | |
| ]) | |
| pipeline.extend([ | |
| # at this point, we have an iterator over the shards assigned to each worker at each node | |
| tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
| wds.shuffle( | |
| bufsize=_SAMPLE_SHUFFLE_SIZE, | |
| initial=_SAMPLE_SHUFFLE_INITIAL, | |
| ), | |
| ]) | |
| else: | |
| pipeline.extend([ | |
| wds.split_by_worker, | |
| # at this point, we have an iterator over the shards assigned to each worker | |
| wds.tarfile_to_samples(handler=log_and_continue), | |
| ]) | |
| pipeline.extend([ | |
| wds.select(filter_no_caption_or_no_image), | |
| wds.decode("pilrgb", handler=log_and_continue), | |
| wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
| # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
| wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text, | |
| max_length=tokenizer.model_max_length, | |
| padding="max_length", | |
| truncation=True, | |
| return_tensors='pt')['input_ids']), | |
| wds.to_tuple("image", "text"), | |
| wds.batched(args.batch_size, partial=not is_train) | |
| ]) | |
| dataset = wds.DataPipeline(*pipeline) | |
| if is_train: | |
| if not resampled: | |
| assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
| # roll over and repeat a few samples to get same number of full batches on each node | |
| round_fn = math.floor if floor else math.ceil | |
| global_batch_size = args.batch_size * args.world_size | |
| num_batches = round_fn(num_samples / global_batch_size) | |
| num_workers = max(1, args.workers) | |
| num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * global_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
| else: | |
| # last batches are partial, eval is done on single (master) node | |
| num_batches = math.ceil(num_samples / args.batch_size) | |
| dataloader = wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=args.workers, | |
| persistent_workers=True, | |
| ) | |
| # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
| # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
| # if is_train: | |
| # # roll over and repeat a few samples to get same number of full batches on each node | |
| # global_batch_size = args.batch_size * args.world_size | |
| # num_batches = math.ceil(num_samples / global_batch_size) | |
| # num_workers = max(1, args.workers) | |
| # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
| # num_samples = num_batches * global_batch_size | |
| # dataloader = dataloader.with_epoch(num_batches) | |
| # else: | |
| # # last batches are partial, eval is done on single (master) node | |
| # num_batches = math.ceil(num_samples / args.batch_size) | |
| # add meta-data to dataloader instance for convenience | |
| dataloader.num_batches = num_batches | |
| dataloader.num_samples = num_samples | |
| return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
| def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): | |
| input_filename = args.train_data if is_train else args.val_data | |
| assert input_filename | |
| dataset = CsvDataset( | |
| input_filename, | |
| preprocess_fn, | |
| img_key=args.csv_img_key, | |
| caption_key=args.csv_caption_key, | |
| sep=args.csv_separator, | |
| tokenizer=tokenizer | |
| ) | |
| num_samples = len(dataset) | |
| sampler = DistributedSampler(dataset) if args.distributed and is_train else None | |
| shuffle = is_train and sampler is None | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=args.batch_size, | |
| shuffle=shuffle, | |
| num_workers=args.workers, | |
| pin_memory=True, | |
| sampler=sampler, | |
| drop_last=is_train, | |
| ) | |
| dataloader.num_samples = num_samples | |
| dataloader.num_batches = len(dataloader) | |
| return DataInfo(dataloader, sampler) | |
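| # SyntheticDataset serves a constant dummy image and caption; handy for benchmarking the input | |
| # pipeline and model step time without touching real data. | |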
| class SyntheticDataset(Dataset): | |
| def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None): | |
| self.transform = transform | |
| self.image_size = image_size | |
| self.caption = caption | |
| self.image = Image.new('RGB', image_size) | |
| self.dataset_size = dataset_size | |
| self.preprocess_txt = lambda text: tokenizer(text)[0] | |
| def __len__(self): | |
| return self.dataset_size | |
| def __getitem__(self, idx): | |
| # fall back to the untransformed dummy image when no transform is provided | |
| image = self.image | |
| if self.transform is not None: | |
| image = self.transform(image) | |
| return image, self.preprocess_txt(self.caption) | |
| def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): | |
| image_size = preprocess_fn.transforms[0].size | |
| dataset = SyntheticDataset( | |
| transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer) | |
| num_samples = len(dataset) | |
| sampler = DistributedSampler(dataset) if args.distributed and is_train else None | |
| shuffle = is_train and sampler is None | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=args.batch_size, | |
| shuffle=shuffle, | |
| num_workers=args.workers, | |
| pin_memory=True, | |
| sampler=sampler, | |
| drop_last=is_train, | |
| ) | |
| dataloader.num_samples = num_samples | |
| dataloader.num_batches = len(dataloader) | |
| return DataInfo(dataloader, sampler) | |
| def get_dataset_fn(data_path, dataset_type): | |
| if dataset_type == "webdataset": | |
| return get_wds_dataset | |
| elif dataset_type == "csv": | |
| return get_csv_dataset | |
| elif dataset_type == "synthetic": | |
| return get_synthetic_dataset | |
| elif dataset_type == "auto": | |
| ext = data_path.split('.')[-1] | |
| if ext in ['csv', 'tsv']: | |
| return get_csv_dataset | |
| elif ext in ['tar']: | |
| return get_wds_dataset | |
| else: | |
| raise ValueError( | |
| f"Tried to figure out dataset type, but failed for extension {ext}.") | |
| else: | |
| raise ValueError(f"Unsupported dataset type: {dataset_type}") | |
| def get_data(args, preprocess_fns, epoch=0, tokenizer=None): | |
| preprocess_train, preprocess_val = preprocess_fns | |
| data = {} | |
| if args.train_data or args.dataset_type == "synthetic": | |
| data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( | |
| args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) | |
| if args.val_data: | |
| data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( | |
| args, preprocess_val, is_train=False, tokenizer=tokenizer) | |
| if args.imagenet_val is not None: | |
| data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") | |
| if args.imagenet_v2 is not None: | |
| data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") | |
| return data | |
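| if __name__ == "__main__": | |
| # Illustrative smoke test only, not part of the training entry point: it exercises the | |
| # synthetic-dataset path with hypothetical argument values and a stand-in tokenizer. | |
| # Real runs pass the parsed argparse namespace and a HuggingFace tokenizer instead. | |
| from types import SimpleNamespace | |
| _args = SimpleNamespace( | |
| dataset_type="synthetic", train_data=None, val_data=None, | |
| train_num_samples=64, batch_size=8, workers=0, distributed=False, | |
| imagenet_val=None, imagenet_v2=None, | |
| ) | |
| _preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]) | |
| _tok = lambda text: torch.zeros(1, 77, dtype=torch.long) # stand-in for a real text tokenizer | |
| _data = get_data(_args, (_preprocess, _preprocess), tokenizer=_tok) | |
| images, texts = next(iter(_data["train"].dataloader)) | |
| print(images.shape, texts.shape) # expected: (8, 3, 224, 224) and (8, 77) | |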