Spaces:
Configuration error
Configuration error
import os | |
import numpy as np | |
from PIL import Image | |
from einops import rearrange | |
from pathlib import Path | |
import torch | |
from torch.utils.data import Dataset | |
from .transform import short_size_scale, random_crop, center_crop, offset_crop | |
from ..common.image_util import IMAGE_EXTENSION | |
import cv2 | |
class ImageSequenceDataset(Dataset):
    """Dataset of fixed-length frame clips with per-layout binary masks.

    Frames are the files in ``path`` whose names end with ``IMAGE_EXTENSION``
    (sorted by file name).  Masks are read from
    ``<layout_mask_dir>/<layout_name>/<frame index, 5 digits>.png`` for every
    name in ``layout_mask_order``.

    Each item is a dict with the keys ``"images"``, ``"masks"``,
    ``"layouts"`` and ``"prompt_ids"``; when ``class_data_root`` is given it
    additionally carries ``"class_images"`` and ``"class_prompt_ids"``.
    """

    def __init__(
        self,
        path: str,
        layout_mask_dir: str,
        layout_mask_order: list,
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # only used during tuning to sample a long video
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        class_data_root: str = None,
        class_prompt_ids: torch.Tensor = None,
        offset: dict = None,
        **args
    ):
        """Index the frame folder and pre-compute the clip sampling layout.

        Args:
            path: Directory holding the video frames.
            layout_mask_dir: Directory with one sub-directory of masks per
                layout name.
            layout_mask_order: Layout sub-directory names, in the order they
                are stacked in the ``"layouts"`` output.
            prompt_ids: Returned unchanged with every item.
            prompt: Stored on the instance as ``self.prompt``.
            start_sample_frame: Index of the first frame of the first clip;
                ``None`` is treated as 0.
            n_sample_frame: Frames per clip; a negative value means "as many
                frames as there are images".
            sampling_rate: Step between consecutive frames inside a clip.
            stride: Step between the starts of consecutive clips; a
                non-positive value yields a single clip.
            image_mode: Mode passed to ``PIL.Image.convert``.
            image_size: Target size forwarded to the scale and crop helpers.
            crop: ``"center"`` or ``"random"``.
            class_data_root: Optional directory of class (regularisation)
                images.
            class_prompt_ids: Returned with every item when
                ``class_data_root`` is set.
            offset: Keyword arguments forwarded to ``offset_crop``; defaults
                to zero for ``left``/``right``/``top``/``bottom``.
            **args: Accepted and ignored.

        Raises:
            ValueError: If a clip does not fit in the frames available from
                ``start_sample_frame`` onwards, or ``crop`` is unknown.
        """
        self.path = path
        self.images = self.get_image_list(path)

        self.layout_mask_dir = layout_mask_dir
        self.layout_mask_order = list(layout_mask_order)
        layout_mask_dir0 = os.path.join(
            self.layout_mask_dir, self.layout_mask_order[0]
        )
        self.masks_index = self.get_image_list(layout_mask_dir0)

        self.n_images = len(self.images)
        # Build the default per instance so instances never share one dict.
        if offset is None:
            offset = {"left": 0, "right": 0, "top": 0, "bottom": 0}
        self.offset = offset
        self.start_sample_frame = start_sample_frame

        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame

        # Local sampling rate from the video.
        self.sampling_rate = sampling_rate
        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1

        # Clips start at start_sample_frame, so only the frames from there on
        # can be used; validating against n_images alone would let an
        # out-of-range clip through and fail later inside load_frame.
        first_frame = self.start_sample_frame or 0
        usable_frames = self.n_images - first_frame
        if usable_frames < self.sequence_length:
            raise ValueError(
                f"Required number of frames {self.sequence_length} is larger "
                f"than the {usable_frames} frames available from "
                f"start_sample_frame={first_frame} "
                f"(total frames in the dataset: {self.n_images})"
            )

        # During tuning, if the video is too long, we sample it every
        # self.stride frames globally; the fallback stride exceeds n_images
        # so exactly one clip is produced.
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = (usable_frames - self.sequence_length) // self.stride + 1

        self.image_mode = image_mode
        self.image_size = image_size

        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError(
                f"Unknown crop method {crop!r}; "
                f"expected one of {sorted(crop_methods)}"
            )
        self.crop = crop_methods[crop]

        self.prompt = prompt
        self.prompt_ids = prompt_ids

        # Negative prompt for regularization to avoid overfitting during
        # one-shot tuning.  The attributes below exist only when class data
        # is configured; __len__ and __getitem__ test for them with hasattr.
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_images_path = sorted(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self.class_prompt_ids = class_prompt_ids

    def __len__(self):
        """Return the number of clips, or the class-image count if larger."""
        max_len = self.video_len
        if hasattr(self, "num_class_images"):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        """Load one clip together with its masks (and class images, if any).

        ``index`` wraps around ``self.video_len``, so indices beyond the
        number of clips reuse earlier clips.

        Returns:
            dict with
            ``"images"``: output of ``self.transform`` on the clip frames;
            ``"masks"``: half tensor laid out ``c f h w``, the union of all
            layout masks per frame;
            ``"layouts"``: half tensor laid out ``f s c h w`` with one mask
            per layout (``s`` follows ``layout_mask_order``);
            ``"prompt_ids"``: ``self.prompt_ids``;
            plus ``"class_images"`` / ``"class_prompt_ids"`` when class data
            was configured.
        """
        return_batch = {}

        # get_frame_indices returns a one-shot generator; materialise it once
        # so the frames and every layout share the same indices.
        frame_indices = list(self.get_frame_indices(index % self.video_len))
        frames = self.transform([self.load_frame(i) for i in frame_indices])

        # Stacked as (s, f, c, h, w): layouts, frames, then the mask array
        # produced by _read_mask.
        layout_ = np.stack(
            [
                np.stack(
                    [
                        self._read_mask(
                            os.path.join(self.layout_mask_dir, layout_name), i
                        )
                        for i in frame_indices
                    ]
                )
                for layout_name in self.layout_mask_order
            ]
        )

        # A pixel is set in the merged mask when any layout marks it.
        merged_masks = layout_.any(axis=0).astype(np.uint8)
        masks = rearrange(merged_masks, "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return_batch.update(
            {
                "images": frames,
                "masks": masks,
                "layouts": layouts,
                "prompt_ids": self.prompt_ids,
            }
        )

        if hasattr(self, "class_data_root"):
            # Guard the modulus: with exactly n_sample_frame class images the
            # span is 0 and the only valid window starts at 0.
            span = self.num_class_images - self.n_sample_frame
            class_index = index % span if span > 0 else 0
            class_indices = self.get_class_indices(class_index)
            class_frames = [self.load_class_frame(i) for i in class_indices]
            return_batch["class_images"] = self.tensorize_frames(class_frames)
            return_batch["class_prompt_ids"] = self.class_prompt_ids

        return return_batch

    def transform(self, frames):
        """Tensorize ``frames``, then offset-crop, rescale and crop them.

        The crop/scale helpers come from the sibling ``transform`` module;
        they are applied in the order offset_crop -> short_size_scale ->
        the crop function selected in ``__init__``.
        """
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        """Stack frames and return a ``c f h w`` tensor mapped by x/255*2-1.

        ``frames`` is a sequence of equally sized ``h w c`` images (anything
        ``np.stack`` accepts).  For 8-bit input the result lies in [-1, 1].
        """
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_path, index: int):
        """Read, binarise and downscale the mask of frame ``index``.

        The file ``<mask_path>/<index as 5 digits>.png`` is read as
        grayscale, thresholded at ``> 0``, resized to one eighth of its
        width and height with nearest-neighbour interpolation, and returned
        as a ``uint8`` array with a leading axis of size 1.

        Raises:
            FileNotFoundError: If the mask cannot be read.
        """
        mask_path = os.path.join(mask_path, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read mask image: {mask_path}")
        mask = (mask > 0).astype(np.uint8)

        # Destination size is derived from the mask itself; cv2.resize takes
        # (width, height).
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        # Nearest neighbour keeps the mask strictly binary.
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        return mask[np.newaxis, ...]

    def load_frame(self, index):
        """Open frame ``index`` of ``self.images`` converted to ``image_mode``."""
        image_path = os.path.join(self.path, self.images[index])
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def load_class_frame(self, index):
        """Open class image ``index`` converted to ``image_mode``."""
        image_path = self.class_images_path[index]
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def get_frame_indices(self, index):
        """Return a generator over the frame indices of clip ``index``.

        The clip starts at ``start_sample_frame + stride * index`` (the start
        offset is skipped when it is ``None``) and advances by
        ``sampling_rate`` for ``n_sample_frame`` frames.
        """
        frame_start = self.stride * index
        if self.start_sample_frame is not None:
            frame_start += self.start_sample_frame
        return (
            frame_start + i * self.sampling_rate
            for i in range(self.n_sample_frame)
        )

    def get_class_indices(self, index):
        """Return a generator over ``n_sample_frame`` consecutive indices from ``index``."""
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        """Return the sorted file names in ``path`` ending with ``IMAGE_EXTENSION``."""
        return [
            file
            for file in sorted(os.listdir(path))
            if file.endswith(IMAGE_EXTENSION)
        ]