Andre Embury commited on
Commit
8297189
·
unverified ·
1 Parent(s): 84d59ce

First test with ControlNet Union

Browse files

Take inspiration:
https://github.com/xinsir6/ControlNetPlus/blob/main/controlnet_union_test_canny.py

Files changed (1) hide show
  1. app.py +63 -17
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  from diffusers import (
8
  # StableDiffusionControlNetImg2ImgPipeline,
9
  ControlNetModel,
 
10
  StableDiffusionXLControlNetPipeline,
11
  )
12
  import torch
@@ -14,9 +15,14 @@ import torch
14
  import requests
15
  from fastapi import FastAPI, HTTPException
16
  from PIL import Image
 
17
 
18
- # from controlnet_aux import CannyDetector
19
- from controlnet_aux import ScribbleDetector
 
 
 
 
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  # model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
@@ -48,23 +54,36 @@ else:
48
  # variant="fp16",
49
  # use_safetensors=True,
50
  # ).to(device)
 
 
51
 
52
- # Load SDXL-compatible ControlNet (scribble version)
53
- controlnet = ControlNetModel.from_pretrained(
54
- "diffusers/controlnet-scribble-sdxl-1.0", torch_dtype=torch.float16
55
  )
56
 
57
- # Load SDXL base pipeline with the ControlNet
58
- pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
 
 
 
 
 
 
 
 
 
 
59
  "stabilityai/stable-diffusion-xl-base-1.0",
60
- controlnet=controlnet,
 
61
  torch_dtype=torch.float16,
62
- variant="fp16",
63
- use_safetensors=True,
64
- ).to(device)
 
 
65
 
66
- # canny = CannyDetector()
67
- scribble_detector = ScribbleDetector()
68
 
69
  MAX_SEED = np.iinfo(np.int32).max
70
  MAX_IMAGE_SIZE = 1024
@@ -111,10 +130,17 @@ def infer(
111
  # img = Image.open(io.BytesIO(resp.content)).convert("RGB")
112
  img = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
113
  # img = img.resize((req.width, req.height))
114
- img = img.resize((width, height))
115
 
116
  # control_net_image = canny(img).resize((width, height))
117
- control_net_image = scribble_detector(img).resize((width, height))
 
 
 
 
 
 
 
118
 
119
  prompt = (
120
  "redraw the logo from scratch, clean sharp vector-style, "
@@ -124,8 +150,8 @@ def infer(
124
  output = pipe(
125
  prompt=prompt,
126
  negative_prompt=NEGATIVE,
127
- image=img,
128
- control_image=control_net_image,
129
  # strength=req.strength,
130
  guidance_scale=guidance_scale,
131
  num_inference_steps=num_inference_steps,
@@ -153,6 +179,26 @@ NEGATIVE = "blurry, distorted, messy, gradients, background noise"
153
  WIDTH = 512
154
  HEIGHT = 512
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  with gr.Blocks(css=css) as demo:
157
  with gr.Column(elem_id="col-container"):
158
  gr.Markdown(" # Text-to-Image Gradio Template")
 
7
  from diffusers import (
8
  # StableDiffusionControlNetImg2ImgPipeline,
9
  ControlNetModel,
10
+ ControlNetUnionModel,
11
  StableDiffusionXLControlNetPipeline,
12
  )
13
  import torch
 
15
  import requests
16
  from fastapi import FastAPI, HTTPException
17
  from PIL import Image
18
+ from controlnet_aux import CannyDetector
19
 
20
+ from diffusers import AutoencoderKL
21
+ from diffusers import (
22
+ EulerAncestralDiscreteScheduler,
23
+ StableDiffusionXLControlNetUnionPipeline,
24
+ )
25
+ import cv2
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  # model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
54
  # variant="fp16",
55
  # use_safetensors=True,
56
  # ).to(device)
57
+ # # pipe = pipe.to(device)
58
+ # canny = CannyDetector()
59
 
60
+
61
+ eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
62
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
63
  )
64
 
65
+ # when test with other base model, you need to change the vae also.
66
+ vae = AutoencoderKL.from_pretrained(
67
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
68
+ )
69
+
70
+ controlnet_model = ControlNetUnionModel.from_pretrained(
71
+ "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
72
+ )
73
+
74
+ # controlnet_union_model = ControlNetUnionModel([controlnet_model])
75
+
76
+ pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
77
  "stabilityai/stable-diffusion-xl-base-1.0",
78
+ controlnet=controlnet_model,
79
+ vae=vae,
80
  torch_dtype=torch.float16,
81
+ scheduler=eulera_scheduler,
82
+ control_mode=[0],
83
+ )
84
+
85
+ pipe = pipe.to(device)
86
 
 
 
87
 
88
  MAX_SEED = np.iinfo(np.int32).max
89
  MAX_IMAGE_SIZE = 1024
 
130
  # img = Image.open(io.BytesIO(resp.content)).convert("RGB")
131
  img = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
132
  # img = img.resize((req.width, req.height))
133
+ # img = img.resize((width, height))
134
 
135
  # control_net_image = canny(img).resize((width, height))
136
+
137
+ img_np = np.array(img)
138
+
139
+ controlnet_img = cv2.resize(img_np, (width, height))
140
+
141
+ controlnet_img = cv2.Canny(controlnet_img, 100, 200)
142
+ controlnet_img = HWC3(controlnet_img)
143
+ controlnet_img = Image.fromarray(controlnet_img)
144
 
145
  prompt = (
146
  "redraw the logo from scratch, clean sharp vector-style, "
 
150
  output = pipe(
151
  prompt=prompt,
152
  negative_prompt=NEGATIVE,
153
+ # image=img,
154
+ control_image=controlnet_img,
155
  # strength=req.strength,
156
  guidance_scale=guidance_scale,
157
  num_inference_steps=num_inference_steps,
 
179
  WIDTH = 512
180
  HEIGHT = 512
181
 
182
+
183
+ def HWC3(x):
184
+ assert x.dtype == np.uint8
185
+ if x.ndim == 2:
186
+ x = x[:, :, None]
187
+ assert x.ndim == 3
188
+ H, W, C = x.shape
189
+ assert C == 1 or C == 3 or C == 4
190
+ if C == 3:
191
+ return x
192
+ if C == 1:
193
+ return np.concatenate([x, x, x], axis=2)
194
+ if C == 4:
195
+ color = x[:, :, 0:3].astype(np.float32)
196
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
197
+ y = color * alpha + 255.0 * (1.0 - alpha)
198
+ y = y.clip(0, 255).astype(np.uint8)
199
+ return y
200
+
201
+
202
  with gr.Blocks(css=css) as demo:
203
  with gr.Column(elem_id="col-container"):
204
  gr.Markdown(" # Text-to-Image Gradio Template")