convert to CPU
Files changed:
- annotator/midas/__init__.py +2 -2
- app.py +109 -110
- ckpt/cldm_v15.yaml +2 -0
- requirements.txt +1 -0
- stablevideo/atlas_data.py +1 -1
- stablevideo/atlas_utils.py +1 -1
annotator/midas/__init__.py
CHANGED
@@ -8,13 +8,13 @@ from .api import MiDaSInference
 
 class MidasDetector:
     def __init__(self):
-        self.model = MiDaSInference(model_type="dpt_hybrid")
+        self.model = MiDaSInference(model_type="dpt_hybrid")
 
     def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
         assert input_image.ndim == 3
         image_depth = input_image
         with torch.no_grad():
-            image_depth = torch.from_numpy(image_depth).float()
+            image_depth = torch.from_numpy(image_depth).float()
             image_depth = image_depth / 127.5 - 1.0
             image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
             depth = self.model(image_depth)[0]
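For reference, the detector shown above expects an H x W x C uint8 image, scales it from [0, 255] to [-1, 1], and reshapes it to a 1 x C x H x W batch before the MiDaS forward pass. A minimal standalone sketch of that preprocessing on a CPU-only machine (the dummy frame is hypothetical; no model is loaded):

import numpy as np
import torch
from einops import rearrange

# Hypothetical uint8 RGB frame, shape (H, W, C), values in [0, 255].
frame = np.random.randint(0, 256, size=(384, 384, 3), dtype=np.uint8)

with torch.no_grad():
    x = torch.from_numpy(frame).float()      # stays on the CPU
    x = x / 127.5 - 1.0                      # [0, 255] -> [-1, 1]
    x = rearrange(x, 'h w c -> 1 c h w')     # batch of one, channels first
# x now has the 1 x 3 x H x W layout that MidasDetector.__call__ feeds to MiDaSInference.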
app.py
CHANGED
@@ -48,7 +48,7 @@ class StableVideo:
     ):
         self.apply_canny = CannyDetector()
         canny_model = create_model(base_cfg).cpu()
-        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='
+        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
         self.canny_ddim_sampler = DDIMSampler(canny_model)
         self.canny_model = canny_model
 
@@ -59,7 +59,7 @@ class StableVideo:
     ):
         self.apply_midas = MidasDetector()
         depth_model = create_model(base_cfg).cpu()
-        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='
+        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
         self.depth_ddim_sampler = DDIMSampler(depth_model)
         self.depth_model = depth_model
 
@@ -101,7 +101,7 @@ class StableVideo:
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
-        control = torch.from_numpy(detected_map.copy()).float()
+        control = torch.from_numpy(detected_map.copy()).float() / 255.0
         control = torch.stack([control for _ in range(1)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
@@ -128,7 +128,7 @@ class StableVideo:
 
     @torch.no_grad()
     def edit_background(self, *args, **kwargs):
-        self.depth_model = self.depth_model
+        self.depth_model = self.depth_model
 
         input_image = self.b_atlas_origin
         self.depth_edit(input_image, *args, **kwargs)
@@ -155,7 +155,7 @@ class StableVideo:
                 if_net=False,
                 num_samples=1):
 
-        self.canny_model = self.canny_model
+        self.canny_model = self.canny_model
 
        keyframes = [int(x) for x in keyframes.split(",")]
        if self.data is None:
@@ -186,7 +186,7 @@ class StableVideo:
            # get canny control
            detected_map = self.apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
-            control = torch.from_numpy(detected_map.copy()).float()
+            control = torch.from_numpy(detected_map.copy()).float() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
 
            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
@@ -195,7 +195,7 @@ class StableVideo:
 
            # if not the key frame, calculate the mapping from last atlas
            if i == 0:
-                latent = torch.randn((1, 4, H // 8, W // 8))
+                latent = torch.randn((1, 4, H // 8, W // 8))
                samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
                                                            shape, cond, verbose=False, eta=eta,
                                                            unconditional_guidance_scale=scale,
@@ -209,7 +209,7 @@ class StableVideo:
                mapped_img = mapped_img.resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
-                mapped_img = torch.from_numpy(mapped_img)
+                mapped_img = torch.from_numpy(mapped_img)
                mapped_img = 2. * mapped_img - 1.
                latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
 
@@ -232,7 +232,7 @@ class StableVideo:
            result = alpha * result
 
            # buffer for training
-            result_copy = result.clone()
+            result_copy = result.clone()
            result_copy.requires_grad = True
            result_list.append(result_copy)
 
@@ -259,7 +259,7 @@ class StableVideo:
        # aggregate net #
        #####################################
        lr, n_epoch = 1e-3, 500
-        agg_net = AGGNet()
+        agg_net = AGGNet()
        loss_fn = nn.L1Loss()
        optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
        for _ in range(n_epoch):
@@ -291,12 +291,12 @@ class StableVideo:
    def render(self, f_atlas, b_atlas):
        # foreground
        if f_atlas == None:
-            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
+            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
        else:
            f_atlas, mask = f_atlas["image"], f_atlas["mask"]
-            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
-            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
-            mask = transforms.ToTensor()(mask).unsqueeze(0)
+            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
+            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
+            mask = transforms.ToTensor()(mask).unsqueeze(0)
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
@@ -326,7 +326,7 @@ class StableVideo:
        if b_atlas == None:
            b_atlas = self.b_atlas_origin
 
-        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
+        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
        background_edit = F.grid_sample(
            b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
@@ -349,99 +349,98 @@ class StableVideo:
        return save_name
 
 if __name__ == '__main__':
+    stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
+                              canny_model_cfg="ckpt/control_sd15_canny.pth",
+                              depth_model_cfg="ckpt/control_sd15_depth.pth",
+                              save_memory=True)
+    stablevideo.load_canny_model()
+    stablevideo.load_depth_model()
+
+    block = gr.Blocks().queue()
+    with block:
+        with gr.Row():
+            gr.Markdown("## StableVideo")
+        with gr.Row():
+            with gr.Column():
+                original_video = gr.Video(label="Original Video", interactive=False)
+                with gr.Row():
+                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
+                    background_atlas = gr.Image(label="Background Atlas", type="pil")
+                gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
+                avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
+                video_name = gr.Radio(choices=avail_video,
+                                      label="Select Example Videos",
+                                      value="car-turn")
+                load_video_button = gr.Button("Load Video")
+                gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
+                with gr.Row():
+                    f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
+                    b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
+                with gr.Row():
+                    with gr.Accordion("Advanced Foreground Options", open=False):
+                        adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
+                        adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
+                        adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                        adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+                        adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
+                        adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                        adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
+                        adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
+                        adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                        adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
+                        adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
+                        adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+                        adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)
+
+                    with gr.Accordion("Background Options", open=False):
+                        b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                        b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
+                        b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                        b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                        b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                        b_eta = gr.Number(label="eta (DDIM)", value=0.0)
+                        b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                        b_n_prompt = gr.Textbox(label="Negative Prompt",
+                                                value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+                gr.Markdown("### Step 3. edit each one and render.")
+                with gr.Row():
+                    f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
+                    b_run_button = gr.Button("Edit Background")
+                    run_button = gr.Button("Render")
+            with gr.Column():
+                output_video = gr.Video(label="Output Video", interactive=False)
+                # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
+                output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
+                output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
+
+        # edit param
+        f_adv_edit_param = [adv_keyframes,
+                            adv_atlas_resolution,
+                            f_prompt,
+                            adv_a_prompt,
+                            adv_n_prompt,
+                            adv_image_resolution,
+                            adv_low_threshold,
+                            adv_high_threshold,
+                            adv_ddim_steps,
+                            adv_s,
+                            adv_scale,
+                            adv_seed,
+                            adv_eta,
+                            adv_if_net]
+        b_edit_param = [b_prompt,
+                        b_a_prompt,
+                        b_n_prompt,
+                        b_image_resolution,
+                        b_detect_resolution,
+                        b_ddim_steps,
+                        b_scale,
+                        b_seed,
+                        b_eta]
+        # action
+        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
+        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
+        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
+        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
+
+    block.launch()
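The rewritten __main__ block is standard Gradio 3.x wiring: declare components inside gr.Blocks(), connect buttons to the StableVideo methods with .click(), and serialize long-running jobs through .queue(). A stripped-down sketch of the same pattern, using a placeholder callback instead of the real StableVideo methods (all names here are illustrative, not taken from app.py):

import gradio as gr

def placeholder_render(prompt):
    # Stand-in for stablevideo.render / edit_* so the wiring can be tested by itself.
    return f"would render with prompt: {prompt}"

block = gr.Blocks().queue()   # queue() serializes requests, which matters for slow CPU-only inference
with block:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        run_button = gr.Button("Render")
    output = gr.Textbox(label="Output", interactive=False)
    # Event wiring: fn receives the input component values and fills the output component.
    run_button.click(fn=placeholder_render, inputs=prompt, outputs=output)

block.launch()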
ckpt/cldm_v15.yaml
CHANGED
@@ -77,3 +77,5 @@ model:
 
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        device: "cpu"
requirements.txt
CHANGED
@@ -120,3 +120,4 @@ wcwidth==0.2.6
 websockets==11.0.3
 Werkzeug==2.3.7
 yarl==1.9.2
+xformers
stablevideo/atlas_data.py
CHANGED
@@ -30,7 +30,7 @@ class AtlasData():
         maximum_number_of_frames = json_dict["maximum_number_of_frames"]
 
         config = {
-            "device": "
+            "device": "cpu",
             "checkpoint_path": f"data/{video_name}/checkpoint.ckpt",
             "resx": json_dict["resx"],
             "resy": json_dict["resy"],
stablevideo/atlas_utils.py
CHANGED
@@ -72,7 +72,7 @@ def load_neural_atlases_models(config):
         skip_layers=[],
     ).to(config["device"])
 
-    checkpoint = torch.load(config["checkpoint_path"])
+    checkpoint = torch.load(config["checkpoint_path"], map_location=torch.device('cpu'))
     foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
     background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
     foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
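Taken together with the app.py and atlas_data.py changes, this commit follows one pattern: keep every module on the CPU and make every checkpoint load device-agnostic, since torch.load would otherwise try to restore CUDA tensors onto a GPU the Space does not have. A minimal sketch of that loading pattern (the helper name and the commented usage are illustrative, not code from this repository):

import torch

def load_checkpoint_cpu(path: str) -> dict:
    # map_location remaps tensors that were saved on a CUDA device back onto the CPU,
    # so the checkpoint can be opened on a machine without a GPU.
    return torch.load(path, map_location=torch.device("cpu"))

# Hypothetical usage, mirroring load_neural_atlases_models above:
# checkpoint = load_checkpoint_cpu("data/car-turn/checkpoint.ckpt")
# foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])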