File size: 433 Bytes
15acbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from sklearn.mixture import GaussianMixture

def get_noise_sampler(sample_type='gau'):
    """Return a callable that draws random noise shaped like a given tensor.

    Args:
        sample_type: Name of the noise distribution to use.

            * ``'gau'`` -- a single ``torch.randn_like`` draw.
            * ``'gau_offset'`` -- the sum of two separate
              ``torch.randn_like`` draws.
            * ``'gmm'`` -- reserved name; no implementation exists yet.

    Returns:
        A function ``sampler(latnt_sz)`` that passes its argument to
        ``torch.randn_like`` and returns the resulting tensor.

    Raises:
        NotImplementedError: If ``sample_type`` is ``'gmm'``.
        ValueError: If ``sample_type`` is not one of the names above.
    """
    if sample_type == 'gau':
        # NOTE(review): despite its name, ``latnt_sz`` must be a tensor (it is
        # handed straight to ``torch.randn_like``), not a size tuple. The name
        # is kept so keyword calls to the sampler keep working.
        def sampler(latnt_sz):
            return torch.randn_like(latnt_sz)
    elif sample_type == 'gau_offset':
        # NOTE(review): the "offset" is itself a second random draw, not a
        # fixed shift -- confirm this is the intended behavior.
        def sampler(latnt_sz):
            return torch.randn_like(latnt_sz) + torch.randn_like(latnt_sz)
    elif sample_type == 'gmm':
        # This branch was an empty placeholder; fail loudly instead of
        # handing the caller ``None`` in place of a sampler.
        raise NotImplementedError("sample_type 'gmm' is not implemented yet")
    else:
        raise ValueError(f"unknown sample_type: {sample_type!r}")
    # The original ended with a bare ``return`` and so always returned None.
    return sampler

if __name__ == "__main__":
    # Script entry point: currently a no-op placeholder.
    pass