JiantaoLin commited on
Commit
5926261
Β·
1 Parent(s): 1576cad
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +2 -1
pipeline/kiss3d_wrapper.py CHANGED
@@ -74,7 +74,8 @@ def init_wrapper_from_config(config_path):
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
-
 
78
  # load flux model and controlnet
79
  if flux_controlnet_pth is not None and False:
80
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
 
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
+ flux_pipe.enable_vae_slicing()
78
+ flux_pipe.enable_xformers_memory_efficient_attention()
79
  # load flux model and controlnet
80
  if flux_controlnet_pth is not None and False:
81
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)