JiantaoLin commited on
Commit
93ec96b
·
1 Parent(s): cfc3663
Files changed (2) hide show
  1. app.py +1 -0
  2. pipeline/kiss3d_wrapper.py +8 -9
app.py CHANGED
@@ -146,6 +146,7 @@ def text_to_detailed(prompt, seed=None):
146
  # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
147
  return k3d_wrapper.get_detailed_prompt(prompt, seed)
148
 
 
149
  def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
150
  # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
151
  k3d_wrapper.renew_uuid()
 
146
  # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
147
  return k3d_wrapper.get_detailed_prompt(prompt, seed)
148
 
149
+ @spaces.GPU
150
  def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
151
  # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
152
  k3d_wrapper.renew_uuid()
pipeline/kiss3d_wrapper.py CHANGED
@@ -70,22 +70,22 @@ def init_wrapper_from_config(config_path):
70
  flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
71
  flux_redux_pth = config_['flux'].get('redux', None)
72
 
73
- # if flux_base_model_pth.endswith('safetensors'):
74
- # flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
- # else:
76
- # flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
 
78
  # load flux model and controlnet
79
  if flux_controlnet_pth is not None and False:
80
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
81
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
82
 
83
- # flux_pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(flux_pipe.scheduler.config)
84
 
85
  # load lora weights
86
- # flux_pipe.load_lora_weights(flux_lora_pth)
87
- # flux_pipe.to(device=flux_device)
88
- flux_pipe = None
89
 
90
  # load redux model
91
  flux_redux_pipe = None
@@ -465,7 +465,6 @@ class kiss3d_wrapper(object):
465
 
466
  return preprocessed
467
 
468
- @spaces.GPU
469
  def generate_3d_bundle_image_text(self,
470
  prompt,
471
  image=None,
 
70
  flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
71
  flux_redux_pth = config_['flux'].get('redux', None)
72
 
73
+ if flux_base_model_pth.endswith('safetensors'):
74
+ flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
+ else:
76
+ flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
 
78
  # load flux model and controlnet
79
  if flux_controlnet_pth is not None and False:
80
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
81
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
82
 
83
+ flux_pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(flux_pipe.scheduler.config)
84
 
85
  # load lora weights
86
+ flux_pipe.load_lora_weights(flux_lora_pth)
87
+ flux_pipe.to(device=flux_device)
88
+ # flux_pipe = None
89
 
90
  # load redux model
91
  flux_redux_pipe = None
 
465
 
466
  return preprocessed
467
 
 
468
  def generate_3d_bundle_image_text(self,
469
  prompt,
470
  image=None,