robinwitch committed on
Commit
1317804
·
1 Parent(s): cd50369
Files changed (2) hide show
  1. app.py +84 -83
  2. models/vq/quantizer.py +1 -1
app.py CHANGED
@@ -172,84 +172,7 @@ class BaseTrainer(object):
172
  self.args.vae_layer = 4
173
  self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
174
  other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name)
175
-
176
- elif vq_type=="rvqvae":
177
-
178
- args.num_quantizers = 6
179
- args.shared_codebook = False
180
- args.quantize_dropout_prob = 0.2
181
- args.mu = 0.99
182
-
183
- args.nb_code = 512
184
- args.code_dim = 512
185
- args.code_dim = 512
186
- args.down_t = 2
187
- args.stride_t = 2
188
- args.width = 512
189
- args.depth = 3
190
- args.dilation_growth_rate = 3
191
- args.vq_act = "relu"
192
- args.vq_norm = None
193
-
194
- dim_pose = 78
195
- args.body_part = "upper"
196
- self.vq_model_upper = RVQVAE(args,
197
- dim_pose,
198
- args.nb_code,
199
- args.code_dim,
200
- args.code_dim,
201
- args.down_t,
202
- args.stride_t,
203
- args.width,
204
- args.depth,
205
- args.dilation_growth_rate,
206
- args.vq_act,
207
- args.vq_norm)
208
-
209
- dim_pose = 180
210
- args.body_part = "hands"
211
- self.vq_model_hands = RVQVAE(args,
212
- dim_pose,
213
- args.nb_code,
214
- args.code_dim,
215
- args.code_dim,
216
- args.down_t,
217
- args.stride_t,
218
- args.width,
219
- args.depth,
220
- args.dilation_growth_rate,
221
- args.vq_act,
222
- args.vq_norm)
223
-
224
- dim_pose = 54
225
- if args.use_trans:
226
- dim_pose = 57
227
- self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
228
- args.body_part = "lower"
229
- self.vq_model_lower = RVQVAE(args,
230
- dim_pose,
231
- args.nb_code,
232
- args.code_dim,
233
- args.code_dim,
234
- args.down_t,
235
- args.stride_t,
236
- args.width,
237
- args.depth,
238
- args.dilation_growth_rate,
239
- args.vq_act,
240
- args.vq_norm)
241
-
242
- self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
243
- self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
244
- self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
245
-
246
- self.vqvae_latent_scale = self.args.vqvae_latent_scale
247
 
248
- self.vq_model_upper.eval().to(self.rank)
249
- self.vq_model_hands.eval().to(self.rank)
250
- self.vq_model_lower.eval().to(self.rank)
251
-
252
-
253
 
254
 
255
 
@@ -260,10 +183,7 @@ class BaseTrainer(object):
260
  self.args.vae_length = 240
261
 
262
 
263
- # self.vq_model_face.eval()
264
- self.vq_model_upper.eval()
265
- self.vq_model_hands.eval()
266
- self.vq_model_lower.eval()
267
 
268
  self.cls_loss = nn.NLLLoss().to(self.rank)
269
  self.reclatent_loss = nn.MSELoss().to(self.rank)
@@ -609,8 +529,91 @@ class BaseTrainer(object):
609
  'rec_exps': rec_exps,
610
  }
611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  @spaces.GPU(duration=149)
613
  def _warp(self, batch_data):
 
 
 
614
  loaded_data = self._load_data(batch_data)
615
  net_out = self._g_test(loaded_data)
616
  return net_out
@@ -634,8 +637,6 @@ class BaseTrainer(object):
634
  latent_ori = []
635
  l2_all = 0
636
  lvel = 0
637
- self.model = self.model.cuda()
638
- self.model.eval()
639
  # self.eval_copy.eval()
640
  with torch.no_grad():
641
  for its, batch_data in enumerate(self.test_loader):
 
172
  self.args.vae_layer = 4
173
  self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
174
  other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
 
 
 
 
 
