|
from torch.utils.data import Sampler |
|
import numpy as np |
|
import logging |
|
from collections import defaultdict |
|
from pathlib import Path |
|
import torch |
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
class MultilabelStratifiedSampler(Sampler):
    """Sampler that draws dataset indices with replacement, group by group.

    Each draw first picks a group with probability proportional to the
    group's size, then picks one member of that group uniformly at random.
    One pass over the sampler yields ``num_batches * batch_size`` indices,
    where ``num_batches`` is ``len(labels) // batch_size`` (at least 1).

    NOTE(review): ``labels`` is stored and length-checked but does not
    influence which indices are drawn; only ``groups`` does. When every
    sample belongs to exactly one group, size-proportional group weights
    combined with a uniform pick inside the group make every sample equally
    likely -- confirm this is the intended stratification.

    Args:
        labels: Per-sample labels (any sequence accepted by ``np.array``).
        groups: Per-sample group identifiers, same length as ``labels``.
        batch_size: Positive integer number of indices per batch.
        cached_size: Accepted for interface compatibility; currently unused.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer, if
            ``labels`` and ``groups`` differ in length, or if there is no
            group to sample from (e.g. empty input).
    """

    def __init__(self, labels, groups, batch_size, cached_size=None):
        # ``Sampler.__init__`` needs no argument on recent PyTorch releases
        # (``data_source`` is deprecated there) but requires it positionally
        # on older ones, so try the modern form first.
        try:
            super().__init__()
        except TypeError:
            super().__init__(None)

        # Fail early with a clear message instead of a ZeroDivisionError
        # (batch_size == 0) or a sampler whose length and iteration disagree
        # (batch_size < 0).
        if not isinstance(batch_size, (int, np.integer)) or batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        self.labels = np.array(labels)
        self.groups = np.array(groups)
        self.batch_size = batch_size
        self.num_samples = len(labels)

        if len(self.labels) != len(self.groups):
            raise ValueError("Length mismatch between labels and groups")

        # Map every distinct group value to the positions where it occurs.
        # Values that never compare equal to themselves (e.g. NaN) match
        # nothing and are therefore left out.
        self.group_indices = {}
        for group in np.unique(self.groups):
            members = np.where(self.groups == group)[0]
            if len(members) > 0:
                self.group_indices[group] = members

        # Without at least one group iteration could only fail later, deep
        # inside ``np.random.choice``; report the problem at construction.
        if not self.group_indices:
            raise ValueError(
                "Cannot build sampler: there are no samples to draw from"
            )

        # Group selection probability is proportional to group size.
        group_sizes = np.array(
            [len(members) for members in self.group_indices.values()]
        )
        self.group_probs = group_sizes / group_sizes.sum()
        self.valid_groups = list(self.group_indices.keys())

        # Always produce at least one batch, even for tiny datasets.
        self.num_batches = max(1, self.num_samples // self.batch_size)
        self.total_samples = self.num_batches * self.batch_size

    def __iter__(self):
        """Return an iterator over ``total_samples`` randomly drawn indices.

        Indices are drawn with replacement using NumPy's global random
        state and are yielded as plain Python ``int`` objects.
        """
        groups = self.valid_groups

        # Draw the group of every sample in one vectorised call instead of
        # one ``np.random.choice`` call per sample.
        picked_groups = np.random.choice(
            len(groups), size=self.total_samples, p=self.group_probs
        )

        sampled = np.empty(self.total_samples, dtype=np.intp)
        for position, group in enumerate(groups):
            slots = picked_groups == position
            count = int(slots.sum())
            if count:
                # Uniform pick (with replacement) among the group's members
                # for every slot that was assigned to this group.
                sampled[slots] = np.random.choice(
                    self.group_indices[group], size=count
                )

        return iter(sampled.tolist())

    def __len__(self):
        """Return the number of indices yielded by one full iteration."""
        return self.total_samples