Spaces:
Runtime error
Runtime error
File size: 2,481 Bytes
ef1c94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import csv
import os

import numpy as np
import torch
from torch.utils.data import Dataset
class HighlightDataset(Dataset):
    """Video highlight dataset backed by a TVSum-style annotation file.

    ``root_dir`` must contain a ``videos`` directory and the annotation file
    ``ydata-tvsum50-anno.tsv``. Indexing the dataset yields the path of one
    video together with its importance scores averaged over the annotators.
    """

    def __init__(self, root_dir, transform=None):
        """
        Arguments:
            root_dir (string): Directory with all data, including videos and annotations.
            transform (callable, optional): Accepted for interface
                compatibility; currently unused.
        """
        self.root_dir = root_dir
        self.video_dir = os.path.join(root_dir, "videos")
        self.anno_path = os.path.join(root_dir, "ydata-tvsum50-anno.tsv")

        # Read the raw tab-separated annotation rows.
        with open(self.anno_path, newline='') as f:
            raw_annotations = list(csv.reader(f, delimiter='\t'))

        # Each video is expected to own this many consecutive rows of the
        # annotation file (one row per annotator).
        self.num_annotator = 20
        self.annotations = self.parse_annotations(raw_annotations)  # {video_id: mean importance scores}

        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same video across runs and machines.
        self.video_list = sorted(os.listdir(self.video_dir))

    def parse_annotations(self, annotations):
        """Average the per-annotator importance scores of every video.

        Arguments:
            annotations (list): Rows of the form
                ``[video_id, video_category, importance_scores]`` where
                ``importance_scores`` is a string of comma-separated numbers.
                Rows are consumed in consecutive groups of
                ``self.num_annotator``; the video id is taken from the first
                row of each group. Empty rows are ignored. The input rows
                are not modified.

        Returns:
            dict: Maps each video id to a NumPy array holding the element-wise
            mean of that group's score rows.
        """
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # dropping them keeps the fixed-size grouping aligned.
        rows = [row for row in annotations if row]

        parsed_annotations = {}
        for start in range(0, len(rows), self.num_annotator):
            group = rows[start:start + self.num_annotator]
            video_id = group[0][0]
            # Parse into fresh lists instead of overwriting row[2], so the
            # caller's rows stay intact and can be parsed again.
            scores = np.array(
                [[float(score) for score in row[2].split(',')] for row in group]
            )
            parsed_annotations[video_id] = np.mean(scores, axis=0)
        return parsed_annotations

    def __len__(self):
        """Return the number of entries found in the videos directory."""
        return len(self.video_list)

    def __getitem__(self, idx):
        """Return ``(video_path, annotations)`` for the entry at ``idx``.

        Arguments:
            idx (int or torch.Tensor): Position in the sorted video listing.

        Returns:
            tuple: The path of the video file and the averaged importance
            scores stored for its video id.

        Raises:
            KeyError: If the annotation file has no entry for the video id.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        video_name = self.video_list[idx]
        # The video id is the part of the file name before the first dot.
        video_id = video_name.split('.')[0]
        video_path = os.path.join(self.video_dir, video_name)

        try:
            annotations = self.annotations[video_id]
        except KeyError:
            raise KeyError(
                f"no annotations found for video id {video_id!r} "
                f"(file {video_name!r})"
            ) from None
        return video_path, annotations