yeq6x committed
Commit bb35a51 · 1 Parent(s): 2867bc8
Files changed (1):
  1. app.py +190 -54
app.py CHANGED
@@ -1,14 +1,13 @@
  import spaces
- from diffusers import ControlNetModel
- from diffusers import StableDiffusionXLControlNetPipeline
- from diffusers import EulerAncestralDiscreteScheduler
- from PIL import Image
- import torch
  import numpy as np
- import cv2
+ from PIL import Image
  import gradio as gr
- from torchvision import transforms
- from controlnet_aux import OpenposeDetector
+ import open3d as o3d
+ import trimesh
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, EulerAncestralDiscreteScheduler
+ import torch
+ from collections import Counter
+ import random

  ratios_map = {
      0.5:{"width":704,"height":1408},
@@ -31,9 +30,6 @@ ratios_map = {
  }
  ratios = np.array(list(ratios_map.keys()))

-
- openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
-
  controlnet = ControlNetModel.from_pretrained(
      "yeq6x/Image2PositionColor_v3",
      torch_dtype=torch.float16
@@ -54,8 +50,6 @@ pipe.scheduler = EulerAncestralDiscreteScheduler(
      num_train_timesteps=1000,
      steps_offset=1
  )
- # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
- # pipe.enable_xformers_memory_efficient_attention()
  pipe.force_zeros_for_empty_prompt = False

  def get_size(init_image):
@@ -72,17 +66,6 @@ def resize_image(image):
      w,h = get_size(image)
      resized_image = image.resize((w, h))
      return resized_image
-
- def resize_image_old(image):
-     image = image.convert('RGB')
-     current_size = image.size
-     if current_size[0] > current_size[1]:
-         center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1]))
-     else:
-         center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0]))
-     resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
-     return resized_image
-

  @spaces.GPU
  def generate_(prompt, negative_prompt, pose_image, input_image, num_steps, controlnet_conditioning_scale, seed):
@@ -99,40 +82,193 @@ def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditio
      # resize input_image to 1024x1024
      input_image = resize_image(input_image)

-     pose_image = openpose(input_image, include_body=True, include_hand=True, include_face=True)
-
      images = generate_(prompt, negative_prompt, pose_image, input_image, num_steps, controlnet_conditioning_scale, seed)

      return [pose_image,images[0]]
-
- block = gr.Blocks().queue()
-
- with block:
-     gr.Markdown("## BRIA 2.3 ControlNet Pose")
-     gr.HTML('''
-         <p style="margin-bottom: 10px; font-size: 94%">
-         This is a demo for ControlNet Pose that using
-         <a href="https://huggingface.co/briaai/BRIA-2.3" target="_blank">BRIA 2.3 text-to-image model</a> as backbone.
-         Trained on licensed data, BRIA 2.3 provide full legal liability coverage for copyright and privacy infringement.
-         </p>
-     ''')
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
-             prompt = gr.Textbox(label="Prompt")
-             negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
-             num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
-             controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
-             seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
-             run_button = gr.Button(value="Run")
-
-         with gr.Column():
-             with gr.Row():
-                 pose_image_output = gr.Image(label="Pose Image", type="pil", interactive=False)
-                 generated_image_output = gr.Image(label="Generated Image", type="pil", interactive=False)
-
-     ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
-     run_button.click(fn=process, inputs=ips, outputs=[pose_image_output, generated_image_output])
-
- block.launch(debug = True)
+
+ @spaces.GPU
+ def predict_image(cond_image, prompt, negative_prompt, controlnet_conditioning_scale):
+     print("predict position map")
+     global pipe
+     generator = torch.Generator()
+     generator.manual_seed(random.randint(0, 2147483647))
+     image = pipe(
+         prompt,
+         negative_prompt=negative_prompt,
+         image=cond_image,
+         width=1024,
+         height=1024,
+         guidance_scale=8,
+         num_inference_steps=20,
+         generator=generator,
+         guess_mode=True,
+         controlnet_conditioning_scale=controlnet_conditioning_scale
+     ).images[0]
+
+     return image
+
+ # block = gr.Blocks().queue()
+
+ # with block:
+ #     with gr.Row():
+ #         with gr.Column():
+ #             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
+ #             prompt = gr.Textbox(label="Prompt")
+ #             negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
+ #             num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
+ #             controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
+ #             seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
+ #             run_button = gr.Button(value="Run")

+ #         with gr.Column():
+ #             with gr.Row():
+ #                 pose_image_output = gr.Image(label="Pose Image", type="pil", interactive=False)
+ #                 generated_image_output = gr.Image(label="Generated Image", type="pil", interactive=False)
+
+ #     ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
+ #     run_button.click(fn=process, inputs=ips, outputs=[pose_image_output, generated_image_output])
+
+
+ # block.launch(debug = True)
+
+ def convert_pil_to_opencv(pil_image):
+     return np.array(pil_image)
+
+ def inv_func(y,
+              c = -712.380100,
+              a = 137.375240,
+              b = 192.435866):
+     return (np.exp((y - c) / a) - np.exp(-c/a)) / 964.8468371292845
+
+ def create_point_cloud(img1, img2):
+     if img1.shape != img2.shape:
+         raise ValueError("Both images must have the same dimensions.")
+
+     h, w, _ = img1.shape
+     points = []
+     colors = []
+     for y in range(h):
+         for x in range(w):
+             # Use the RGB value at pixel (x, y) as the XYZ position
+             r, g, b = img1[y, x]
+             r = inv_func(r) * 0.9
+             g = inv_func(g) / 1.7 * 0.6
+             b = inv_func(b)
+             r *= 150
+             g *= 150
+             b *= 150
+             points.append([g, b, r])  # X, Y, Z
+             # Take the color from the same pixel position in the second image
+             colors.append(img2[y, x] / 255.0)  # scale colors to 0-1

+     return np.array(points), np.array(colors)
+
+ def point_cloud_to_glb(points, colors):
+     # Build the point cloud with Open3D
+     pc = o3d.geometry.PointCloud()
+     pc.points = o3d.utility.Vector3dVector(points)
+     pc.colors = o3d.utility.Vector3dVector(colors)
+
+     # Save temporarily in PLY format
+     temp_ply_file = "temp_output.ply"
+     o3d.io.write_point_cloud(temp_ply_file, pc)
+
+     # Convert the PLY to GLB
+     mesh = trimesh.load(temp_ply_file)
+     glb_file = "output.glb"
+     mesh.export(glb_file)
+
+     return glb_file
+
+ def visualize_3d(image1, image2):
+     print("Processing...")
+     # Convert the PIL images to OpenCV (NumPy) format
+     img1 = convert_pil_to_opencv(image1)
+     img2 = convert_pil_to_opencv(image2)
+
+     # Generate the point cloud
+     points, colors = create_point_cloud(img1, img2)
+
+     # Convert it to GLB format
+     glb_file = point_cloud_to_glb(points, colors)
+
+     return glb_file
+
+ def scale_image(original_image):
+     aspect_ratio = original_image.width / original_image.height
+
+     if original_image.width > original_image.height:
+         new_width = 1024
+         new_height = round(new_width / aspect_ratio)
+     else:
+         new_height = 1024
+         new_width = round(new_height * aspect_ratio)
+
+     resized_original = original_image.resize((new_width, new_height), Image.LANCZOS)
+
+     return resized_original
+
+ def get_edge_mode_color(img, edge_width=10):
+     # Grab the 10-pixel border regions
+     left = img.crop((0, 0, edge_width, img.height))  # left edge
+     right = img.crop((img.width - edge_width, 0, img.width, img.height))  # right edge
+     top = img.crop((0, 0, img.width, edge_width))  # top edge
+     bottom = img.crop((0, img.height - edge_width, img.width, img.height))  # bottom edge
+
+     # Collect the pixel data from each region and combine it
+     colors = list(left.getdata()) + list(right.getdata()) + list(top.getdata()) + list(bottom.getdata())
+
+     # Compute the mode (the most frequently occurring color)
+     mode_color = Counter(colors).most_common(1)[0][0]
+
+     return mode_color
+
+ def paste_image(resized_img):
+     # Use the mode color of the outer 10 px border as the background color
+     mode_color = get_edge_mode_color(resized_img, edge_width=10)
+     mode_background = Image.new("RGBA", (1024, 1024), mode_color)
+     mode_background = mode_background.convert('RGB')
+
+     x = (1024 - resized_img.width) // 2
+     y = (1024 - resized_img.height) // 2
+     mode_background.paste(resized_img, (x, y))
+
+     return mode_background
+
+ def outpaint_image(image):
+     if image is None:
+         return None
+     resized_img = scale_image(image)
+     image = paste_image(resized_img)
+
+     return image
+
+ # Gradio application
+ with gr.Blocks() as demo:
+     gr.Markdown("## Position Map Visualizer")
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 img1 = gr.Image(type="pil", label="color Image", height=300)
+                 img2 = gr.Image(type="pil", label="map Image", height=300)
+             prompt = gr.Textbox("position map, 1girl, white background", label="Prompt")
+             negative_prompt = gr.Textbox("lowres, bad anatomy, bad hands, bad feet, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry", label="Negative Prompt")
+             controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=0.6, step=0.05)
+             predict_map_btn = gr.Button("Predict Position Map")
+             visualize_3d_btn = gr.Button("Generate 3D Point Cloud")
+         with gr.Column():
+             reconstruction_output = gr.Model3D(label="3D Viewer", height=600)
+             gr.Examples(
+                 examples=[
+                     ["resources/source/000006.png", "resources/target/000006.png"],
+                     ["resources/source/006420.png", "resources/target/006420.png"],
+                 ],
+                 inputs=[img1, img2]
+             )

+     img1.input(outpaint_image, inputs=img1, outputs=img1)
+     predict_map_btn.click(predict_image, inputs=[img1, prompt, negative_prompt, controlnet_conditioning_scale], outputs=img2)
+     visualize_3d_btn.click(visualize_3d, inputs=[img2, img1], outputs=reconstruction_output)

+ demo.launch()
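
To try the new point-cloud path outside the Gradio UI, a minimal sketch along these lines should work. It assumes the Space's dependencies (Pillow, open3d, trimesh) are installed and that importing app.py is acceptable, since the module loads the ControlNet pipeline at import time; the example files are the pair referenced in gr.Examples, and any color image plus a same-sized position map would do. Note that create_point_cloud loops over every pixel in Python, so a 1024x1024 pair takes a while.

from PIL import Image
from app import outpaint_image, convert_pil_to_opencv, create_point_cloud, point_cloud_to_glb

# Pad the color image to 1024x1024 the same way the UI does on upload
color = outpaint_image(Image.open("resources/source/000006.png").convert("RGB"))
# Load the matching position map and force it to the same size
pos_map = Image.open("resources/target/000006.png").convert("RGB").resize(color.size)

# First argument supplies XYZ (decoded from the position map's RGB), second supplies vertex colors
points, colors = create_point_cloud(convert_pil_to_opencv(pos_map), convert_pil_to_opencv(color))
glb_path = point_cloud_to_glb(points, colors)  # writes temp_output.ply, then output.glb
print("wrote", glb_path)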