yamildiego commited on
Commit
847ce27
·
1 Parent(s): 9b12fa2
Files changed (1) hide show
  1. handler.py +13 -11
handler.py CHANGED
@@ -6,6 +6,7 @@ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
  # import Safety Checker
8
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 
9
 
10
  import torch
11
 
@@ -75,15 +76,11 @@ class EndpointHandler():
75
  controlnet=self.controlnet,
76
  torch_dtype=dtype,
77
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to("cuda")
 
 
 
78
  # Define Generator with seed
79
  self.generator = torch.Generator(device=device.type).manual_seed(3)
80
-
81
- targets = [self.pipe.vae, self.pipe.unet]
82
- for target in targets:
83
- for module in target.modules():
84
- if isinstance(module, torch.nn.Conv2d):
85
- module.padding_mode = "circular"
86
-
87
 
88
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
89
  """
@@ -107,7 +104,12 @@ class EndpointHandler():
107
  self.pipe.controlnet = self.controlnet
108
 
109
 
110
-
 
 
 
 
 
111
  # hyperparamters
112
  num_inference_steps = data.pop("num_inference_steps", 30)
113
  guidance_scale = data.pop("guidance_scale", 7.4)
@@ -127,17 +129,17 @@ class EndpointHandler():
127
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
128
 
129
  # run inference pipeline
130
- out = self.pipe(
131
  prompt=prompt,
132
  negative_prompt=negative_prompt,
133
  #image=control_image,
134
- image=image,
135
  num_inference_steps=num_inference_steps,
136
  guidance_scale=guidance_scale,
137
  num_images_per_prompt=1,
138
  height=height,
139
  width=width,
140
- controlnet_conditioning_scale=controlnet_conditioning_scale,
141
  generator=self.generator
142
  )
143
 
 
6
  #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
  # import Safety Checker
8
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
10
 
11
  import torch
12
 
 
76
  controlnet=self.controlnet,
77
  torch_dtype=dtype,
78
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to("cuda")
79
+
80
+ self.pipe_without_controlnet = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype).to(device.type)
81
+
82
  # Define Generator with seed
83
  self.generator = torch.Generator(device=device.type).manual_seed(3)
 
 
 
 
 
 
 
84
 
85
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
86
  """
 
104
  self.pipe.controlnet = self.controlnet
105
 
106
 
107
+ targets = [self.pipe_without_controlnet.vae, self.pipe_without_controlnet.unet]
108
+ for target in targets:
109
+ for module in target.modules():
110
+ if isinstance(module, torch.nn.Conv2d):
111
+ module.padding_mode = "circular"
112
+
113
  # hyperparamters
114
  num_inference_steps = data.pop("num_inference_steps", 30)
115
  guidance_scale = data.pop("guidance_scale", 7.4)
 
129
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
130
 
131
  # run inference pipeline
132
+ out = self.pipe_without_controlnet(
133
  prompt=prompt,
134
  negative_prompt=negative_prompt,
135
  #image=control_image,
136
+ #image=image,
137
  num_inference_steps=num_inference_steps,
138
  guidance_scale=guidance_scale,
139
  num_images_per_prompt=1,
140
  height=height,
141
  width=width,
142
+ #controlnet_conditioning_scale=controlnet_conditioning_scale,
143
  generator=self.generator
144
  )
145