Spaces:
Runtime error
Runtime error
File size: 6,049 Bytes
f844f44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import os
from skimage import io, img_as_float32
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread
from skimage.transform import resize
import numpy as np
from torch.utils.data import Dataset
from augmentation import AllAugmentationTransform
import glob
from functools import partial
def read_video(name, frame_shape):
    """
    Read video which can be:
      - an image of concatenated frames ('.png' or '.jpg')
      - '.gif', '.mp4' or '.mov'
      - folder with frames (one image file per frame)

    :param name: path to the folder or file to read; the file type is
        detected from the (case-insensitive) extension
    :param frame_shape: shape of a single frame.
        For an image of concatenated frames it is required (must not be None)
        and is used to cut the image into frames.
        For '.gif' / '.mp4' / '.mov' every frame is resized to this shape;
        None keeps the frames at their native size.
        For a folder it is ignored.
    :return: array holding all frames, converted with img_as_float32
    :raises Exception: if name is not a folder and has an unsupported extension
    """
    if os.path.isdir(name):
        # Sort the listing so that frames come out in file-name order.
        frame_files = sorted(os.listdir(name))
        return np.array(
            [img_as_float32(io.imread(os.path.join(name, frame_file))) for frame_file in frame_files])

    lowered = name.lower()

    if lowered.endswith(('.png', '.jpg')):
        image = io.imread(name)

        # gray2rgb appends a new channel axis, so a trailing singleton
        # channel has to be removed first to end up with (H, W, 3).
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[..., 0]
        if image.ndim == 2:
            image = gray2rgb(image)

        if image.shape[2] == 4:
            # Discard the alpha channel.
            image = image[..., :3]

        image = img_as_float32(image)

        # Frames sit side by side along the second axis: bring that axis to
        # the front, cut it into frames, then swap the two spatial axes back.
        video_array = np.moveaxis(image, 1, 0)
        # tuple() lets callers pass frame_shape as a list (e.g. from a config).
        video_array = video_array.reshape((-1,) + tuple(frame_shape))
        return np.moveaxis(video_array, 1, 2)

    if lowered.endswith(('.gif', '.mp4', '.mov')):
        frames = mimread(name)
        if len(frames[0].shape) == 2:
            frames = [gray2rgb(frame) for frame in frames]

        # Drop alpha BEFORE resizing: resizing a 4-channel frame to a
        # 3-channel frame_shape would interpolate across the channel axis
        # and blend alpha into the colours.
        frames = [frame[..., :3] if frame.ndim == 3 and frame.shape[-1] == 4 else frame
                  for frame in frames]

        if frame_shape is not None:
            frames = [resize(frame, frame_shape) for frame in frames]

        return img_as_float32(np.array(frames))

    raise Exception("Unknown file extensions %s" % name)
