yamildiego committed
Commit b792488 · 2 Parent(s): 69d41c4 4b5fca1

Merge branch 'main' of https://huggingface.co/Charles-Elena/ControlNet-endpoint-CRL-test

Files changed (1):
  handler.py +20 -53
handler.py CHANGED

@@ -4,7 +4,6 @@ from PIL import Image
 from io import BytesIO
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers import StableDiffusionPipeline
-from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
 
 import torch
 
@@ -21,10 +20,6 @@ class EndpointHandler():
         self.stable_diffusion_id = "Lykon/dreamshaper-8"
         self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
 
-        self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
-        self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
-
-
         self.generator = torch.Generator(device=device.type).manual_seed(3)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
@@ -32,53 +27,25 @@ class EndpointHandler():
         # :param data: A dictionary contains `inputs` and optional `image` field.
         # :return: A dictionary with `image` field contains image in base64.
         # """
-        prompt = data.pop("inputs", None)
-        num_inference_steps = data.pop("num_inference_steps", 30)
-        guidance_scale = data.pop("guidance_scale", 7.4)
-        negative_prompt = data.pop("negative_prompt", None)
-        height = data.pop("height", None)
-        width = data.pop("width", None)
-
-        # # run inference pipeline
-        # out = self.pipe(
-        #     prompt=prompt,
-        #     negative_prompt=negative_prompt,
-        #     num_inference_steps=num_inference_steps,
-        #     guidance_scale=guidance_scale,
-        #     num_images_per_prompt=1,
-        #     height=height,
-        #     width=width,
-        #     generator=self.generator
-        # )
-
-        self.prior_pipeline.to(device)
-        self.decoder_pipeline.to(device)
-
-        prior_output = prior_pipeline(
-            prompt=prompt,
-            height=height,
-            width=width,
-            num_inference_steps=num_inference_steps,
-            # timesteps=DEFAULT_STAGE_C_TIMESTEPS,
-            negative_prompt=negative_prompt,
-            guidance_scale=guidance_scale,
-            num_images_per_prompt=1,
-            generator=self.generator,
-            # callback=callback_prior,
-            # callback_steps=callback_steps
-        )
-
-
-        decoder_output = self.decoder_pipeline(
-            image_embeddings=prior_output.image_embeddings,
-            prompt=prompt,
-            num_inference_steps=num_inference_steps,
-            # timesteps=decoder_timesteps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            generator=self.generator,
-            output_type="pil",
-        ).images
-
-        return decoder_output[0]
-
+        prompt = data.pop("inputs", None)
+        num_inference_steps = data.pop("num_inference_steps", 30)
+        guidance_scale = data.pop("guidance_scale", 7.4)
+        negative_prompt = data.pop("negative_prompt", None)
+        height = data.pop("height", None)
+        width = data.pop("width", None)
+
+        # run inference pipeline
+        out = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=1,
+            height=height,
+            width=width,
+            generator=self.generator
+        )
+
+
+        # return first generate PIL image
+        return out.images[0]
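For context: the removed branch loaded Stable Cascade's prior and decoder pipelines but then invoked `prior_pipeline(...)` without the `self.` prefix, which would have raised a `NameError` at request time; the merge drops that path entirely and keeps the single Dreamshaper `StableDiffusionPipeline`. For reference, here is a minimal standalone sketch of the two-stage flow the removed code appears to have been attempting. The model IDs come from the removed lines; the dtypes, resolution, and guidance values are assumptions following the usual diffusers Stable Cascade example, not anything in this repo:

```python
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
prompt = "a portrait photo of an astronaut"  # hypothetical prompt for illustration

# Stage C: the prior turns the prompt into image embeddings.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.0,
    num_inference_steps=20,
    num_images_per_prompt=1,
)

# Stage B: the decoder turns those embeddings into a PIL image.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)
images = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=10,
    output_type="pil",
).images
images[0].save("cascade.png")
```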
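After the merge, `__call__` returns `out.images[0]` directly, a raw `PIL.Image`, so the commented-out docstring (a base64 `image` field) and the `List[List[Dict[str, float]]]` return annotation no longer describe the actual return value. A hypothetical local smoke test, assuming `handler.py` is importable from the working directory and that the constructor takes no required arguments (its signature falls outside the hunks shown):

```python
# Hypothetical local smoke test for the merged handler.
from handler import EndpointHandler

handler = EndpointHandler()
image = handler({
    "inputs": "a scenic mountain lake at sunrise",  # maps to `prompt`
    "num_inference_steps": 30,   # defaults shown in the diff
    "guidance_scale": 7.4,
    "negative_prompt": "blurry, low quality",
    "height": 512,
    "width": 512,
})
image.save("out.png")  # __call__ now returns a PIL.Image, not base64
```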