# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv
import glob
import json
import numpy as np
import os.path as osp
import pickle
import random

import decord
import pandas as pd
import torch


def datetime2sec(str):
    hh, mm, ss = str.split(':')
    return int(hh) * 3600 + int(mm) * 60 + float(ss)
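

# Example of the conversion above (illustrative value):
#   datetime2sec('00:01:30.50') -> 90.5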


def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
    if chunk_len == -1:
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
        second_offset = second
        if end_second is not None:
            end_second = min(end_second, len(vr) / vr.get_avg_fps())
        else:
            end_second = len(vr) / vr.get_avg_fps()
    else:
        chunk_start = int(second) // chunk_len * chunk_len
        second_offset = second - chunk_start
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
    if fps == -1:
        fps = vr.get_avg_fps()

    # calculate frame_ids
    frame_offset = int(np.round(second_offset * fps))
    total_duration = max(int((end_second - second) * fps), clip_length)
    if chunk_len == -1:
        if end_second <= second:
            raise ValueError("end_second should be greater than second")
        else:
            frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
    else:
        frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)

    # load frames
    if max(frame_ids) < len(vr):
        try:
            frames = vr.get_batch(frame_ids).asnumpy()
        except decord.DECORDError as error:
            print(error)
            frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
    else:
        # find the remaining frames in the next chunk
        try:
            frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
            frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
            vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
            frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
            frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
            frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
            frames = np.concatenate([frames_part1, frames_part2], axis=0)
        # the next chunk does not exist; the current chunk is the last one
        except (RuntimeError, decord.DECORDError) as error:
            print(error)
            frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
            frames = vr.get_batch(frame_ids).asnumpy()

    frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    return torch.stack(frames, dim=0)
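

# Usage sketch for video_loader (the root, uid and layout below are hypothetical).
# With chunk_len > 0 the loader expects chunked videos laid out as
# <root>/<video_uid>.mp4/<chunk_start>.mp4; with chunk_len=-1 it expects a single
# <root>/<video_uid>.mp4. Note that, as written, end_second must be provided
# whenever chunk_len != -1.
#   clip = video_loader('/data/ego4d/video_chunked', 'some-video-uid',
#                       second=125.3, end_second=127.1, chunk_len=300, fps=30,
#                       clip_length=32, jitter=True)
#   clip is a float32 tensor of shape (32, H, W, 3) holding raw RGB values in [0, 255].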


def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq
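

# Sampling behaviour of get_frame_ids (illustrative): the span [start_frame, end_frame)
# is split into num_segments equal segments and one index is picked per segment --
# the segment midpoint when jitter=False, a uniformly random index otherwise, e.g.
#   get_frame_ids(0, 64, num_segments=4, jitter=False) -> [8, 24, 39, 55]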


def video_loader_by_frames(root, vid, frame_ids):
    vr = decord.VideoReader(osp.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)
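

# Illustrative call (the root and relative path are hypothetical):
#   frames = video_loader_by_frames('/data/EK100/videos', 'P01/P01_01.MP4', [0, 30, 60, 90])
# On a decoding error the function logs it and falls back to zero-filled 240x320x3 frames.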


class VideoCaptionDatasetBase(torch.utils.data.Dataset):
    def __init__(self, dataset, root, metadata, is_trimmed=True):
        self.dataset = dataset
        self.root = root
        self.is_trimmed = is_trimmed

        if self.dataset == 'ego4d':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        elif self.dataset == 'ego4d_mcq':
            with open(metadata, 'r') as f:
                self.samples = json.load(f)
        elif self.dataset in ['ek100_cls', 'ek100_mir']:
            video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    pid, vid = row[1:3]
                    # start_frame, end_frame = int(row[6]), int(row[7])
                    # Deprecated: some videos might have fps mismatch issue
                    start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
                    narration = row[8]
                    verb, noun = int(row[10]), int(row[12])
                    vid_path = '{}/{}.MP4'.format(pid, vid)
                    fps = fps_dict[osp.join(self.root, vid_path)]
                    start_frame = int(np.round(fps * start_timestamp))
                    end_frame = int(np.ceil(fps * end_timestamp))
                    self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
            if self.dataset == 'ek100_mir':
                self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
                if 'train' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
                elif 'test' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
                else:
                    raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
                self.relevancy = .1
        elif self.dataset == 'egtea':
            video_list = glob.glob(osp.join(self.root, '*/*'))
            len_dict = {video: len(decord.VideoReader(video)) for video in video_list}

            vn_list, labels = [], []
            for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
                row = row.strip()
                vn = int(row.split(' ')[-1])
                vn_list.append(vn)
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
                # labels.append(narration)
            mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}

            self.samples = []
            with open(metadata) as f:
                for row in f:
                    clip_id, action_idx = row.strip().split(' ')[:2]
                    video_id = '-'.join(clip_id.split('-')[:3])
                    vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
                    vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
                    self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
        elif self.dataset == 'charades_ego':
            video_list = glob.glob(osp.join(self.root, '*.mp4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    video_id = row[0]
                    if self.is_trimmed:
                        for action_tuple in row[9].split(';'):
                            if not action_tuple:
                                continue
                            action, start_timestamp, end_timestamp = action_tuple.split(' ')
                            start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
                            vid_path = '{}.mp4'.format(video_id)
                            fps = fps_dict[osp.join(self.root, vid_path)]
                            start_frame = int(np.round(fps * start_timestamp))
                            end_frame = int(np.ceil(fps * end_timestamp))
                            self.samples.append((vid_path, start_frame, end_frame, action))
                    else:
                        if not row[9]:
                            action_list = []
                        else:
                            action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
                        vid_path = '{}.mp4'.format(video_id)
                        fps = fps_dict[osp.join(self.root, vid_path)]
                        duration = fps * float(row[10])
                        self.samples.append((vid_path, 0, duration, action_list))
        elif self.dataset == 'charades_ego_trimmed':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        else:
            raise NotImplementedError

    def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
                     narration_selection='random'):
        if self.dataset == 'ego4d':
            if len(self.samples[i]) == 4:
                vid, start_second, end_second, narration = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
            elif len(self.samples[i]) == 5:
                # TODO: need better filtering strategy based on nll
                vid, start_second, end_second, narration, _ = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
        elif self.dataset == 'ego4d_mcq':
            itemMCQ = self.samples[str(i)]
            answerIndex = itemMCQ['answer']
            textQuery = itemMCQ['query']['clip_text']
            sampleOptions = itemMCQ['choices']
            frames_options = []
            narration_options = []
            for option_id in range(len(sampleOptions)):
                option = sampleOptions[str(option_id)]
                frames = video_loader(self.root, option['video_uid'],
                                      float(option['clip_start']), end_second=float(option['clip_end']),
                                      clip_length=clip_length,
                                      jitter=is_training)
                frames_options.append(frames)
                narration_options.append(option['clip_text'])
            return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
        elif self.dataset == 'ek100_mir':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
            # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            if is_training:
                positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
                if positive_list != []:
                    pos = random.sample(positive_list, min(len(positive_list), 1))[0]
                    if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
                        return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
            else:
                return frames, (narration, 1)
        elif self.dataset == 'ek100_cls':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, '{}:{}'.format(verb, noun)
        elif self.dataset == 'egtea':
            vid_path, start_frame, end_frame, sentence = self.samples[i]
            if is_training:
                assert num_clips == 1
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                else:
                    start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
                    frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, sentence
        elif self.dataset == 'charades_ego':
            vid_path, start_frame, end_frame, action_list = self.samples[i]
            if sparse_sample:
                frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
                frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    # print('frame_ids:', frame_ids)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, action_list, vid_path
        elif self.dataset == 'charades_ego_trimmed':
            vid, start_second, end_second, narration = self.samples[i]
            frames = video_loader(self.root, vid, start_second,
                                  end_second=end_second,
                                  chunk_len=-1,  # no chunk for CharadesEgo
                                  fps=-1,  # could be variable fps
                                  clip_length=clip_length,
                                  jitter=is_training)
            return frames, narration
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)
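

# The three wrappers below all build on VideoCaptionDatasetBase.get_raw_item:
#   - VideoCaptionDatasetCLIP: (frames, caption[, mask], relevancy) pairs for video-text training
#   - VideoCaptionDatasetMCQ:  (query, candidate clips, candidate narrations, answer, type) for ego4d_mcq
#   - VideoClassyDataset:      (frames, label, vid_path) for (multi-label) classification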


class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random',
                 num_hard_negatives=0,
                 subsample_stride=None):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        if isinstance(subsample_stride, int):
            self.samples = self.samples[::subsample_stride]
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection
        self.num_hard_negatives = num_hard_negatives
        if num_hard_negatives > 0:
            assert self.dataset == 'htm_aa'

    def __getitem__(self, i):
        frames, caption = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # ek100_mir will also output relevancy value
        if isinstance(caption, tuple):
            caption, relevancy = caption
        else:
            relevancy = 0.

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        # tokenize caption
        if self.tokenizer is not None:
            caption = self.tokenizer(caption)

        if isinstance(caption, tuple):
            caption, mask = caption
            return frames, caption, mask, relevancy
        else:
            return frames, caption, relevancy


class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random'):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection

    def __getitem__(self, i):
        textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # apply transformation
        if self.transform is not None:
            frames_options = [self.transform(frames) for frames in frames_options]

        # tokenize caption
        if self.tokenizer is not None:
            textQuery = self.tokenizer(textQuery)
            narration_options = self.tokenizer(narration_options)
            if isinstance(textQuery, tuple):
                textQuery, mask_query = textQuery
                narration_options, mask_options = narration_options
                return (
                    textQuery, torch.stack(frames_options, dim=0),
                    narration_options, answerIndex, q_type,
                    mask_query, mask_options
                )
            else:
                return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type


class VideoClassyDataset(VideoCaptionDatasetBase):
    def __init__(
        self, dataset, root, metadata, transform=None,
        is_training=True, label_mapping=None,
        num_clips=1,
        clip_length=32, clip_stride=2,
        sparse_sample=False,
        is_trimmed=True,
    ):
        super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)

        self.transform = transform
        self.is_training = is_training
        self.label_mapping = label_mapping
        self.num_clips = num_clips
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample

    def __getitem__(self, i):
        frames, label, vid_path = self.get_raw_item(
            i, is_training=self.is_training,
            num_clips=self.num_clips,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
        )

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        if self.label_mapping is not None:
            if isinstance(label, list):
                # multi-label case
                res_array = np.zeros(len(self.label_mapping))
                for lbl in label:
                    res_array[self.label_mapping[lbl]] = 1.
                label = res_array
            else:
                label = self.label_mapping[label]

        return frames, label, vid_path


def get_dataset(train_transform, tokenizer, cfg, is_training=True):
    narration_selection = cfg.get('narration_selection', 'random')
    num_hard_neg = cfg.get('num_hard_neg', 0)
    data_cfg = cfg['data']
    if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
        if is_training:
            metadata = data_cfg['metadata']
        else:
            metadata = data_cfg['metadata_val']
        return VideoCaptionDatasetCLIP(
            data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
            is_training=is_training,
            tokenizer=tokenizer,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            narration_selection=narration_selection,
            num_hard_negatives=num_hard_neg
        )
    else:
        raise NotImplementedError


def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
    data_cfg = cfg['data']
    n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
    if is_training:
        metadata = data_cfg['metadata']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=True, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )
    else:
        metadata = data_cfg['metadata_val']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=False, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            is_trimmed=not data_cfg['dataset'] == 'charades_ego'
        )
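

# Minimal usage sketch for the two factory functions above. Every path and value in
# this cfg is a hypothetical placeholder rather than a shipped default, and
# train_transform / val_transform / tokenizer / label_mapping are assumed to come
# from the caller. The only condition get_dataset checks is that
# cfg['model']['arch'] starts with 'CLIP' or 'VCLM'.
#
#   cfg = {
#       'model': {'arch': 'CLIP_EXAMPLE_ARCH'},
#       'data': {
#           'dataset': 'ek100_mir',
#           'root': '/data/EK100/videos',
#           'metadata': '/data/EK100/annotations/EPIC_100_retrieval_train.csv',
#           'metadata_val': '/data/EK100/annotations/EPIC_100_retrieval_test.csv',
#           'clip_length': 16, 'clip_stride': 4,
#           'sparse_sample': False, 'num_clips': 1,
#       },
#   }
#   train_set = get_dataset(train_transform, tokenizer, cfg, is_training=True)
#   val_set = get_downstream_dataset(val_transform, tokenizer, cfg, is_training=False,
#                                    num_clips=0, label_mapping=label_mapping)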