File size: 433 Bytes
15acbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from sklearn.mixture import GaussianMixture
def get_noise_sampler(sample_type='gau'):
    """Return a noise-sampling function selected by ``sample_type``.

    The returned callable takes a template tensor and returns freshly drawn
    noise with the same shape, dtype and device (via ``torch.randn_like``).

    Args:
        sample_type: Which sampler to build.
            ``'gau'``        -- standard Gaussian noise.
            ``'gau_offset'`` -- sum of two independent ``randn_like`` draws
                                (see the review note in that branch).
            ``'gmm'``        -- not implemented yet (placeholder).

    Returns:
        A callable ``sampler(latnt_sz) -> torch.Tensor``.

    Raises:
        NotImplementedError: For ``'gmm'``, which is still a placeholder.
        ValueError: For any other unrecognised ``sample_type``.
    """
    # torch is not imported at the top of this file; import it here so the
    # returned samplers do not fail with NameError when they are called.
    import torch

    if sample_type == 'gau':
        def sampler(latnt_sz):
            # ``latnt_sz`` is used as a template tensor (randn_like copies its
            # shape/dtype/device), despite the name suggesting a size.
            return torch.randn_like(latnt_sz)
    elif sample_type == 'gau_offset':
        def sampler(latnt_sz):
            # NOTE(review): this adds a second full-shape standard-normal draw,
            # which is simply Gaussian noise with variance 2 rather than a
            # per-channel "offset". Kept as in the original -- confirm the
            # intended offset term (e.g. a scaled, broadcast per-channel draw).
            return torch.randn_like(latnt_sz) + torch.randn_like(latnt_sz)
    elif sample_type == 'gmm':
        # The original branch was an empty ``...`` placeholder; fail loudly
        # instead of returning an unusable value.
        raise NotImplementedError("sample_type='gmm' is not implemented yet")
    else:
        raise ValueError(
            f"unknown sample_type {sample_type!r}; "
            "expected 'gau', 'gau_offset' or 'gmm'"
        )
    # The original used a bare ``return`` here, so the sampler was discarded.
    return sampler
if __name__ == "__main__":
    # Script entry point: intentionally empty for now (placeholder).
    pass