Spaces:
Running
on
T4
Running
on
T4
Update Modules/ControllabilityGAN/wgan/wgan_qc.py
Browse files
Modules/ControllabilityGAN/wgan/wgan_qc.py
CHANGED
@@ -237,9 +237,9 @@ class WassersteinGanQuadraticCost:
|
|
237 |
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
|
238 |
self.G.eval()
|
239 |
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
240 |
-
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim)
|
241 |
else:
|
242 |
-
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
|
243 |
latent_samples = latent_samples.to(self.device)
|
244 |
if nograd:
|
245 |
with torch.no_grad():
|
|
|
237 |
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
|
238 |
self.G.eval()
|
239 |
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
240 |
+
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0)
|
241 |
else:
|
242 |
+
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0)
|
243 |
latent_samples = latent_samples.to(self.device)
|
244 |
if nograd:
|
245 |
with torch.no_grad():
|