# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv
import glob
import json
import numpy as np
import os.path as osp
import pickle
import random

import decord
import pandas as pd
import torch


def datetime2sec(str):
    # Convert an 'HH:MM:SS[.ms]' timestamp (as used in the EPIC-KITCHENS CSVs)
    # into seconds, e.g. '00:01:30.52' -> 90.52.
    hh, mm, ss = str.split(':')
    return int(hh) * 3600 + int(mm) * 60 + float(ss)


def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
    """Load `clip_length` frames of video `vid` between `second` and `end_second`.

    If `chunk_len` > 0, the video is assumed to be stored as fixed-length chunks
    named by their start second (`<root>/<vid>.mp4/<chunk_start>.mp4`);
    `chunk_len == -1` means a single unchunked file. Returns a float32 tensor of
    shape (clip_length, H, W, 3).
    """
    if chunk_len == -1:
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
        second_offset = second
        if end_second is not None:
            end_second = min(end_second, len(vr) / vr.get_avg_fps())
        else:
            end_second = len(vr) / vr.get_avg_fps()
    else:
        chunk_start = int(second) // chunk_len * chunk_len
        second_offset = second - chunk_start
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
    if fps == -1:
        fps = vr.get_avg_fps()

    # calculate frame_ids
    frame_offset = int(np.round(second_offset * fps))
    total_duration = max(int((end_second - second) * fps), clip_length)
    if chunk_len == -1:
        if end_second <= second:
            raise ValueError("end_second should be greater than second")
        else:
            frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
    else:
        frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)

    # load frames
    if max(frame_ids) < len(vr):
        try:
            frames = vr.get_batch(frame_ids).asnumpy()
        except decord.DECORDError as error:
            print(error)
            frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
    else:
        # find the remaining frames in the next chunk
        try:
            frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
            frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
            vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
            frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
            frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
            frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
            frames = np.concatenate([frames_part1, frames_part2], axis=0)
        # the next chunk does not exist; the current chunk is the last one
        except (RuntimeError, decord.DECORDError) as error:
            print(error)
            frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
            frames = vr.get_batch(frame_ids).asnumpy()

    frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    return torch.stack(frames, dim=0)
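
# Usage sketch (illustration only; the root and video uid below are hypothetical
# placeholders). With chunked storage (chunk_len=300), a clip starting at t=601s
# is read from '<root>/<vid>.mp4/600.mp4' with a 1-second offset into that chunk:
#
#   frames = video_loader('/path/to/chunked_videos', 'some_video_uid',
#                         second=601.0, end_second=603.0,
#                         chunk_len=300, fps=30, clip_length=32, jitter=False)
#   # frames: float32 tensor of shape (32, H, W, 3)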


def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq
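
# Sampling behaviour sketch (illustration only): the range between start_frame
# and end_frame is split into num_segments equal segments and one index is taken
# per segment -- the segment centre when jitter=False (deterministic, used at
# eval time), or a uniformly random position inside the segment when jitter=True
# (used during training).
#
#   eval_ids = get_frame_ids(0, 300, num_segments=4, jitter=False)
#   train_ids = get_frame_ids(0, 300, num_segments=4, jitter=True)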


def video_loader_by_frames(root, vid, frame_ids):
    vr = decord.VideoReader(osp.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)
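
# Usage sketch (illustration only; the root and relative path are hypothetical
# placeholders). Decodes an explicit list of frame indices from a single file;
# on decoding errors it falls back to black 240x320 frames, so the result always
# contains len(frame_ids) frames.
#
#   frame_ids = get_frame_ids(0, 120, num_segments=32, jitter=False)
#   frames = video_loader_by_frames('/path/to/EK100', 'P01/P01_01.MP4', frame_ids)
#   # frames: float32 tensor of shape (32, H, W, 3)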


class VideoCaptionDatasetBase(torch.utils.data.Dataset):
    """Base dataset that parses the metadata of the supported benchmarks
    (Ego4D, Ego4D-MCQ, EK100, EGTEA, CharadesEgo) into `self.samples`."""

    def __init__(self, dataset, root, metadata, is_trimmed=True):
        self.dataset = dataset
        self.root = root
        self.is_trimmed = is_trimmed

        if self.dataset == 'ego4d':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        elif self.dataset == 'ego4d_mcq':
            with open(metadata, 'r') as f:
                self.samples = json.load(f)
        elif self.dataset in ['ek100_cls', 'ek100_mir']:
            video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    pid, vid = row[1:3]
                    # start_frame, end_frame = int(row[6]), int(row[7])
                    # Deprecated: some videos might have fps mismatch issue
                    start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
                    narration = row[8]
                    verb, noun = int(row[10]), int(row[12])
                    vid_path = '{}/{}.MP4'.format(pid, vid)
                    fps = fps_dict[osp.join(self.root, vid_path)]
                    start_frame = int(np.round(fps * start_timestamp))
                    end_frame = int(np.ceil(fps * end_timestamp))
                    self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
            if self.dataset == 'ek100_mir':
                self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
                if 'train' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
                elif 'test' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
                else:
                    raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
                self.relevancy = .1
        elif self.dataset == 'egtea':
            video_list = glob.glob(osp.join(self.root, '*/*'))
            len_dict = {video: len(decord.VideoReader(video)) for video in video_list}

            vn_list, labels = [], []
            for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
                row = row.strip()
                vn = int(row.split(' ')[-1])
                vn_list.append(vn)
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
                # labels.append(narration)
            mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}

            self.samples = []
            with open(metadata) as f:
                for row in f:
                    clip_id, action_idx = row.strip().split(' ')[:2]
                    video_id = '-'.join(clip_id.split('-')[:3])
                    vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
                    vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
                    self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
        elif self.dataset == 'charades_ego':
            video_list = glob.glob(osp.join(self.root, '*.mp4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    video_id = row[0]
                    if self.is_trimmed:
                        for action_tuple in row[9].split(';'):
                            if not action_tuple:
                                continue
                            action, start_timestamp, end_timestamp = action_tuple.split(' ')
                            start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
                            vid_path = '{}.mp4'.format(video_id)
                            fps = fps_dict[osp.join(self.root, vid_path)]
                            start_frame = int(np.round(fps * start_timestamp))
                            end_frame = int(np.ceil(fps * end_timestamp))
                            self.samples.append((vid_path, start_frame, end_frame, action))
                    else:
                        if not row[9]:
                            action_list = []
                        else:
                            action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
                        vid_path = '{}.mp4'.format(video_id)
                        fps = fps_dict[osp.join(self.root, vid_path)]
                        duration = fps * float(row[10])
                        self.samples.append((vid_path, 0, duration, action_list))
        elif self.dataset == 'charades_ego_trimmed':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        else:
            raise NotImplementedError

    def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
                     narration_selection='random'):
        """Decode the clip(s) and annotation of sample `i`; the layout of the
        returned tuple depends on `self.dataset`."""
        if self.dataset == 'ego4d':
            if len(self.samples[i]) == 4:
                vid, start_second, end_second, narration = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
            elif len(self.samples[i]) == 5:
                # TODO: need better filtering strategy based on nll
                vid, start_second, end_second, narration, _ = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
        elif self.dataset == 'ego4d_mcq':
            itemMCQ = self.samples[str(i)]
            answerIndex = itemMCQ['answer']
            textQuery = itemMCQ['query']['clip_text']
            sampleOptions = itemMCQ['choices']
            frames_options = []
            narration_options = []
            for option_id in range(len(sampleOptions)):
                option = sampleOptions[str(option_id)]
                frames = video_loader(self.root, option['video_uid'],
                                      float(option['clip_start']), end_second=float(option['clip_end']),
                                      clip_length=clip_length,
                                      jitter=is_training)
                frames_options.append(frames)
                narration_options.append(option['clip_text'])
            return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
        elif self.dataset == 'ek100_mir':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
            # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            if is_training:
                positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
                if positive_list != []:
                    pos = random.sample(positive_list, min(len(positive_list), 1))[0]
                    if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
                        return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
            else:
                return frames, (narration, 1)
        elif self.dataset == 'ek100_cls':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, '{}:{}'.format(verb, noun)
        elif self.dataset == 'egtea':
            vid_path, start_frame, end_frame, sentence = self.samples[i]
            if is_training:
                assert num_clips == 1
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                else:
                    start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
                    frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, sentence
        elif self.dataset == 'charades_ego':
            vid_path, start_frame, end_frame, action_list = self.samples[i]
            if sparse_sample:
                frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
                frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    # print('frame_ids:', frame_ids)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, action_list, vid_path
        elif self.dataset == 'charades_ego_trimmed':
            vid, start_second, end_second, narration = self.samples[i]
            frames = video_loader(self.root, vid, start_second,
                                  end_second=end_second,
                                  chunk_len=-1,  # no chunk for CharadesEgo
                                  fps=-1,  # could be variable fps
                                  clip_length=clip_length,
                                  jitter=is_training)
            return frames, narration
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)


class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random',
                 num_hard_negatives=0,
                 subsample_stride=None):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        if isinstance(subsample_stride, int):
            self.samples = self.samples[::subsample_stride]
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection
        self.num_hard_negatives = num_hard_negatives
        if num_hard_negatives > 0:
            assert self.dataset == 'htm_aa'

    def __getitem__(self, i):
        frames, caption = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # ek100_mir will also output relevancy value
        if isinstance(caption, tuple):
            caption, relevancy = caption
        else:
            relevancy = 0.

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        # tokenize caption
        if self.tokenizer is not None:
            caption = self.tokenizer(caption)

        if isinstance(caption, tuple):
            caption, mask = caption
            return frames, caption, mask, relevancy
        else:
            return frames, caption, relevancy
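
# Usage sketch (illustration only; the Ego4D clip root and metadata pickle are
# hypothetical placeholders). Without a tokenizer and with the default
# narration_selection, __getitem__ yields (frames, caption, relevancy) with a raw
# caption string and relevancy == 0.; in practice a spatial transform that
# resizes/crops to a fixed resolution is passed so clips can be collated.
#
#   dataset = VideoCaptionDatasetCLIP('ego4d', '/path/to/ego4d/clips',
#                                     '/path/to/ego4d_train.pkl',
#                                     transform=None, is_training=True,
#                                     tokenizer=None, clip_length=32)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
#                                        num_workers=8, drop_last=True)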


class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random'):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection

    def __getitem__(self, i):
        textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # apply transformation
        if self.transform is not None:
            frames_options = [self.transform(frames) for frames in frames_options]

        # tokenize caption
        if self.tokenizer is not None:
            textQuery = self.tokenizer(textQuery)
            narration_options = self.tokenizer(narration_options)
            if isinstance(textQuery, tuple):
                textQuery, mask_query = textQuery
                narration_options, mask_options = narration_options
                return (
                    textQuery, torch.stack(frames_options, dim=0),
                    narration_options, answerIndex, q_type,
                    mask_query, mask_options
                )
            else:
                return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type


class VideoClassyDataset(VideoCaptionDatasetBase):
    def __init__(
        self, dataset, root, metadata, transform=None,
        is_training=True, label_mapping=None,
        num_clips=1,
        clip_length=32, clip_stride=2,
        sparse_sample=False,
        is_trimmed=True,
    ):
        super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)

        self.transform = transform
        self.is_training = is_training
        self.label_mapping = label_mapping
        self.num_clips = num_clips
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample

    def __getitem__(self, i):
        frames, label, vid_path = self.get_raw_item(
            i, is_training=self.is_training,
            num_clips=self.num_clips,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
        )

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        if self.label_mapping is not None:
            if isinstance(label, list):
                # multi-label case
                res_array = np.zeros(len(self.label_mapping))
                for lbl in label:
                    res_array[self.label_mapping[lbl]] = 1.
                label = res_array
            else:
                label = self.label_mapping[label]

        return frames, label, vid_path


def get_dataset(train_transform, tokenizer, cfg, is_training=True):
    narration_selection = cfg.get('narration_selection', 'random')
    num_hard_neg = cfg.get('num_hard_neg', 0)
    data_cfg = cfg['data']
    if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
        if is_training:
            metadata = data_cfg['metadata']
        else:
            metadata = data_cfg['metadata_val']
        return VideoCaptionDatasetCLIP(
            data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
            is_training=is_training,
            tokenizer=tokenizer,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            narration_selection=narration_selection,
            num_hard_negatives=num_hard_neg
        )
    else:
        raise NotImplementedError


def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
    data_cfg = cfg['data']
    n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
    if is_training:
        metadata = data_cfg['metadata']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=True, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )
    else:
        metadata = data_cfg['metadata_val']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=False, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            is_trimmed=not data_cfg['dataset'] == 'charades_ego'
        )
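

# ---------------------------------------------------------------------------
# Smoke-test sketch (not part of the original module). The config keys mirror
# what get_downstream_dataset() reads above; every path below is a hypothetical
# placeholder and must point at real data for this to run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_cfg = {
        'data': {
            'dataset': 'ek100_cls',
            'root': '/path/to/EK100/videos',                      # placeholder
            'metadata': '/path/to/EPIC_100_train.csv',            # placeholder
            'metadata_val': '/path/to/EPIC_100_validation.csv',   # placeholder
            'num_clips': 1,
            'clip_length': 32,
            'clip_stride': 2,
            'sparse_sample': False,
        },
    }
    dataset = get_downstream_dataset(transform=None, tokenizer=None,
                                     cfg=example_cfg, is_training=False,
                                     num_clips=0, label_mapping=None)
    print('num samples:', len(dataset))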