sky24h commited on
Commit
a9e865c
·
1 Parent(s): 9cbebfb

add ZeroGPU support

Browse files
Files changed (1) hide show
  1. inference_utils.py +8 -8
inference_utils.py CHANGED
@@ -21,7 +21,7 @@ from diffusers import DDIMScheduler, ControlNetModel
21
  from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
22
  from detail_encoder.encoder_plus import detail_encoder
23
 
24
-
25
 
26
  def get_draw(pil_img, size):
27
  cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
@@ -63,23 +63,23 @@ def init_pipeline():
63
  id_encoder_path = base_path + "/pytorch_model_1.bin"
64
  pose_encoder_path = base_path + "/pytorch_model_2.bin"
65
 
66
- Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
67
  id_encoder = ControlNetModel.from_unet(Unet)
68
  pose_encoder = ControlNetModel.from_unet(Unet)
69
- makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", "cuda", dtype=torch.float32)
70
  id_state_dict = torch.load(id_encoder_path)
71
  pose_state_dict = torch.load(pose_encoder_path)
72
  makeup_state_dict = torch.load(makeup_encoder_path)
73
  id_encoder.load_state_dict(id_state_dict, strict=False)
74
  pose_encoder.load_state_dict(pose_state_dict, strict=False)
75
  makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
76
- id_encoder.to("cuda")
77
- pose_encoder.to("cuda")
78
- makeup_encoder.to("cuda")
79
 
80
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
81
- model_id, safety_checker=None, unet=Unet, controlnet=[id_encoder, pose_encoder], torch_dtype=torch.float32
82
- ).to("cuda")
83
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
84
  return pipe, makeup_encoder
85
 
 
21
  from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
22
  from detail_encoder.encoder_plus import detail_encoder
23
 
24
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
25
 
26
  def get_draw(pil_img, size):
27
  cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
 
63
  id_encoder_path = base_path + "/pytorch_model_1.bin"
64
  pose_encoder_path = base_path + "/pytorch_model_2.bin"
65
 
66
+ Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, device=device, subfolder="unet")
67
  id_encoder = ControlNetModel.from_unet(Unet)
68
  pose_encoder = ControlNetModel.from_unet(Unet)
69
+ makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", device=device, dtype=torch.float16)
70
  id_state_dict = torch.load(id_encoder_path)
71
  pose_state_dict = torch.load(pose_encoder_path)
72
  makeup_state_dict = torch.load(makeup_encoder_path)
73
  id_encoder.load_state_dict(id_state_dict, strict=False)
74
  pose_encoder.load_state_dict(pose_state_dict, strict=False)
75
  makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
76
+ id_encoder.to(device=device)
77
+ pose_encoder.to(device=device)
78
+ makeup_encoder.to(device=device)
79
 
80
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
81
+ model_id, safety_checker=None, unet=Unet, controlnet=[id_encoder, pose_encoder], device=device, torch_dtype=torch.float16
82
+ ).to(device=device)
83
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
84
  return pipe, makeup_encoder
85