saurav4455 commited on
Commit
ae31928
·
verified ·
1 Parent(s): aa5b197

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -119,20 +119,24 @@ DEFAULT_TEXT = "A person is "
119
  if not os.path.exists("checkpoints/t2m"):
120
  os.system("bash prepare/download_models_demo.sh")
121
 
 
 
 
122
  ##########################
123
- ######Preparing demo######
124
  ##########################
125
  parser = EvalT2MOptions()
126
  opt = parser.parse()
127
- fixseed(opt.seed)
128
- opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
 
129
  dim_pose = 263
130
  root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
131
  model_dir = pjoin(root_dir, 'model')
132
  model_opt_path = pjoin(root_dir, 'opt.txt')
133
  model_opt = get_opt(model_opt_path, device=opt.device)
134
 
135
- ######Loading RVQ######
136
  vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
137
  vq_opt = get_opt(vq_opt_path, device=opt.device)
138
  vq_opt.dim_pose = dim_pose
@@ -142,19 +146,20 @@ model_opt.num_tokens = vq_opt.nb_code
142
  model_opt.num_quantizers = vq_opt.num_quantizers
143
  model_opt.code_dim = vq_opt.code_dim
144
 
145
- ######Loading R-Transformer######
146
  res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
147
  res_opt = get_opt(res_opt_path, device=opt.device)
148
  res_model = load_res_model(res_opt, vq_opt, opt)
149
 
150
  assert res_opt.vq_name == model_opt.vq_name
151
 
152
- ######Loading M-Transformer######
153
  t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
154
 
155
- #####Loading Length Predictor#####
156
  length_estimator = load_len_estimator(model_opt)
157
 
 
158
  t2m_transformer.eval()
159
  vq_model.eval()
160
  res_model.eval()
@@ -168,6 +173,7 @@ length_estimator.to(opt.device)
168
  opt.nb_joints = 22
169
  mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
170
  std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
 
171
  def inv_transform(data):
172
  return data * std + mean
173
 
 
119
  if not os.path.exists("checkpoints/t2m"):
120
  os.system("bash prepare/download_models_demo.sh")
121
 
122
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
123
+ print("CUDA is not required; running on CPU.")
124
+
125
  ##########################
126
+ ###### Preparing Demo ####
127
  ##########################
128
  parser = EvalT2MOptions()
129
  opt = parser.parse()
130
+
131
+ # Force use of CPU
132
+ opt.device = torch.device("cpu")
133
  dim_pose = 263
134
  root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
135
  model_dir = pjoin(root_dir, 'model')
136
  model_opt_path = pjoin(root_dir, 'opt.txt')
137
  model_opt = get_opt(model_opt_path, device=opt.device)
138
 
139
+ ###### Loading RVQ ######
140
  vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
141
  vq_opt = get_opt(vq_opt_path, device=opt.device)
142
  vq_opt.dim_pose = dim_pose
 
146
  model_opt.num_quantizers = vq_opt.num_quantizers
147
  model_opt.code_dim = vq_opt.code_dim
148
 
149
+ ###### Loading R-Transformer ######
150
  res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
151
  res_opt = get_opt(res_opt_path, device=opt.device)
152
  res_model = load_res_model(res_opt, vq_opt, opt)
153
 
154
  assert res_opt.vq_name == model_opt.vq_name
155
 
156
+ ###### Loading M-Transformer ######
157
  t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
158
 
159
+ ##### Loading Length Predictor #####
160
  length_estimator = load_len_estimator(model_opt)
161
 
162
+ # Set models to evaluation mode and move to CPU
163
  t2m_transformer.eval()
164
  vq_model.eval()
165
  res_model.eval()
 
173
  opt.nb_joints = 22
174
  mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
175
  std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
176
+
177
  def inv_transform(data):
178
  return data * std + mean
179