JiantaoLin committed on
Commit
fd1d806
·
1 Parent(s): da7a94c
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +13 -13
pipeline/kiss3d_wrapper.py CHANGED
@@ -141,20 +141,20 @@ def init_wrapper_from_config(config_path):
141
  caption_model = None
142
 
143
  # load reconstruction model
144
- # logger.info('==> Loading reconstruction model ...')
145
- # recon_device = config_['reconstruction'].get('device', 'cpu')
146
- # recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
147
- # recon_model = instantiate_from_config(recon_model_config.model_config)
148
- # # load recon model checkpoint
149
- # model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
150
- # state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
151
- # state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
152
- # recon_model.load_state_dict(state_dict, strict=True)
153
- # recon_model.to(recon_device)
154
- # recon_model.eval()
155
  # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
156
- recon_model = None
157
- recon_model_config = None
158
  # load llm
159
  llm_configs = config_.get('llm', None)
160
  if llm_configs is not None and False:
 
141
  caption_model = None
142
 
143
  # load reconstruction model
144
+ logger.info('==> Loading reconstruction model ...')
145
+ recon_device = config_['reconstruction'].get('device', 'cpu')
146
+ recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
147
+ recon_model = instantiate_from_config(recon_model_config.model_config)
148
+ # load recon model checkpoint
149
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
150
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
151
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
152
+ recon_model.load_state_dict(state_dict, strict=True)
153
+ recon_model.to(recon_device)
154
+ recon_model.eval()
155
  # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
156
+ # recon_model = None
157
+ # recon_model_config = None
158
  # load llm
159
  llm_configs = config_.get('llm', None)
160
  if llm_configs is not None and False: