robinwitch commited on
Commit
828cd89
·
1 Parent(s): 7c6f179
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -432,7 +432,7 @@ class BaseTrainer(object):
432
  "tar_contact": tar_contact,
433
  "style_feature":style_feature,
434
  }
435
- @spaces.GPU(duration=149)
436
  def _g_test(self, loaded_data):
437
  sample_fn = self.diffusion.p_sample_loop
438
  if self.args.use_ddim:
@@ -617,7 +617,12 @@ class BaseTrainer(object):
617
  'rec_exps': rec_exps,
618
  }
619
 
620
-
 
 
 
 
 
621
  def test_demo(self, epoch):
622
  '''
623
  input audio and text, output motion
@@ -642,8 +647,9 @@ class BaseTrainer(object):
642
  # self.eval_copy.eval()
643
  with torch.no_grad():
644
  for its, batch_data in enumerate(self.test_loader):
645
- loaded_data = self._load_data(batch_data)
646
- net_out = self._g_test(loaded_data)
 
647
  tar_pose = net_out['tar_pose']
648
  rec_pose = net_out['rec_pose']
649
  tar_exps = net_out['tar_exps']
 
432
  "tar_contact": tar_contact,
433
  "style_feature":style_feature,
434
  }
435
+
436
  def _g_test(self, loaded_data):
437
  sample_fn = self.diffusion.p_sample_loop
438
  if self.args.use_ddim:
 
617
  'rec_exps': rec_exps,
618
  }
619
 
620
+ @spaces.GPU(duration=149)
621
+ def _warp(self, batch_data):
622
+ loaded_data = self._load_data(batch_data)
623
+ net_out = self._g_test(loaded_data)
624
+ return net_out
625
+
626
  def test_demo(self, epoch):
627
  '''
628
  input audio and text, output motion
 
647
  # self.eval_copy.eval()
648
  with torch.no_grad():
649
  for its, batch_data in enumerate(self.test_loader):
650
+ # loaded_data = self._load_data(batch_data)
651
+ # net_out = self._g_test(loaded_data)
652
+ net_out = self._warp(batch_data)
653
  tar_pose = net_out['tar_pose']
654
  rec_pose = net_out['rec_pose']
655
  tar_exps = net_out['tar_exps']