Commit 3ec1733
Parent(s): 608c551
Update model.py

model.py  CHANGED
@@ -25,7 +25,7 @@ from annotator.util import HWC3, resize_image

CONTROLNET_MODEL_IDS = {

-    '
+    'hough': 'lllyasviel/sd-controlnet-hough',

}

@@ -38,7 +38,7 @@ def download_all_controlnet_weights() -> None:
class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
-                 task_name: str = '
+                 task_name: str = 'hough'):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.base_model_id = ''

@@ -123,29 +123,32 @@ class Model:
            generator=generator,
            image=control_image).images

-
-    def
+    @staticmethod
+    def preprocess_hough(
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
-
+        value_threshold: float,
+        distance_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        input_image = HWC3(input_image)
-
-
-
-
-
-
-
-
-
-
+        control_image = apply_mlsd(
+            resize_image(input_image, detect_resolution), value_threshold,
+            distance_threshold)
+        control_image = HWC3(control_image)
+        image = resize_image(input_image, image_resolution)
+        H, W = image.shape[:2]
+        control_image = cv2.resize(control_image, (W, H),
+                                   interpolation=cv2.INTER_NEAREST)
+
+        vis_control_image = 255 - cv2.dilate(
+            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
+
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-
+            vis_control_image)

    @torch.inference_mode()
-    def
+    def process_hough(
        self,
        input_image: np.ndarray,
        prompt: str,

@@ -157,20 +160,23 @@ class Model:
        num_steps: int,
        guidance_scale: float,
        seed: int,
-
+        value_threshold: float,
+        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.
+        control_image, vis_control_image = self.preprocess_hough(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
-
+            value_threshold=value_threshold,
+            distance_threshold=distance_threshold,
        )
-        self.load_controlnet_weight('
+        self.load_controlnet_weight('hough')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
-
+
+            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
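Taken together, the hunks register a 'hough' task backed by the MLSD line detector (apply_mlsd), add a preprocess_hough helper that turns the detected line map into a control image plus an inverted, dilated visualization, and expose the task through process_hough, threading two new detector knobs (value_threshold, distance_threshold) down from the caller. The snippet below is a minimal, hypothetical driver for the new method, not part of the commit: it assumes this Space's model.py and its annotator dependencies are importable, and the keyword names are taken from what the diff shows (the full parameter list of process_hough is not visible in the hunk, so the call is approximate).

```python
# Hypothetical driver for the new 'hough' task -- illustrative only.
# Keyword names mirror those visible in the diff; values are placeholders.
import numpy as np

from model import Model  # model.py from this Space

model = Model(base_model_id='runwayml/stable-diffusion-v1-5',
              task_name='hough')

# Any RGB uint8 image works; a blank canvas keeps the snippet self-contained.
input_image = np.zeros((512, 512, 3), dtype=np.uint8)

images = model.process_hough(
    input_image=input_image,
    prompt='a minimalist bedroom',
    additional_prompt='best quality, extremely detailed',
    negative_prompt='lowres, cropped, worst quality',
    num_images=1,
    image_resolution=512,
    detect_resolution=512,
    num_steps=20,
    guidance_scale=9.0,
    seed=0,
    value_threshold=0.1,      # new MLSD threshold added in this commit
    distance_threshold=0.1,   # new MLSD threshold added in this commit
)
images[0].save('hough_sample.png')
```

Internally, per the added lines above, process_hough resizes the input for detection, runs apply_mlsd with the two thresholds, loads the 'hough' ControlNet weights via load_controlnet_weight, and generates the samples through run_pipe.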