yamildiego commited on
Commit
69d41c4
·
1 Parent(s): 3a07267

try to implemnmt stable cascade

Browse files
Files changed (1) hide show
  1. handler.py +53 -20
handler.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  from io import BytesIO
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from diffusers import StableDiffusionPipeline
 
7
 
8
  import torch
9
 
@@ -20,6 +21,10 @@ class EndpointHandler():
20
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
21
  self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
22
 
 
 
 
 
23
  self.generator = torch.Generator(device=device.type).manual_seed(3)
24
 
25
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
@@ -27,25 +32,53 @@ class EndpointHandler():
27
  # :param data: A dictionary contains `inputs` and optional `image` field.
28
  # :return: A dictionary with `image` field contains image in base64.
29
  # """
30
- prompt = data.pop("inputs", None)
31
- num_inference_steps = data.pop("num_inference_steps", 30)
32
- guidance_scale = data.pop("guidance_scale", 7.4)
33
- negative_prompt = data.pop("negative_prompt", None)
34
- height = data.pop("height", None)
35
- width = data.pop("width", None)
36
-
37
- # run inference pipeline
38
- out = self.pipe(
39
- prompt=prompt,
40
- negative_prompt=negative_prompt,
41
- num_inference_steps=num_inference_steps,
42
- guidance_scale=guidance_scale,
43
- num_images_per_prompt=1,
44
- height=height,
45
- width=width,
46
- generator=self.generator
47
- )
48
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # return first generate PIL image
51
- return out.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from io import BytesIO
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from diffusers import StableDiffusionPipeline
7
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
8
 
9
  import torch
10
 
 
21
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
22
  self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
23
 
24
+ self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
25
+ self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
26
+
27
+
28
  self.generator = torch.Generator(device=device.type).manual_seed(3)
29
 
30
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
 
32
  # :param data: A dictionary contains `inputs` and optional `image` field.
33
  # :return: A dictionary with `image` field contains image in base64.
34
  # """
35
+ prompt = data.pop("inputs", None)
36
+ num_inference_steps = data.pop("num_inference_steps", 30)
37
+ guidance_scale = data.pop("guidance_scale", 7.4)
38
+ negative_prompt = data.pop("negative_prompt", None)
39
+ height = data.pop("height", None)
40
+ width = data.pop("width", None)
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # # run inference pipeline
43
+ # out = self.pipe(
44
+ # prompt=prompt,
45
+ # negative_prompt=negative_prompt,
46
+ # num_inference_steps=num_inference_steps,
47
+ # guidance_scale=guidance_scale,
48
+ # num_images_per_prompt=1,
49
+ # height=height,
50
+ # width=width,
51
+ # generator=self.generator
52
+ # )
53
 
54
+ self.prior_pipeline.to(device)
55
+ self.decoder_pipeline.to(device)
56
+
57
+ prior_output = prior_pipeline(
58
+ prompt=prompt,
59
+ height=height,
60
+ width=width,
61
+ num_inference_steps=num_inference_steps,
62
+ # timesteps=DEFAULT_STAGE_C_TIMESTEPS,
63
+ negative_prompt=negative_prompt,
64
+ guidance_scale=guidance_scale,
65
+ num_images_per_prompt=1,
66
+ generator=self.generator,
67
+ # callback=callback_prior,
68
+ # callback_steps=callback_steps
69
+ )
70
+
71
+
72
+ decoder_output = self.decoder_pipeline(
73
+ image_embeddings=prior_output.image_embeddings,
74
+ prompt=prompt,
75
+ num_inference_steps=num_inference_steps,
76
+ # timesteps=decoder_timesteps,
77
+ guidance_scale=guidance_scale,
78
+ negative_prompt=negative_prompt,
79
+ generator=self.generator,
80
+ output_type="pil",
81
+ ).images
82
+
83
+ return decoder_output[0]
84
+