from sklearn.mixture import GaussianMixture


# NOTE(review): `torch` is referenced below, but no `import torch` is visible
# in this view of the file -- confirm it is imported elsewhere.
def get_noise_sampler(sample_type='gau'):
    """Select a noise-sampling callable based on ``sample_type``.

    Args:
        sample_type: Name of the sampling scheme. The branches visible here
            are ``'gau'``, ``'gau_offset'`` and ``'gmm'``; any other value
            falls through to the final ``else`` branch.

    Returns:
        As written in this view, the function ends with a bare ``return`` and
        therefore returns ``None``.
        NOTE(review): the ``sampler`` built in the branches is never returned
        here -- presumably ``return sampler`` was intended; confirm.
    """
    if sample_type == 'gau':
        # The argument is handed straight to torch.randn_like, so despite the
        # name `latnt_sz` it is presumably a tensor rather than a size --
        # TODO confirm against callers.
        sampler = lambda latnt_sz: torch.randn_like(latnt_sz)
    elif sample_type == 'gau_offset':
        # Sum of two separate torch.randn_like draws on the same argument.
        sampler = lambda latnt_sz: torch.randn_like(latnt_sz) + (torch.randn_like(latnt_sz))
        ...
    elif sample_type == 'gmm':
        # Body is only a placeholder in this view; presumably this is where
        # the imported GaussianMixture is used -- verify.
        ...
    else:
        # Handling for unrecognised sample types is a placeholder in this view.
        ...
    return


if __name__ == "__main__":
    # Script entry point; body is only a placeholder in this view.
    ...