Spaces:
Runtime error
Runtime error
| import copy | |
| import csv | |
| import json | |
| import numpy as np | |
| import os | |
| import pickle | |
| import random | |
| import torch | |
| from torch.utils.data.sampler import Sampler | |
| import pdb | |
class ASMRSampler(Sampler):
    """Yield sample indices so that every batch draws from distinct videos.

    Samples are grouped by video name (element 0 of each sample). For each
    epoch, batches are assembled by picking ``batch_size`` different videos at
    random and taking one not-yet-used clip from each of them.

    Total videos: 2794. The sampler ends when last $BATCH_SIZE videos are
    left, so the clips still unused at that point are dropped for the epoch.

    Args:
        list_sample: List of mutable sample records (lists) whose first
            element is the video name. NOTE: every record is modified in
            place -- its position in ``list_sample`` is appended as the last
            element.
        batch_size: Number of distinct videos drawn for each batch.
        rand_per_epoch: When false, the global ``random`` module is seeded
            with 1234 once, at construction time.
    """

    def __init__(self, list_sample, batch_size, rand_per_epoch=True):
        self.list_sample = list_sample
        self.batch_size = batch_size
        if not rand_per_epoch:
            # Seeds the process-wide RNG, not a private generator.
            random.seed(1234)
        self.N = len(self.list_sample)
        self.sample_class_dict = self.generate_vid_dict()

    def generate_vid_dict(self):
        """Tag each sample with its index and group the samples by video.

        Returns:
            Dict mapping video name -> list of sample records, in the order
            the samples appear in ``self.list_sample``.
        """
        grouped = {}
        for position, sample in enumerate(self.list_sample):
            # The appended position is what gen_index_batchwise() reads back
            # as ``sample[-1]``.
            sample.append(position)
            grouped.setdefault(sample[0], []).append(sample)
        return grouped

    def gen_index_batchwise(self):
        """Build one epoch's flat list of sample indices.

        Returns:
            List whose consecutive ``batch_size``-long slices each hold
            indices from ``batch_size`` different videos. No index appears
            more than once.
        """
        # Per-epoch pools of still-unused sample indices; only the indices
        # are needed, so the sample records themselves are not copied.
        remaining = {
            vid: [clip[-1] for clip in clips]
            for vid, clips in self.sample_class_dict.items()
        }
        indexes = []
        for _ in range(self.N // self.batch_size):
            # Stop once no more than batch_size videos still have clips.
            if len(remaining) <= self.batch_size:
                break
            # random.sample() requires a sequence; dict views are rejected
            # on Python 3.11+, hence the explicit list().
            for vid in random.sample(list(remaining), self.batch_size):
                pool = remaining[vid]
                picked = random.choice(pool)
                indexes.append(picked)
                pool.remove(picked)
                # Drop exhausted videos so they cannot be drawn again.
                if not pool:
                    del remaining[vid]
        return indexes

    def __iter__(self):
        # A fresh random ordering is generated every time iteration starts.
        return iter(self.gen_index_batchwise())

    def __len__(self):
        # Total number of samples; an epoch may yield fewer indices because
        # the leftover clips are dropped.
        return self.N
| class VoxcelebSampler(Sampler): | |
| def __init__(self, list_sample, batch_size, rand_per_epoch=True): | |
| self.list_sample = list_sample | |
| self.batch_size = batch_size | |
| if not rand_per_epoch: | |
| random.seed(1234) | |
| self.N = len(self.list_sample) | |
| self.sample_class_dict = self.generate_vid_dict() | |
| def generate_vid_dict(self): | |
| _ = [self.sample[i].append(i) for i in range(len(self.list_sample))] | |
| sample_class_dict = {} | |
| pdb.set_trace() | |
| for i in range(len(self.list_sample)): | |
| video_name = self.list_sample[i][0] | |
| if video_name in batch_vid: | |
| pdb.set_trace() |