Spaces:
Running
on
Zero
Running
on
Zero
import importlib | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.distributed as dist | |
import os | |
from einops import rearrange | |
import imageio | |
import torchvision | |
from PIL import Image | |
import io | |
from matplotlib import pyplot as plt | |
RY = 15 | |
YG = 6 | |
GC = 4 | |
CB = 11 | |
BM = 13 | |
MR = 6 | |
COLORWHEEL = torch.zeros((RY + YG + GC + CB + BM + MR, 3)) | |
col = 0 | |
# RY | |
COLORWHEEL[0:RY, 0] = 255 | |
COLORWHEEL[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) | |
col = col + RY | |
# YG | |
COLORWHEEL[col:col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) | |
COLORWHEEL[col:col + YG, 1] = 255 | |
col = col + YG | |
# GC | |
COLORWHEEL[col:col + GC, 1] = 255 | |
COLORWHEEL[col:col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) | |
col = col + GC | |
# CB | |
COLORWHEEL[col:col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) | |
COLORWHEEL[col:col + CB, 2] = 255 | |
col = col + CB | |
# BM | |
COLORWHEEL[col:col + BM, 2] = 255 | |
COLORWHEEL[col:col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) | |
col = col + BM | |
# MR | |
COLORWHEEL[col:col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) | |
COLORWHEEL[col:col + MR, 0] = 255 | |
def count_params(model, verbose=False): | |
total_params = sum(p.numel() for p in model.parameters()) | |
if verbose: | |
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") | |
return total_params | |
def check_istarget(name, para_list): | |
""" | |
name: full name of source para | |
para_list: partial name of target para | |
""" | |
istarget=False | |
for para in para_list: | |
if para in name: | |
return True | |
return istarget | |
def instantiate_from_config(config): | |
if not "target" in config: | |
if config == '__is_first_stage__': | |
return None | |
elif config == "__is_unconditional__": | |
return None | |
raise KeyError("Expected key `target` to instantiate.") | |
return get_obj_from_str(config["target"])(**config.get("params", dict())) | |
def get_obj_from_str(string, reload=False): | |
module, cls = string.rsplit(".", 1) | |
if reload: | |
module_imp = importlib.import_module(module) | |
importlib.reload(module_imp) | |
return getattr(importlib.import_module(module, package=None), cls) | |
def load_npz_from_dir(data_dir): | |
data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] | |
data = np.concatenate(data, axis=0) | |
return data | |
def load_npz_from_paths(data_paths): | |
data = [np.load(data_path)['arr_0'] for data_path in data_paths] | |
data = np.concatenate(data, axis=0) | |
return data | |
def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): | |
h, w = image.shape[:2] | |
if resize_short_edge is not None: | |
k = resize_short_edge / min(h, w) | |
else: | |
k = max_resolution / (h * w) | |
k = k**0.5 | |
h = int(np.round(h * k / 64)) * 64 | |
w = int(np.round(w * k / 64)) * 64 | |
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) | |
return image | |
def setup_dist(args): | |
if dist.is_initialized(): | |
return | |
torch.cuda.set_device(args.local_rank) | |
torch.distributed.init_process_group( | |
'nccl', | |
init_method='env://' | |
) | |
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): | |
videos = rearrange(videos, "b c t h w -> t b c h w") | |
outputs = [] | |
for x in videos: | |
x = torchvision.utils.make_grid(x, nrow=n_rows) | |
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = (x * 255).numpy().astype(np.uint8) | |
outputs.append(x) | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
imageio.mimsave(path, outputs, fps=fps) | |
def save_images_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6): | |
videos = rearrange(videos, "b c t h w -> t b c h w") | |
os.makedirs(path, exist_ok=True) | |
for time_idx, x in enumerate(videos): | |
x = torchvision.utils.make_grid(x, nrow=n_rows) | |
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = (x * 255).numpy().astype(np.uint8) | |
image = Image.fromarray(x) | |
image.save(os.path.join(path, f"{time_idx:04d}.png")) | |
def save_image_with_mask(image: torch.Tensor, masks: torch.Tensor, path: str, rescale=False, alpha=0.6): | |
# image: [C, H, W], mask: [N, H, W] | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
image = rearrange(image, "c h w -> h w c") | |
if rescale: | |
image = (image + 1.0) / 2.0 # -1,1 -> 0,1 | |
image = (image * 255).numpy().astype(np.uint8) | |
final_image = Image.fromarray(image).convert("RGBA") | |
cmap = plt.get_cmap("tab20c") | |
masks = masks.cpu().numpy().astype(np.float32) | |
for i, img in enumerate(masks): | |
mask_color = np.array([*cmap(i * 4 + 2)[:3], alpha]) | |
mask = img[:,:,None] * mask_color[None,None,:] * 255 | |
mask = mask.astype(np.uint8) | |
mask = Image.fromarray(mask).convert("RGBA") | |
final_image = Image.alpha_composite(final_image, mask) | |
final_image.save(path) | |
def save_videos_with_heatmap(videos: torch.Tensor, trajectory: torch.Tensor, path: str, n_rows=6, fps=8): | |
# use Image RGBA and alpha_composite to combine video and trajectory | |
# use imageio to save video | |
videos = rearrange(videos, "b c t h w -> t b c h w") | |
trajectory = rearrange(trajectory, "b c t h w -> t b c h w") | |
outputs = [] | |
for x, y in zip(videos, trajectory): | |
x = torchvision.utils.make_grid(x, nrow=6) | |
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
x = (x * 255).numpy().astype(np.uint8) | |
y = torchvision.utils.make_grid(y, nrow=6) | |
y = y.transpose(0, 1).transpose(1, 2).squeeze(-1) | |
y = torch.cat([y, torch.mean(y, dim=-1, keepdim=True)], dim=-1) | |
y = (y * 255).numpy().astype(np.uint8) | |
x = Image.fromarray(x).convert("RGBA") | |
y = Image.fromarray(y) | |
x = Image.alpha_composite(x, y) | |
outputs.append(x) | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
imageio.mimsave(path, outputs, fps=fps) | |
def save_videos_with_traj(videos: torch.Tensor, trajectory: torch.Tensor, path: str, rescale=False, fps=8, line_width=3, circle_radius=5): | |
# videos: [C, F, H, W] | |
# trajectory: [F, N, 2] | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
videos = rearrange(videos, "c f h w -> f h w c") | |
if rescale: | |
videos = (videos + 1) / 2 | |
videos = (videos * 255).numpy().astype(np.uint8) | |
outputs = [] | |
for frame_idx, img in enumerate(videos): | |
# img: [H, W, C], traj: [N, 2] | |
# draw trajectory use cv2.line | |
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) | |
for traj_idx in range(trajectory.shape[1]): | |
for history_idx in range(frame_idx): | |
cv2.line(img, tuple(trajectory[history_idx, traj_idx].int().tolist()), tuple(trajectory[history_idx+1, traj_idx].int().tolist()), (0, 0, 255), line_width) | |
cv2.circle(img, tuple(trajectory[frame_idx, traj_idx].int().tolist()), circle_radius, (100, 230, 160), -1) | |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
outputs.append(img) | |
imageio.mimsave(path, outputs, fps=fps) | |
def save_layer_prompts_video(videos, layer_masks, motion_scores, flow_maps, path, alpha=0.6, fps=8, flow_step=10, flow_scale=1.0): | |
# videos: [F, C, H, W] | |
# layer_masks: [N, F, H, W] | |
# motion_scores: [N, ] | |
# flow_maps: [F, 2, H, W] | |
frame_length = videos.shape[0] | |
h, w = videos.shape[-2:] | |
n_keyframes = layer_masks.shape[1] | |
if n_keyframes == 1: | |
keyframe_indices = [0] | |
elif n_keyframes == 2: | |
keyframe_indices = [0, frame_length - 1] | |
else: | |
keyframe_indices = list(range(n_keyframes)) | |
videos = rearrange(videos, "t c h w -> t h w c") | |
videos = ((videos + 1) / 2 * 255).clamp(0, 255).numpy().astype(np.uint8) | |
layer_masks = layer_masks.numpy() | |
flow_maps = flow_maps.float().numpy() | |
frame_list = [] | |
cmap = plt.get_cmap("tab10") | |
for frame_idx in range(frame_length): | |
output_frame = Image.new("RGBA", (w * 2, h * 2)) | |
frame = Image.fromarray(videos[frame_idx]).convert("RGBA") | |
frame_mask = None | |
output_frame.paste(frame, (0, 0)) | |
for layer_idx, layer_mask in enumerate(layer_masks): | |
if frame_idx in keyframe_indices: | |
layer_color = (np.array([*cmap(layer_idx)[:3], alpha]) * 255).astype(np.uint8) | |
if frame_idx == frame_length - 1: | |
mask_with_color = Image.fromarray(layer_mask[-1, :, :, np.newaxis] * layer_color[np.newaxis, np.newaxis, :]) | |
else: | |
mask_with_color = Image.fromarray(layer_mask[frame_idx, :, :, np.newaxis] * layer_color[np.newaxis, np.newaxis, :]) | |
else: | |
mask_with_color = Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8)) | |
frame = Image.alpha_composite(frame, mask_with_color) | |
frame_mask = Image.alpha_composite(frame_mask, mask_with_color) if frame_mask is not None else mask_with_color | |
output_frame.paste(frame, (w, 0)) | |
output_frame.paste(frame_mask, (0, h)) | |
flow_x = flow_maps[frame_idx, 0] * flow_scale | |
flow_y = flow_maps[frame_idx, 1] * flow_scale | |
x, y = np.arange(0, w, step=flow_step), np.arange(0, h, step=flow_step) | |
X, Y = np.meshgrid(x, y) | |
U, V = flow_x[::flow_step, ::flow_step], flow_y[::flow_step, ::flow_step] | |
plt.figure() | |
plt.gca().set_facecolor('white') | |
plt.quiver(X, Y, U, V, color='black', angles='xy', scale_units='xy', scale=1) | |
plt.xlim(0, w) | |
plt.ylim(h, 0) | |
plt.gca().set_xticks([]) | |
plt.gca().set_yticks([]) | |
buf = io.BytesIO() | |
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) | |
buf.seek(0) | |
flow = Image.open(buf).convert("RGBA") | |
output_frame.paste(flow, (w, h)) | |
plt.close() | |
frame_list.append(output_frame) | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
imageio.mimsave(path, frame_list, fps=fps) | |
def flow_uv_to_colors(u, v, rad, convert_to_bgr=False): | |
""" | |
Applies the flow color wheel to (possibly clipped) flow components u and v. | |
According to the C++ source code of Daniel Scharstein | |
According to the Matlab source code of Deqing Sun | |
Args: | |
u (torch.tensor): Input horizontal flow of shape [N,H,W] | |
v (torch.tensor): Input vertical flow of shape [N,H,W] | |
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. | |
Returns: | |
torch.tensor: Flow visualization image of shape [N,3,H,W] | |
""" | |
flow_image = torch.zeros((u.shape[0], 3, u.shape[1], u.shape[2]), dtype=torch.uint8, device=u.device) | |
colorwheel = COLORWHEEL.to(u.device) | |
ncols = colorwheel.shape[0] | |
a = torch.arctan2(-v, -u) / np.pi | |
fk = (a + 1) / 2 * (ncols - 1) | |
k0 = torch.floor(fk).int() | |
k1 = k0 + 1 | |
k1[k1 == ncols] = 0 | |
f = fk - k0 | |
for i in range(colorwheel.shape[1]): | |
tmp = colorwheel[:, i] | |
col0 = tmp[k0] / 255.0 | |
col1 = tmp[k1] / 255.0 | |
col = (1 - f) * col0 + f * col1 | |
idx = rad <= 1 | |
col[idx] = 1 - rad[idx] * (1 - col[idx]) | |
col[~idx] = col[~idx] * 0.75 # out of range | |
# Note the 2-i => BGR instead of RGB | |
ch_idx = 2 - i if convert_to_bgr else i | |
flow_image[:, ch_idx, :, :] = torch.floor(255 * col) | |
return flow_image | |
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): | |
""" | |
Adapted from Tora: https://github.com/alibaba/Tora/blob/14db1b0a074284a6c265564eef07f5320911dc00/sat/utils/flow_utils.py#L120 | |
Expects a two dimensional flow image of shape. | |
Args: | |
flow_uv (torch.Tensor): Flow UV image of shape [N,2,H,W] | |
clip_flow (float, optional): Clip maximum of flow values. Defaults to None. | |
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. | |
Returns: | |
torch.Tensor: Flow visualization image of shape [N,3,H,W] | |
""" | |
if clip_flow is not None: | |
flow_uv = torch.clamp(flow_uv, 0, clip_flow) | |
u = flow_uv[:, 0] | |
v = flow_uv[:, 1] | |
rad = torch.sqrt(u**2 + v**2) | |
rad_max = torch.max(rad) | |
epsilon = 1e-5 | |
u = u / (rad_max + epsilon) | |
v = v / (rad_max + epsilon) | |
flow_image = flow_uv_to_colors(u, v, rad, convert_to_bgr) | |
return flow_image | |
def generate_gaussian_template(imgSize=200): | |
""" Adapted from DragAnything: https://github.com/showlab/DragAnything/blob/79355363218a7eb9b3437a31b8604b6d436d9337/dataset/dataset.py#L110""" | |
circle_img = np.zeros((imgSize, imgSize), np.float32) | |
circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1) | |
isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32) | |
# Guass Map | |
for i in range(imgSize): | |
for j in range(imgSize): | |
isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp( | |
-1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2))) | |
isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask | |
isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32) | |
isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8) | |
# isotropicGrayscaleImage = cv2.resize(isotropicGrayscaleImage, (40, 40)) | |
return isotropicGrayscaleImage | |
def generate_gaussian_heatmap(tracks, width, height, layer_index, layer_capacity, side=20, offset=True): | |
heatmap_template = generate_gaussian_template() | |
num_frames, num_points = tracks.shape[:2] | |
if isinstance(tracks, torch.Tensor): | |
tracks = tracks.cpu().numpy() | |
if offset: | |
offset_kernel = cv2.resize(heatmap_template / 255, (2 * side + 1, 2 * side + 1)) | |
offset_kernel /= np.sum(offset_kernel) | |
offset_kernel /= offset_kernel[side, side] | |
heatmaps = [] | |
for frame_idx in range(num_frames): | |
if offset: | |
layer_imgs = np.zeros((layer_capacity, height, width, 3), dtype=np.float32) | |
else: | |
layer_imgs = np.zeros((layer_capacity, height, width, 1), dtype=np.float32) | |
layer_heatmaps = [] | |
for point_idx in range(num_points): | |
x, y = tracks[frame_idx, point_idx] | |
layer_id = layer_index[point_idx] | |
if x < 0 or y < 0 or x >= width or y >= height: | |
continue | |
x1 = int(max(x - side, 0)) | |
x2 = int(min(x + side, width - 1)) | |
y1 = int(max(y - side, 0)) | |
y2 = int(min(y + side, height - 1)) | |
if (x2 - x1) < 1 or (y2 - y1) < 1: | |
continue | |
temp_map = cv2.resize(heatmap_template, (x2-x1, y2-y1)) | |
layer_imgs[layer_id, y1:y2,x1:x2, 0] = np.maximum(layer_imgs[layer_id, y1:y2,x1:x2, 0], temp_map) | |
if offset: | |
if frame_idx < num_frames - 1: | |
next_x, next_y = tracks[frame_idx + 1, point_idx] | |
else: | |
next_x, next_y = x, y | |
layer_imgs[layer_id, int(y), int(x), 1] = next_x - x | |
layer_imgs[layer_id, int(y), int(x), 2] = next_y - y | |
for img in layer_imgs: | |
if offset: | |
img[:, :, 1:] = cv2.filter2D(img[:, :, 1:], -1, offset_kernel) | |
else: | |
img = cv2.cvtColor(img[:, :, 0].astype(np.uint8), cv2.COLOR_GRAY2RGB) | |
layer_heatmaps.append(img) | |
heatmaps.append(np.stack(layer_heatmaps, axis=0)) | |
heatmaps = np.stack(heatmaps, axis=0) | |
return torch.from_numpy(heatmaps).permute(0, 1, 4, 2, 3).contiguous().float() # [F, N_layer, C, H, W] | |