robinwitch commited on
Commit
26b89f7
·
1 Parent(s): 3cc12c0
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -36,6 +36,7 @@ from diffusion.model_util import create_gaussian_diffusion
36
  from diffusion.resample import create_named_schedule_sampler
37
  from models.vq.model import RVQVAE
38
  import spaces
 
39
 
40
  command = ["bash","./demo/install_mfa1.sh"]
41
  result = subprocess.run(command, capture_output=True, text=True)
@@ -267,7 +268,13 @@ class BaseTrainer(object):
267
  for its, batch_data in enumerate(self.test_loader):
268
  # loaded_data = self._load_data(batch_data)
269
  # net_out = self._g_test(loaded_data)
270
- net_out = _warp(self.args,self.model, batch_data,self.joints,self.joint_mask_upper,self.joint_mask_hands,self.joint_mask_lower,self.use_trans,self.mean_upper,self.mean_hands,self.mean_lower,self.std_upper,self.std_hands,self.std_lower,self.trans_mean,self.trans_std)
 
 
 
 
 
 
271
  tar_pose = net_out['tar_pose']
272
  rec_pose = net_out['rec_pose']
273
  tar_exps = net_out['tar_exps']
@@ -341,6 +348,8 @@ def _warp(args,model, batch_data,joints,joint_mask_upper,joint_mask_hands,joint_
341
  batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower
342
  )
343
  net_out = _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower)
 
 
344
  return net_out
345
 
346
  def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
 
36
  from diffusion.resample import create_named_schedule_sampler
37
  from models.vq.model import RVQVAE
38
  import spaces
39
+ import pickle
40
 
41
  command = ["bash","./demo/install_mfa1.sh"]
42
  result = subprocess.run(command, capture_output=True, text=True)
 
268
  for its, batch_data in enumerate(self.test_loader):
269
  # loaded_data = self._load_data(batch_data)
270
  # net_out = self._g_test(loaded_data)
271
+ try:
272
+ net_out = _warp(self.args,self.model, batch_data,self.joints,self.joint_mask_upper,self.joint_mask_hands,self.joint_mask_lower,self.use_trans,self.mean_upper,self.mean_hands,self.mean_lower,self.std_upper,self.std_hands,self.std_lower,self.trans_mean,self.trans_std)
273
+ print("debug8: return try")
274
+ except:
275
+ print("debug9: return fail, use pickle load file")
276
+ with open("tmp_file", "rb") as tmp_file:
277
+ net_out = pickle.loads(tmp_file)
278
  tar_pose = net_out['tar_pose']
279
  rec_pose = net_out['rec_pose']
280
  tar_exps = net_out['tar_exps']
 
348
  batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower
349
  )
350
  net_out = _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower)
351
+ with open("tmp_file", "wb") as tmp_file:
352
+ pickle.dump(net_out, tmp_file)
353
  return net_out
354
 
355
  def _warp_inverse_selection_tensor(filtered_t, selection_array, n):