File size: 1,582 Bytes
93c029f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from torch.utils.data.sampler import RandomSampler, Sampler
import numpy as np

class FixedLenRandomSampler(RandomSampler):
    """
    Sampler yielding a fixed number of batches per epoch, drawing indices
    without replacement across epochs so every item is visited before any
    item is repeated.

    Code from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

    Args:
        data_source: dataset; only ``len(data_source)`` is used.
        bs: batch size — number of indices per yielded batch.
        epoch_size: number of batches yielded per epoch.
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size
        self.bs = bs
        # not_sampled[i] is True while index i has not yet been yielded
        # during the current pass over the dataset.
        self.not_sampled = np.array([True] * len(data_source))
        # Total number of indices consumed per epoch.
        self.size_to_sample = self.epoch_size * self.bs

    def _reset_state(self):
        """Mark every index as available again (begin a new pass).

        Note: this was a ``@property`` invoked as a bare statement for its
        side effect; it is now a plain method.
        """
        self.not_sampled[:] = True

    def __iter__(self):
        """Yield ``epoch_size`` batches of ``bs`` indices each."""
        ns = int(self.not_sampled.sum())
        idx_last = []
        if ns >= self.size_to_sample:
            idx = np.random.choice(np.where(self.not_sampled)[0],
                                   size=self.size_to_sample,
                                   replace=False).tolist()
            if ns == self.size_to_sample:
                # Pass exactly exhausted: start a new one.  The indices just
                # drawn are re-marked below, preserving the original code's
                # behaviour of carrying them into the new pass as "seen".
                self._reset_state()
        else:
            # Not enough unseen items to fill the epoch: flush the remainder
            # first, then start a new pass for the rest of this epoch.
            idx_last = np.where(self.not_sampled)[0].tolist()
            self._reset_state()
            # BUGFIX: exclude the flushed indices from the new pass so an
            # epoch never contains the same index twice (the original code
            # sampled from the full pool and could duplicate them).
            self.not_sampled[idx_last] = False
            idx = np.random.choice(np.where(self.not_sampled)[0],
                                   size=self.size_to_sample - len(idx_last),
                                   replace=False).tolist()
        self.not_sampled[idx] = False
        idx = [*idx_last, *idx]
        # Chop the flat index list into epoch_size batches of bs indices.
        batches = [idx[i * self.bs:(i + 1) * self.bs]
                   for i in range(self.epoch_size)]
        return iter(batches)

    def __len__(self):
        # Length is measured in batches, not individual samples.
        return self.epoch_size