Flux9665 commited on
Commit
8b95ccf
·
verified ·
1 Parent(s): d5f2d74

Update Modules/ControllabilityGAN/wgan/wgan_qc.py

Browse files
Modules/ControllabilityGAN/wgan/wgan_qc.py CHANGED
@@ -237,9 +237,9 @@ class WassersteinGanQuadraticCost:
237
  def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
238
  self.G.eval()
239
  if isinstance(self.G, torch.nn.parallel.DataParallel):
240
- latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim)
241
  else:
242
- latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
243
  latent_samples = latent_samples.to(self.device)
244
  if nograd:
245
  with torch.no_grad():
 
237
  def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
238
  self.G.eval()
239
  if isinstance(self.G, torch.nn.parallel.DataParallel):
240
+ latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0)
241
  else:
242
+ latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0)
243
  latent_samples = latent_samples.to(self.device)
244
  if nograd:
245
  with torch.no_grad():