yamildiego commited on
Commit
0afaab3
·
1 Parent(s): e0d4444

test without decoder StableCascadePipeline

Browse files
Files changed (1) hide show
  1. handler.py +2 -15
handler.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  from io import BytesIO
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from diffusers import StableDiffusionPipeline
7
- from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
8
 
9
  import torch
10
 
@@ -19,10 +19,9 @@ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.
19
  class EndpointHandler():
20
  def __init__(self, path=""):
21
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
22
- self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
23
 
24
  self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
25
- self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
26
 
27
 
28
  self.generator = torch.Generator(device=device.type).manual_seed(3)
@@ -39,18 +38,6 @@ class EndpointHandler():
39
  height = data.pop("height", None)
40
  width = data.pop("width", None)
41
 
42
- # # run inference pipeline
43
- # out = self.pipe(
44
- # prompt=prompt,
45
- # negative_prompt=negative_prompt,
46
- # num_inference_steps=num_inference_steps,
47
- # guidance_scale=guidance_scale,
48
- # num_images_per_prompt=1,
49
- # height=height,
50
- # width=width,
51
- # generator=self.generator
52
- # )
53
-
54
  self.prior_pipeline.to(device)
55
  self.decoder_pipeline.to(device)
56
 
 
4
  from io import BytesIO
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from diffusers import StableDiffusionPipeline
7
+ from diffusers import StableCascadePipeline, StableCascadePriorPipeline
8
 
9
  import torch
10
 
 
19
  class EndpointHandler():
20
  def __init__(self, path=""):
21
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
 
22
 
23
  self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
24
+ self.decoder_pipeline = StableCascadePipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
25
 
26
 
27
  self.generator = torch.Generator(device=device.type).manual_seed(3)
 
38
  height = data.pop("height", None)
39
  width = data.pop("width", None)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  self.prior_pipeline.to(device)
42
  self.decoder_pipeline.to(device)
43