JiantaoLin
commited on
Commit
Β·
844009d
1
Parent(s):
af53f48
- pipeline/kiss3d_wrapper.py +15 -15
pipeline/kiss3d_wrapper.py
CHANGED
|
@@ -75,8 +75,8 @@ def init_wrapper_from_config(config_path):
|
|
| 75 |
else:
|
| 76 |
flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
| 77 |
# flux_pipe.enable_vae_slicing()
|
| 78 |
-
|
| 79 |
-
flux_pipe.enable_sequential_cpu_offload()
|
| 80 |
# load flux model and controlnet
|
| 81 |
if flux_controlnet_pth is not None and False:
|
| 82 |
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
|
|
@@ -139,20 +139,20 @@ def init_wrapper_from_config(config_path):
|
|
| 139 |
caption_model = None
|
| 140 |
|
| 141 |
# load reconstruction model
|
| 142 |
-
logger.info('==> Loading reconstruction model ...')
|
| 143 |
-
recon_device = config_['reconstruction'].get('device', 'cpu')
|
| 144 |
-
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
| 145 |
-
recon_model = instantiate_from_config(recon_model_config.model_config)
|
| 146 |
-
# load recon model checkpoint
|
| 147 |
-
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
| 148 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
| 149 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
| 150 |
-
recon_model.load_state_dict(state_dict, strict=True)
|
| 151 |
-
recon_model.to(recon_device)
|
| 152 |
-
recon_model.eval()
|
| 153 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
| 154 |
-
|
| 155 |
-
|
| 156 |
# load llm
|
| 157 |
llm_configs = config_.get('llm', None)
|
| 158 |
if llm_configs is not None and False:
|
|
|
|
| 75 |
else:
|
| 76 |
flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
| 77 |
# flux_pipe.enable_vae_slicing()
|
| 78 |
+
flux_pipe.enable_xformers_memory_efficient_attention()
|
| 79 |
+
# flux_pipe.enable_sequential_cpu_offload()
|
| 80 |
# load flux model and controlnet
|
| 81 |
if flux_controlnet_pth is not None and False:
|
| 82 |
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
|
|
|
|
| 139 |
caption_model = None
|
| 140 |
|
| 141 |
# load reconstruction model
|
| 142 |
+
# logger.info('==> Loading reconstruction model ...')
|
| 143 |
+
# recon_device = config_['reconstruction'].get('device', 'cpu')
|
| 144 |
+
# recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
| 145 |
+
# recon_model = instantiate_from_config(recon_model_config.model_config)
|
| 146 |
+
# # load recon model checkpoint
|
| 147 |
+
# model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
| 148 |
+
# state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
| 149 |
+
# state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
| 150 |
+
# recon_model.load_state_dict(state_dict, strict=True)
|
| 151 |
+
# recon_model.to(recon_device)
|
| 152 |
+
# recon_model.eval()
|
| 153 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
| 154 |
+
recon_model = None
|
| 155 |
+
recon_model_config = None
|
| 156 |
# load llm
|
| 157 |
llm_configs = config_.get('llm', None)
|
| 158 |
if llm_configs is not None and False:
|