|
|
|
|
|
import sys |
|
|
|
|
|
def test_video_dataset():
    """Check that ``VideoDataset`` yields the same sample in both configurations.

    The dataset is built twice over ``assets/tests/``: once with
    ``caption_column="prompts.txt"`` / ``video_column="videos.txt"`` and once
    with ``dataset_file="assets/tests/metadata.csv"`` plus the column names
    ``"caption"`` / ``"video"``. Each dataset must contain exactly one item,
    the first video must have shape ``(49, 3, 480, 720)``, and the two
    datasets must return element-wise equal videos.
    """
    from cogvideox.dataset import VideoDataset

    # Arguments that are identical for both construction modes.
    shared_kwargs = dict(
        data_root="assets/tests/",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )

    dataset_dirs = VideoDataset(
        caption_column="prompts.txt",
        video_column="videos.txt",
        **shared_kwargs,
    )
    dataset_csv = VideoDataset(
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        **shared_kwargs,
    )

    for dataset in (dataset_dirs, dataset_csv):
        assert len(dataset) == 1

    # presumably (frames, channels, height, width) -- TODO confirm axis order
    expected_shape = (49, 3, 480, 720)
    assert dataset_dirs[0]["video"].shape == expected_shape

    video_from_dirs = dataset_dirs[0]["video"]
    video_from_csv = dataset_csv[0]["video"]
    assert (video_from_dirs == video_from_csv).all()

    print(dataset_dirs[0]["video"].shape)
|
|
|
|
|
def test_video_dataset_with_resizing():
    """Check that ``VideoDatasetWithResizing`` yields the same sample in both configurations.

    The dataset is built twice over ``assets/tests/``: once with
    ``caption_column="prompts.txt"`` / ``video_column="videos.txt"`` and once
    with ``dataset_file="assets/tests/metadata.csv"`` plus the column names
    ``"caption"`` / ``"video"``. Each dataset must contain exactly one item
    and the two datasets must return element-wise equal videos.

    Note that the expected first dimension is 48 although ``max_num_frames``
    is 49.
    """
    from cogvideox.dataset import VideoDatasetWithResizing

    # Arguments that are identical for both construction modes.
    shared_kwargs = dict(
        data_root="assets/tests/",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )

    dataset_dirs = VideoDatasetWithResizing(
        caption_column="prompts.txt",
        video_column="videos.txt",
        **shared_kwargs,
    )
    dataset_csv = VideoDatasetWithResizing(
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        **shared_kwargs,
    )

    for dataset in (dataset_dirs, dataset_csv):
        assert len(dataset) == 1

    # presumably (frames, channels, height, width) -- TODO confirm axis order
    expected_shape = (48, 3, 480, 720)
    assert dataset_dirs[0]["video"].shape == expected_shape

    video_from_dirs = dataset_dirs[0]["video"]
    video_from_csv = dataset_csv[0]["video"]
    assert (video_from_dirs == video_from_csv).all()

    print(dataset_dirs[0]["video"].shape)
|
|
|
|
|
def test_video_dataset_with_bucket_sampler():
    """Check the first two batches produced via ``BucketSampler``.

    A ``VideoDatasetWithResizing`` is built from ``prompts_multi.txt`` /
    ``videos_multi.txt`` and wrapped in a ``BucketSampler`` with
    ``batch_size=8``. The first batch must have video shape
    ``(8, 48, 3, 480, 720)`` and the second ``(8, 48, 3, 256, 360)``; both
    must carry eight captions whose first entry is a ``str``. Iteration stops
    after the second batch. If the loader yields fewer than two batches, the
    corresponding assertions simply never run.
    """
    import torch

    from cogvideox.dataset import BucketSampler, VideoDatasetWithResizing

    from torch.utils.data import DataLoader

    dataset_dirs = VideoDatasetWithResizing(
        data_root="assets/tests/",
        caption_column="prompts_multi.txt",
        video_column="videos_multi.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    sampler = BucketSampler(dataset_dirs, batch_size=8)

    def collate_fn(data):
        # Only the first element of ``data`` is used; it is iterated as a
        # collection of dicts with "prompt" and "video" entries.
        prompts = [item["prompt"] for item in data[0]]
        stacked = torch.stack([item["video"] for item in data[0]])
        return prompts, stacked

    dataloader = DataLoader(dataset_dirs, batch_size=1, sampler=sampler, collate_fn=collate_fn)

    for batch_index, (captions, videos) in enumerate(dataloader):
        # The caption check is the same for both inspected batches.
        assert len(captions) == 8 and isinstance(captions[0], str)

        if batch_index == 0:
            assert videos.shape == (8, 48, 3, 480, 720)
            continue

        # Second batch: different spatial size expected, then stop iterating.
        assert videos.shape == (8, 48, 3, 256, 360)
        break
|
|
|
|
|
if __name__ == "__main__":
    # Extend the import search path before any test runs; the tests import
    # ``cogvideox.dataset`` lazily (presumably located under ./training).
    sys.path.append("./training")

    # Run the tests one after another, in this fixed order.
    for run_test in (
        test_video_dataset,
        test_video_dataset_with_resizing,
        test_video_dataset_with_bucket_sampler,
    ):
        run_test()
|
|