Commit 1317804
Parent(s): cd50369
add

Files changed:
- app.py +84 -83
- models/vq/quantizer.py +1 -1
app.py
CHANGED
@@ -172,84 +172,7 @@ class BaseTrainer(object):
            self.args.vae_layer = 4
            self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
            other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name)
-
-        elif vq_type=="rvqvae":
-
-            args.num_quantizers = 6
-            args.shared_codebook = False
-            args.quantize_dropout_prob = 0.2
-            args.mu = 0.99
-
-            args.nb_code = 512
-            args.code_dim = 512
-            args.code_dim = 512
-            args.down_t = 2
-            args.stride_t = 2
-            args.width = 512
-            args.depth = 3
-            args.dilation_growth_rate = 3
-            args.vq_act = "relu"
-            args.vq_norm = None
-
-            dim_pose = 78
-            args.body_part = "upper"
-            self.vq_model_upper = RVQVAE(args,
-                                         dim_pose,
-                                         args.nb_code,
-                                         args.code_dim,
-                                         args.code_dim,
-                                         args.down_t,
-                                         args.stride_t,
-                                         args.width,
-                                         args.depth,
-                                         args.dilation_growth_rate,
-                                         args.vq_act,
-                                         args.vq_norm)
-
-            dim_pose = 180
-            args.body_part = "hands"
-            self.vq_model_hands = RVQVAE(args,
-                                         dim_pose,
-                                         args.nb_code,
-                                         args.code_dim,
-                                         args.code_dim,
-                                         args.down_t,
-                                         args.stride_t,
-                                         args.width,
-                                         args.depth,
-                                         args.dilation_growth_rate,
-                                         args.vq_act,
-                                         args.vq_norm)
-
-            dim_pose = 54
-            if args.use_trans:
-                dim_pose = 57
-                self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
-            args.body_part = "lower"
-            self.vq_model_lower = RVQVAE(args,
-                                         dim_pose,
-                                         args.nb_code,
-                                         args.code_dim,
-                                         args.code_dim,
-                                         args.down_t,
-                                         args.stride_t,
-                                         args.width,
-                                         args.depth,
-                                         args.dilation_growth_rate,
-                                         args.vq_act,
-                                         args.vq_norm)
-
-            self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
-            self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
-            self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
-
-            self.vqvae_latent_scale = self.args.vqvae_latent_scale

-            self.vq_model_upper.eval().to(self.rank)
-            self.vq_model_hands.eval().to(self.rank)
-            self.vq_model_lower.eval().to(self.rank)
-
-

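Side note on the removed branch above: the three RVQVAE models (upper, hands, lower) are built with identical hyperparameters and differ only in dim_pose and args.body_part. A small helper along the following lines could express that in one place; build_rvqvae is a hypothetical name used only for illustration and is not defined anywhere in this repository.

# Hypothetical helper (illustration only, not part of this commit): one shared
# RVQVAE constructor call, varying only the pose dimension and body part.
def build_rvqvae(args, dim_pose, body_part):
    args.body_part = body_part
    return RVQVAE(args, dim_pose,
                  args.nb_code, args.code_dim, args.code_dim,
                  args.down_t, args.stride_t, args.width, args.depth,
                  args.dilation_growth_rate, args.vq_act, args.vq_norm)

# e.g. (the args.use_trans checkpoint-path switch is left out of this sketch):
#   self.vq_model_upper = build_rvqvae(args, 78, "upper")
#   self.vq_model_hands = build_rvqvae(args, 180, "hands")
#   self.vq_model_lower = build_rvqvae(args, 57 if args.use_trans else 54, "lower")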
@@ -260,10 +183,7 @@ class BaseTrainer(object):
        self.args.vae_length = 240


-
-        self.vq_model_upper.eval()
-        self.vq_model_hands.eval()
-        self.vq_model_lower.eval()
+

        self.cls_loss = nn.NLLLoss().to(self.rank)
        self.reclatent_loss = nn.MSELoss().to(self.rank)
@@ -609,8 +529,91 @@ class BaseTrainer(object):
            'rec_exps': rec_exps,
        }

+
+    def _create_cuda_model(self):
+        args = self.args
+        args.num_quantizers = 6
+        args.shared_codebook = False
+        args.quantize_dropout_prob = 0.2
+        args.mu = 0.99
+
+        args.nb_code = 512
+        args.code_dim = 512
+        args.code_dim = 512
+        args.down_t = 2
+        args.stride_t = 2
+        args.width = 512
+        args.depth = 3
+        args.dilation_growth_rate = 3
+        args.vq_act = "relu"
+        args.vq_norm = None
+
+        dim_pose = 78
+        args.body_part = "upper"
+        self.vq_model_upper = RVQVAE(args,
+                                     dim_pose,
+                                     args.nb_code,
+                                     args.code_dim,
+                                     args.code_dim,
+                                     args.down_t,
+                                     args.stride_t,
+                                     args.width,
+                                     args.depth,
+                                     args.dilation_growth_rate,
+                                     args.vq_act,
+                                     args.vq_norm)
+
+        dim_pose = 180
+        args.body_part = "hands"
+        self.vq_model_hands = RVQVAE(args,
+                                     dim_pose,
+                                     args.nb_code,
+                                     args.code_dim,
+                                     args.code_dim,
+                                     args.down_t,
+                                     args.stride_t,
+                                     args.width,
+                                     args.depth,
+                                     args.dilation_growth_rate,
+                                     args.vq_act,
+                                     args.vq_norm)
+
+        dim_pose = 54
+        if args.use_trans:
+            dim_pose = 57
+            self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
+        args.body_part = "lower"
+        self.vq_model_lower = RVQVAE(args,
+                                     dim_pose,
+                                     args.nb_code,
+                                     args.code_dim,
+                                     args.code_dim,
+                                     args.down_t,
+                                     args.stride_t,
+                                     args.width,
+                                     args.depth,
+                                     args.dilation_growth_rate,
+                                     args.vq_act,
+                                     args.vq_norm)
+
+        self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
+        self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
+        self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
+
+        self.vqvae_latent_scale = self.args.vqvae_latent_scale
+
+        self.vq_model_upper.eval().to(self.rank)
+        self.vq_model_hands.eval().to(self.rank)
+        self.vq_model_lower.eval().to(self.rank)
+
+        self.model = self.model.cuda()
+        self.model.eval()
+
    @spaces.GPU(duration=149)
    def _warp(self, batch_data):
+        self._create_cuda_model()
+
+
        loaded_data = self._load_data(batch_data)
        net_out = self._g_test(loaded_data)
        return net_out

@@ -634,8 +637,6 @@ class BaseTrainer(object):
        latent_ori = []
        l2_all = 0
        lvel = 0
-        self.model = self.model.cuda()
-        self.model.eval()
        # self.eval_copy.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
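Taken together, the app.py changes move every CUDA-touching step (building the three RVQVAE models, loading their checkpoints, and moving self.model onto the GPU) out of trainer construction and into _create_cuda_model(), which is only called from _warp() under @spaces.GPU(duration=149). This matches the ZeroGPU constraint on Hugging Face Spaces: the GPU is only attached inside a @spaces.GPU-decorated call, so nothing may touch CUDA before then. Below is a minimal sketch of the same pattern; the class, method names, and the nn.Linear stand-in are illustrative only, not the repository's code.

# Minimal sketch of the lazy CUDA-initialization pattern used by _warp above.
import spaces
import torch
import torch.nn as nn

class LazyCudaRunner:
    def __init__(self):
        # No CUDA work here: under ZeroGPU the GPU is not attached yet.
        self.model = None

    def _create_cuda_model(self):
        # Build the model lazily and move it to the GPU only once a GPU
        # has actually been allocated to this call.
        if self.model is None:
            self.model = nn.Linear(16, 16)
        self.model = self.model.cuda().eval()

    @spaces.GPU(duration=149)
    def run(self, batch):
        # All CUDA usage stays inside this decorated method, as in _warp().
        self._create_cuda_model()
        with torch.no_grad():
            return self.model(batch.cuda()).cpu()

As committed, _warp() calls _create_cuda_model() on every invocation, so the three RVQVAE checkpoints are rebuilt and reloaded per request; the None-guard in the sketch is one way to pay that cost only once, but it is not something this commit does.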
models/vq/quantizer.py
CHANGED
@@ -44,7 +44,7 @@ class QuantizeEMAReset(nn.Module):
        self.init = False
        self.code_sum = None
        self.code_count = None
-        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False))
+        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False).cuda())

    def _tile(self, x):
        nb_code_x, code_dim = x.shape
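The quantizer.py change creates the EMA codebook buffer directly on the GPU instead of on the CPU. The commit does not state the motivation, but it is consistent with the app.py change above: the RVQVAE modules are now constructed inside the GPU window and immediately receive CUDA tensors, so the codebook must already live on the same device. A minimal sketch of the difference, assuming a CUDA device is available at construction time (the CodebookHolder class is illustrative, not the repository's QuantizeEMAReset):

# Minimal sketch of buffer placement at construction time.
import torch
import torch.nn as nn

class CodebookHolder(nn.Module):
    def __init__(self, nb_code=512, code_dim=512, on_gpu=True):
        super().__init__()
        codebook = torch.zeros(nb_code, code_dim, requires_grad=False)
        if on_gpu:
            # As in the committed line: the buffer is placed on the GPU when the
            # module is built, instead of waiting for a later .to(device).
            codebook = codebook.cuda()
        # register_buffer keeps the tensor in state_dict without making it a Parameter.
        self.register_buffer('codebook', codebook)

holder = CodebookHolder(on_gpu=torch.cuda.is_available())
print(holder.codebook.device)

The trade-off is that with .cuda() hard-coded, QuantizeEMAReset can no longer be instantiated on a CPU-only machine; the more common pattern registers the buffer on CPU and relies on module.to(device) or .cuda() to move buffers together with the parameters.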