import hashlib
import json
import random
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_images,
    load_images_from_videos,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
    load_binary_mask_compressed,
)

if TYPE_CHECKING:
    from finetune.trainer import Trainer

import decord

# Make decord return torch tensors instead of numpy arrays.
decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class I2VFlowDataset(Dataset):
    """
    A dataset for (image, flow)-to-video or image-to-flow-video generation that
    resizes all inputs to fixed dimensions.

    Videos are resized to max_num_frames x height x width and images are resized
    to height x width. Prompt embeddings, video latents, and flow latents are
    loaded from precomputed .safetensors files under data_root.

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
        data_root (str): Root directory containing metadata_revised.jsonl and the
            precomputed embedding/latent subdirectories
        caption_column (str): Unused by this dataset; prompts are read from the metadata file
        video_column (str): Unused by this dataset; video latents are resolved from the metadata file
        image_column (str | None): Unused by this dataset; first frames are resolved from the metadata file
        device (torch.device): Device used by the trainer's encoding functions
        trainer (Trainer): Trainer providing encode_video, encode_text, and args.train_resolution
    """

    def __init__(
        self,
        max_num_frames: int,
        height: int,
        width: int,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        data_root = Path(data_root)
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, you need metadata_revised.jsonl in the root path"

        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))

        self.prompts = [x["prompt"] for x in metadata]
        if "curated" in str(data_root).lower():
            self.prompt_embeddings = [
                data_root / "prompt_embeddings" / (x["hash_code"] + ".safetensors") for x in metadata
            ]
        else:
            self.prompt_embeddings = [
                data_root / "prompt_embeddings_revised" / (x["hash_code"] + ".safetensors") for x in metadata
            ]
        train_resolution_str = "x".join(str(x) for x in trainer.args.train_resolution)
        self.videos = [
            data_root / "video_latent" / train_resolution_str / (x["hash_code"] + ".safetensors") for x in metadata
        ]
        self.images = [data_root / "first_frames" / (x["hash_code"] + ".png") for x in metadata]
        self.flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + ".safetensors") for x in metadata]

        self.trainer = trainer
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        if not (len(self.videos) == len(self.prompts) == len(self.images) == len(self.flows)):
            raise ValueError(
                f"Expected the number of prompts, videos, images, and flows to match, but found "
                f"{len(self.prompts)=}, {len(self.videos)=}, {len(self.images)=} and {len(self.flows)=}. "
                f"Please ensure that the numbers of caption prompts, videos, images, and flows in your dataset match."
            )

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        # Normalize uint8 pixel values from [0, 255] to [-1, 1].
        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

        self.length = len(self.videos)

        logger.info(f"Dataset size: {self.length}")

    def __len__(self) -> int:
        return self.length

    def load_data_pair(self, index):
        prompt_embedding_path = self.prompt_embeddings[index]
        encoded_video_path = self.videos[index]
        encoded_flow_path = self.flows[index]

        # Each .safetensors file stores a single tensor under a fixed key.
        prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
        encoded_video = load_file(encoded_video_path)["encoded_video"]
        encoded_flow = load_file(encoded_flow_path)["encoded_flow_f"]

        return prompt_embedding, encoded_video, encoded_flow

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Retry with a random index if a cached latent file is missing or corrupted.
        while True:
            try:
                prompt_embedding, encoded_video, encoded_flow = self.load_data_pair(index)
                break
            except Exception as e:
                print(f"Error loading {self.prompt_embeddings[index]}: {str(e)}")
                index = random.randint(0, self.length - 1)

        image_path = self.images[index]
        prompt = self.prompts[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        # Only the first-frame image needs preprocessing here; the video is already a cached latent.
        _, image = self.preprocess(None, image_path)
        image = self.image_transform(image)

        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "encoded_flow": encoded_flow,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    @override
    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        # Apply the per-frame normalization to each frame and restack along the time dimension.
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
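

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes a data_root that already contains metadata_revised.jsonl plus the
# precomputed prompt_embeddings, video_latent/<resolution>, first_frames, and
# flow_direct_f_latent folders. The Trainer is stubbed with a SimpleNamespace;
# in real use, the Trainer from finetune.trainer supplies encode_video,
# encode_text, and args.train_resolution. All paths and numbers below are
# placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    # Hypothetical stand-in for the real Trainer; only the attributes this
    # dataset actually touches are provided here.
    fake_trainer = SimpleNamespace(
        args=SimpleNamespace(train_resolution=(49, 480, 720)),  # placeholder resolution
        encode_video=lambda video: video,
        encode_text=lambda text: text,
    )

    dataset = I2VFlowDataset(
        max_num_frames=49,
        height=480,
        width=720,
        data_root="/path/to/dataset",  # placeholder path
        caption_column="prompt",
        video_column="video",
        image_column=None,
        device=torch.device("cpu"),
        trainer=fake_trainer,
    )

    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    sample = next(iter(loader))
    print(sample["encoded_video"].shape, sample["encoded_flow"].shape)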