Spaces:
Paused
Paused
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from decord import VideoReader
from packaging import version as pver
from safetensors.torch import load_file
from torch.utils.data.dataset import Dataset
class RealEstate10KPCDRenderLatentCapEmbDataset(Dataset):
    """Dataset pairing pre-computed joint latents with caption embeddings,
    the first video frame, and masks.

    Files read by this class, relative to ``video_root_dir``:

    * ``joint_latents/<clip>.safetensors`` -- tensor under key ``'joint_latent'``
    * ``videos/<clip>.mp4`` -- source video; only frame 0 is decoded
    * ``masks/<clip>.npz`` -- array under key ``'mask'``

    Caption embeddings are read from ``<text_embedding_path>/<clip>.pt``.
    The list of clips is derived from the contents of ``joint_latents/``.
    """

    _LATENT_SUFFIX = '.safetensors'

    def __init__(
        self,
        video_root_dir,
        text_embedding_path,
        max_retries=10,
    ):
        """Index the clips available under ``video_root_dir``.

        Args:
            video_root_dir: Root directory holding the ``joint_latents``,
                ``videos``, ``captions`` and ``masks`` sub-directories.
            text_embedding_path: Directory holding one ``<clip>.pt`` caption
                embedding per clip.
            max_retries: Maximum number of samples ``__getitem__`` tries
                before giving up with ``RuntimeError``.
        """
        super().__init__()
        root_path = video_root_dir
        self.root_path = root_path
        self.latent_root = os.path.join(self.root_path, 'joint_latents')
        self.source_video_root = os.path.join(self.root_path, 'videos')
        self.captions_root = os.path.join(self.root_path, 'captions')
        self.mask_root = os.path.join(self.root_path, 'masks')
        self.text_embedding_path = text_embedding_path
        self.max_retries = max_retries

        # Keep only real latent files and strip just the trailing suffix, so
        # stray entries (e.g. editor or OS metadata files) never become clip
        # names that can only fail at load time.
        suffix = self._LATENT_SUFFIX
        self.dataset = sorted(
            name[:-len(suffix)]
            for name in os.listdir(self.latent_root)
            if name.endswith(suffix)
        )
        self.length = len(self.dataset)

    def get_batch(self, idx):
        """Load every tensor belonging to the clip at position ``idx``.

        Returns:
            Tuple ``(source_latent, anchor_latent, first_frame, masks,
            video_caption_emb, clip_name)``.

        Raises:
            Whatever the underlying loaders raise (missing or corrupt files,
            bad index, ...); ``__getitem__`` is responsible for handling it.
        """
        clip_name = self.dataset[idx]

        cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt')
        video_caption_emb = torch.load(cap_emb_path, weights_only=True)

        joint_latent_path = os.path.join(
            self.latent_root, clip_name + self._LATENT_SUFFIX
        )
        joint_latent = load_file(joint_latent_path, device='cpu')['joint_latent']

        video_path = os.path.join(self.source_video_root, clip_name + '.mp4')
        video_reader = VideoReader(video_path)
        # Only frame 0 is needed. The permute moves the last axis to position 1;
        # presumably decord yields (N, H, W, C), giving (1, C, H, W) -- TODO confirm.
        frames = video_reader.get_batch([0]).asnumpy()
        first_frame = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
        # Rescale pixel values from 0..255 to -1..1.
        first_frame = (first_frame / 255.) * 2 - 1

        # The joint latent stores two sequences back to back along dim 2:
        # the first half is the source, the remainder is the anchor.
        T = joint_latent.shape[2] // 2
        source_latent = joint_latent[:, :, :T]
        anchor_latent = joint_latent[:, :, T:]

        # ``* 1.0`` converts the stored mask to floating point; ``unsqueeze(1)``
        # inserts a singleton axis at position 1 (presumably channels -- TODO confirm).
        mask_path = os.path.join(self.mask_root, clip_name + '.npz')
        masks = np.load(mask_path)['mask'] * 1.0
        masks = torch.from_numpy(masks).unsqueeze(1)

        return source_latent, anchor_latent, first_frame, masks, video_caption_emb, clip_name

    def __len__(self):
        """Return the number of indexed clips."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to random clips on failure.

        If loading fails, a warning is logged and a uniformly random index is
        tried instead, up to ``max_retries`` attempts in total.

        Returns:
            Dict with keys ``'source_latent'``, ``'anchor_latent'``,
            ``'image'``, ``'caption_emb'`` and ``'mask'``.

        Raises:
            IndexError: If the dataset contains no clips.
            RuntimeError: If every attempt failed; the last underlying error
                is attached as ``__cause__``.
        """
        if self.length == 0:
            raise IndexError('dataset is empty: no latent files were found')

        attempts = max(1, self.max_retries)
        last_error = None
        for _ in range(attempts):
            try:
                (source_latent, anchor_latent, image, mask,
                 video_caption_emb, _clip_name) = self.get_batch(idx)
            except Exception as exc:
                # Best effort: a broken sample must not abort training, but
                # the failure is reported instead of being silently dropped.
                last_error = exc
                logging.getLogger(__name__).warning(
                    'Failed to load sample %r (%s: %s); retrying with a random sample.',
                    idx, type(exc).__name__, exc,
                )
                idx = random.randint(0, self.length - 1)
            else:
                return {
                    'source_latent': source_latent,
                    'anchor_latent': anchor_latent,
                    'image': image,
                    'caption_emb': video_caption_emb,
                    'mask': mask,
                }

        raise RuntimeError(
            'could not load any sample after %d attempts' % attempts
        ) from last_error