Flux9665 commited on
Commit
2889ed0
·
verified ·
1 Parent(s): 4a80e28

Update Modules/ControllabilityGAN/GAN.py

Browse files
Files changed (1) hide show
  1. Modules/ControllabilityGAN/GAN.py +7 -5
Modules/ControllabilityGAN/GAN.py CHANGED
@@ -14,6 +14,8 @@ class GanWrapper:
14
  self.wgan = None
15
  self.normalize = True
16
 
 
 
17
  self.load_model(path_wgan)
18
 
19
  self.U = self.compute_controllability()
@@ -21,12 +23,12 @@ class GanWrapper:
21
  self.z_list = list()
22
 
23
  while len(self.z_list) < num_cached_voices + 2:
24
- z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
25
- sims = [-1.0]
26
  for other_z in self.z_list:
27
- sims.append(torch.nn.functional.cosine_similarity(z, other_z))
28
- print(max(sims), len(self.z_list))
29
- if max(sims) < 0.25:
30
  self.z_list.append(z)
31
  self.z = self.z_list[0]
32
 
 
14
  self.wgan = None
15
  self.normalize = True
16
 
17
+ torch.manual_seed(160923)
18
+
19
  self.load_model(path_wgan)
20
 
21
  self.U = self.compute_controllability()
 
23
  self.z_list = list()
24
 
25
  while len(self.z_list) < num_cached_voices + 2:
26
+ z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.4)
27
+ l1_distances = [100.0]
28
  for other_z in self.z_list:
29
+ l1_distances.append(torch.nn.functional.l1_loss(z, other_z))
30
+ print("dist: ", min(l1_distances), len(self.z_list))
31
+ if min(l1_distances) > 0.5:
32
  self.z_list.append(z)
33
  self.z = self.z_list[0]
34