[Minor] Use a generator function to generate a list
app.py CHANGED
@@ -273,7 +273,6 @@ def generate(
                 m_img.astype('float') / 2.0 * red).astype('uint8'))
 
 
-
     mask_video_path = "mask.mp4"
     fps = 30
     with imageio.get_writer(mask_video_path, fps=fps) as video:
@@ -282,7 +281,45 @@ def generate(
 
     return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
 
-
+
+def single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width):
+    model.cuda()
+    with torch.no_grad(), autocast("cuda"), model.ema_scope():
+        cond = {}
+        input_image_torch = 2 * torch.tensor(np.array(input_image_copy.to(model.device))).float() / 255 - 1
+        input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
+        cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
+        cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
+
+        uncond = {}
+        uncond["c_crossattn"] = [null_token.to(model.device)]
+        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
+
+        sigmas = model_wrap.get_sigmas(steps).to(model.device)
+
+        extra_args = {
+            "cond": cond,
+            "uncond": uncond,
+            "text_cfg_scale": text_cfg_scale,
+            "image_cfg_scale": image_cfg_scale,
+        }
+        torch.manual_seed(seed)
+        z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+        z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
+
+        z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
+
+        x_0 = model.decode_first_stage(z_0)
+
+        x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
+        x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
+
+        x_1_mean = torch.sum(x_1).item()/x_1.numel()
+
+        return x_0, x_1, x_1_mean
+
+
+@spaces.GPU(duration=150)
 def generate_list(
     input_image: Image.Image,
     generate_list: str,
@@ -322,9 +359,11 @@ def generate_list(
     while generate_index < len(generate_list):
         print(f'generate_index: {str(generate_index)}')
         instruction = generate_list[generate_index]
+
+        # x_0, x_1, x_1_mean = single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width)
         with torch.no_grad(), autocast("cuda"), model.ema_scope():
             cond = {}
-            input_image_torch = 2 * torch.tensor(np.array(input_image_copy.to(model.device))).float() / 255 - 1
+            input_image_torch = 2 * torch.tensor(np.array(input_image_copy)).float() / 255 - 1
             input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
             cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
             cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
@@ -351,8 +390,10 @@ def generate_list(
 
             x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
             x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
+
+            x_1_mean = torch.sum(x_1).item()/x_1.numel()
 
-            if torch.sum(x_1).item()/x_1.numel() < -0.99:
+            if x_1_mean < -0.99:
                 seed += 1
                 retry_number +=1
                 if retry_number > max_retry:
@@ -384,20 +425,22 @@ def generate_list(
 
         image_video.append((mix_image_np * 255).astype(np.uint8))
         mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
-
-
-
-
-
+
+        mix_result_with_red_mask = None
+        mask_video_path = None
+        image_video_path = None
+        edited_mask_copy = None
+
+        if generate_index == len(generate_list):
+            image_video_path = "image.mp4"
+            fps = 2
+            with imageio.get_writer(image_video_path, fps=fps) as video:
+                for image in image_video:
+                    video.append_data(image)
 
-
-    fps = 2
-    with imageio.get_writer(image_video_path, fps=fps) as video:
-        for image in image_video:
-            video.append_data(image)
+        yield [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
 
-
-    return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
+        input_image_copy = mix_image
 
 
 def reset():
@@ -553,4 +596,5 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
 # demo.launch(share=True)
 
 
+# demo.queue().launch(enable_queue=True)
 demo.queue().launch()
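The heart of this commit is the switch in generate_list from a single final return to a yield inside the loop, which turns the handler into a generator so Gradio can stream each intermediate edit to the UI (generator handlers require the queue, which the app already enables via demo.queue()). Below is a minimal, self-contained sketch of that pattern; the handler and component names are illustrative stand-ins, not the app's actual code.

```python
import time

import gradio as gr


def run_instructions(instructions: str):
    """Toy stand-in for generate_list: one 'edit' per line of input."""
    done = []
    for instruction in instructions.splitlines():
        time.sleep(0.5)           # placeholder for one diffusion edit
        done.append(instruction)  # placeholder for the edited image
        # Yielding here (instead of returning once after the loop) makes this
        # handler a generator; Gradio pushes every yielded value to the outputs.
        yield "\n".join(done)


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Instructions (one per line)")
    progress = gr.Textbox(label="Progress")
    prompt.submit(run_instructions, inputs=prompt, outputs=progress)

# Generator handlers need the queue enabled, as the app already does.
demo.queue().launch()
```

Note that every yield must supply a value for each declared output, which is why the diff fills the outputs that are only ready on the last iteration (mix_result_with_red_mask, mask_video_path, image_video_path, edited_mask_copy) with None until the final video is written.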