class FramesDataset(Dataset):
    """
    Dataset of videos, each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames

    If root_dir contains a 'train' sub-folder, a 'test' sub-folder must exist
    too and the two are used as the predefined split; otherwise the entries
    of root_dir are split randomly (80% train / 20% test).

    :param root_dir: folder with the videos (or with 'train' / 'test' folders)
    :param frame_shape: shape of a single frame, or None to keep native size;
        sequences such as lists are converted to a tuple
    :param id_sampling: if True (predefined split only), training items are
        the unique name prefixes before '#', and one matching '.mp4' file is
        picked at random on every access
    :param is_train: select the train (True) or test (False) part
    :param random_seed: seed for the random train-test split
    :param pairs_list: stored on the instance; not used inside this class
    :param augmentation_params: keyword arguments for AllAugmentationTransform
        (train mode only); None means no keyword arguments
    :raises FileNotFoundError: if 'train' exists in root_dir but 'test' does not
    """

    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, pairs_list=None, augmentation_params=None):
        self.root_dir = root_dir
        self.videos = os.listdir(root_dir)
        # Normalise to a tuple: read_video concatenates it with another tuple,
        # which fails for a list (as produced by e.g. YAML configs).
        self.frame_shape = tuple(frame_shape) if frame_shape is not None else None
        print(self.frame_shape)
        self.pairs_list = pairs_list
        self.id_sampling = id_sampling

        train_dir = os.path.join(root_dir, 'train')
        if os.path.exists(train_dir):
            test_dir = os.path.join(root_dir, 'test')
            # Explicit check instead of `assert`, which is stripped under -O.
            if not os.path.exists(test_dir):
                raise FileNotFoundError(
                    "Found '%s' but the matching test folder '%s' is missing" % (train_dir, test_dir))
            print("Use predefined train-test split.")
            if id_sampling:
                # Unique ids; sorted so that index -> id is stable between runs
                # (plain set order depends on string hashing).
                train_videos = sorted({os.path.basename(video).split('#')[0]
                                       for video in os.listdir(train_dir)})
            else:
                train_videos = os.listdir(train_dir)
            test_videos = os.listdir(test_dir)
            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

        self.videos = train_videos if is_train else test_videos
        self.is_train = is_train

        if self.is_train:
            # The default augmentation_params=None must not crash on `**None`.
            # NOTE(review): assumes AllAugmentationTransform accepts being
            # called without arguments - confirm against augmentation.py.
            self.transform = AllAugmentationTransform(**(augmentation_params or {}))
        else:
            self.transform = None

    def __len__(self):
        """Number of entries (videos, or ids when id_sampling is used) in the selected split."""
        return len(self.videos)

    def __getitem__(self, idx):
        """
        Load one item.

        Train mode: two frames are drawn at random (with replacement, indices
        sorted), passed through the augmentation transform and returned as
        'source' and 'driving', each transposed with (2, 0, 1).
        Test mode: the whole video is returned as 'video', transposed with
        (3, 0, 1, 2).
        In both modes 'name' holds the base name of the loaded path.
        """
        name = self.videos[idx]
        if self.is_train and self.id_sampling:
            # Escape the fixed part so glob metacharacters in folder or id
            # names are matched literally; only the trailing '*' is a wildcard.
            pattern = glob.escape(os.path.join(self.root_dir, name)) + '*.mp4'
            path = np.random.choice(glob.glob(pattern))
        else:
            path = os.path.join(self.root_dir, name)

        video_name = os.path.basename(path)

        if self.is_train and os.path.isdir(path):
            # Only the two sampled frames are read from disk.
            # Sorted like read_video does, so that the sorted indices below
            # correspond to file-name order.
            frames = sorted(os.listdir(path))
            num_frames = len(frames)
            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))

            if self.frame_shape is not None:
                resize_fn = partial(resize, output_shape=self.frame_shape)
            else:
                resize_fn = img_as_float32

            video_array = []
            for i in frame_idx:
                frame_name = frames[i]
                if isinstance(frame_name, bytes):
                    frame_name = frame_name.decode('utf-8')
                video_array.append(resize_fn(io.imread(os.path.join(path, frame_name))))
        else:
            video_array = read_video(path, frame_shape=self.frame_shape)
            num_frames = len(video_array)
            if self.is_train:
                frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
            else:
                frame_idx = range(num_frames)
            video_array = video_array[frame_idx]

        if self.transform is not None:
            video_array = self.transform(video_array)

        out = {}
        if self.is_train:
            source = np.array(video_array[0], dtype='float32')
            driving = np.array(video_array[1], dtype='float32')

            out['driving'] = driving.transpose((2, 0, 1))
            out['source'] = source.transpose((2, 0, 1))
        else:
            video = np.array(video_array, dtype='float32')
            out['video'] = video.transpose((3, 0, 1, 2))

        out['name'] = video_name

        return out
class DatasetRepeater(Dataset):
    """
    Pass several times over the same dataset for better i/o performance

    The wrapper reports a length of num_repeats * len(dataset) and maps every
    index back onto the wrapped dataset by taking it modulo len(dataset).

    :param dataset: dataset to repeat; must support len() and indexing
    :param num_repeats: how many times the dataset is repeated
    """

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        """Length of the wrapped dataset multiplied by num_repeats."""
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        """
        Return the item of the wrapped dataset that idx maps to.

        :raises IndexError: if idx is outside [-len(self), len(self))
        """
        total = len(self)
        # Without a bounds check the modulo never fails, so plain iteration
        # (`for x in repeater`, `list(repeater)`) would never stop; it also
        # turns the ZeroDivisionError of an empty dataset into an IndexError.
        if not -total <= idx < total:
            raise IndexError("index %s is out of range for a dataset of length %s" % (idx, total))
        return self.dataset[idx % len(self.dataset)]
|