import ast
import io
import json
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from multiprocessing import Value

import braceexpand
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision.datasets as datasets
import webdataset as wds
import PIL
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample

ImageFile.LOAD_TRUNCATED_IMAGES = True

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None
def vis_landmark_on_img(img, shape, linewidth=8):
    '''
    Visualize 68-point facial landmarks on an image.
    '''
    def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth):
        for i in idx_list:
            cv2.line(img, (shape[i][0], shape[i][1]), (shape[i + 1][0], shape[i + 1][1]), color, lineWidth)
        if loop:
            cv2.line(img, (shape[idx_list[0]][0], shape[idx_list[0]][1]),
                     (shape[idx_list[-1] + 1][0], shape[idx_list[-1] + 1][1]), color, lineWidth)

    draw_curve(list(range(0, 16)), color=(255, 144, 25))  # jaw
    draw_curve(list(range(17, 21)), color=(50, 205, 50))  # eyebrows
    draw_curve(list(range(22, 26)), color=(50, 205, 50))
    draw_curve(list(range(27, 35)), color=(208, 224, 63))  # nose
    draw_curve(list(range(36, 41)), loop=True, color=(71, 99, 255))  # eyes
    draw_curve(list(range(42, 47)), loop=True, color=(71, 99, 255))
    draw_curve(list(range(48, 59)), loop=True, color=(238, 130, 238))  # mouth
    draw_curve(list(range(60, 67)), loop=True, color=(238, 130, 238))

    return img.astype("uint8")
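# Illustrative usage of vis_landmark_on_img (a sketch; `landmarks_68x2` is a
# hypothetical (68, 2) integer array of x, y coordinates, not defined in this file):
#
#   canvas = np.zeros((512, 512, 3), dtype=np.uint8)
#   landmarks_68x2 = np.random.randint(0, 512, size=(68, 2))
#   canvas = vis_landmark_on_img(canvas, landmarks_68x2)
#   cv2.imwrite("landmarks_vis.png", canvas)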
def imshow_keypoints(img, | |
pose_result, | |
skeleton=None, | |
kpt_score_thr=0.3, | |
pose_kpt_color=None, | |
pose_link_color=None, | |
radius=4, | |
thickness=1, | |
show_keypoint_weight=False, | |
height=None, | |
width=None): | |
"""Draw keypoints and links on an image. | |
Args: | |
img (str or Tensor): The image to draw poses on. If an image array | |
is given, id will be modified in-place. | |
pose_result (list[kpts]): The poses to draw. Each element kpts is | |
a set of K keypoints as an Kx3 numpy.ndarray, where each | |
keypoint is represented as x, y, score. | |
kpt_score_thr (float, optional): Minimum score of keypoints | |
to be shown. Default: 0.3. | |
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
the keypoint will not be drawn. | |
pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
links will not be drawn. | |
thickness (int): Thickness of lines. | |
""" | |
# img = mmcv.imread(img) | |
# img_h, img_w, _ = img.shape | |
if img is None: | |
img = np.zeros((height, width, 3), dtype=np.uint8) | |
img_h, img_w = height, width | |
else: | |
img_h, img_w, _ = img.shape | |
for kpts in pose_result: | |
kpts = np.array(kpts, copy=False) | |
# draw each point on image | |
if pose_kpt_color is not None: | |
assert len(pose_kpt_color) == len(kpts) | |
for kid, kpt in enumerate(kpts): | |
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
if kpt_score > kpt_score_thr: | |
color = tuple(int(c) for c in pose_kpt_color[kid]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
radius, color, -1) | |
transparency = max(0, min(1, kpt_score)) | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
color, -1) | |
# draw links | |
if skeleton is not None and pose_link_color is not None: | |
assert len(pose_link_color) == len(skeleton) | |
for sk_id, sk in enumerate(skeleton): | |
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
# if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
# and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
# and pos2[1] > 0 and pos2[1] < img_h | |
# and kpts[sk[0], 2] > kpt_score_thr | |
# and kpts[sk[1], 2] > kpt_score_thr): | |
if (kpts[sk[0], 2] > kpt_score_thr | |
and kpts[sk[1], 2] > kpt_score_thr): | |
color = tuple(int(c) for c in pose_link_color[sk_id]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
X = (pos1[0], pos2[0]) | |
Y = (pos1[1], pos2[1]) | |
mX = np.mean(X) | |
mY = np.mean(Y) | |
length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
angle = math.degrees( | |
math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
stickwidth = thickness | |
polygon = cv2.ellipse2Poly( | |
(int(mX), int(mY)), | |
(int(length / 2), int(stickwidth)), int(angle), 0, | |
360, 1) | |
cv2.fillConvexPoly(img_copy, polygon, color) | |
# transparency = max( | |
# 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
transparency = 1 | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.line(img, pos1, pos2, color, thickness=thickness) | |
return img | |
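# Minimal sketch of calling imshow_keypoints directly (illustrative values; a single
# person with K keypoints is given as a (K, 3) array of x, y, score, with one color
# per keypoint):
#
#   kpts = np.array([[100, 120, 0.9], [140, 118, 0.8]])    # K = 2 toy keypoints
#   kpt_colors = np.array([[255, 0, 0], [0, 255, 0]])      # one BGR color per keypoint
#   out = imshow_keypoints(None, [kpts], skeleton=None,
#                          pose_kpt_color=kpt_colors, radius=4,
#                          height=256, width=256)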
def imshow_keypoints_body(img, | |
pose_result, | |
skeleton=None, | |
kpt_score_thr=0.3, | |
pose_kpt_color=None, | |
pose_link_color=None, | |
radius=4, | |
thickness=1, | |
show_keypoint_weight=False, | |
height=None, | |
width=None): | |
"""Draw keypoints and links on an image. | |
Args: | |
img (str or Tensor): The image to draw poses on. If an image array | |
is given, id will be modified in-place. | |
pose_result (list[kpts]): The poses to draw. Each element kpts is | |
a set of K keypoints as an Kx3 numpy.ndarray, where each | |
keypoint is represented as x, y, score. | |
kpt_score_thr (float, optional): Minimum score of keypoints | |
to be shown. Default: 0.3. | |
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
the keypoint will not be drawn. | |
pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
links will not be drawn. | |
thickness (int): Thickness of lines. | |
""" | |
# img = mmcv.imread(img) | |
# img_h, img_w, _ = img.shape | |
if img is None: | |
img = np.zeros((height, width, 3), dtype=np.uint8) | |
img_h, img_w = height, width | |
else: | |
img_h, img_w, _ = img.shape | |
for kpts in pose_result: | |
kpts = np.array(kpts, copy=False) | |
# draw each point on image | |
if pose_kpt_color is not None: | |
assert len(pose_kpt_color) == len(kpts) | |
for kid, kpt in enumerate(kpts): | |
if kid in [17, 18, 19, 20, 21, 22]: | |
continue | |
if kid in [13, 14, 15, 16]: | |
if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): | |
continue | |
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
if kpt_score > kpt_score_thr: | |
color = tuple(int(c) for c in pose_kpt_color[kid]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
radius, color, -1) | |
transparency = max(0, min(1, kpt_score)) | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
color, -1) | |
# draw links | |
if skeleton is not None and pose_link_color is not None: | |
assert len(pose_link_color) == len(skeleton) | |
for sk_id, sk in enumerate(skeleton): | |
if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: | |
continue | |
if sk[0] in [13, 14, 15, 16]: | |
if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): | |
continue | |
if sk[1] in [13, 14, 15, 16]: | |
if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): | |
continue | |
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
# if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
# and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
# and pos2[1] > 0 and pos2[1] < img_h | |
# and kpts[sk[0], 2] > kpt_score_thr | |
# and kpts[sk[1], 2] > kpt_score_thr): | |
if (kpts[sk[0], 2] > kpt_score_thr | |
and kpts[sk[1], 2] > kpt_score_thr): | |
color = tuple(int(c) for c in pose_link_color[sk_id]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
X = (pos1[0], pos2[0]) | |
Y = (pos1[1], pos2[1]) | |
mX = np.mean(X) | |
mY = np.mean(Y) | |
length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
angle = math.degrees( | |
math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
stickwidth = thickness | |
polygon = cv2.ellipse2Poly( | |
(int(mX), int(mY)), | |
(int(length / 2), int(stickwidth)), int(angle), 0, | |
360, 1) | |
cv2.fillConvexPoly(img_copy, polygon, color) | |
# transparency = max( | |
# 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
transparency = 1 | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.line(img, pos1, pos2, color, thickness=thickness) | |
return img | |
def imshow_keypoints_whole(img, | |
pose_result, | |
skeleton=None, | |
kpt_score_thr=0.3, | |
pose_kpt_color=None, | |
pose_link_color=None, | |
radius=4, | |
thickness=1, | |
show_keypoint_weight=False, | |
height=None, | |
width=None): | |
"""Draw keypoints and links on an image. | |
Args: | |
img (str or Tensor): The image to draw poses on. If an image array | |
is given, id will be modified in-place. | |
pose_result (list[kpts]): The poses to draw. Each element kpts is | |
a set of K keypoints as an Kx3 numpy.ndarray, where each | |
keypoint is represented as x, y, score. | |
kpt_score_thr (float, optional): Minimum score of keypoints | |
to be shown. Default: 0.3. | |
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, | |
the keypoint will not be drawn. | |
pose_link_color (np.array[Mx3]): Color of M links. If None, the | |
links will not be drawn. | |
thickness (int): Thickness of lines. | |
""" | |
# img = mmcv.imread(img) | |
# img_h, img_w, _ = img.shape | |
if img is None: | |
img = np.zeros((height, width, 3), dtype=np.uint8) | |
img_h, img_w = height, width | |
else: | |
img_h, img_w, _ = img.shape | |
for kpts in pose_result: | |
kpts = np.array(kpts, copy=False) | |
# draw each point on image | |
if pose_kpt_color is not None: | |
assert len(pose_kpt_color) == len(kpts) | |
for kid, kpt in enumerate(kpts): | |
if kid in [17, 18, 19, 20, 21, 22]: | |
continue | |
if kid in [13, 14, 15, 16]: | |
if kpt[0] > min(kpts[23:91, 0]) and kpt[0] < max(kpts[23:91, 0]) and kpt[1] > min(kpts[23:91, 1]) and kpt[1] < max(kpts[23:91, 1]): | |
continue | |
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] | |
if kpt_score > kpt_score_thr: | |
color = tuple(int(c) for c in pose_kpt_color[kid]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
cv2.circle(img_copy, (int(x_coord), int(y_coord)), | |
radius, color, -1) | |
transparency = max(0, min(1, kpt_score)) | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.circle(img, (int(x_coord), int(y_coord)), radius, | |
color, -1) | |
# draw links | |
if skeleton is not None and pose_link_color is not None: | |
assert len(pose_link_color) == len(skeleton) | |
for sk_id, sk in enumerate(skeleton): | |
if sk[0] in [17, 18, 19, 20, 21, 22] or sk[1] in [17, 18, 19, 20, 21, 22]: | |
continue | |
if sk[0] in [13, 14, 15, 16]: | |
if kpts[sk[0], 0] > min(kpts[23:91, 0]) and kpts[sk[0], 0] < max(kpts[23:91, 0]) and kpts[sk[0], 1] > min(kpts[23:91, 1]) and kpts[sk[0], 1] < max(kpts[23:91, 1]): | |
continue | |
if sk[1] in [13, 14, 15, 16]: | |
if kpts[sk[1], 0] > min(kpts[23:91, 0]) and kpts[sk[1], 0] < max(kpts[23:91, 0]) and kpts[sk[1], 1] > min(kpts[23:91, 1]) and kpts[sk[1], 1] < max(kpts[23:91, 1]): | |
continue | |
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) | |
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) | |
# if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0 | |
# and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w | |
# and pos2[1] > 0 and pos2[1] < img_h | |
# and kpts[sk[0], 2] > kpt_score_thr | |
# and kpts[sk[1], 2] > kpt_score_thr): | |
if (kpts[sk[0], 2] > kpt_score_thr | |
and kpts[sk[1], 2] > kpt_score_thr): | |
color = tuple(int(c) for c in pose_link_color[sk_id]) | |
if show_keypoint_weight: | |
img_copy = img.copy() | |
X = (pos1[0], pos2[0]) | |
Y = (pos1[1], pos2[1]) | |
mX = np.mean(X) | |
mY = np.mean(Y) | |
length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 | |
angle = math.degrees( | |
math.atan2(Y[0] - Y[1], X[0] - X[1])) | |
stickwidth = thickness | |
polygon = cv2.ellipse2Poly( | |
(int(mX), int(mY)), | |
(int(length / 2), int(stickwidth)), int(angle), 0, | |
360, 1) | |
cv2.fillConvexPoly(img_copy, polygon, color) | |
# transparency = max( | |
# 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2]))) | |
transparency = 1 | |
cv2.addWeighted( | |
img_copy, | |
transparency, | |
img, | |
1 - transparency, | |
0, | |
dst=img) | |
else: | |
cv2.line(img, pos1, pos2, color, thickness=thickness) | |
return img | |
def draw_whole_body_skeleton( | |
img, | |
pose, | |
radius=4, | |
thickness=1, | |
kpt_score_thr=0.3, | |
height=None, | |
width=None, | |
): | |
palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
[230, 230, 0], [255, 153, 255], [153, 204, 255], | |
[255, 102, 255], [255, 51, 255], [102, 178, 255], | |
[51, 153, 255], [255, 153, 153], [255, 102, 102], | |
[255, 51, 51], [153, 255, 153], [102, 255, 102], | |
[51, 255, 51], [0, 255, 0], [0, 0, 255], | |
[255, 0, 0], [255, 255, 255]]) | |
# below are for the whole body keypoints | |
skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
[5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
[8, 10], [1, 2], [0, 1], [0, 2], | |
[1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18], | |
[15, 19], [16, 20], [16, 21], [16, 22], [91, 92], | |
[92, 93], [93, 94], [94, 95], [91, 96], [96, 97], | |
[97, 98], [98, 99], [91, 100], [100, 101], [101, 102], | |
[102, 103], [91, 104], [104, 105], [105, 106], | |
[106, 107], [91, 108], [108, 109], [109, 110], | |
[110, 111], [112, 113], [113, 114], [114, 115], | |
[115, 116], [112, 117], [117, 118], [118, 119], | |
[119, 120], [112, 121], [121, 122], [122, 123], | |
[123, 124], [112, 125], [125, 126], [126, 127], | |
[127, 128], [112, 129], [129, 130], [130, 131], | |
[131, 132]] | |
pose_link_color = palette[[ | |
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 | |
] + [16, 16, 16, 16, 16, 16] + [ | |
0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
16 | |
] + [ | |
0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
16 | |
]] | |
pose_kpt_color = palette[ | |
[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + | |
[0, 0, 0, 0, 0, 0] + [19] * (68 + 42)] | |
draw = imshow_keypoints_whole(img, pose, skeleton, | |
kpt_score_thr=0.3, | |
pose_kpt_color=pose_kpt_color, | |
pose_link_color=pose_link_color, | |
radius=radius, | |
thickness=thickness, | |
show_keypoint_weight=True, | |
height=height, | |
width=width) | |
return draw | |
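# The skeleton and palette above follow the 133-keypoint COCO-WholeBody layout
# (17 body + 6 foot + 68 face + 42 hand keypoints). A hedged usage sketch, assuming
# `wholebody_kpts` is a (133, 3) array of x, y, score from a pose detector:
#
#   wholebody_kpts = np.zeros((133, 3), dtype=np.float32)   # placeholder pose
#   vis = draw_whole_body_skeleton(None, [wholebody_kpts],
#                                  radius=4, thickness=2,
#                                  height=512, width=512)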
def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10): | |
humansd_skeleton=[ | |
[0,0,1], | |
[1,0,2], | |
[2,1,3], | |
[3,2,4], | |
[4,3,5], | |
[5,4,6], | |
[6,5,7], | |
[7,6,8], | |
[8,7,9], | |
[9,8,10], | |
[10,5,11], | |
[11,6,12], | |
[12,11,13], | |
[13,12,14], | |
[14,13,15], | |
[15,14,16], | |
] | |
# humansd_skeleton_width=10 | |
humansd_color=sns.color_palette("hls", len(humansd_skeleton)) | |
def plot_kpts(img_draw, kpts, color, edgs,width): | |
for idx, kpta, kptb in edgs: | |
if kpts[kpta,2]>mmpose_detection_thresh and \ | |
kpts[kptb,2]>mmpose_detection_thresh : | |
line_color = tuple([int(255*color_i) for color_i in color[idx]]) | |
cv2.line(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), (int(kpts[kptb,0]),int(kpts[kptb,1])), line_color,width) | |
cv2.circle(img_draw, (int(kpts[kpta,0]),int(kpts[kpta,1])), width//2, line_color, -1) | |
cv2.circle(img_draw, (int(kpts[kptb,0]),int(kpts[kptb,1])), width//2, line_color, -1) | |
if image is None: | |
pose_image = np.zeros((height, width, 3), dtype=np.uint8) | |
else: | |
pose_image = np.array(image, dtype=np.uint8) | |
for person_i in range(len(pose)): | |
if np.sum(pose[person_i])>0: | |
plot_kpts(pose_image, pose[person_i],humansd_color,humansd_skeleton,humansd_skeleton_width) | |
return pose_image | |
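# draw_humansd_skeleton renders the 17-keypoint COCO body skeleton in the HumanSD
# style (one hue per limb via seaborn). Hedged sketch, assuming `poses` has shape
# (num_people, 17, 3) with x, y, score per keypoint:
#
#   poses = np.zeros((1, 17, 3), dtype=np.float32)   # placeholder detection
#   skel_img = draw_humansd_skeleton(None, poses, mmpose_detection_thresh=0.3,
#                                    height=512, width=512)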
def draw_controlnet_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None): | |
if image is None: | |
canvas = np.zeros((height, width, 3), dtype=np.uint8) | |
else: | |
H, W, C = image.shape | |
canvas = np.array(image, dtype=np.uint8) | |
for pose_i in range(len(pose)): | |
present_pose=pose[pose_i] | |
candidate=[ | |
[present_pose[0,0],present_pose[0,1],present_pose[0,2],0], | |
[(present_pose[6,0]+present_pose[5,0])/2,(present_pose[6,1]+present_pose[5,1])/2,(present_pose[6,2]+present_pose[5,2])/2,1] if present_pose[6,2]>mmpose_detection_thresh and present_pose[5,2]>mmpose_detection_thresh else [-1,-1,0,1], | |
[present_pose[6,0],present_pose[6,1],present_pose[6,2],2], | |
[present_pose[8,0],present_pose[8,1],present_pose[8,2],3], | |
[present_pose[10,0],present_pose[10,1],present_pose[10,2],4], | |
[present_pose[5,0],present_pose[5,1],present_pose[5,2],5], | |
[present_pose[7,0],present_pose[7,1],present_pose[7,2],6], | |
[present_pose[9,0],present_pose[9,1],present_pose[9,2],7], | |
[present_pose[12,0],present_pose[12,1],present_pose[12,2],8], | |
[present_pose[14,0],present_pose[14,1],present_pose[14,2],9], | |
[present_pose[16,0],present_pose[16,1],present_pose[16,2],10], | |
[present_pose[11,0],present_pose[11,1],present_pose[11,2],11], | |
[present_pose[13,0],present_pose[13,1],present_pose[13,2],12], | |
[present_pose[15,0],present_pose[15,1],present_pose[15,2],13], | |
[present_pose[2,0],present_pose[2,1],present_pose[2,2],14], | |
[present_pose[1,0],present_pose[1,1],present_pose[1,2],15], | |
[present_pose[4,0],present_pose[4,1],present_pose[4,2],16], | |
[present_pose[3,0],present_pose[3,1],present_pose[3,2],17], | |
] | |
stickwidth = 4 | |
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ | |
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ | |
[1, 16], [16, 18], [3, 17], [6, 18]] | |
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ | |
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ | |
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] | |
for i in range(17): | |
if candidate[limbSeq[i][0]-1][2]>mmpose_detection_thresh and candidate[limbSeq[i][1]-1][2]>mmpose_detection_thresh: | |
Y=[candidate[limbSeq[i][1]-1][0],candidate[limbSeq[i][0]-1][0]] | |
X=[candidate[limbSeq[i][1]-1][1],candidate[limbSeq[i][0]-1][1]] | |
mX = np.mean(X) | |
mY = np.mean(Y) | |
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 | |
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) | |
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) | |
cur_canvas = canvas.copy() | |
cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) | |
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) | |
for i in range(18): | |
if candidate[i][2]>mmpose_detection_thresh: | |
x, y = candidate[i][0:2] | |
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) | |
return canvas | |
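# draw_controlnet_skeleton remaps 17 COCO keypoints to the 18-point OpenPose layout
# used by ControlNet conditioning (the neck is synthesized as the midpoint of the
# shoulders) and renders limbs as alpha-blended ellipses. Hedged sketch with a
# placeholder pose:
#
#   poses = np.zeros((1, 17, 3), dtype=np.float32)
#   openpose_img = draw_controlnet_skeleton(None, poses,
#                                           mmpose_detection_thresh=0.3,
#                                           height=512, width=512)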
def draw_body_skeleton( | |
img, | |
pose, | |
radius=4, | |
thickness=1, | |
kpt_score_thr=0.3, | |
height=None, | |
width=None, | |
): | |
palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
[230, 230, 0], [255, 153, 255], [153, 204, 255], | |
[255, 102, 255], [255, 51, 255], [102, 178, 255], | |
[51, 153, 255], [255, 153, 153], [255, 102, 102], | |
[255, 51, 51], [153, 255, 153], [102, 255, 102], | |
[51, 255, 51], [0, 255, 0], [0, 0, 255], | |
[255, 0, 0], [255, 255, 255]]) | |
# below are for the body keypoints | |
# skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
# [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
# [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], | |
# [3, 5], [4, 6]] | |
skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], | |
[5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], | |
[8, 10], [3, 4], | |
[3, 5], [4, 6]] | |
# pose_link_color = palette[[ | |
# 0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16 | |
# ]] | |
# pose_kpt_color = palette[[ | |
# 16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0 | |
# ]] | |
pose_link_color = palette[[ | |
12, 16, 1, 5, 9, 13, 19, 15, 11, 7, 3, 18, 14, 8, 0 | |
]] | |
pose_kpt_color = palette[[ | |
19, 15, 11, 7, 3, 18, 14, 10, 6, 2, 17, 13, 9, 5, 1, 16, 12 | |
]] | |
draw = imshow_keypoints_body(img, pose, skeleton, | |
kpt_score_thr=0.3, | |
pose_kpt_color=pose_kpt_color, | |
pose_link_color=pose_link_color, | |
radius=radius, | |
thickness=thickness, | |
show_keypoint_weight=True, | |
height=height, | |
width=width) | |
return draw | |
def draw_face_skeleton( | |
img, | |
pose, | |
radius=4, | |
thickness=1, | |
kpt_score_thr=0.3, | |
height=None, | |
width=None, | |
): | |
palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
[230, 230, 0], [255, 153, 255], [153, 204, 255], | |
[255, 102, 255], [255, 51, 255], [102, 178, 255], | |
[51, 153, 255], [255, 153, 153], [255, 102, 102], | |
[255, 51, 51], [153, 255, 153], [102, 255, 102], | |
[51, 255, 51], [0, 255, 0], [0, 0, 255], | |
[255, 0, 0], [255, 255, 255]]) | |
# below are for the face keypoints | |
skeleton = [] | |
pose_link_color = palette[[]] | |
pose_kpt_color = palette[[19] * 68] | |
kpt_score_thr = 0 | |
draw = imshow_keypoints(img, pose, skeleton, | |
kpt_score_thr=kpt_score_thr, | |
pose_kpt_color=pose_kpt_color, | |
pose_link_color=pose_link_color, | |
radius=radius, | |
thickness=thickness, | |
show_keypoint_weight=True, | |
height=height, | |
width=width) | |
return draw | |
def draw_hand_skeleton( | |
img, | |
pose, | |
radius=4, | |
thickness=1, | |
kpt_score_thr=0.3, | |
height=None, | |
width=None, | |
): | |
palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], | |
[230, 230, 0], [255, 153, 255], [153, 204, 255], | |
[255, 102, 255], [255, 51, 255], [102, 178, 255], | |
[51, 153, 255], [255, 153, 153], [255, 102, 102], | |
[255, 51, 51], [153, 255, 153], [102, 255, 102], | |
[51, 255, 51], [0, 255, 0], [0, 0, 255], | |
[255, 0, 0], [255, 255, 255]]) | |
# hand option 1 | |
skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], | |
[7, 8], [0, 9], [9, 10], [10, 11], [11, 12], [0, 13], | |
[13, 14], [14, 15], [15, 16], [0, 17], [17, 18], | |
[18, 19], [19, 20]] | |
pose_link_color = palette[[ | |
0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
16 | |
]] | |
pose_kpt_color = palette[[ | |
0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, | |
16, 16 | |
]] | |
# # hand option 2 | |
# skeleton = [[0, 1], [1, 2], [2, 3], [4, 5], [5, 6], [6, 7], [8, 9], | |
# [9, 10], [10, 11], [12, 13], [13, 14], [14, 15], | |
# [16, 17], [17, 18], [18, 19], [3, 20], [7, 20], | |
# [11, 20], [15, 20], [19, 20]] | |
# pose_link_color = palette[[ | |
# 0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, | |
# 16 | |
# ]] | |
# pose_kpt_color = palette[[ | |
# 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, | |
# 16, 0 | |
# ]] | |
draw = imshow_keypoints(img, pose, skeleton, | |
kpt_score_thr=kpt_score_thr, | |
pose_kpt_color=pose_kpt_color, | |
pose_link_color=pose_link_color, | |
radius=radius, | |
thickness=thickness, | |
show_keypoint_weight=True, | |
height=height, | |
width=width) | |
return draw | |
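# Hedged usage sketch for draw_hand_skeleton, assuming a (21, 3) array of hand
# keypoints in the standard 21-point layout (wrist plus four joints per finger):
#
#   hand_kpts = np.zeros((21, 3), dtype=np.float32)
#   hand_img = draw_hand_skeleton(None, [hand_kpts], radius=3, thickness=2,
#                                 height=256, width=256)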
class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

        self.tokenize = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts
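# Hedged usage sketch for CsvDataset. The TSV path, column names, and tokenizer
# below are illustrative assumptions, not definitions from this module; any callable
# that maps a list of strings to a tensor of token ids works (e.g. open_clip's tokenizer):
#
#   import open_clip
#   tokenizer = open_clip.get_tokenizer("ViT-B-32")
#   preprocess = transforms.Compose([transforms.Resize(256),
#                                    transforms.CenterCrop(224),
#                                    transforms.ToTensor()])
#   ds = CsvDataset("train.tsv", preprocess, img_key="filepath",
#                   caption_key="title", sep="\t", tokenizer=tokenizer)
#   loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)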
class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
def expand_urls(urls, weights=None):
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        return expanded_urls, None
    if isinstance(urls, str):
        urllist = urls.split("::")
        weights = weights.split('::')
        assert len(weights) == len(urllist), \
            f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
        weights = [float(weight) for weight in weights]
        all_urls, all_weights = [], []
        for url, weight in zip(urllist, weights):
            expanded_url = list(braceexpand.braceexpand(url))
            expanded_weights = [weight for _ in expanded_url]
            all_urls.extend(expanded_url)
            all_weights.extend(expanded_weights)
        return all_urls, all_weights
    else:
        all_urls = list(urls)
        return all_urls, weights
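# expand_urls accepts "::"-separated shard specs with brace notation and optional
# matching "::"-separated upsampling weights. Hedged example (the paths are made up):
#
#   urls, weights = expand_urls(
#       "data/laion/{00000..00099}.tar::data/coyo/{00000..00049}.tar",
#       weights="1.0::2.0")
#   # -> 150 shard urls; the first 100 carry weight 1.0, the last 50 carry weight 2.0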
def get_dataset_size(shards):
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        with open(sizes_filename, 'r') as f:
            sizes = json.load(f)
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        with open(len_filename, 'r') as f:
            total_size = ast.literal_eval(f.read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors' last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards
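# get_dataset_size looks for an optional sizes.json next to the shards mapping each
# shard basename to its sample count, e.g. (illustrative):
#
#   {"00000.tar": 10000, "00001.tar": 9987}
#
# or for a plain __len__ file containing a single integer; otherwise it returns None
# and the sample count must be supplied via --train-num-samples.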
def get_imagenet(args, preprocess_fns, split): | |
assert split in ["train", "val", "v2"] | |
is_train = split == "train" | |
preprocess_train, preprocess_val = preprocess_fns | |
if split == "v2": | |
from imagenetv2_pytorch import ImageNetV2Dataset | |
dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) | |
else: | |
if is_train: | |
data_path = args.imagenet_train | |
preprocess_fn = preprocess_train | |
else: | |
data_path = args.imagenet_val | |
preprocess_fn = preprocess_val | |
assert data_path | |
dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) | |
if is_train: | |
idxs = np.zeros(len(dataset.targets)) | |
target_array = np.array(dataset.targets) | |
k = 50 | |
for c in range(1000): | |
m = target_array == c | |
n = len(idxs[m]) | |
arr = np.zeros(n) | |
arr[:k] = 1 | |
np.random.shuffle(arr) | |
idxs[m] = arr | |
idxs = idxs.astype('int') | |
sampler = SubsetRandomSampler(np.where(idxs)[0]) | |
else: | |
sampler = None | |
dataloader = torch.utils.data.DataLoader( | |
dataset, | |
batch_size=args.batch_size, | |
num_workers=args.workers, | |
sampler=sampler, | |
) | |
return DataInfo(dataloader=dataloader, sampler=sampler) | |
def count_samples(dataloader): | |
os.environ["WDS_EPOCH"] = "0" | |
n_elements, n_batches = 0, 0 | |
for images, texts in dataloader: | |
n_batches += 1 | |
n_elements += len(images) | |
assert len(images) == len(texts) | |
return n_elements, n_batches | |
def filter_no_caption_or_no_image(sample): | |
has_caption = ('txt' in sample) | |
has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) | |
return has_caption and has_image | |
def filter_no_image_or_no_ldmk(sample): | |
has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) | |
has_ldmk = ('ldmk' in sample) | |
return has_image and has_ldmk | |
def log_and_continue(exn): | |
"""Call in an exception handler to ignore any exception, issue a warning, and continue.""" | |
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') | |
return True | |
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): | |
"""Return function over iterator that groups key, value pairs into samples. | |
:param keys: function that splits the key into key and extension (base_plus_ext) | |
:param lcase: convert suffixes to lower case (Default value = True) | |
""" | |
current_sample = None | |
for filesample in data: | |
assert isinstance(filesample, dict) | |
fname, value = filesample["fname"], filesample["data"] | |
prefix, suffix = keys(fname) | |
if prefix is None: | |
continue | |
if lcase: | |
suffix = suffix.lower() | |
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for | |
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next | |
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset | |
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: | |
if valid_sample(current_sample): | |
yield current_sample | |
current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) | |
if suffixes is None or suffix in suffixes: | |
current_sample[suffix] = value | |
if valid_sample(current_sample): | |
yield current_sample | |
def tarfile_to_samples_nothrow(src, handler=log_and_continue): | |
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw | |
streams = url_opener(src, handler=handler) | |
files = tar_file_expander(streams, handler=handler) | |
samples = group_by_keys_nothrow(files, handler=handler) | |
return samples | |
def pytorch_worker_seed(increment=0): | |
"""get dataloader worker seed from pytorch""" | |
worker_info = get_worker_info() | |
if worker_info is not None: | |
# favour using the seed already created for pytorch dataloader workers if it exists | |
seed = worker_info.seed | |
if increment: | |
# space out seed increments so they can't overlap across workers in different iterations | |
seed += increment * max(1, worker_info.num_workers) | |
return seed | |
# fallback to wds rank based seed | |
return wds.utils.pytorch_worker_seed() | |
_SHARD_SHUFFLE_SIZE = 2000 | |
_SHARD_SHUFFLE_INITIAL = 500 | |
_SAMPLE_SHUFFLE_SIZE = 5000 | |
_SAMPLE_SHUFFLE_INITIAL = 1000 | |
class detshuffle2(wds.PipelineStage): | |
def __init__( | |
self, | |
bufsize=1000, | |
initial=100, | |
seed=0, | |
epoch=-1, | |
): | |
self.bufsize = bufsize | |
self.initial = initial | |
self.seed = seed | |
self.epoch = epoch | |
def run(self, src): | |
if isinstance(self.epoch, SharedEpoch): | |
epoch = self.epoch.get_value() | |
else: | |
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
self.epoch += 1 | |
epoch = self.epoch | |
rng = random.Random() | |
if self.seed < 0: | |
# If seed is negative, we use the worker's seed, this will be different across all nodes/workers | |
seed = pytorch_worker_seed(epoch) | |
else: | |
            # This seed is deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch | |
rng.seed(seed) | |
return _shuffle(src, self.bufsize, self.initial, rng) | |
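# detshuffle2 is a deterministic shard shuffle: with a non-negative seed every
# rank/worker shuffles the shard list identically for a given epoch (seed + epoch),
# while a negative seed falls back to the per-worker seed so shuffles differ across
# workers. The epoch is read from the SharedEpoch value when one is provided.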
class ResampledShards2(IterableDataset): | |
"""An iterable dataset yielding a list of urls.""" | |
def __init__( | |
self, | |
urls, | |
weights=None, | |
nshards=sys.maxsize, | |
worker_seed=None, | |
deterministic=False, | |
epoch=-1, | |
): | |
"""Sample shards from the shard list with replacement. | |
:param urls: a list of URLs as a Python list or brace notation string | |
""" | |
super().__init__() | |
urls, weights = expand_urls(urls, weights) | |
self.urls = urls | |
self.weights = weights | |
if self.weights is not None: | |
assert len(self.urls) == len(self.weights), f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match." | |
assert isinstance(self.urls[0], str) | |
self.nshards = nshards | |
self.rng = random.Random() | |
self.worker_seed = worker_seed | |
self.deterministic = deterministic | |
self.epoch = epoch | |
def __iter__(self): | |
"""Return an iterator over the shards.""" | |
if isinstance(self.epoch, SharedEpoch): | |
epoch = self.epoch.get_value() | |
else: | |
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
self.epoch += 1 | |
epoch = self.epoch | |
if self.deterministic: | |
# reset seed w/ epoch if deterministic | |
if self.worker_seed is None: | |
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id | |
seed = pytorch_worker_seed(epoch) | |
else: | |
seed = self.worker_seed() + epoch | |
self.rng.seed(seed) | |
for _ in range(self.nshards): | |
if self.weights is None: | |
yield dict(url=self.rng.choice(self.urls)) | |
else: | |
yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) | |
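# Hedged sketch of using ResampledShards2 for weighted sampling with replacement at
# the start of a webdataset pipeline (the shard paths are illustrative):
#
#   shards = ResampledShards2(
#       "data/a/{00000..00009}.tar::data/b/{00000..00004}.tar",
#       weights="1.0::3.0", deterministic=True, epoch=SharedEpoch(0))
#   pipeline = wds.DataPipeline(shards, tarfile_to_samples_nothrow,
#                               wds.decode("pilrgb", handler=log_and_continue))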
def get_wds_dataset_filter(args, preprocess_img): | |
input_shards = args.train_data | |
assert input_shards is not None | |
pipeline = [wds.SimpleShardList(input_shards)] | |
def replicate_img(sample): | |
import copy | |
sample["original"] = copy.copy(sample["image"]) | |
return sample | |
def decode_byte_to_rgb(sample): | |
# import io | |
# import PIL | |
# from PIL import ImageFile | |
# ImageFile.LOAD_TRUNCATED_IMAGES = True | |
with io.BytesIO(sample["image"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
sample["image"] = img | |
return sample | |
            except Exception:
                logging.warning("A broken image was encountered; replacing it with a blank placeholder.")
                sample["image"] = Image.new('RGB', (512, 512))
                return sample
# at this point we have an iterator over all the shards | |
pipeline.extend([ | |
wds.split_by_node, | |
wds.split_by_worker, | |
tarfile_to_samples_nothrow, | |
wds.select(filter_no_caption_or_no_image), | |
# wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map(replicate_img), | |
wds.map(decode_byte_to_rgb), | |
# wds.map_dict(image=preprocess_img, text=lambda x: x.encode('utf-8'), \ | |
# __key__=lambda x: x.encode('utf-8'), __url__=lambda x: x.encode('utf-8')), | |
wds.map_dict(image=preprocess_img), | |
wds.to_tuple("original", "image", "text", "__key__", "__url__", "json"), | |
wds.batched(args.batch_size, partial=True) | |
]) | |
dataset = wds.DataPipeline(*pipeline) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
drop_last=False | |
) | |
return DataInfo(dataloader=dataloader) | |
def get_wds_dataset_cond_face(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
def preprocess_image(sample): | |
# print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
sample["image"] = resize_transform(sample["image"]) | |
sample["ldmk"] = resize_transform(sample["ldmk"]) | |
transform_list = [] | |
image_height, image_width = sample["image"].height, sample["image"].width | |
i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
if main_args.center_crop or not is_train: | |
transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
else: | |
if image_height < main_args.resolution or image_width < main_args.resolution: | |
raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution and image_height == main_args.resolution: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
if is_train and torch.rand(1) < 0.5: | |
transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
train_transforms = transforms.Compose(transform_list) | |
sample["image"] = train_transforms(sample["image"]) | |
sample["ldmk"] = train_transforms(sample["ldmk"]) | |
return sample | |
# def extract_ldmk(sample): | |
# image_height, image_width = sample["image"].height, sample["image"].width | |
# preds = fa.get_landmarks(np.array(sample["image"])) | |
# lands = [] | |
# if preds is not None: | |
# for pred in preds: | |
# land = pred.reshape(-1, 3)[:,:2].astype(int) | |
# lands.append(land) | |
# lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") | |
# if len(lands) > 0: | |
# for land in lands: | |
# lms_color_map = vis_landmark_on_img(lms_color_map, land) | |
# # print(lms_color_map.shape) | |
# sample["ldmk"] = Image.fromarray(lms_color_map) | |
# return sample | |
def visualize_ldmk(sample): | |
image_height, image_width = sample["image"].height, sample["image"].width | |
lands = np.frombuffer(sample["ldmk"], dtype=np.float32) | |
lms_color_map = np.zeros(shape=(image_height, image_width, 3)).astype("uint8") | |
if len(lands) > 0: | |
lands = lands.reshape(-1, 68, 3).astype(int) | |
for i in range(lands.shape[0]): | |
lms_color_map = vis_landmark_on_img(lms_color_map, lands[i]) | |
# print(lms_color_map.shape) | |
sample["ldmk"] = Image.fromarray(lms_color_map) | |
return sample | |
def filter_ldmk_none(sample): | |
return not (sample["ldmk"] == -1).all() | |
def filter_low_res(sample): | |
if filter_lowres: | |
string_json = sample["json"].decode('utf-8') | |
dict_json = json.loads(string_json) | |
if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
min_length = min(dict_json["height"], dict_json["width"]) | |
return min_length >= main_args.resolution | |
else: | |
return True | |
return True | |
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.select(filter_low_res), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
# wds.map(extract_ldmk), | |
wds.map(visualize_ldmk), | |
wds.map(preprocess_image), | |
wds.select(filter_ldmk_none), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict( | |
text=lambda text: tokenizer(text, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids'], | |
), | |
wds.to_tuple("image", "text", "ldmk"), | |
# wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
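# Hedged sketch of wiring get_wds_dataset_cond_face into a training loop. The
# `data_args` / `main_args` namespaces and the tokenizer are assumptions mirroring
# how the function is parameterized above, not definitions from this file:
#
#   data = get_wds_dataset_cond_face(data_args, main_args, is_train=True,
#                                    epoch=0, tokenizer=tokenizer)
#   for epoch in range(num_epochs):
#       data.set_epoch(epoch)                    # keeps detshuffle2 / resampling in sync
#       for image, text, ldmk in data.dataloader:
#           ...                                  # training step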
def get_wds_dataset_depth(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_mface=False, filter_wpose=False): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
def decode_image(sample): | |
with io.BytesIO(sample["omni_depth"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
sample["depth"] = img | |
            except Exception:
                logging.warning("A broken depth image was encountered; replacing it with a blank placeholder.")
                sample["depth"] = Image.new('RGB', (512, 512))
return sample | |
train_transforms = transforms.Compose( | |
[ | |
transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), | |
transforms.CenterCrop(main_args.resolution) if main_args.center_crop else transforms.RandomCrop(main_args.resolution), | |
transforms.RandomHorizontalFlip() if main_args.random_flip else transforms.Lambda(lambda x: x), | |
transforms.ToTensor(), | |
transforms.Normalize([0.5], [0.5]), | |
] | |
) | |
def filter_depth_none(sample): | |
return not (sample["depth"] == -1).all() | |
def filter_low_res(sample): | |
if filter_lowres: | |
string_json = sample["json"].decode('utf-8') | |
dict_json = json.loads(string_json) | |
if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
min_length = min(dict_json["height"], dict_json["width"]) | |
return min_length >= main_args.resolution | |
else: | |
return True | |
return True | |
def filter_multi_face(sample): | |
if filter_mface: | |
face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) | |
if face_kp.shape[0] > 1: | |
return False | |
return True | |
def filter_whole_skeleton(sample): | |
if filter_wpose: | |
height, width = sample["image"].height, sample["image"].width | |
body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) | |
if (body_kp[:, 0] > 0).all() and (body_kp[:, 0] < width).all() and (body_kp[:, 1] > 0).all() and (body_kp[:, 1] < height).all(): | |
return True | |
else: | |
return False | |
return True | |
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.select(filter_multi_face), | |
wds.select(filter_low_res), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
wds.select(filter_whole_skeleton), | |
wds.map(decode_image), | |
wds.map_dict(depth=train_transforms), | |
wds.select(filter_depth_none), | |
# wds.map_dict(depth=train_transforms, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict( | |
text=lambda text: tokenizer(text, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids']), | |
wds.to_tuple("depth", "text"), | |
# wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_wds_dataset_depth2canny(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
def decode_image(sample): | |
with io.BytesIO(sample["omni_depth"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
sample["depth"] = img | |
            except Exception:
                logging.warning("A broken depth image was encountered; replacing it with a blank placeholder.")
                sample["depth"] = Image.new('RGB', (512, 512))
return sample | |
def add_canny(sample): | |
canny = np.array(sample["image"]) | |
low_threshold = 100 | |
high_threshold = 200 | |
canny = cv2.Canny(canny, low_threshold, high_threshold) | |
canny = canny[:, :, None] | |
canny = np.concatenate([canny, canny, canny], axis=2) | |
sample["canny"] = Image.fromarray(canny) | |
return sample | |
def preprocess_image(sample): | |
# print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
if grid_dnc: | |
resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
else: | |
resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
sample["image"] = resize_transform(sample["image"]) | |
sample["canny"] = resize_transform(sample["canny"]) | |
sample["depth"] = resize_transform(sample["depth"]) | |
transform_list = [] | |
image_height, image_width = sample["image"].height, sample["image"].width | |
if grid_dnc: | |
i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
else: | |
i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
if main_args.center_crop or not is_train: | |
if grid_dnc: | |
transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
else: | |
transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
else: | |
if grid_dnc: | |
if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
else: | |
if image_height < main_args.resolution or image_width < main_args.resolution: | |
raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution and image_height == main_args.resolution: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
if is_train and torch.rand(1) < 0.5: | |
transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
train_transforms = transforms.Compose(transform_list) | |
sample["image"] = train_transforms(sample["image"]) | |
sample["canny"] = train_transforms(sample["canny"]) | |
sample["depth"] = train_transforms(sample["depth"]) | |
return sample | |
def random_mask(sample): | |
if is_train and dropout: | |
random_num = torch.rand(1) | |
if random_num < 0.1: | |
sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
return sample | |
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
wds.map(decode_image), | |
wds.map(add_canny), | |
wds.map(preprocess_image), | |
wds.map(random_mask), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict( | |
text=lambda text: tokenizer(text, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids'], | |
), | |
wds.to_tuple("canny", "text", "depth"), | |
# wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
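# Pipeline output (as assembled above): each batch is a ("canny", "text", "depth") tuple where | |
# canny and depth are float tensors normalized to [-1, 1] and text holds the padded/truncated | |
# input_ids produced by the HuggingFace-style tokenizer call in map_dict. | |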
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
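# Worked example of the rounding above (illustrative numbers, default round_fn=ceil): with | |
# num_samples=1,000,000, batch_size=32, world_size=8 and workers=4, global_batch_size=256, | |
# num_batches=ceil(3906.25)=3907, num_worker_batches=ceil(3907/4)=977, so num_batches becomes | |
# 977*4=3908 and num_samples is rounded up to 3908*256=1,000,448; each worker then yields | |
# exactly 977 batches per epoch. | |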
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_wds_dataset_depth2normal(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
def decode_image(sample): | |
with io.BytesIO(sample["omni_normal"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
sample["normal"] = img | |
except Exception: | |
print("Encountered a broken image, replacing it with a placeholder") | |
image = Image.new('RGB', (512, 512)) | |
sample["normal"] = image | |
with io.BytesIO(sample["omni_depth"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
sample["depth"] = img | |
except Exception: | |
print("Encountered a broken image, replacing it with a placeholder") | |
image = Image.new('RGB', (512, 512)) | |
sample["depth"] = image | |
return sample | |
def preprocess_image(sample): | |
# print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
if grid_dnc: | |
resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
else: | |
resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
sample["image"] = resize_transform(sample["image"]) | |
sample["normal"] = resize_transform(sample["normal"]) | |
sample["depth"] = resize_transform(sample["depth"]) | |
transform_list = [] | |
image_height, image_width = sample["image"].height, sample["image"].width | |
if grid_dnc: | |
i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
else: | |
i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
if main_args.center_crop or not is_train: | |
if grid_dnc: | |
transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
else: | |
transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
else: | |
if grid_dnc: | |
if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
else: | |
if image_height < main_args.resolution or image_width < main_args.resolution: | |
raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution and image_height == main_args.resolution: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
if is_train and torch.rand(1) < 0.5: | |
transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
train_transforms = transforms.Compose(transform_list) | |
sample["image"] = train_transforms(sample["image"]) | |
sample["normal"] = train_transforms(sample["normal"]) | |
sample["depth"] = train_transforms(sample["depth"]) | |
return sample | |
def random_mask(sample): | |
if is_train and dropout: | |
random_num = torch.rand(1) | |
if random_num < 0.1: | |
sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
return sample | |
def filter_low_res(sample): | |
if filter_lowres: | |
string_json = sample["json"].decode('utf-8') | |
dict_json = json.loads(string_json) | |
if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
min_length = min(dict_json["height"], dict_json["width"]) | |
return min_length >= main_args.resolution | |
else: | |
return True | |
return True | |
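# Illustrative check for filter_low_res: with main_args.resolution=512, a shard entry whose json | |
# metadata reads {"height": 384, "width": 640} gives min_length=384 < 512 and is filtered out, | |
# while {"height": 768, "width": 1024} passes; samples without height/width metadata are kept. | |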
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.select(filter_low_res), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
wds.map(decode_image), | |
wds.map(preprocess_image), | |
wds.map(random_mask), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict( | |
text=lambda text: tokenizer(text, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids'], | |
), | |
wds.to_tuple("normal", "text", "depth"), | |
# wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
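# Hedged usage sketch (illustration only; not executed): one way get_wds_dataset_depth2normal | |
# could be wired up. The shard pattern, sample count and CLIP tokenizer are assumptions; any | |
# HuggingFace tokenizer exposing model_max_length would behave the same in the map_dict step. | |
#   from types import SimpleNamespace | |
#   from transformers import CLIPTokenizer | |
#   args = SimpleNamespace(train_data="/data/shards/{00000..00009}.tar", val_data=None, | |
#                          train_num_samples=10_000, dataset_resampled=False, | |
#                          train_data_upsampling_factors=None, batch_size=4, workers=2, | |
#                          world_size=1, seed=0) | |
#   main_args = SimpleNamespace(resolution=512, center_crop=False) | |
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | |
#   data_info = get_wds_dataset_depth2normal(args, main_args, is_train=True, tokenizer=tokenizer) | |
#   for normal, text_ids, depth in data_info.dataloader:   # tuple order set by wds.to_tuple above | |
#       print(normal.shape, text_ids.shape, depth.shape); break | |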
# def get_wds_dataset_cond_sdxl(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False): | |
# input_shards = args.train_data if is_train else args.val_data | |
# assert input_shards is not None | |
# resampled = getattr(args, 'dataset_resampled', False) and is_train | |
# num_samples, num_shards = get_dataset_size(input_shards) | |
# if not num_samples: | |
# if is_train: | |
# num_samples = args.train_num_samples | |
# if not num_samples: | |
# raise RuntimeError( | |
# 'Currently, number of dataset samples must be specified for training dataset. ' | |
# 'Please specify via `--train-num-samples` if no dataset length info present.') | |
# else: | |
# num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
# shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
# if resampled: | |
# pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
# else: | |
# assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
# pipeline = [wds.SimpleShardList(input_shards)] | |
# # at this point we have an iterator over all the shards | |
# if is_train: | |
# if not resampled: | |
# pipeline.extend([ | |
# detshuffle2( | |
# bufsize=_SHARD_SHUFFLE_SIZE, | |
# initial=_SHARD_SHUFFLE_INITIAL, | |
# seed=args.seed, | |
# epoch=shared_epoch, | |
# ), | |
# wds.split_by_node, | |
# wds.split_by_worker, | |
# ]) | |
# pipeline.extend([ | |
# # at this point, we have an iterator over the shards assigned to each worker at each node | |
# tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
# wds.shuffle( | |
# bufsize=_SAMPLE_SHUFFLE_SIZE, | |
# initial=_SAMPLE_SHUFFLE_INITIAL, | |
# ), | |
# ]) | |
# else: | |
# pipeline.extend([ | |
# wds.split_by_worker, | |
# # at this point, we have an iterator over the shards assigned to each worker | |
# wds.tarfile_to_samples(handler=log_and_continue), | |
# ]) | |
# def pose2img(sample): | |
# height, width = sample["image"].height, sample["image"].width | |
# min_length = min(height, width) | |
# radius_body = max(int(4. * min_length / main_args.resolution), 4) | |
# thickness_body = max(int(2. * min_length / main_args.resolution), 2) | |
# radius_face = max(int(2. * min_length / main_args.resolution), 2) | |
# thickness_face = max(int(1. * min_length / main_args.resolution), 1) | |
# radius_hand = max(int(2. * min_length / main_args.resolution), 2) | |
# thickness_hand = max(int(1. * min_length / main_args.resolution), 1) | |
# # if "getty" in sample["__url__"]: | |
# # radius_body *= 4 | |
# # thickness_body *= 4 | |
# # radius_face *= 4 | |
# # thickness_face *= 4 | |
# # radius_hand *= 4 | |
# # thickness_hand *= 4 | |
# body_kp = np.frombuffer(sample["body_kp"], dtype=np.float32).reshape(17, 2) | |
# body_kpconf = np.frombuffer(sample["body_kpconf"], dtype=np.float32) | |
# body_all = np.concatenate([body_kp, body_kpconf[:, np.newaxis]], axis=1) | |
# body_all = body_all[np.newaxis, ...] | |
# body_draw = draw_body_skeleton( | |
# img=None, | |
# pose=body_all, | |
# radius=radius_body, | |
# thickness=thickness_body, | |
# height=height, | |
# width=width | |
# ) | |
# body_draw = Image.fromarray(body_draw) | |
# face_kp = np.frombuffer(sample["face_kp"], dtype=np.float32).reshape(-1, 98, 2) | |
# face_kpconf = np.frombuffer(sample["face_kpconf"], dtype=np.float32).reshape(-1, 98) | |
# face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
# face_draw = draw_face_skeleton( | |
# # img=np.array(img), | |
# img=None, | |
# pose=face_all, | |
# radius=radius_face, | |
# thickness=thickness_face, | |
# height=height, | |
# width=width | |
# ) | |
# face_draw = Image.fromarray(face_draw) | |
# hand_kp = np.frombuffer(sample["hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
# hand_kpconf = np.frombuffer(sample["hand_kpconf"], dtype=np.float32).reshape(-1, 21) | |
# hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
# hand_draw = draw_hand_skeleton( | |
# # img=np.array(img), | |
# img=None, | |
# pose=hand_all, | |
# radius=radius_hand, | |
# thickness=thickness_hand, | |
# height=height, | |
# width=width | |
# ) | |
# hand_draw = Image.fromarray(hand_draw) | |
# sample["body"] = body_draw | |
# sample["face"] = face_draw | |
# sample["hand"] = hand_draw | |
# return sample | |
# def decode_image(sample): | |
# with io.BytesIO(sample["omni_normal"]) as stream: | |
# try: | |
# img = PIL.Image.open(stream) | |
# img.load() | |
# img = img.convert("RGB") | |
# sample["normal"] = img | |
# except: | |
# print("A broken image is encountered, replace w/ a placeholder") | |
# image = Image.new('RGB', (512, 512)) | |
# sample["normal"] = image | |
# with io.BytesIO(sample["omni_depth"]) as stream: | |
# try: | |
# img = PIL.Image.open(stream) | |
# img.load() | |
# img = img.convert("RGB") | |
# sample["depth"] = img | |
# except: | |
# print("A broken image is encountered, replace w/ a placeholder") | |
# image = Image.new('RGB', (512, 512)) | |
# sample["depth"] = image | |
# return sample | |
# def add_canny(sample): | |
# canny = np.array(sample["image"]) | |
# low_threshold = 100 | |
# high_threshold = 200 | |
# canny = cv2.Canny(canny, low_threshold, high_threshold) | |
# canny = canny[:, :, None] | |
# canny = np.concatenate([canny, canny, canny], axis=2) | |
# sample["canny"] = Image.fromarray(canny) | |
# return sample | |
# def decode_text(sample): | |
# sample["blip"] = sample["blip"].decode("utf-8") | |
# sample["blip_raw"] = sample["blip"] | |
# sample["text_raw"] = sample["text"] | |
# return sample | |
# def augment_text(sample): | |
# if is_train and string_concat: | |
# sample["text"] = sample["text"] + " " + sample["blip"] | |
# if is_train and string_substitute: | |
# sample["text"] = sample["blip"] | |
# return sample | |
# def preprocess_image(sample): | |
# # print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
# if grid_dnc: | |
# resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
# else: | |
# resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
# sample["image"] = resize_transform(sample["image"]) | |
# sample["normal"] = resize_transform(sample["normal"]) | |
# sample["depth"] = resize_transform(sample["depth"]) | |
# sample["canny"] = resize_transform(sample["canny"]) | |
# sample["body"] = resize_transform(sample["body"]) | |
# sample["face"] = resize_transform(sample["face"]) | |
# sample["hand"] = resize_transform(sample["hand"]) | |
# transform_list = [] | |
# image_height, image_width = sample["image"].height, sample["image"].width | |
# if grid_dnc: | |
# i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
# j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
# else: | |
# i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
# j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
# if main_args.center_crop or not is_train: | |
# sample["description"]["crop_tl_h"] = (image_height - main_args.resolution) // 2 | |
# sample["description"]["crop_tl_w"] = (image_width - main_args.resolution) // 2 | |
# if grid_dnc: | |
# transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
# else: | |
# transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
# else: | |
# if grid_dnc: | |
# if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
# raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
# elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
# i, j = 0, 0 | |
# transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
# else: | |
# if image_height < main_args.resolution or image_width < main_args.resolution: | |
# raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
# elif image_width == main_args.resolution and image_height == main_args.resolution: | |
# i, j = 0, 0 | |
# transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
# sample["description"]["crop_tl_h"] = i | |
# sample["description"]["crop_tl_w"] = j | |
# if is_train and torch.rand(1) < 0.5: | |
# transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
# transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
# train_transforms = transforms.Compose(transform_list) | |
# sample["image"] = train_transforms(sample["image"]) | |
# sample["normal"] = train_transforms(sample["normal"]) | |
# sample["depth"] = train_transforms(sample["depth"]) | |
# sample["canny"] = train_transforms(sample["canny"]) | |
# sample["body"] = train_transforms(sample["body"]) | |
# sample["face"] = train_transforms(sample["face"]) | |
# sample["hand"] = train_transforms(sample["hand"]) | |
# return sample | |
# def random_mask(sample): | |
# if is_train and dropout: | |
# random_num = torch.rand(1) | |
# if random_num < 0.1: | |
# sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
# sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
# sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
# sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
# sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
# sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
# elif random_num > 0.9: | |
# pass | |
# else: | |
# if torch.rand(1) < 0.5: | |
# sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
# if torch.rand(1) < 0.5: | |
# sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
# if torch.rand(1) < 0.8: | |
# sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
# if torch.rand(1) < 0.5: | |
# sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
# if torch.rand(1) < 0.5: | |
# sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
# if torch.rand(1) < 0.2: | |
# sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
# return sample | |
# def make_grid_dnc(sample): | |
# if grid_dnc: | |
# resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
# resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
# resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
# resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
# grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), | |
# torch.cat([resized_normal, resized_canny], dim=2)], dim=1) | |
# assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution | |
# sample["image"] = grid | |
# return sample | |
# def filter_low_res(sample): | |
# if main_args.filter_res is None: | |
# main_args.filter_res = main_args.resolution | |
# if filter_lowres: | |
# string_json = sample["json"].decode('utf-8') | |
# dict_json = json.loads(string_json) | |
# if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
# min_length = min(dict_json["height"], dict_json["width"]) | |
# return min_length >= main_args.filter_res | |
# else: | |
# return True | |
# return True | |
# def add_original_hw(sample): | |
# image_height, image_width = sample["image"].height, sample["image"].width | |
# sample["description"] = {"h": image_height, "w": image_width} | |
# return sample | |
# def add_description(sample): | |
# # string_json = sample["json"].decode('utf-8') | |
# # dict_json = json.loads(string_json) | |
# dict_json = sample["json"] | |
# if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
# sample["description"]["h"] = dict_json["height"] | |
# sample["description"]["w"] = dict_json["width"] | |
# return sample | |
# pipeline.extend([ | |
# wds.select(filter_no_caption_or_no_image), | |
# wds.select(filter_low_res), | |
# wds.decode("pilrgb", handler=log_and_continue), | |
# wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
# wds.map(add_original_hw), | |
# wds.map(decode_text), | |
# wds.map(augment_text), | |
# wds.map(pose2img), | |
# wds.map(decode_image), | |
# wds.map(add_canny), | |
# wds.map(preprocess_image), | |
# wds.map(make_grid_dnc), | |
# wds.map(random_mask), | |
# wds.map(add_description), | |
# # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
# # wds.map_dict( | |
# # text=lambda text: tokenizer(text, \ | |
# # max_length=tokenizer.model_max_length, \ | |
# # padding="max_length", truncation=True, \ | |
# # return_tensors='pt')['input_ids'], | |
# # blip=lambda blip: tokenizer(blip, \ | |
# # max_length=tokenizer.model_max_length, \ | |
# # padding="max_length", truncation=True, \ | |
# # return_tensors='pt')['input_ids'] | |
# # ), | |
# wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", "body", "face", "hand", "normal", "depth", "canny", "description"), | |
# # wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
# wds.batched(args.batch_size, partial=not is_train) | |
# ]) | |
# dataset = wds.DataPipeline(*pipeline) | |
# if is_train: | |
# if not resampled: | |
# assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# round_fn = math.floor if floor else math.ceil | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = round_fn(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
# num_batches = num_worker_batches * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# dataloader = wds.WebLoader( | |
# dataset, | |
# batch_size=None, | |
# shuffle=False, | |
# num_workers=args.workers, | |
# persistent_workers=True, | |
# ) | |
# # FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# # if is_train: | |
# # # roll over and repeat a few samples to get same number of full batches on each node | |
# # global_batch_size = args.batch_size * args.world_size | |
# # num_batches = math.ceil(num_samples / global_batch_size) | |
# # num_workers = max(1, args.workers) | |
# # num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# # num_samples = num_batches * global_batch_size | |
# # dataloader = dataloader.with_epoch(num_batches) | |
# # else: | |
# # # last batches are partial, eval is done on single (master) node | |
# # num_batches = math.ceil(num_samples / args.batch_size) | |
# # add meta-data to dataloader instance for convenience | |
# dataloader.num_batches = num_batches | |
# dataloader.num_samples = num_samples | |
# return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_wds_dataset_cond(args, main_args, is_train, epoch=0, floor=False, tokenizer=None, dropout=False, string_concat=False, string_substitute=False, grid_dnc=False, filter_lowres=False, filter_res=512, filter_mface=False, filter_wpose=False): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
def pose2img(sample, scale): | |
height, width = sample["image"].height, sample["image"].width | |
# min_length = min(height, width) | |
# radius_body = int(4. * min_length / main_args.resolution) | |
# thickness_body = int(4. * min_length / main_args.resolution) | |
# radius_face = int(1.5 * min_length / main_args.resolution) | |
# thickness_face = int(2. * min_length / main_args.resolution) | |
# radius_hand = int(1.5 * min_length / main_args.resolution) | |
# thickness_hand = int(2. * min_length / main_args.resolution) | |
# if "getty" in sample["__url__"]: | |
# radius_body *= 4 | |
# thickness_body *= 4 | |
# radius_face *= 4 | |
# thickness_face *= 4 | |
# radius_hand *= 4 | |
# thickness_hand *= 4 | |
try: | |
location = np.frombuffer(sample["location"], dtype=np.float32) | |
body_kp = np.frombuffer(sample["new_i_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
x_coord = (body_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
y_coord = (body_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
body_kp = np.stack([x_coord, y_coord], axis=2) | |
body_kp = body_kp * scale | |
# body_kp[:, :, 0] -= j | |
# body_kp[:, :, 1] -= i | |
body_kpconf = np.frombuffer(sample["new_i_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
except Exception:  # fall back to the stored keypoints when the "location"/"new_i_*" keys are missing | |
body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
body_kp = body_kp * scale | |
# body_kp[:, :, 0] -= j | |
# body_kp[:, :, 1] -= i | |
body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
# body_ratio = 0. | |
# for i_body in range(body_kp.shape[0]): | |
# body_ratio = max((np.max(body_kp[i_body, :, 0]) - np.min(body_kp[i_body, :, 0])) / min_length, body_ratio) | |
# print(body_ratio) | |
# body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
# body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
# body_all = np.concatenate([body_kp, body_kpconf[..., np.newaxis]], axis=2) | |
# body_draw = draw_controlnet_skeleton(image=None, pose=body_all, height=height, width=width) | |
# body_draw = draw_humansd_skeleton(image=None, pose=body_all, height=height, width=width, humansd_skeleton_width=int(10. * body_ratio * min_length / main_args.resolution)) | |
body_draw = draw_humansd_skeleton( | |
# image=np.array(sample["image"]), | |
image=None, | |
pose=body_all, | |
height=height, | |
width=width, | |
humansd_skeleton_width=int(10 * main_args.resolution / 512), | |
) | |
# body_draw = draw_body_skeleton( | |
# img=None, | |
# pose=body_all, | |
# radius=radius_body, | |
# thickness=thickness_body, | |
# height=height, | |
# width=width | |
# ) | |
body_draw = Image.fromarray(body_draw) | |
try: | |
location = np.frombuffer(sample["location"], dtype=np.float32) | |
face_kp = np.frombuffer(sample["new_i_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
x_coord = (face_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
y_coord = (face_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
face_kp = np.stack([x_coord, y_coord], axis=2) | |
face_kp = face_kp * scale | |
# face_kp[:, :, 0] -= j | |
# face_kp[:, :, 1] -= i | |
face_kpconf = np.frombuffer(sample["new_i_face_kp_score"], dtype=np.float32).reshape(-1, 68) | |
face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
except Exception: | |
face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
face_kp = face_kp * scale | |
# face_kp[:, :, 0] -= j | |
# face_kp[:, :, 1] -= i | |
face_kpconf = np.frombuffer(sample["new_face_kp_score"], dtype=np.float32).reshape(-1, 68) | |
face_all = np.concatenate([face_kp, face_kpconf[..., np.newaxis]], axis=2) | |
face_draw = draw_face_skeleton( | |
# img=np.array(sample["image"]), | |
img=None, | |
pose=face_all, | |
# radius=radius_face, | |
# thickness=thickness_face, | |
height=height, | |
width=width, | |
) | |
face_draw = Image.fromarray(face_draw) | |
try: | |
location = np.frombuffer(sample["location"], dtype=np.float32) | |
hand_kp = np.frombuffer(sample["new_i_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
x_coord = (hand_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
y_coord = (hand_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
hand_kp = np.stack([x_coord, y_coord], axis=2) | |
hand_kp = hand_kp * scale | |
# hand_kp[:, :, 0] -= j | |
# hand_kp[:, :, 1] -= i | |
hand_kpconf = np.frombuffer(sample["new_i_hand_kp_score"], dtype=np.float32).reshape(-1, 21) | |
hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
except Exception: | |
hand_kp = np.frombuffer(sample["new_hand_kp"], dtype=np.float32).reshape(-1, 21, 2) | |
hand_kp = hand_kp * scale | |
# hand_kp[:, :, 0] -= j | |
# hand_kp[:, :, 1] -= i | |
hand_kpconf = np.frombuffer(sample["new_hand_kp_score"], dtype=np.float32).reshape(-1, 21) | |
hand_all = np.concatenate([hand_kp, hand_kpconf[..., np.newaxis]], axis=2) | |
hand_draw = draw_hand_skeleton( | |
# img=np.array(sample["image"]), | |
img=None, | |
pose=hand_all, | |
# radius=radius_hand, | |
# thickness=thickness_hand, | |
height=height, | |
width=width, | |
) | |
hand_draw = Image.fromarray(hand_draw) | |
# whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
# whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
# whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
try: | |
location = np.frombuffer(sample["location"], dtype=np.float32) | |
whole_kp = np.frombuffer(sample["new_i_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
x_coord = (whole_kp[:, :, 0] - location[0]) / location[2] * location[7] | |
y_coord = (whole_kp[:, :, 1] - location[1]) / location[3] * location[8] | |
whole_kp = np.stack([x_coord, y_coord], axis=2) | |
whole_kp = whole_kp * scale | |
# whole_kp[:, :, 0] -= j | |
# whole_kp[:, :, 1] -= i | |
whole_kpconf = np.frombuffer(sample["new_i_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
except Exception: | |
whole_kp = np.frombuffer(sample["new_wholebody_kp"], dtype=np.float32).reshape(-1, 133, 2) | |
whole_kp = whole_kp * scale | |
# whole_kp[:, :, 0] -= j | |
# whole_kp[:, :, 1] -= i | |
whole_kpconf = np.frombuffer(sample["new_wholebody_kp_score"], dtype=np.float32).reshape(-1, 133) | |
whole_all = np.concatenate([whole_kp, whole_kpconf[..., np.newaxis]], axis=2) | |
whole_draw = draw_whole_body_skeleton( | |
# img=np.array(sample["image"]), | |
img=None, | |
pose=whole_all, | |
# radius=radius_body, | |
# thickness=thickness_body, | |
height=height, | |
width=width, | |
) | |
whole_draw = Image.fromarray(whole_draw) | |
sample["body"] = body_draw | |
sample["face"] = face_draw | |
sample["hand"] = hand_draw | |
if main_args.change_whole_to_body: | |
sample["whole"] = body_draw | |
else: | |
sample["whole"] = whole_draw | |
return sample | |
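# Note on the try/except blocks above: the "new_i_*" keys hold keypoints stored relative to a crop, | |
# and the `location` buffer maps them back to image coordinates before applying `scale`. The exact | |
# layout of `location` is an assumption inferred from the arithmetic (offset in [0:2], crop size in | |
# [2:4], target size in [7:9]); e.g. with offset (10, 20), crop 200x300 and target 400x600, a stored | |
# keypoint (110, 170) maps to ((110-10)/200*400, (170-20)/300*600) = (200, 300). | |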
def decode_image(sample): | |
with io.BytesIO(sample["omni_normal"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
sample["normal"] = img | |
except Exception: | |
print("Encountered a broken image, replacing it with a placeholder") | |
image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
sample["normal"] = image | |
with io.BytesIO(sample["omni_depth"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
sample["depth"] = img | |
except Exception: | |
print("Encountered a broken image, replacing it with a placeholder") | |
image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
sample["depth"] = image | |
with io.BytesIO(sample["midas_depth"]) as stream: | |
try: | |
img = PIL.Image.open(stream) | |
img.load() | |
img = img.convert("RGB") | |
img = transforms.Resize((sample["image"].height, sample["image"].width), interpolation=transforms.InterpolationMode.BICUBIC)(img) | |
sample["midas_depth"] = img | |
except Exception: | |
print("Encountered a broken image, replacing it with a placeholder") | |
image = Image.new('RGB', (main_args.resolution, main_args.resolution)) | |
sample["midas_depth"] = image | |
return sample | |
def add_canny(sample): | |
canny = np.array(sample["image"]) | |
low_threshold = 100 | |
high_threshold = 200 | |
canny = cv2.Canny(canny, low_threshold, high_threshold) | |
canny = canny[:, :, None] | |
canny = np.concatenate([canny, canny, canny], axis=2) | |
sample["canny"] = Image.fromarray(canny) | |
return sample | |
def decode_text(sample): | |
try: | |
sample["blip"] = sample["blip"].decode("utf-8") | |
sample["blip_raw"] = sample["blip"] | |
except Exception:  # no usable BLIP caption; fall back to the original text | |
sample["blip"] = sample["text"] | |
sample["blip_raw"] = sample["text"].encode("utf-8") | |
sample["text_raw"] = sample["text"] | |
return sample | |
def augment_text(sample): | |
if is_train and string_concat: | |
sample["text"] = sample["text"] + " " + sample["blip"] | |
if is_train and string_substitute: | |
if main_args.rv_prompt: | |
sample["text"] = "RAW photo, " + sample["blip"] + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" | |
else: | |
sample["text"] = sample["blip"] | |
return sample | |
def dropout_text(sample): | |
if is_train: | |
try: | |
random_num = torch.rand(1) | |
if random_num < main_args.dropout_text: | |
sample["text"] = sample["text_raw"] = "" | |
except Exception:  # main_args.dropout_text may be unset; skip caption dropout in that case | |
pass | |
return sample | |
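# Note: dropout_text blanks the caption (text and text_raw set to "") with probability | |
# main_args.dropout_text during training; dropping the text condition this way is the usual | |
# ingredient for enabling classifier-free guidance on the text branch at sampling time. | |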
def preprocess_image(sample): | |
# print(main_args.resolution, main_args.center_crop, main_args.random_flip) | |
if grid_dnc: | |
resize_transform = transforms.Resize(main_args.resolution // 2, interpolation=transforms.InterpolationMode.BICUBIC) | |
else: | |
resize_transform = transforms.Resize(main_args.resolution, interpolation=transforms.InterpolationMode.BICUBIC) | |
scale = main_args.resolution * 1. / min(sample["image"].height, sample["image"].width) | |
sample["image"] = resize_transform(sample["image"]) | |
sample["normal"] = resize_transform(sample["normal"]) | |
sample["depth"] = resize_transform(sample["depth"]) | |
sample["midas_depth"] = resize_transform(sample["midas_depth"]) | |
sample["canny"] = resize_transform(sample["canny"]) | |
# sample["body"] = resize_transform(sample["body"]) | |
# sample["face"] = resize_transform(sample["face"]) | |
# sample["hand"] = resize_transform(sample["hand"]) | |
# sample["whole"] = resize_transform(sample["whole"]) | |
transform_list = [] | |
image_height, image_width = sample["image"].height, sample["image"].width | |
if grid_dnc: | |
i = torch.randint(0, image_height - main_args.resolution // 2 + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution // 2 + 1, size=(1,)).item() | |
else: | |
i = torch.randint(0, image_height - main_args.resolution + 1, size=(1,)).item() | |
j = torch.randint(0, image_width - main_args.resolution + 1, size=(1,)).item() | |
if main_args.center_crop or not is_train: | |
sample["description"]["crop_tl_h"] = i = (image_height - main_args.resolution) // 2 | |
sample["description"]["crop_tl_w"] = j = (image_width - main_args.resolution) // 2 | |
if grid_dnc: | |
transform_list.append(transforms.CenterCrop(main_args.resolution // 2)) | |
else: | |
transform_list.append(transforms.CenterCrop(main_args.resolution)) | |
else: | |
if grid_dnc: | |
if image_height < main_args.resolution // 2 or image_width < main_args.resolution // 2: | |
raise ValueError(f"Required crop size {(main_args.resolution // 2, main_args.resolution // 2)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution // 2 and image_height == main_args.resolution // 2: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution // 2, main_args.resolution // 2))) | |
else: | |
if image_height < main_args.resolution or image_width < main_args.resolution: | |
raise ValueError(f"Required crop size {(main_args.resolution, main_args.resolution)} is larger than input image size {(image_height, image_width)}") | |
elif image_width == main_args.resolution and image_height == main_args.resolution: | |
i, j = 0, 0 | |
transform_list.append(transforms.Lambda(lambda img: transforms.functional.crop(img, i, j, main_args.resolution, main_args.resolution))) | |
sample["description"]["crop_tl_h"] = i | |
sample["description"]["crop_tl_w"] = j | |
sample = pose2img(sample, scale) | |
if is_train and torch.rand(1) < 0.5: | |
transform_list.append(transforms.RandomHorizontalFlip(p=1.)) | |
transform_list.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) | |
train_transforms = transforms.Compose(transform_list) | |
sample["image"] = train_transforms(sample["image"]) | |
sample["normal"] = train_transforms(sample["normal"]) | |
sample["depth"] = train_transforms(sample["depth"]) | |
sample["midas_depth"] = train_transforms(sample["midas_depth"]) | |
sample["canny"] = train_transforms(sample["canny"]) | |
sample["body"] = train_transforms(sample["body"]) | |
sample["face"] = train_transforms(sample["face"]) | |
sample["hand"] = train_transforms(sample["hand"]) | |
sample["whole"] = train_transforms(sample["whole"]) | |
return sample | |
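# Worked example for `scale` (illustrative): a 768x1024 (h x w) source with resolution=512 gives | |
# scale = 512/768 = 2/3; Resize puts the short side at 512 (long side ~683), and pose2img | |
# multiplies the image-space keypoints by the same 2/3 so e.g. (x=300, y=600) lands at (200, 400) | |
# in the resized frame before the shared crop/flip/normalize transforms are applied. | |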
def random_mask(sample): | |
sample["normal_ori"] = sample["normal"].clone() | |
sample["depth_ori"] = sample["depth"].clone() | |
sample["midas_depth_ori"] = sample["midas_depth"].clone() | |
sample["canny_ori"] = sample["canny"].clone() | |
sample["body_ori"] = sample["body"].clone() | |
sample["face_ori"] = sample["face"].clone() | |
sample["hand_ori"] = sample["hand"].clone() | |
sample["whole_ori"] = sample["whole"].clone() | |
mask_list = [] | |
if is_train and dropout: | |
random_num = torch.rand(1) | |
if random_num < 0.15: | |
sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) | |
sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
sample["whole"] = torch.ones_like(sample["whole"]) * (-1) | |
mask_list = ["normal", "depth", "midas_depth", "canny", "body", "face", "hand", "whole"] | |
elif random_num > 0.9: | |
pass | |
else: | |
if torch.rand(1) < 0.5: | |
sample["normal"] = torch.ones_like(sample["normal"]) * (-1) | |
mask_list.append("normal") | |
if torch.rand(1) < 0.5: | |
sample["depth"] = torch.ones_like(sample["depth"]) * (-1) | |
mask_list.append("depth") | |
if torch.rand(1) < 0.5: | |
sample["midas_depth"] = torch.ones_like(sample["midas_depth"]) * (-1) | |
mask_list.append("midas_depth") | |
if torch.rand(1) < 0.8: | |
sample["canny"] = torch.ones_like(sample["canny"]) * (-1) | |
mask_list.append("canny") | |
if torch.rand(1) < 0.5: | |
sample["body"] = torch.ones_like(sample["body"]) * (-1) | |
mask_list.append("body") | |
if torch.rand(1) < 0.5: | |
sample["face"] = torch.ones_like(sample["face"]) * (-1) | |
mask_list.append("face") | |
if torch.rand(1) < 0.2: | |
sample["hand"] = torch.ones_like(sample["hand"]) * (-1) | |
mask_list.append("hand") | |
if torch.rand(1) < 0.5: | |
sample["whole"] = torch.ones_like(sample["whole"]) * (-1) | |
mask_list.append("whole") | |
sample["normal_dt"] = sample["normal"].clone() | |
sample["depth_dt"] = sample["depth"].clone() | |
sample["midas_depth_dt"] = sample["midas_depth"].clone() | |
sample["canny_dt"] = sample["canny"].clone() | |
sample["body_dt"] = sample["body"].clone() | |
sample["face_dt"] = sample["face"].clone() | |
sample["hand_dt"] = sample["hand"].clone() | |
sample["whole_dt"] = sample["whole"].clone() | |
mask_list = [x for x in mask_list if x in main_args.cond_type] | |
if len(mask_list) > 0: | |
target = random.choice(mask_list) | |
sample[target + "_dt"] = sample[target + "_ori"].clone() | |
else: | |
if len(main_args.cond_type) > 0: | |
target = random.choice(main_args.cond_type) | |
sample[target + "_dt"] = torch.ones_like(sample[target]) * (-1) | |
return sample | |
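# Summary of the masking scheme above (grounded in the branches): with prob ~0.15 every condition | |
# is blanked to -1, with prob ~0.1 all are kept, and otherwise each condition is dropped | |
# independently (canny at 0.8, hand at 0.2, the rest at 0.5). The "*_ori" copies keep the unmasked | |
# tensors, while the "*_dt" copies differ from the masked set by exactly one condition: one masked | |
# entry from main_args.cond_type is restored, or, if nothing was masked, one entry is blanked. | |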
def make_grid_dnc(sample): | |
if grid_dnc: | |
resized_image = transforms.functional.resize(sample["image"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
resized_depth = transforms.functional.resize(sample["depth"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
resized_normal = transforms.functional.resize(sample["normal"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
resized_canny = transforms.functional.resize(sample["body"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
# resized_canny = transforms.functional.resize(sample["canny"], (main_args.resolution // 2, main_args.resolution // 2), interpolation=transforms.InterpolationMode.BICUBIC) | |
grid = torch.cat([torch.cat([resized_image, resized_depth], dim=2), | |
torch.cat([resized_normal, resized_canny], dim=2)], dim=1) | |
assert grid.shape[1] == main_args.resolution and grid.shape[2] == main_args.resolution | |
sample["image"] = grid | |
return sample | |
def filter_low_res(sample): | |
if main_args.filter_res is None: | |
main_args.filter_res = main_args.resolution | |
if filter_lowres: | |
# string_json = sample["json"].decode('utf-8') | |
# dict_json = json.loads(string_json) | |
dict_json = sample["json"] | |
if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
min_length = min(dict_json["height"], dict_json["width"]) | |
return min_length >= main_args.filter_res | |
else: | |
min_length = min(sample["image"].height, sample["image"].width) | |
return min_length >= main_args.filter_res | |
return True | |
def filter_watermark(sample): | |
if main_args.filter_wm: | |
if sample["description"]["watermark"] >= 100: | |
return False | |
return True | |
def add_original_hw(sample): | |
image_height, image_width = sample["image"].height, sample["image"].width | |
sample["description"] = {"h": image_height, "w": image_width} | |
return sample | |
def add_description(sample): | |
# string_json = sample["json"].decode('utf-8') | |
# dict_json = json.loads(string_json) | |
try: | |
dict_json = sample["json"] | |
if "height" in dict_json.keys() and "width" in dict_json.keys(): | |
sample["description"]["h"] = dict_json["height"] | |
sample["description"]["w"] = dict_json["width"] | |
# try: | |
if "coyo" in sample["__url__"]: | |
sample["description"]["aes"] = torch.tensor(sample["json"]["aesthetic_score_laion_v2"] * 1e2) | |
sample["description"]["watermark"] = torch.tensor(sample["json"]["watermark_score"] * 1e3) | |
elif "laion" in sample["__url__"]: | |
sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) | |
sample["description"]["watermark"] = torch.tensor(np.frombuffer(sample["watermark_score"], dtype=np.float32) * 1e3) | |
elif "getty" in sample["__url__"]: | |
sample["description"]["aes"] = torch.tensor(np.frombuffer(sample["aesthetic_score_laion_v2"], dtype=np.float32) * 1e2) | |
sample["description"]["watermark"] = torch.tensor(float(sample["json"]["display_sizes"][-1]["is_watermarked"] or 0) * 1e3) | |
elif "fake" in sample["__url__"]: | |
sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) | |
sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) | |
except Exception:  # missing or unparseable metadata; fall back to default aesthetic/watermark scores | |
# sample["description"]["h"] = | |
# sample["description"]["w"] = | |
sample["description"]["aes"] = torch.tensor(random.uniform(5.5, 6.0) * 1e2) | |
sample["description"]["watermark"] = torch.tensor(random.uniform(0., 0.1) * 1e3) | |
# except: | |
# sample["description"]["aes"] = 0. | |
# sample["description"]["watermark"] = 0. | |
return sample | |
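# Note on the scaling above: aesthetic scores are stored x100 and watermark scores x1000 (e.g. an | |
# aesthetic score of 5.5 becomes 550 and a watermark probability of 0.1 becomes 100), which is why | |
# filter_watermark rejects samples once description["watermark"] >= 100, i.e. a raw watermark | |
# score of roughly 0.1 or higher. | |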
def filter_multi_face(sample): | |
if filter_mface: | |
face_kp = np.frombuffer(sample["new_face_kp"], dtype=np.float32).reshape(-1, 68, 2) | |
if face_kp.shape[0] > 1: | |
return False | |
return True | |
def filter_whole_skeleton(sample): | |
if filter_wpose: | |
height, width = sample["image"].height, sample["image"].width | |
area = height * width | |
body_kp = np.frombuffer(sample["new_body_kp"], dtype=np.float32).reshape(-1, 17, 2) | |
body_kpconf = np.frombuffer(sample["new_body_kp_score"], dtype=np.float32).reshape(-1, 17) | |
if (body_kp.shape[0] == 1) and (body_kpconf > 0.5).all() and (body_kp[0, :15, 0] > 0).all() \ | |
and (body_kp[0, :15, 0] < width).all() and (body_kp[0, :15, 1] > 0).all() and \ | |
(body_kp[0, :15, 1] < height).all(): | |
x_min = max(np.amin(body_kp[0, :, 0]), 0) | |
x_max = min(np.amax(body_kp[0, :, 0]), width) | |
y_min = max(np.amin(body_kp[0, :, 1]), 0) | |
y_max = min(np.amax(body_kp[0, :, 1]), height) | |
if (x_max - x_min) * (y_max - y_min) / area > 0.2: | |
return True | |
else: | |
return False | |
else: | |
return False | |
return True | |
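# Illustrative check for filter_whole_skeleton: exactly one detected person, every keypoint score | |
# above 0.5, and the first 15 body joints inside the frame; the keypoint bounding box must also | |
# cover more than 20% of the image. For a 512x512 image (area 262,144), a 300x200 keypoint box | |
# covers 60,000 / 262,144 ~= 0.23 > 0.2, so the sample is kept. | |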
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.select(filter_multi_face), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
wds.select(filter_whole_skeleton), | |
wds.select(filter_low_res), | |
wds.map(add_original_hw), | |
wds.map(decode_text), | |
wds.map(augment_text), | |
wds.map(dropout_text), | |
# wds.map(pose2img), | |
wds.map(decode_image), | |
wds.map(add_canny), | |
wds.map(preprocess_image), | |
wds.map(make_grid_dnc), | |
wds.map(random_mask), | |
wds.map(add_description), | |
wds.select(filter_watermark), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict( | |
text=lambda text: tokenizer(text, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids'], | |
blip=lambda blip: tokenizer(blip, \ | |
max_length=tokenizer.model_max_length, \ | |
padding="max_length", truncation=True, \ | |
return_tensors='pt')['input_ids'] | |
), | |
wds.to_tuple("image", "text", "text_raw", "blip", "blip_raw", \ | |
"body", "face", "hand", "normal", "depth", "midas_depth", "canny", "whole", "description", \ | |
"body_ori", "face_ori", "hand_ori", "normal_ori", "depth_ori", "midas_depth_ori", "canny_ori", "whole_ori", \ | |
"body_dt", "face_dt", "hand_dt", "normal_dt", "depth_dt", "midas_depth_dt", "canny_dt", "whole_dt"), | |
# wds.to_tuple("image", "text", "blip", "normal", "depth", "canny"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=True, | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_wds_dataset_img(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, number of dataset samples must be specified for training dataset. ' | |
'Please specify via `--train-num-samples` if no dataset length info present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp"), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict(image=preprocess_img), | |
wds.to_tuple("image"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
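# Downstream, each element from this pipeline is a one-tuple containing a batch of | |
# preprocessed images; this image-only variant carries no text/caption branch. | |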
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=args.workers > 0,  # persistent workers are only valid with num_workers > 0 | |
) | |
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): | |
input_shards = args.train_data if is_train else args.val_data | |
assert input_shards is not None | |
resampled = getattr(args, 'dataset_resampled', False) and is_train | |
num_samples, num_shards = get_dataset_size(input_shards) | |
if not num_samples: | |
if is_train: | |
num_samples = args.train_num_samples | |
if not num_samples: | |
raise RuntimeError( | |
'Currently, the number of dataset samples must be specified for the training dataset. ' | |
'Please specify it via `--train-num-samples` if no dataset length info is present.') | |
else: | |
num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified | |
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc | |
if resampled: | |
pipeline = [ResampledShards2(input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch)] | |
else: | |
assert args.train_data_upsampling_factors is None, "--train_data_upsampling_factors is only supported when sampling with replacement (together with --dataset-resampled)." | |
pipeline = [wds.SimpleShardList(input_shards)] | |
# at this point we have an iterator over all the shards | |
if is_train: | |
if not resampled: | |
pipeline.extend([ | |
detshuffle2( | |
bufsize=_SHARD_SHUFFLE_SIZE, | |
initial=_SHARD_SHUFFLE_INITIAL, | |
seed=args.seed, | |
epoch=shared_epoch, | |
), | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
pipeline.extend([ | |
# at this point, we have an iterator over the shards assigned to each worker at each node | |
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), | |
wds.shuffle( | |
bufsize=_SAMPLE_SHUFFLE_SIZE, | |
initial=_SAMPLE_SHUFFLE_INITIAL, | |
), | |
]) | |
else: | |
pipeline.extend([ | |
wds.split_by_worker, | |
# at this point, we have an iterator over the shards assigned to each worker | |
wds.tarfile_to_samples(handler=log_and_continue), | |
]) | |
pipeline.extend([ | |
wds.select(filter_no_caption_or_no_image), | |
wds.decode("pilrgb", handler=log_and_continue), | |
wds.rename(image="jpg;png;jpeg;webp", text="txt"), | |
# wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), | |
wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text, | |
max_length=tokenizer.model_max_length, | |
padding="max_length", | |
truncation=True, | |
return_tensors='pt')['input_ids']), | |
wds.to_tuple("image", "text"), | |
wds.batched(args.batch_size, partial=not is_train) | |
]) | |
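# Assuming a Hugging Face tokenizer here: with return_tensors='pt' each sample's | |
# 'input_ids' has shape (1, model_max_length), so the batched text tensor likely | |
# carries that singleton dimension and may need a squeeze downstream. | |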
dataset = wds.DataPipeline(*pipeline) | |
if is_train: | |
if not resampled: | |
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' | |
# roll over and repeat a few samples to get same number of full batches on each node | |
round_fn = math.floor if floor else math.ceil | |
global_batch_size = args.batch_size * args.world_size | |
num_batches = round_fn(num_samples / global_batch_size) | |
num_workers = max(1, args.workers) | |
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker | |
num_batches = num_worker_batches * num_workers | |
num_samples = num_batches * global_batch_size | |
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this | |
else: | |
# last batches are partial, eval is done on single (master) node | |
num_batches = math.ceil(num_samples / args.batch_size) | |
dataloader = wds.WebLoader( | |
dataset, | |
batch_size=None, | |
shuffle=False, | |
num_workers=args.workers, | |
persistent_workers=args.workers > 0,  # persistent workers are only valid with num_workers > 0 | |
) | |
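# Batching already happened inside the wds pipeline (wds.batched above), so the | |
# loader is created with batch_size=None and shuffle=False to avoid re-batching | |
# or re-ordering the pre-built batches. | |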
# FIXME not clear which approach is better, with_epoch before vs after dataloader? | |
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 | |
# if is_train: | |
# # roll over and repeat a few samples to get same number of full batches on each node | |
# global_batch_size = args.batch_size * args.world_size | |
# num_batches = math.ceil(num_samples / global_batch_size) | |
# num_workers = max(1, args.workers) | |
# num_batches = math.ceil(num_batches / num_workers) * num_workers | |
# num_samples = num_batches * global_batch_size | |
# dataloader = dataloader.with_epoch(num_batches) | |
# else: | |
# # last batches are partial, eval is done on single (master) node | |
# num_batches = math.ceil(num_samples / args.batch_size) | |
# add meta-data to dataloader instance for convenience | |
dataloader.num_batches = num_batches | |
dataloader.num_samples = num_samples | |
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) | |
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): | |
input_filename = args.train_data if is_train else args.val_data | |
assert input_filename | |
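# CsvDataset (defined earlier in this file) reads a delimited file whose image-path | |
# and caption columns are selected by --csv-img-key / --csv-caption-key and split on | |
# --csv-separator, as wired up below. | |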
dataset = CsvDataset( | |
input_filename, | |
preprocess_fn, | |
img_key=args.csv_img_key, | |
caption_key=args.csv_caption_key, | |
sep=args.csv_separator, | |
tokenizer=tokenizer | |
) | |
num_samples = len(dataset) | |
sampler = DistributedSampler(dataset) if args.distributed and is_train else None | |
shuffle = is_train and sampler is None | |
dataloader = DataLoader( | |
dataset, | |
batch_size=args.batch_size, | |
shuffle=shuffle, | |
num_workers=args.workers, | |
pin_memory=True, | |
sampler=sampler, | |
drop_last=is_train, | |
) | |
dataloader.num_samples = num_samples | |
dataloader.num_batches = len(dataloader) | |
return DataInfo(dataloader, sampler) | |
class SyntheticDataset(Dataset): | |
def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None): | |
self.transform = transform | |
self.image_size = image_size | |
self.caption = caption | |
self.image = Image.new('RGB', image_size) | |
self.dataset_size = dataset_size | |
self.preprocess_txt = lambda text: tokenizer(text)[0] | |
def __len__(self): | |
return self.dataset_size | |
def __getitem__(self, idx): | |
# fall back to the raw blank image when no transform is given, so `image` is always defined | |
image = self.image if self.transform is None else self.transform(self.image) | |
return image, self.preprocess_txt(self.caption) | |
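# SyntheticDataset yields the same blank image / dummy caption pair on every index, | |
# which makes it handy for benchmarking the input pipeline and training step speed. | |
# Minimal sketch (hypothetical transform/tokenizer objects): | |
# ds = SyntheticDataset(transform=preprocess_train, dataset_size=1024, tokenizer=tokenizer) | |
# img, txt = ds[0] | |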
def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): | |
image_size = preprocess_fn.transforms[0].size | |
dataset = SyntheticDataset( | |
transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer) | |
num_samples = len(dataset) | |
sampler = DistributedSampler(dataset) if args.distributed and is_train else None | |
shuffle = is_train and sampler is None | |
dataloader = DataLoader( | |
dataset, | |
batch_size=args.batch_size, | |
shuffle=shuffle, | |
num_workers=args.workers, | |
pin_memory=True, | |
sampler=sampler, | |
drop_last=is_train, | |
) | |
dataloader.num_samples = num_samples | |
dataloader.num_batches = len(dataloader) | |
return DataInfo(dataloader, sampler) | |
def get_dataset_fn(data_path, dataset_type): | |
if dataset_type == "webdataset": | |
return get_wds_dataset | |
elif dataset_type == "csv": | |
return get_csv_dataset | |
elif dataset_type == "synthetic": | |
return get_synthetic_dataset | |
elif dataset_type == "auto": | |
ext = data_path.split('.')[-1] | |
if ext in ['csv', 'tsv']: | |
return get_csv_dataset | |
elif ext in ['tar']: | |
return get_wds_dataset | |
else: | |
raise ValueError( | |
f"Tried to figure out dataset type, but failed for extension {ext}.") | |
else: | |
raise ValueError(f"Unsupported dataset type: {dataset_type}") | |
def get_data(args, preprocess_fns, epoch=0, tokenizer=None): | |
preprocess_train, preprocess_val = preprocess_fns | |
data = {} | |
if args.train_data or args.dataset_type == "synthetic": | |
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( | |
args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) | |
if args.val_data: | |
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( | |
args, preprocess_val, is_train=False, tokenizer=tokenizer) | |
if args.imagenet_val is not None: | |
data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") | |
if args.imagenet_v2 is not None: | |
data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") | |
return data | |
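# Minimal usage sketch (hypothetical values; only the args fields read above are | |
# shown, a real argparse Namespace also needs batch_size, workers, world_size, | |
# seed, distributed, dataset_resampled, train_num_samples, etc.): | |
# args.train_data = "shards/{00000..00099}.tar" | |
# args.dataset_type = "webdataset" | |
# args.val_data = args.imagenet_val = args.imagenet_v2 = None | |
# data = get_data(args, (preprocess_train, preprocess_val), epoch=0, tokenizer=tokenizer) | |
# train_loader = data["train"].dataloader | |
# print(train_loader.num_batches, train_loader.num_samples) | |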