yamildiego commited on
Commit
9b12fa2
·
1 Parent(s): 1a01558
Files changed (1) hide show
  1. handler.py +9 -7
handler.py CHANGED
@@ -72,11 +72,18 @@ class EndpointHandler():
72
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
73
 
74
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
75
- #controlnet=self.controlnet,
76
  torch_dtype=dtype,
77
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to("cuda")
78
  # Define Generator with seed
79
  self.generator = torch.Generator(device=device.type).manual_seed(3)
 
 
 
 
 
 
 
80
 
81
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
82
  """
@@ -100,12 +107,7 @@ class EndpointHandler():
100
  self.pipe.controlnet = self.controlnet
101
 
102
 
103
- targets = [self.pipe.vae, self.pipe.unet]
104
- for target in targets:
105
- for module in target.modules():
106
- if isinstance(module, torch.nn.Conv2d):
107
- module.padding_mode = "circular"
108
-
109
  # hyperparamters
110
  num_inference_steps = data.pop("num_inference_steps", 30)
111
  guidance_scale = data.pop("guidance_scale", 7.4)
 
72
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
73
 
74
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
75
+ controlnet=self.controlnet,
76
  torch_dtype=dtype,
77
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to("cuda")
78
  # Define Generator with seed
79
  self.generator = torch.Generator(device=device.type).manual_seed(3)
80
+
81
+ targets = [self.pipe.vae, self.pipe.unet]
82
+ for target in targets:
83
+ for module in target.modules():
84
+ if isinstance(module, torch.nn.Conv2d):
85
+ module.padding_mode = "circular"
86
+
87
 
88
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
89
  """
 
107
  self.pipe.controlnet = self.controlnet
108
 
109
 
110
+
 
 
 
 
 
111
  # hyperparamters
112
  num_inference_steps = data.pop("num_inference_steps", 30)
113
  guidance_scale = data.pop("guidance_scale", 7.4)