JiantaoLin
commited on
Commit
·
4ffb78d
1
Parent(s):
32921ee
new
Browse files
pipeline/kiss3d_wrapper.py
CHANGED
@@ -113,11 +113,11 @@ def init_wrapper_from_config(config_path):
|
|
113 |
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
114 |
)
|
115 |
|
116 |
-
# unet_ckpt_path =
|
117 |
-
unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="
|
118 |
if unet_ckpt_path is not None:
|
119 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
120 |
-
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
121 |
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
122 |
|
123 |
multiview_pipeline.to(multiview_device)
|
|
|
113 |
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
114 |
)
|
115 |
|
116 |
+
# unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
|
117 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
|
118 |
if unet_ckpt_path is not None:
|
119 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
120 |
+
# state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
121 |
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
122 |
|
123 |
multiview_pipeline.to(multiview_device)
|