Spaces:
Runtime error
Runtime error
import numpy as np | |
import random | |
import copy | |
import time | |
import warnings | |
import random | |
from torch.utils.data import Sampler | |
from torch._six import int_classes as _int_classes | |
class CustomGCSampler(Sampler):
    """Batch sampler that yields pose-balanced mini-batches of dataset indices.

    Images are partitioned into five pose groups (see ``_GROUP_POSE_NAMES``) and
    every batch draws a fixed number of images from each group, so that rare
    poses show up in every batch. Optionally, two indices of "non-flat" images
    are appended to every batch; those indices start right after the regular
    images, i.e. at ``len(name_list)``.

    The structure of this sampler is more complicated than necessary because it
    is a shortened/simplified version of CustomBatchSampler. The relations
    between breeds are not relevant for the CVPR 2022 paper, but the structure
    that was used for the experiments with clade related losses was kept.
    ToDo: restructure this sampler.

    Args:
        data_sampler_info_gc (dict): a dictionary containing information about
            the dataset. The keys read here are ``'name_list'`` (image names),
            ``'gc_annots_categories'`` (maps an image name to a dict with a
            ``'pose'`` entry) and, if ``add_nonflat`` is set,
            ``'name_list_nonflat'``.
        batch_size (int): Size of mini-batch. Must be a positive integer. It is
            only validated and stored: the number of indices per batch is fixed
            by the per-group sample counts (12, or 14 with ``add_nonflat``). A
            warning is issued if ``batch_size`` differs from that number.
        add_nonflat (bool): append two non-flat image indices to every batch.
        more_standing (bool): use the alternative per-group sample counts that
            draw more 'standing_4paws' images (see ``_SAMPLES_PER_BATCH``).

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.
    """

    # Pose categories belonging to each of the five sampling groups. Image counts
    # noted by the original author for the full dataset (4305 images):
    #   sitting_sym: 561, lying_sym: 199, jumping_touching: 21, standing_4paws: 1999,
    #   running: 132, sitting_comp: 306, onhindlegs: 16, walking: 325, lying_comp: 596,
    #   standing_fewpaws: 98, otherpose: 22, downwardfacingdog: 14, jumping_nottouching: 16
    # which gives the group sizes 89, 867, 795, 555 and 1999.
    # (A finer 7-group split -- sitting and lying each divided into sym/comp, sampled as
    # 2, 1.5, 1.5, 1.5, 1.5, 2, 2 -- was considered by the author but is not implemented.)
    _GROUP_POSE_NAMES = {
        0: ('otherpose', 'downwardfacingdog', 'jumping_nottouching', 'onhindlegs', 'jumping_touching'),
        1: ('sitting_sym', 'sitting_comp'),
        2: ('lying_sym', 'lying_comp'),
        3: ('standing_fewpaws', 'running', 'walking'),
        4: ('standing_4paws',),
    }

    # Number of images drawn from each group per batch, keyed by `more_standing`.
    # Both layouts add up to 12 images.
    _SAMPLES_PER_BATCH = {
        False: {0: 2, 1: 3, 2: 3, 3: 2, 4: 2},
        True: {0: 2, 1: 2, 2: 2, 3: 2, 4: 4},
    }

    # Number of non-flat indices appended to each batch when `add_nonflat` is set.
    _N_NONFLAT_PER_BATCH = 2

    def __init__(self, data_sampler_info_gc, batch_size, add_nonflat=False, more_standing=False):
        # bool is a subclass of int, hence the explicit exclusion.
        if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        self.data_sampler_info_gc = data_sampler_info_gc
        self.batch_size = batch_size
        self.add_nonflat = add_nonflat
        self.more_standing = more_standing

        # The batches produced by __iter__ have a fixed size which does not depend on
        # batch_size; tell the caller if the two disagree instead of silently ignoring it.
        n_per_batch = sum(self._SAMPLES_PER_BATCH[bool(more_standing)].values())
        if add_nonflat:
            n_per_batch += self._N_NONFLAT_PER_BATCH
        if batch_size != n_per_batch:
            warnings.warn(
                "CustomGCSampler always yields batches of {} indices, "
                "batch_size={} is ignored".format(n_per_batch, batch_size),
                stacklevel=2)

        name_list = self.data_sampler_info_gc['name_list']
        pose_annots = self.data_sampler_info_gc['gc_annots_categories']
        self.n_images_tot = len(name_list)

        # pose_dict: pose name -> image names (in name_list order)
        # dict_name_to_idx: image name -> position in name_list (= dataset index)
        self.pose_dict = {}
        self.dict_name_to_idx = {}
        for ind_img, img in enumerate(name_list):
            self.dict_name_to_idx[img] = ind_img
            self.pose_dict.setdefault(pose_annots[img]['pose'], []).append(img)

        # prepare non-flat images (0 when unused, so get_nonflat_idx_list stays callable)
        self.n_images_nonflat_tot = 0
        if self.add_nonflat:
            self.n_images_nonflat_tot = len(self.data_sampler_info_gc['name_list_nonflat'])

        # An epoch is sized by the sitting group (ind_g=1) with 3 draws per batch.
        # The shuffle is irrelevant for counting; it is kept so that the consumption of
        # the global `random` state (and thereby seeded runs) stays unchanged.
        # NOTE(review): the divisor 3 matches the default layout only; with
        # more_standing=True the sitting group contributes 2 images per batch, so an
        # epoch does not necessarily cover all sitting images -- confirm this is intended.
        n_sitting = len(self.get_list_for_group_index(
            ind_g=1, n_groups=5, shuffle=True, more_standing=self.more_standing))
        self.n_desired_batches = (n_sitting + 2) // 3   # ceil(n_sitting / 3)

    def get_description(self):
        """Return a short human-readable description of this sampler."""
        description = ("This sampler returns stanext data such that poses are more balanced. \n"
                       "-> works on top of stanext24_withgc_v2")
        return description

    def get_nonflat_idx_list(self, shuffle=True):
        """Return the indices of all non-flat images, optionally shuffled.

        The indices cover ``n_images_tot ... n_images_tot + n_images_nonflat_tot - 1``.
        The list is empty if the sampler was created with ``add_nonflat=False``.
        """
        first_idx = self.n_images_tot
        all_nonflat_idxs = list(range(first_idx, first_idx + self.n_images_nonflat_tot))
        if shuffle:
            random.shuffle(all_nonflat_idxs)
        return all_nonflat_idxs

    def get_list_for_group_index(self, ind_g, n_groups=5, shuffle=True, return_info=False, more_standing=False):
        """Collect the names of all images that belong to pose group ``ind_g``.

        Args:
            ind_g (int): group index, 0 to 4 (see ``_GROUP_POSE_NAMES``).
            n_groups (int): number of groups; only 5 is supported.
            shuffle (bool): shuffle the returned list (uses the global ``random`` state).
            return_info (bool): additionally return the pose names of the group and
                the number of samples drawn from it per batch.
            more_standing (bool): selects which per-batch sample counts are reported.

        Returns:
            list, or ``(list, pose_names, n_samples_per_batch)`` if ``return_info``.

        Raises:
            ValueError: if ``ind_g`` is not a valid group index.
            KeyError: if one of the group's poses does not occur in the dataset.
        """
        assert (n_groups == 5)
        try:
            pose_names = list(self._GROUP_POSE_NAMES[ind_g])
            n_samples_per_batch = self._SAMPLES_PER_BATCH[bool(more_standing)][ind_g]
        except KeyError:
            raise ValueError("invalid group index: {}".format(ind_g)) from None

        all_imgs_this_group = []
        for pose_name in pose_names:
            all_imgs_this_group.extend(self.pose_dict[pose_name])
        if shuffle:
            random.shuffle(all_imgs_this_group)

        if return_info:
            return all_imgs_this_group, pose_names, n_samples_per_batch
        return all_imgs_this_group

    def __iter__(self):
        """Yield ``n_desired_batches`` lists of dataset indices.

        Within a batch the indices are ordered by group (0 to 4), followed by the
        two non-flat indices if ``add_nonflat`` is set. A group that runs out of
        images is refilled with a freshly shuffled list, so images of small groups
        are repeated within an epoch. All batches are assembled before the first
        one is yielded.
        """
        n_groups = 5
        # The shuffled lists are reversed once so that images can be taken with the
        # O(1) list.pop() while still being consumed in their shuffled order.
        group_lists = {}
        n_samples_per_batch = {}
        for ind_g in range(n_groups):
            names, _, n_samples_per_batch[ind_g] = self.get_list_for_group_index(
                ind_g, n_groups=n_groups, shuffle=True, return_info=True, more_standing=self.more_standing)
            names.reverse()
            group_lists[ind_g] = names
        if self.add_nonflat:
            nonflat_idx_list = self.get_nonflat_idx_list()
            nonflat_idx_list.reverse()

        # we want to sample all sitting poses at least once per epoch (and thus all other
        # images except standing on 4 paws)
        all_batches = []
        for _ in range(self.n_desired_batches):
            batch_with_idxs = []
            for ind_g in range(n_groups):
                for _ in range(n_samples_per_batch[ind_g]):
                    if not group_lists[ind_g]:
                        refill = self.get_list_for_group_index(
                            ind_g, n_groups=n_groups, shuffle=True, more_standing=self.more_standing)
                        refill.reverse()
                        group_lists[ind_g] = refill
                    name = group_lists[ind_g].pop()
                    batch_with_idxs.append(self.dict_name_to_idx[name])
            if self.add_nonflat:
                for _ in range(self._N_NONFLAT_PER_BATCH):
                    if not nonflat_idx_list:
                        nonflat_idx_list = self.get_nonflat_idx_list()
                        nonflat_idx_list.reverse()
                    batch_with_idxs.append(nonflat_idx_list.pop())
            all_batches.append(batch_with_idxs)

        yield from all_batches

    def __len__(self):
        """Return the number of batches per epoch.

        Small groups are repeated and large ones are only partially used within an
        epoch, so not every image is shown exactly once; the epoch length is the
        fixed number of batches computed in ``__init__``.
        """
        return self.n_desired_batches