176
 
177
 
178
 
 
183
  self.args.vae_length = 240
184
 
185
 
186
+
 
 
 
187
 
188
  self.cls_loss = nn.NLLLoss().to(self.rank)
189
  self.reclatent_loss = nn.MSELoss().to(self.rank)
 
529
  'rec_exps': rec_exps,
530
  }
531
 
532
+
533
+ def _create_cuda_model(self):
534
+ args = self.args
535
+ args.num_quantizers = 6
536
+ args.shared_codebook = False
537
+ args.quantize_dropout_prob = 0.2
538
+ args.mu = 0.99
539
+
540
+ args.nb_code = 512
541
+ args.code_dim = 512
542
+ args.code_dim = 512
543
+ args.down_t = 2
544
+ args.stride_t = 2
545
+ args.width = 512
546
+ args.depth = 3
547
+ args.dilation_growth_rate = 3
548
+ args.vq_act = "relu"
549
+ args.vq_norm = None
550
+
551
+ dim_pose = 78
552
+ args.body_part = "upper"
553
+ self.vq_model_upper = RVQVAE(args,
554
+ dim_pose,
555
+ args.nb_code,
556
+ args.code_dim,
557
+ args.code_dim,
558
+ args.down_t,
559
+ args.stride_t,
560
+ args.width,
561
+ args.depth,
562
+ args.dilation_growth_rate,
563
+ args.vq_act,
564
+ args.vq_norm)
565
+
566
+ dim_pose = 180
567
+ args.body_part = "hands"
568
+ self.vq_model_hands = RVQVAE(args,
569
+ dim_pose,
570
+ args.nb_code,
571
+ args.code_dim,
572
+ args.code_dim,
573
+ args.down_t,
574
+ args.stride_t,
575
+ args.width,
576
+ args.depth,
577
+ args.dilation_growth_rate,
578
+ args.vq_act,
579
+ args.vq_norm)
580
+
581
+ dim_pose = 54
582
+ if args.use_trans:
583
+ dim_pose = 57
584
+ self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
585
+ args.body_part = "lower"
586
+ self.vq_model_lower = RVQVAE(args,
587
+ dim_pose,
588
+ args.nb_code,
589
+ args.code_dim,
590
+ args.code_dim,
591
+ args.down_t,
592
+ args.stride_t,
593
+ args.width,
594
+ args.depth,
595
+ args.dilation_growth_rate,
596
+ args.vq_act,
597
+ args.vq_norm)
598
+
599
+ self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
600
+ self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
601
+ self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
602
+
603
+ self.vqvae_latent_scale = self.args.vqvae_latent_scale
604
+
605
+ self.vq_model_upper.eval().to(self.rank)
606
+ self.vq_model_hands.eval().to(self.rank)
607
+ self.vq_model_lower.eval().to(self.rank)
608
+
609
+ self.model = self.model.cuda()
610
+ self.model.eval()
611
+
612
  @spaces.GPU(duration=149)
613
  def _warp(self, batch_data):
614
+ self._create_cuda_model()
615
+
616
+
617
  loaded_data = self._load_data(batch_data)
618
  net_out = self._g_test(loaded_data)
619
  return net_out
 
637
  latent_ori = []
638
  l2_all = 0
639
  lvel = 0
 
 
640
  # self.eval_copy.eval()
641
  with torch.no_grad():
642
  for its, batch_data in enumerate(self.test_loader):
models/vq/quantizer.py CHANGED
@@ -44,7 +44,7 @@ class QuantizeEMAReset(nn.Module):
44
  self.init = False
45
  self.code_sum = None
46
  self.code_count = None
47
- self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False))
48
 
49
  def _tile(self, x):
50
  nb_code_x, code_dim = x.shape
 
44
  self.init = False
45
  self.code_sum = None
46
  self.code_count = None
47
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False).cuda())
48
 
49
  def _tile(self, x):
50
  nb_code_x, code_dim = x.shape