Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import csv | |
import json | |
import numpy as np | |
import os | |
import pickle | |
import random | |
import torch | |
from torch.utils.data.sampler import Sampler | |
import pdb | |
class ASMRSampler(Sampler): | |
""" | |
Total videos: 2794. The sampler ends when last $BATCH_SIZE videos are left. | |
""" | |
def __init__(self, list_sample, batch_size, rand_per_epoch=True): | |
self.list_sample = list_sample | |
self.batch_size = batch_size | |
if not rand_per_epoch: | |
random.seed(1234) | |
self.N = len(self.list_sample) | |
self.sample_class_dict = self.generate_vid_dict() | |
# self.indexes = self.gen_index_batchwise() | |
# pdb.set_trace() | |
def generate_vid_dict(self): | |
_ = [self.list_sample[i].append(i) for i in range(len(self.list_sample))] | |
sample_class_dict = {} | |
for i in range(len(self.list_sample)): | |
video_name = self.list_sample[i][0] | |
if video_name not in sample_class_dict: | |
sample_class_dict[video_name] = [] | |
sample_class_dict[video_name].append(self.list_sample[i]) | |
return sample_class_dict | |
def gen_index_batchwise(self): | |
indexes = [] | |
scd_copy = copy.deepcopy(self.sample_class_dict) | |
for i in range(self.N // self.batch_size): | |
if len(list(scd_copy.keys())) <= self.batch_size: | |
break | |
batch_vid = random.sample(scd_copy.keys(), self.batch_size) | |
for vid in batch_vid: | |
rand_clip = random.choice(scd_copy[vid]) | |
indexes.append(rand_clip[-1]) | |
scd_copy[vid].remove(rand_clip) # removed added element | |
# remove dict if empty | |
if len(scd_copy[vid]) == 0: | |
del scd_copy[vid] | |
# add remain items to indexes | |
# for k, v in scd_copy.items(): | |
# for item in v: | |
# indexes.append(item[-1]) | |
return indexes | |
def __iter__(self): | |
return iter(self.gen_index_batchwise()) | |
def __len__(self): | |
return self.N | |
class VoxcelebSampler(Sampler): | |
def __init__(self, list_sample, batch_size, rand_per_epoch=True): | |
self.list_sample = list_sample | |
self.batch_size = batch_size | |
if not rand_per_epoch: | |
random.seed(1234) | |
self.N = len(self.list_sample) | |
self.sample_class_dict = self.generate_vid_dict() | |
def generate_vid_dict(self): | |
_ = [self.sample[i].append(i) for i in range(len(self.list_sample))] | |
sample_class_dict = {} | |
pdb.set_trace() | |
for i in range(len(self.list_sample)): | |
video_name = self.list_sample[i][0] | |
if video_name in batch_vid: | |
pdb.set_trace() |