RamAnanth1 committed on
Commit
3ec1733
·
1 Parent(s): 608c551

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +28 -22
model.py CHANGED
@@ -25,7 +25,7 @@ from annotator.util import HWC3, resize_image
25
 
26
  CONTROLNET_MODEL_IDS = {
27
 
28
- 'depth': 'lllyasviel/sd-controlnet-depth',
29
 
30
  }
31
 
@@ -38,7 +38,7 @@ def download_all_controlnet_weights() -> None:
38
  class Model:
39
  def __init__(self,
40
  base_model_id: str = 'runwayml/stable-diffusion-v1-5',
41
- task_name: str = 'depth'):
42
  self.device = torch.device(
43
  'cuda:0' if torch.cuda.is_available() else 'cpu')
44
  self.base_model_id = ''
@@ -123,29 +123,32 @@ class Model:
123
  generator=generator,
124
  image=control_image).images
125
 
126
- @staticmethod
127
- def preprocess_depth(
128
  input_image: np.ndarray,
129
  image_resolution: int,
130
  detect_resolution: int,
131
- is_depth_image: bool,
 
132
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
133
  input_image = HWC3(input_image)
134
- if not is_depth_image:
135
- control_image, _ = apply_midas(
136
- resize_image(input_image, detect_resolution))
137
- control_image = HWC3(control_image)
138
- image = resize_image(input_image, image_resolution)
139
- H, W = image.shape[:2]
140
- control_image = cv2.resize(control_image, (W, H),
141
- interpolation=cv2.INTER_LINEAR)
142
- else:
143
- control_image = resize_image(input_image, image_resolution)
 
 
144
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
145
- control_image)
146
 
147
  @torch.inference_mode()
148
- def process_depth(
149
  self,
150
  input_image: np.ndarray,
151
  prompt: str,
@@ -157,20 +160,23 @@ class Model:
157
  num_steps: int,
158
  guidance_scale: float,
159
  seed: int,
160
- is_depth_image: bool,
 
161
  ) -> list[PIL.Image.Image]:
162
- control_image, vis_control_image = self.preprocess_depth(
163
  input_image=input_image,
164
  image_resolution=image_resolution,
165
  detect_resolution=detect_resolution,
166
- is_depth_image=is_depth_image,
 
167
  )
168
- self.load_controlnet_weight('depth')
169
  results = self.run_pipe(
170
  prompt=self.get_prompt(prompt, additional_prompt),
171
  negative_prompt=negative_prompt,
172
  control_image=control_image,
173
- num_images=num_images,
 
174
  num_steps=num_steps,
175
  guidance_scale=guidance_scale,
176
  seed=seed,
 
25
 
26
# Maps task names to pretrained ControlNet checkpoints on the Hugging Face Hub.
# NOTE(review): this diff replaces the 'depth' entry with 'hough'; presumably the
# full file carries more task entries — confirm against the complete model.py.
CONTROLNET_MODEL_IDS = {
    'hough': 'lllyasviel/sd-controlnet-hough',
}
31
 
 
38
  class Model:
39
  def __init__(self,
40
  base_model_id: str = 'runwayml/stable-diffusion-v1-5',
41
+ task_name: str = 'hough'):
42
  self.device = torch.device(
43
  'cuda:0' if torch.cuda.is_available() else 'cpu')
44
  self.base_model_id = ''
 
123
  generator=generator,
124
  image=control_image).images
125
 
126
    @staticmethod
    def preprocess_hough(
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Build the MLSD (Hough-line) control image pair for ControlNet.

        Args:
            input_image: Input image array; passed through HWC3 first, so it
                is normalized to 3-channel HxWxC uint8.
            image_resolution: Resolution the control image is resized to for
                the diffusion pipeline.
            detect_resolution: Resolution at which line detection is run.
            value_threshold: Threshold forwarded to apply_mlsd.
            distance_threshold: Threshold forwarded to apply_mlsd.

        Returns:
            A (control_image, vis_control_image) pair of PIL images: the
            control map fed to the pipeline, and a dilated/inverted copy
            for visualization.
        """
        input_image = HWC3(input_image)
        # Run MLSD line detection at the (typically smaller) detect resolution.
        control_image = apply_mlsd(
            resize_image(input_image, detect_resolution), value_threshold,
            distance_threshold)
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        # Upscale the line map to the output size. INTER_NEAREST — presumably
        # chosen to keep the thin line map hard-edged rather than blurred.
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_NEAREST)

        # Thicken the detected lines (3x3 dilation) and invert intensities so
        # the visualization shows dark lines on a light background.
        vis_control_image = 255 - cv2.dilate(
            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)

        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)
149
 
150
  @torch.inference_mode()
151
+ def process_hough(
152
  self,
153
  input_image: np.ndarray,
154
  prompt: str,
 
160
  num_steps: int,
161
  guidance_scale: float,
162
  seed: int,
163
+ value_threshold: float,
164
+ distance_threshold: float,
165
  ) -> list[PIL.Image.Image]:
166
+ control_image, vis_control_image = self.preprocess_hough(
167
  input_image=input_image,
168
  image_resolution=image_resolution,
169
  detect_resolution=detect_resolution,
170
+ value_threshold=value_threshold,
171
+ distance_threshold=distance_threshold,
172
  )
173
+ self.load_controlnet_weight('hough')
174
  results = self.run_pipe(
175
  prompt=self.get_prompt(prompt, additional_prompt),
176
  negative_prompt=negative_prompt,
177
  control_image=control_image,
178
+
179
+ num_images=num_images,
180
  num_steps=num_steps,
181
  guidance_scale=guidance_scale,
182
  seed=seed,