JiantaoLin commited on
Commit
4ffb78d
·
1 Parent(s): 32921ee
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +3 -3
pipeline/kiss3d_wrapper.py CHANGED
@@ -113,11 +113,11 @@ def init_wrapper_from_config(config_path):
113
  multiview_pipeline.scheduler.config, timestep_spacing='trailing'
114
  )
115
 
116
- # unet_ckpt_path = config_['multiview'].get('unet', None)
117
- unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
118
  if unet_ckpt_path is not None:
119
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
120
- state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
121
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
122
 
123
  multiview_pipeline.to(multiview_device)
 
113
  multiview_pipeline.scheduler.config, timestep_spacing='trailing'
114
  )
115
 
116
+ # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
117
+ unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
118
  if unet_ckpt_path is not None:
119
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
120
+ # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
121
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
122
 
123
  multiview_pipeline.to(multiview_device)