JiantaoLin
commited on
Commit
·
fd1d806
1
Parent(s):
da7a94c
new
Browse files- pipeline/kiss3d_wrapper.py +13 -13
pipeline/kiss3d_wrapper.py
CHANGED
@@ -141,20 +141,20 @@ def init_wrapper_from_config(config_path):
|
|
141 |
caption_model = None
|
142 |
|
143 |
# load reconstruction model
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
#
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
156 |
-
recon_model = None
|
157 |
-
recon_model_config = None
|
158 |
# load llm
|
159 |
llm_configs = config_.get('llm', None)
|
160 |
if llm_configs is not None and False:
|
|
|
141 |
caption_model = None
|
142 |
|
143 |
# load reconstruction model
|
144 |
+
logger.info('==> Loading reconstruction model ...')
|
145 |
+
recon_device = config_['reconstruction'].get('device', 'cpu')
|
146 |
+
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
147 |
+
recon_model = instantiate_from_config(recon_model_config.model_config)
|
148 |
+
# load recon model checkpoint
|
149 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
150 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
151 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
152 |
+
recon_model.load_state_dict(state_dict, strict=True)
|
153 |
+
recon_model.to(recon_device)
|
154 |
+
recon_model.eval()
|
155 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
156 |
+
# recon_model = None
|
157 |
+
# recon_model_config = None
|
158 |
# load llm
|
159 |
llm_configs = config_.get('llm', None)
|
160 |
if llm_configs is not None and False:
|