def forward(
    self,
    sample,
    sample_posterior=False,
    return_dict=True,
    generator=None,
):
    r"""
    Encode `sample`, choose a latent from the resulting distribution, and decode it.

    Args:
        sample (`torch.Tensor`): Input sample, passed unchanged to `self.encode`.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            When `True`, the latent is drawn with `sample(generator=generator)` on the
            encoded distribution; otherwise the distribution's `mode()` is used.
        return_dict (`bool`, *optional*, defaults to `True`):
            Currently unused: the method returns a plain tuple regardless of this flag.
        generator (*optional*, defaults to `None`):
            Forwarded to the distribution's `sample` call; only consulted when
            `sample_posterior` is `True`.

    Returns:
        `tuple`: `(decoded, None, None)`, where `decoded` is the `.sample` attribute of
        `self.decode(latent)`. The last two entries are always `None`.
    """
    latent_dist = self.encode(sample).latent_dist
    latent = (
        latent_dist.sample(generator=generator)
        if sample_posterior
        else latent_dist.mode()
    )
    decoded = self.decode(latent).sample
    # Callers unpack three values; only the first carries data here.
    return decoded, None, None