controlnet now kicks out models to save memory
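In summary: each annotator branch below now imports a matching unload helper next to its apply function and calls it once the conditioning tensor has been built, so annotator weights stop occupying GPU memory between requests. Below is a minimal sketch of the load/unload pattern, modeled on the commented-out MiDaS helpers that this commit deletes from the top of the file; build_midas() is a hypothetical stand-in for the real constructor (MiDaSInference in the deleted block).

import torch

depth_model = None  # module-level cache, built on first use

def apply_midas(input_image, device='cpu'):
    # lazily build the annotator, then move it onto the requested device
    global depth_model
    if depth_model is None:
        depth_model = build_midas()  # hypothetical constructor, see lead-in
    depth_model = depth_model.to(device)
    with torch.no_grad():
        x = torch.from_numpy(input_image).float().to(device)
        x = x / 127.5 - 1.0  # map [0, 255] to [-1, 1]
        return depth_model(x.permute(2, 0, 1)[None])[0]

def unload_midas_model():
    # "kick out" the model: move it back to CPU so it releases GPU memory
    global depth_model
    if depth_model is not None:
        depth_model = depth_model.cpu()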
lib/model_zoo/controlnet.py  CHANGED  (+27 -65)
@@ -14,54 +14,6 @@ from .openaimodel import \
     ResBlock, AttentionBlock, SpatialTransformer, \
     Downsample, timestep_embedding
 
-####################
-# preprocess depth #
-####################
-
-# depth_model = None
-
-# def unload_midas_model():
-#     global depth_model
-#     if depth_model is not None:
-#         depth_model = depth_model.cpu()
-
-# def apply_midas(input_image, a=np.pi*2.0, bg_th=0.1, device='cpu'):
-#     import cv2
-#     from einops import rearrange
-#     from .controlnet_annotators.midas import MiDaSInference
-#     global depth_model
-#     if depth_model is None:
-#         depth_model = MiDaSInference(model_type="dpt_hybrid")
-#     depth_model = depth_model.to(device)
-
-#     assert input_image.ndim == 3
-#     image_depth = input_image
-#     with torch.no_grad():
-#         image_depth = torch.from_numpy(image_depth).float()
-#         image_depth = image_depth.to(device)
-#         image_depth = image_depth / 127.5 - 1.0
-#         image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
-#         depth = depth_model(image_depth)[0]
-
-#         depth_pt = depth.clone()
-#         depth_pt -= torch.min(depth_pt)
-#         depth_pt /= torch.max(depth_pt)
-#         depth_pt = depth_pt.cpu().numpy()
-#         depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
-
-#         depth_np = depth.cpu().numpy()
-#         x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
-#         y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
-#         z = np.ones_like(x) * a
-#         x[depth_pt < bg_th] = 0
-#         y[depth_pt < bg_th] = 0
-#         normal = np.stack([x, y, z], axis=2)
-#         normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
-#         normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
-
-#         return depth_image, normal_image
-
-
 @register('controlnet')
 class ControlNet(nn.Module):
     def __init__(
@@ -360,37 +312,41 @@ class ControlNet(nn.Module):
             return y_torch
 
         elif type == 'depth':
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
            y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_midas_model()
             return y_torch
 
         elif type in ['hed', 'softedge_v11p']:
-            from .controlnet_annotator.hed import apply_hed
+            from .controlnet_annotator.hed import apply_hed, unload_hed_model
             y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            from .controlnet_annotator.midas import model as model_midas
+            unload_hed_model()
             return y_torch
 
         elif type in ['mlsd', 'mlsd_v11p']:
             thr_v = kwargs.pop('thr_v', 0.1)
             thr_d = kwargs.pop('thr_d', 0.1)
-            from .controlnet_annotator.mlsd import apply_mlsd
+            from .controlnet_annotator.mlsd import apply_mlsd, unload_mlsd_model
             y_list = [apply_mlsd(np.array(xi), thr_v=thr_v, thr_d=thr_d, device=device) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
             y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
             y_torch = y_torch.to(device).to(torch.float32)
+            unload_mlsd_model()
             return y_torch
 
         elif type == 'normal':
             bg_th = kwargs.pop('bg_th', 0.4)
-            from .controlnet_annotator.midas import apply_midas
+            from .controlnet_annotator.midas import apply_midas, unload_midas_model
             _, y_list = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, bg_th=bg_th, device=device) for xi in x_list])
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
-
+            unload_midas_model()
             return y_torch
 
         elif type in ['openpose', 'openpose_v11p']:
@@ -403,6 +359,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withface', 'openpose_withface_v11p']:
@@ -415,6 +372,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type in ['openpose_withfacehand', 'openpose_withfacehand_v11p']:
@@ -427,6 +385,7 @@ class ControlNet(nn.Module):
             y_list = [apply_openpose(np.array(xi)) for xi in x_list]
             y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
             y_torch = y_torch.to(device).to(torch.float32)
+            OpenposeModel.unload()
             return y_torch
 
         elif type == 'scribble':
@@ -454,21 +413,23 @@ class ControlNet(nn.Module):
                 return result
 
             if method == 'hed':
-                from .controlnet_annotator.hed import apply_hed
+                from .controlnet_annotator.hed import apply_hed, unload_hed_model
                 y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_hed_model()
                 return y_torch
 
             elif method == 'pidinet':
-                from .controlnet_annotator.pidinet import apply_pidinet
+                from .controlnet_annotator.pidinet import apply_pidinet, unload_pid_model
                 y_list = [apply_pidinet(np.array(xi), device=device) for xi in x_list]
                 y_list = [make_scribble(yi) for yi in y_list]
                 y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
                 y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
                 y_torch = y_torch.to(device).to(torch.float32)
+                unload_pid_model()
                 return y_torch
 
             elif method == 'xdog':
@@ -491,13 +452,14 @@ class ControlNet(nn.Module):
                 raise ValueError
 
         elif type == 'seg':
-            method = kwargs.pop('method', 'ufade20k')
-            if method == 'ufade20k':
-                from .controlnet_annotator.uniformer import apply_uniformer
-                y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
-                y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
-                y_torch = y_torch.to(device).to(torch.float32)
-                return y_torch
-
-            else:
-                raise ValueError
+            assert False, "This part is broken"
+            # method = kwargs.pop('method', 'ufade20k')
+            # if method == 'ufade20k':
+            #     from .controlnet_annotator.uniformer import apply_uniformer
+            #     y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
+            #     y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+            #     y_torch = y_torch.to(device).to(torch.float32)
+            #     return y_torch
+
+            # else:
+            #     raise ValueError
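The openpose branches call OpenposeModel.unload() rather than a free function, which suggests that annotator caches its network on a class instead of at module level. A hedged sketch of that variant; only the unload() name comes from the diff above, the other members are illustrative assumptions.

import torch

class OpenposeModel:
    net = None  # assumed class-level cache for the pose network

    @classmethod
    def unload(cls):
        # same idea as the unload_*_model() helpers: back to CPU, weights stay cached
        if cls.net is not None:
            cls.net = cls.net.cpu()

Moving models to CPU instead of deleting them keeps the one-time load cost while still freeing VRAM between calls; dropping the reference and calling torch.cuda.empty_cache() would reclaim more memory but force a full reload on the next request.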