charleselena commited on
Commit
4b5fca1
·
verified ·
1 Parent(s): 3a07267

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -49
handler.py CHANGED
@@ -1,51 +1,33 @@
1
- from typing import Dict, List, Any
2
- import base64
3
- from PIL import Image
4
- from io import BytesIO
5
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
- from diffusers import StableDiffusionPipeline
7
-
8
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
-
11
- # # set device
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
- if device.type != 'cuda':
14
- raise ValueError("need to run on GPU")
15
- # set mixed precision dtype
16
- dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
17
-
18
- class EndpointHandler():
19
- def __init__(self, path=""):
20
- self.stable_diffusion_id = "Lykon/dreamshaper-8"
21
- self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
22
-
23
- self.generator = torch.Generator(device=device.type).manual_seed(3)
24
-
25
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
26
- # """
27
- # :param data: A dictionary contains `inputs` and optional `image` field.
28
- # :return: A dictionary with `image` field contains image in base64.
29
- # """
30
- prompt = data.pop("inputs", None)
31
- num_inference_steps = data.pop("num_inference_steps", 30)
32
- guidance_scale = data.pop("guidance_scale", 7.4)
33
- negative_prompt = data.pop("negative_prompt", None)
34
- height = data.pop("height", None)
35
- width = data.pop("width", None)
36
-
37
- # run inference pipeline
38
- out = self.pipe(
39
- prompt=prompt,
40
- negative_prompt=negative_prompt,
41
- num_inference_steps=num_inference_steps,
42
- guidance_scale=guidance_scale,
43
- num_images_per_prompt=1,
44
- height=height,
45
- width=width,
46
- generator=self.generator
47
- )
48
-
49
-
50
- # return first generate PIL image
51
- return out.images[0]
 
 
 
 
 
 
 
 
1
  import torch
2
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
+
4
+ device = "cuda"
5
+ num_images_per_prompt = 1
6
+
7
+ prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to(device)
8
+ decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to(device)
9
+
10
+ prompt = "Anthropomorphic cat dressed as a pilot"
11
+ negative_prompt = ""
12
+
13
+ prior_output = prior(
14
+ prompt=prompt,
15
+ height=1024,
16
+ width=1024,
17
+ negative_prompt=negative_prompt,
18
+ guidance_scale=4.0,
19
+ num_images_per_prompt=num_images_per_prompt,
20
+ num_inference_steps=20
21
+ )
22
+
23
+ decoder_output = decoder(
24
+ image_embeddings=prior_output.image_embeddings.half(),
25
+ prompt=prompt,
26
+ negative_prompt=negative_prompt,
27
+ guidance_scale=0.0,
28
+ output_type="pil",
29
+ num_inference_steps=10
30
+ ).images
31
+
32
+ return images[0]
33