JiantaoLin committed
Commit 138cabe · 1 Parent(s): 50e89c5
pipeline/kiss3d_wrapper.py CHANGED
@@ -137,19 +137,20 @@ def init_wrapper_from_config(config_path):
     caption_model = None
 
     # load reconstruction model
-    logger.info('==> Loading reconstruction model ...')
-    recon_device = config_['reconstruction'].get('device', 'cpu')
-    recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
-    recon_model = instantiate_from_config(recon_model_config.model_config)
-    # load recon model checkpoint
-    model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
-    state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-    state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
-    recon_model.load_state_dict(state_dict, strict=True)
-    recon_model.to(recon_device)
-    recon_model.eval()
+    # logger.info('==> Loading reconstruction model ...')
+    # recon_device = config_['reconstruction'].get('device', 'cpu')
+    # recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
+    # recon_model = instantiate_from_config(recon_model_config.model_config)
+    # # load recon model checkpoint
+    # model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+    # state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+    # state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+    # recon_model.load_state_dict(state_dict, strict=True)
+    # recon_model.to(recon_device)
+    # recon_model.eval()
     # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
-
+    recon_model = None
+    recon_model_config = None
     # load llm
     llm_configs = config_.get('llm', None)
     if llm_configs is not None and False:
@@ -503,8 +504,8 @@ class kiss3d_wrapper(object):
             'num_inference_steps': num_inference_steps,
             'guidance_scale': 3.5,
             'num_images_per_prompt': 1,
-            'width': 2048//2,
-            'height': 1024//2,
+            'width': 2048,
+            'height': 1024,
             'output_type': 'np',
             'generator': generator,
             'joint_attention_kwargs': {"scale": lora_scale}
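
Note: this commit disables the eager reconstruction-model load in `init_wrapper_from_config` and leaves `recon_model = None`. As a hedged sketch only (the helper name `load_recon_model` and its import path are assumptions, not part of the repo), the same loading steps could instead be deferred behind a cached lazy loader, so the PRM checkpoint is only downloaded and placed on a device when the reconstruction stage actually runs:

```python
# Hypothetical lazy loader (not part of this commit): it reuses the loading steps
# that the commit comments out, but defers them until reconstruction is requested.
from functools import lru_cache

import torch
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

# instantiate_from_config is the project's own factory used by the original code;
# the import path below is an assumption -- reuse whatever kiss3d_wrapper.py imports.
from models.lrm.utils.train_util import instantiate_from_config


@lru_cache(maxsize=1)
def load_recon_model(model_config_path: str, device: str = 'cpu'):
    """Build the PRM reconstruction model once, cache it, and return it in eval mode."""
    recon_model_config = OmegaConf.load(model_config_path)
    recon_model = instantiate_from_config(recon_model_config.model_config)

    # Download the released checkpoint and keep only the lrm_generator.* weights,
    # stripping the 'lrm_generator.' prefix (14 characters) as the original code did.
    model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
    state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}

    recon_model.load_state_dict(state_dict, strict=True)
    return recon_model.to(device).eval()
```

With such a loader, call sites would fetch the model via `load_recon_model(config_['reconstruction']['model_config'], recon_device)` instead of relying on the module-level `recon_model`, keeping start-up memory at zero until reconstruction is needed.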
pipeline/pipeline_config/default.yaml CHANGED
@@ -14,20 +14,20 @@ multiview:
   unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
   num_inference_steps: 50
   seed: 42
-  device: 'cuda:1'
+  device: 'cuda:0'
 
 reconstruction:
   model_config: "./models/lrm/config/PRM_inference.yaml"
   base_model: "./checkpoint/lrm/final_ckpt.ckpt"
-  device: 'cuda:1'
+  device: 'cuda:0'
 
 caption:
   base_model: "multimodalart/Florence-2-large-no-flash-attn"
-  device: 'cuda:1'
+  device: 'cuda:0'
 
 llm:
   base_model: "Qwen/Qwen2-7B-Instruct"
-  device: 'cuda:1'
+  device: 'cuda:0'
 
 use_zero_gpu: false # for huggingface demo only
 3d_bundle_templates: './init_3d_Bundle'
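
Note: every stage now targets `cuda:0`, which matters on single-GPU hosts (such as the Hugging Face demo) where `cuda:1` does not exist. A minimal sketch, assuming a helper named `resolve_devices` that is not in the repo, of reading these per-stage entries and falling back to CPU when the requested CUDA index is unavailable:

```python
# Hypothetical config helper (not in the repo): resolve each stage's configured
# device from default.yaml, falling back to CPU when the requested CUDA index
# does not exist on the current machine.
import torch
from omegaconf import OmegaConf


def resolve_devices(config_path: str = './pipeline/pipeline_config/default.yaml') -> dict:
    cfg = OmegaConf.load(config_path)
    devices = {}
    # The stage names below match the sections in default.yaml.
    for stage in ('multiview', 'reconstruction', 'caption', 'llm'):
        requested = cfg[stage].get('device', 'cpu')
        if requested.startswith('cuda'):
            index = int(requested.split(':')[1]) if ':' in requested else 0
            if not torch.cuda.is_available() or index >= torch.cuda.device_count():
                requested = 'cpu'  # e.g. 'cuda:1' on a single-GPU host
        devices[stage] = requested
    return devices
```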