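"""Precompute joint VAE latents for paired source/masked videos.

For each clip present in both <video_root>/videos and <video_root>/masked_videos,
encode the source clip and a ray-annotated masked clip with the CogVideoX VAE,
concatenate the two latents along the temporal axis, and save the result to
<video_root>/joint_latents/<name>.safetensors. Intended to run as one process
per GPU over a disjoint index range.
"""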

import argparse
import math
import os

import numpy as np
import torch
import tqdm
from decord import VideoReader
from diffusers import AutoencoderKLCogVideoX
from safetensors.torch import save_file


@torch.no_grad()
def encode_video(video, vae):
    """Encode a (T, C, H, W) video in [-1, 1] into scaled VAE latents."""
    # Add a batch dim and reorder to (1, C, T, H, W), the layout the 3D VAE expects.
    video = video[None].permute(0, 2, 1, 3, 4).contiguous()
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    # Scale sampled latents to the range the diffusion model is trained on.
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent
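
# Note: latent_dist.sample() draws a stochastic latent on every run. diffusers'
# DiagonalGaussianDistribution also exposes .mode() for a deterministic encode,
# which some caching pipelines prefer; this script keeps the sampled variant.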


def add_dashed_rays_to_video(video_tensor, num_perp_samples=50, density_decay=0.075):
    """Overlay dashed rays on frames 1..T-1, colored from frame 0.

    A random direction is drawn once per video. Rays start from
    `num_perp_samples` points on the line through the frame center that is
    perpendicular to that direction, and march outward with step sizes that
    grow by `density_decay`, so the dashes thin out with distance.
    """
    T, C, H, W = video_tensor.shape
    # Longest ray we might need: the frame diagonal plus a small margin.
    max_length = int((H**2 + W**2) ** 0.5) + 10

    # Random direction ((cos, sin) is already unit length) and its perpendicular.
    center = torch.tensor([W / 2, H / 2])
    theta = torch.rand(1).item() * 2 * math.pi
    direction = torch.tensor([math.cos(theta), math.sin(theta)])
    d_perp = torch.tensor([-direction[1], direction[0]])

    # Ray origins: evenly spaced points along the perpendicular through the center.
    half_len = max(H, W) // 2
    positions = torch.linspace(-half_len, half_len, num_perp_samples)
    perp_coords = center[None, :] + positions[:, None] * d_perp[None, :]
    x0, y0 = perp_coords[:, 0], perp_coords[:, 1]

    # Distances along each ray; the spacing grows with distance from the origin.
    steps = []
    dist = 0.0
    while dist < max_length:
        steps.append(dist)
        dist += 1.0 + density_decay * dist
    steps = torch.tensor(steps)
    S = len(steps)

    # All sample points: (num_perp_samples, S, 2) flattened to pixel indices.
    dxdy = direction[None, :] * steps[:, None]
    all_xy = perp_coords[:, None, :] + dxdy[None, :, :]
    all_xy = all_xy.reshape(-1, 2)
    all_x = all_xy[:, 0].round().long()
    all_y = all_xy[:, 1].round().long()
    valid = (all_x >= 0) & (all_x < W) & (all_y >= 0) & (all_y < H)
    all_x = all_x[valid]
    all_y = all_y[valid]

    # Each ray carries the frame-0 color at its (clamped) origin pixel.
    x0r = x0.round().long().clamp(0, W - 1)
    y0r = y0.round().long().clamp(0, H - 1)
    frame0 = video_tensor[0]
    base_colors = frame0[:, y0r, x0r]  # (C, num_perp_samples)
    base_colors = base_colors.repeat_interleave(S, dim=1)[:, valid]

    # Stamp each sample as a 2x2 block; frame 0 stays untouched as the reference.
    video_out = video_tensor.clone()
    offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for dxo, dyo in offsets:
        ox = all_x + dxo
        oy = all_y + dyo
        inside = (ox >= 0) & (ox < W) & (oy >= 0) & (oy < H)
        ox = ox[inside]
        oy = oy[inside]
        colors = base_colors[:, inside]
        for c in range(C):
            video_out[1:, c, oy, ox] = colors[c][None, :].expand(T - 1, -1)
    return video_out
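
# Illustrative sanity check (the clip shape below is an assumption for a
# typical 49-frame clip, not fixed by this script):
#   dummy = torch.rand(49, 3, 480, 720) * 2 - 1
#   out = add_dashed_rays_to_video(dummy)
#   assert out.shape == dummy.shape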


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the CogVideoX VAE once, frozen and in fp16, for encode-only use.
    vae = AutoencoderKLCogVideoX.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=torch.float16)
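    # If memory is tight at high resolutions, the diffusers VAE also supports
    # vae.enable_slicing() and vae.enable_tiling(); whether they are needed
    # depends on your clips and GPU (an option, not part of this script).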

    masked_video_path = os.path.join(args.video_root, "masked_videos")
    source_video_path = os.path.join(args.video_root, "videos")
    joint_latent_path = os.path.join(args.video_root, "joint_latents")
    os.makedirs(joint_latent_path, exist_ok=True)

    # Each worker handles its own [start_idx, end_idx) slice of the sorted list.
    all_video_names = sorted(os.listdir(source_video_path))
    video_names = all_video_names[args.start_idx : args.end_idx]
    for video_name in tqdm.tqdm(video_names, desc=f"GPU {args.gpu_id}"):
        masked_video_file = os.path.join(masked_video_path, video_name)
        source_video_file = os.path.join(source_video_path, video_name)
        output_file = os.path.join(joint_latent_path, video_name.replace(".mp4", ".safetensors"))
        if not os.path.exists(masked_video_file):
            print(f"Skipping {video_name}, masked video not found.")
            continue
        if os.path.exists(output_file):
            # Already encoded; skipping makes the job resumable.
            continue
        try:
            # Encode the first 49 frames of the source clip (CogVideoX's clip length).
            vr = VideoReader(source_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1  # uint8 [0, 255] -> float [-1, 1]
            source_latent = encode_video(video, vae)

            # Encode the masked clip with dashed rays stamped onto frames 1..48.
            vr = VideoReader(masked_video_file)
            video = torch.from_numpy(vr.get_batch(np.arange(49)).asnumpy()).permute(0, 3, 1, 2).contiguous()
            video = (video / 255.0) * 2 - 1
            video = add_dashed_rays_to_video(video)
            masked_latent = encode_video(video, vae)

            # Concatenate along the temporal axis of the (B, C, T, H, W) latents.
            source_latent = source_latent.to("cpu")
            masked_latent = masked_latent.to("cpu")
            joint_latent = torch.cat([source_latent, masked_latent], dim=2)
            save_file({"joint_latents": joint_latent}, output_file)
        except Exception as e:
            # A clip can be shorter than 49 frames or otherwise unreadable;
            # log and move on rather than killing the whole shard.
            print(f"[GPU {args.gpu_id}] Error processing {video_name}: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_root", type=str, required=True)
    parser.add_argument("--pretrained_model_path", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()

    # Pin this process to its GPU before main() creates a CUDA context.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    main(args)
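
# Example launch, one process per GPU (script name, dataset path, and model id
# below are illustrative, not fixed by this file):
#   python encode_joint_latents.py \
#       --video_root /path/to/dataset \
#       --pretrained_model_path THUDM/CogVideoX-2b \
#       --start_idx 0 --end_idx 1000 --gpu_id 0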