Spaces:
Runtime error
Runtime error
| import os, io, csv, math, random | |
| import numpy as np | |
| from einops import rearrange | |
| from decord import VideoReader | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torch.utils.data.dataset import Dataset | |
| from PIA.utils.util import zero_rank_print, detect_edges | |
| import cv2 | |
| def get_score(video_data, | |
| cond_frame_idx, | |
| weight=[1.0, 1.0, 1.0, 1.0], | |
| use_edge=True): | |
| """ | |
| Similar to get_score under utils/util.py/detect_edges | |
| """ | |
| """ | |
| the shape of video_data is f c h w, np.ndarray | |
| """ | |
| h, w = video_data.shape[1], video_data.shape[2] | |
| cond_frame = video_data[cond_frame_idx] | |
| cond_hsv_list = list( | |
| cv2.split( | |
| cv2.cvtColor(cond_frame.astype(np.float32), cv2.COLOR_RGB2HSV))) | |
| if use_edge: | |
| cond_frame_lum = cond_hsv_list[-1] | |
| cond_frame_edge = detect_edges(cond_frame_lum.astype(np.uint8)) | |
| cond_hsv_list.append(cond_frame_edge) | |
| score_sum = [] | |
| for frame_idx in range(video_data.shape[0]): | |
| frame = video_data[frame_idx] | |
| hsv_list = list( | |
| cv2.split(cv2.cvtColor(frame.astype(np.float32), | |
| cv2.COLOR_RGB2HSV))) | |
| if use_edge: | |
| frame_img_lum = hsv_list[-1] | |
| frame_img_edge = detect_edges(lum=frame_img_lum.astype(np.uint8)) | |
| hsv_list.append(frame_img_edge) | |
| hsv_diff = [ | |
| np.abs(hsv_list[c] - cond_hsv_list[c]) for c in range(len(weight)) | |
| ] | |
| hsv_mse = [np.sum(hsv_diff[c]) * weight[c] for c in range(len(weight))] | |
| score_sum.append(sum(hsv_mse) / (h * w) / (sum(weight))) | |
| return score_sum | |
| class WebVid10M(Dataset): | |
| def __init__( | |
| self, | |
| csv_path, video_folder, | |
| sample_size=256, sample_stride=4, sample_n_frames=16, | |
| is_image=False, | |
| ): | |
| zero_rank_print(f"loading annotations from {csv_path} ...") | |
| with open(csv_path, 'r') as csvfile: | |
| self.dataset = list(csv.DictReader(csvfile)) | |
| self.length = len(self.dataset) | |
| zero_rank_print(f"data scale: {self.length}") | |
| self.video_folder = video_folder | |
| self.sample_stride = sample_stride | |
| self.sample_n_frames = sample_n_frames | |
| self.is_image = is_image | |
| sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) | |
| self.pixel_transforms = transforms.Compose([ | |
| transforms.RandomHorizontalFlip(), | |
| transforms.Resize(sample_size[0]), | |
| transforms.CenterCrop(sample_size), | |
| transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), | |
| ]) | |
| def get_batch(self, idx): | |
| video_dict = self.dataset[idx] | |
| videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] | |
| video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") | |
| video_reader = VideoReader(video_dir) | |
| video_length = len(video_reader) | |
| total_frames = len(video_reader) | |
| clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) | |
| start_idx = random.randint(0, video_length - clip_length) | |
| batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) | |
| frame_indice = [random.randint(0, total_frames - 1)] | |
| pixel_values_np = video_reader.get_batch(frame_indice).asnumpy() | |
| cond_frames = random.randint(0, self.sample_n_frames - 1) | |
| # f h w c -> f c h w | |
| pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous() | |
| pixel_values = pixel_values / 255. | |
| del video_reader | |
| if self.is_image: | |
| pixel_values = pixel_values[0] | |
| return pixel_values, name, cond_frames, videoid | |
| def __len__(self): | |
| return self.length | |
| def __getitem__(self, idx): | |
| while True: | |
| try: | |
| video, name, cond_frames, videoid = self.get_batch(idx) | |
| break | |
| except Exception as e: | |
| # zero_rank_print(e) | |
| idx = random.randint(0, self.length-1) | |
| video = self.pixel_transforms(video) | |
| video_ = video.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5 | |
| video_ = video_ * 255 | |
| #video_ = video_.astype(np.uint8) | |
| score = get_score(video_, cond_frame_idx=cond_frames) | |
| del video_ | |
| sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid) | |
| return sample | |