diff --git a/__pycache__/test.cpython-310.pyc b/__pycache__/test.cpython-310.pyc index e271dc00942710da1c76f23520e0f8e23309a6ea..7691b319e155f934ceedffd1e3c183a6b6084503 100644 Binary files a/__pycache__/test.cpython-310.pyc and b/__pycache__/test.cpython-310.pyc differ diff --git a/app.py b/app.py index 80867b33815a0f2faec6dda95567a19941d1b6d1..416fba0a23fe3e113836a789ca50615389de035d 100644 --- a/app.py +++ b/app.py @@ -11,7 +11,7 @@ from webui.merge_config_gradio import merge_config_then_run import huggingface_hub import shutil import os - +import torch HF_TOKEN = os.getenv('HF_TOKEN') pipe = merge_config_then_run() @@ -39,21 +39,10 @@ If you have any questions, please feel free to reach me out at knightyxp@gmai """ +def update_layout_visibility(selected_num): + num = int(selected_num) + return [gr.update(visible=(i < num)) for i in range(len(layout_files))] -def update_layout_visibility(num): - """ - Given the user's selection (string) in ["2","3","4","5"], - return visibility updates for each of the 5 layout video inputs. - """ - n = int(num) - # Show layout_file1 if n >= 1, layout_file2 if n >= 2, etc. - return [ - gr.update(visible=(n >= 1)), - gr.update(visible=(n >= 2)), - gr.update(visible=(n >= 3)), - gr.update(visible=(n >= 4)), - gr.update(visible=(n >= 5)) - ] with gr.Blocks(css='style.css') as demo: # gr.Markdown(TITLE) @@ -138,55 +127,28 @@ with gr.Blocks(css='style.css') as demo: info="Please select the number of editing areas" ) - # Put all layout-video components in one Row to display them horizontally. + # 使用循环生成所有的布局视频组件,并存到列表 layout_files 中 + layout_files = [] with gr.Row(): - layout_file1 = gr.Video( - label="Layout Video 1", - type="numpy", - format="mp4", - visible=True - ) - layout_file2 = gr.Video( - label="Layout Video 2", - type="numpy", - format="mp4", - visible=True - ) - layout_file3 = gr.Video( - label="Layout Video 3", - type="numpy", - format="mp4", - visible=False - ) - layout_file4 = gr.Video( - label="Layout Video 4", - type="numpy", - format="mp4", - visible=False - ) - layout_file5 = gr.Video( - label="Layout Video 5", - type="numpy", - format="mp4", - visible=False - ) + for i in range(5): + video = gr.Video( + label=f"Layout Video {i+1}", + type="numpy", + format="mp4", + visible=(i < 2) # 默认显示前两个 + ) + layout_files.append(video) - # Toggle visibility of the layout videos based on user selection + # 当 num_layouts 改变时,通过回调函数更新 layout_files 列表中各视频组件的 visible 属性 num_layouts.change( fn=update_layout_visibility, inputs=num_layouts, - outputs=[ - layout_file1, - layout_file2, - layout_file3, - layout_file4, - layout_file5 - ] + outputs=layout_files ) prompt = gr.Textbox(label='Prompt', info='Change the prompt, and extract each local prompt in the editing prompts.\ - the local prompt order should be same as layout masks order.)', + (the local prompt order should be same as layout masks order.)', ) model_id = gr.Dropdown( @@ -198,11 +160,25 @@ with gr.Blocks(css='style.css') as demo: value='stable-diffusion-v1-5/stable-diffusion-v1-5') - run_button = gr.Button('Generate') - with gr.Column(): result = gr.Video(label='Result') # result.style(height=512, width=512) + with gr.Accordion('Temporal Crop offset and Sampling Stride', open=False): + n_sample_frame = gr.Slider(label='Number of Frames', + minimum=0, + maximum=32, + step=1, + value=16) + sampling_rate = gr.Slider(label='sampling_rate', + minimum=0, + maximum=20, + step=1, + value=1) + start_sample_frame = gr.Number(label='Start frame in the video', + value=0, + precision=0) + + with gr.Row(): control_list = ['dwpose', 'depth_zoe', 'depth_midas'] control_type = gr.Dropdown( @@ -252,7 +228,9 @@ with gr.Blocks(css='style.css') as demo: value=["1"], info="Select one or more flatten resolution factors. Mapping: 1 -> 64, 2 -> 32 (64/2), 4 -> 16 (64/4), 8 -> 8 (64/8)." ) - + + + run_button = gr.Button('Generate') with gr.Row(): from example import style_example @@ -278,25 +256,22 @@ with gr.Blocks(css='style.css') as demo: # # cache_examples=os.getenv('SYSTEM') == 'spaces' # ) gr.Markdown(ARTICLE) - inputs = [ - model_id, - user_input_video, - num_layouts, - layout_file1, - layout_file2, - layout_file3, - layout_file4, - layout_file5, - prompt, - model_id, - control_type, - dwpose_options, - controlnet_conditioning_scale, - use_pnp, - pnp_inject_steps, - flatten_res, + inputs = [user_input_video, num_layouts, + *layout_files, + prompt, + model_id, + n_sample_frame, + start_sample_frame, + sampling_rate, + control_type, + dwpose_options, + controlnet_conditioning_scale, + use_pnp, + pnp_inject_steps, + flatten_res, ] prompt.submit(fn=pipe.run, inputs=inputs, outputs=result) run_button.click(fn=pipe.run, inputs=inputs, outputs=result) - + if device == 'cuda': + torch.cuda.empty_cache() demo.queue().launch() \ No newline at end of file diff --git a/config/demo_config.yaml b/config/demo_config.yaml index d648343564ad50b31e6170731fd1e8cb01e47581..5861dc8b6932a46488cca4490f5318120f1173ee 100644 --- a/config/demo_config.yaml +++ b/config/demo_config.yaml @@ -2,13 +2,13 @@ pretrained_model_path: "/home/xianyang/Data/code/FateZero/ckpt/stable-diffusion- logdir: ./result/run_two_man/instance_level/3cls_spider_polar_vis_cross_attn dataset_config: - path: "data/run_two_man/run_two_man_fr2" - prompt: 'Man in red hoddie and man in gray shirt are jogging in forest' + path: "" + prompt: "" n_sample_frame: 16 start_sample_frame: 0 - sampling_rate: 2 - layout_mask_dir: "./data/run_two_man/layout_masks_fr2" - layout_mask_order: ['left_man_plus','right_man_plus','trees','trunk'] + sampling_rate: 1 + layout_mask_dir: "" + layout_mask_order: [] negative_promot: "ugly, blurry, low res, unrealistic, unaesthetic" control_config: @@ -34,7 +34,7 @@ editing_config: sample_seeds: [0] num_inference_steps: 50 blending_percentage: 0 - vis_cross_attn: True + vis_cross_attn: False #cluster_inversion_feature: True diff --git a/input-video/00000.png b/input-video/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..1132e2cfe069c2aea2d99988ed37a6aa56c90bd4 Binary files /dev/null and b/input-video/00000.png differ diff --git a/input-video/00001.png b/input-video/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..babfadc86b0d93b46281b52d940f3bcb448f39cf Binary files /dev/null and b/input-video/00001.png differ diff --git a/input-video/00002.png b/input-video/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..d8303513a19273fc80fa3cd246a1d448ef5da964 Binary files /dev/null and b/input-video/00002.png differ diff --git a/input-video/00003.png b/input-video/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..0d00ac6b2b8d6adb849735f2ecf5def3b4d8378f Binary files /dev/null and b/input-video/00003.png differ diff --git a/input-video/00004.png b/input-video/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..bba5fa1780276ad033b8d948feaffce3934da74b Binary files /dev/null and b/input-video/00004.png differ diff --git a/input-video/00005.png b/input-video/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..feb2d2d52d15bf6f9e7236e5b0f9f4065907e0e8 Binary files /dev/null and b/input-video/00005.png differ diff --git a/input-video/00006.png b/input-video/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..146bcc084d172e119279cd14d60373c258d433dd Binary files /dev/null and b/input-video/00006.png differ diff --git a/input-video/00007.png b/input-video/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..94bf0cdaa0e248e48e3e66d161d85cf7f965c42e Binary files /dev/null and b/input-video/00007.png differ diff --git a/input-video/00008.png b/input-video/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..60f25ff1b64b58690355336f476b688d5a07a65e Binary files /dev/null and b/input-video/00008.png differ diff --git a/input-video/00009.png b/input-video/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..afd4e46edfdcfd58d3560f125ada080079ea9313 Binary files /dev/null and b/input-video/00009.png differ diff --git a/input-video/00010.png b/input-video/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..773c3c9763807b5a3c0056397341139d5ff623b6 Binary files /dev/null and b/input-video/00010.png differ diff --git a/input-video/00011.png b/input-video/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..2f0cefd09c3173e2710a7ff3e1e621c7f6fa4189 Binary files /dev/null and b/input-video/00011.png differ diff --git a/input-video/00012.png b/input-video/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..27f3ffb9e75321e394e0b0c2548327c858e2f506 Binary files /dev/null and b/input-video/00012.png differ diff --git a/input-video/00013.png b/input-video/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..a8ff6814f6f3c8a2b0f1f6e93ec37b1cdc10f7db Binary files /dev/null and b/input-video/00013.png differ diff --git a/input-video/00014.png b/input-video/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..7bb96032ac4511bd26eef2050d5677438851fbd2 Binary files /dev/null and b/input-video/00014.png differ diff --git a/input-video/00015.png b/input-video/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..0c11698cf43b5654a4b9dc561de624c8c98950ec Binary files /dev/null and b/input-video/00015.png differ diff --git a/layout_masks/1/00000.png b/layout_masks/1/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..48bd10ff71d62e978995a6516e6319fd6bb8b913 Binary files /dev/null and b/layout_masks/1/00000.png differ diff --git a/layout_masks/1/00001.png b/layout_masks/1/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..8f4a5c7db6cc4b2329979ace4de65798a5c671b8 Binary files /dev/null and b/layout_masks/1/00001.png differ diff --git a/layout_masks/1/00002.png b/layout_masks/1/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..5a5a77b31fd2355f2c4885fa5c6ed04651c85dff Binary files /dev/null and b/layout_masks/1/00002.png differ diff --git a/layout_masks/1/00003.png b/layout_masks/1/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..0cbada401f7287353713c0cac3195cad6f537f83 Binary files /dev/null and b/layout_masks/1/00003.png differ diff --git a/layout_masks/1/00004.png b/layout_masks/1/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..b945bf5f8a4bc001ca0608d317b30ac3cbd296c3 Binary files /dev/null and b/layout_masks/1/00004.png differ diff --git a/layout_masks/1/00005.png b/layout_masks/1/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..419bd4ba8ef78afa983beaf1b0ba3e7c3530254f Binary files /dev/null and b/layout_masks/1/00005.png differ diff --git a/layout_masks/1/00006.png b/layout_masks/1/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..40fea7a951928ee260318c931ee3de2726e2f9a8 Binary files /dev/null and b/layout_masks/1/00006.png differ diff --git a/layout_masks/1/00007.png b/layout_masks/1/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..0a2c7c456cf16e95085f1925326b5070b892cfee Binary files /dev/null and b/layout_masks/1/00007.png differ diff --git a/layout_masks/1/00008.png b/layout_masks/1/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..ed141eb911ae3ab5a9cc1fc9773b08941bc296b2 Binary files /dev/null and b/layout_masks/1/00008.png differ diff --git a/layout_masks/1/00009.png b/layout_masks/1/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..12dd06d8aa011cbb1f038ec70620bc9f2ca21b09 Binary files /dev/null and b/layout_masks/1/00009.png differ diff --git a/layout_masks/1/00010.png b/layout_masks/1/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..cdedf53d66b7b04d48867d13fb24484763895cfc Binary files /dev/null and b/layout_masks/1/00010.png differ diff --git a/layout_masks/1/00011.png b/layout_masks/1/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..2a5d9bef6be697773e3070101e25333cc0726107 Binary files /dev/null and b/layout_masks/1/00011.png differ diff --git a/layout_masks/1/00012.png b/layout_masks/1/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..9f5afaaa39bee9894ce1f1f6b167fcf8480b31f1 Binary files /dev/null and b/layout_masks/1/00012.png differ diff --git a/layout_masks/1/00013.png b/layout_masks/1/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..37b0e822f17997ba2b488d03969a1ac98e745120 Binary files /dev/null and b/layout_masks/1/00013.png differ diff --git a/layout_masks/1/00014.png b/layout_masks/1/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..14a722aa6e7d4b95650673c9d7d0ab497594ccf0 Binary files /dev/null and b/layout_masks/1/00014.png differ diff --git a/layout_masks/1/00015.png b/layout_masks/1/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..83162819d35b77e6431b9abbec2828a4070f14b7 Binary files /dev/null and b/layout_masks/1/00015.png differ diff --git a/layout_masks/2/00000.png b/layout_masks/2/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..62496a7c46ba863bfb00c9a5e4f8597c4b027282 Binary files /dev/null and b/layout_masks/2/00000.png differ diff --git a/layout_masks/2/00001.png b/layout_masks/2/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..620b8af00ee88e11b3cbaaefee8030a78b32c1ab Binary files /dev/null and b/layout_masks/2/00001.png differ diff --git a/layout_masks/2/00002.png b/layout_masks/2/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..c391271d8e77e05bcbcee53d7d7ee4b014526e55 Binary files /dev/null and b/layout_masks/2/00002.png differ diff --git a/layout_masks/2/00003.png b/layout_masks/2/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..9c13dd527b40a2fa6d3e0b24bb2c0c86eeb4313d Binary files /dev/null and b/layout_masks/2/00003.png differ diff --git a/layout_masks/2/00004.png b/layout_masks/2/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..66c977827ff63f0bbaecf04d1b18d7b6f57edc7f Binary files /dev/null and b/layout_masks/2/00004.png differ diff --git a/layout_masks/2/00005.png b/layout_masks/2/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..a4f0f956cddceb0e8083743958437fee455428ff Binary files /dev/null and b/layout_masks/2/00005.png differ diff --git a/layout_masks/2/00006.png b/layout_masks/2/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..7c6562227f196e6eae9228832b05b02400475547 Binary files /dev/null and b/layout_masks/2/00006.png differ diff --git a/layout_masks/2/00007.png b/layout_masks/2/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..2d948900b827f895091e8b093a4b2f2bca8ee2d7 Binary files /dev/null and b/layout_masks/2/00007.png differ diff --git a/layout_masks/2/00008.png b/layout_masks/2/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..56800dc4fd2f1e0e8429e6e5a5708ef99310dae9 Binary files /dev/null and b/layout_masks/2/00008.png differ diff --git a/layout_masks/2/00009.png b/layout_masks/2/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..663b03a46dc345e66bdf89ae2860d29a1867c7be Binary files /dev/null and b/layout_masks/2/00009.png differ diff --git a/layout_masks/2/00010.png b/layout_masks/2/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..d0f5df54efe3a248f15668a54f0390c307d145f8 Binary files /dev/null and b/layout_masks/2/00010.png differ diff --git a/layout_masks/2/00011.png b/layout_masks/2/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..15208f0b1418040bdd9f1104348e14a77e470a7a Binary files /dev/null and b/layout_masks/2/00011.png differ diff --git a/layout_masks/2/00012.png b/layout_masks/2/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..92cada72982a71558cd830ea087adc0e76b9c380 Binary files /dev/null and b/layout_masks/2/00012.png differ diff --git a/layout_masks/2/00013.png b/layout_masks/2/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..6c2a69f339a2018959dc6092878ccaabadd174c2 Binary files /dev/null and b/layout_masks/2/00013.png differ diff --git a/layout_masks/2/00014.png b/layout_masks/2/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..3cc7b6455f95789f3faf1e5c471c8f370b395e37 Binary files /dev/null and b/layout_masks/2/00014.png differ diff --git a/layout_masks/2/00015.png b/layout_masks/2/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..8e6e19820ab671927729cba27c5e17b5ed4d3751 Binary files /dev/null and b/layout_masks/2/00015.png differ diff --git a/layout_masks/3/00000.png b/layout_masks/3/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..d540461e6ba7fe0613082a2d0a0882e6b3dbfd75 Binary files /dev/null and b/layout_masks/3/00000.png differ diff --git a/layout_masks/3/00001.png b/layout_masks/3/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..2673b40a02c4dde6d74da87761c71afcd1d583c0 Binary files /dev/null and b/layout_masks/3/00001.png differ diff --git a/layout_masks/3/00002.png b/layout_masks/3/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..9fb7e23b005bc8dc222fe1956771b8ce366112fd Binary files /dev/null and b/layout_masks/3/00002.png differ diff --git a/layout_masks/3/00003.png b/layout_masks/3/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..aeeac0051f5dc440eb0b8a9fc8a62b3d8bb9c2f1 Binary files /dev/null and b/layout_masks/3/00003.png differ diff --git a/layout_masks/3/00004.png b/layout_masks/3/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..2f0b36af8fb24e96b3e0b47049e6bcda16261754 Binary files /dev/null and b/layout_masks/3/00004.png differ diff --git a/layout_masks/3/00005.png b/layout_masks/3/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..32be1124b2524a20554c4ebc02647ed0be36c827 Binary files /dev/null and b/layout_masks/3/00005.png differ diff --git a/layout_masks/3/00006.png b/layout_masks/3/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..15fbb228a6b0d5477562e488d49a62efe8af24a5 Binary files /dev/null and b/layout_masks/3/00006.png differ diff --git a/layout_masks/3/00007.png b/layout_masks/3/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..88e808110b3449d37472bcd876f90125cf81c2b9 Binary files /dev/null and b/layout_masks/3/00007.png differ diff --git a/layout_masks/3/00008.png b/layout_masks/3/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..90b6ff1503ab4079715136130508fecb6915c58a Binary files /dev/null and b/layout_masks/3/00008.png differ diff --git a/layout_masks/3/00009.png b/layout_masks/3/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..f21dfd8301777c1bd4ad671b2ec885f258496141 Binary files /dev/null and b/layout_masks/3/00009.png differ diff --git a/layout_masks/3/00010.png b/layout_masks/3/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..0ecdc3489901bc0d4a9a9a206b425b69f1d374ca Binary files /dev/null and b/layout_masks/3/00010.png differ diff --git a/layout_masks/3/00011.png b/layout_masks/3/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..8d25050eba1fc0f845207b16749c01465020b369 Binary files /dev/null and b/layout_masks/3/00011.png differ diff --git a/layout_masks/3/00012.png b/layout_masks/3/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..d5054fbca176b94622e226b28d7097aa48a3a0f3 Binary files /dev/null and b/layout_masks/3/00012.png differ diff --git a/layout_masks/3/00013.png b/layout_masks/3/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..7bf390dd040276b2777c9f4cd097ce250145cf9c Binary files /dev/null and b/layout_masks/3/00013.png differ diff --git a/layout_masks/3/00014.png b/layout_masks/3/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..adff194fff603502c585a92fe97cda36261e59e0 Binary files /dev/null and b/layout_masks/3/00014.png differ diff --git a/layout_masks/3/00015.png b/layout_masks/3/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..61aee16b521703d5a90c0b0042ffd7fde836e941 Binary files /dev/null and b/layout_masks/3/00015.png differ diff --git a/layout_masks/4/00000.png b/layout_masks/4/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..0c938fdb97357aec88a1117a96eccf694a306d54 Binary files /dev/null and b/layout_masks/4/00000.png differ diff --git a/layout_masks/4/00001.png b/layout_masks/4/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..9a18c4847b536f5cac050b50fe3a407d08903085 Binary files /dev/null and b/layout_masks/4/00001.png differ diff --git a/layout_masks/4/00002.png b/layout_masks/4/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..fc10b57487b41b0322982a5f7edd6e400077b994 Binary files /dev/null and b/layout_masks/4/00002.png differ diff --git a/layout_masks/4/00003.png b/layout_masks/4/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..0e73940c6ac3e14d6091483facfefe4c95471f0f Binary files /dev/null and b/layout_masks/4/00003.png differ diff --git a/layout_masks/4/00004.png b/layout_masks/4/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..7791fbbee7f630d2dabf94f7c35143df168a133a Binary files /dev/null and b/layout_masks/4/00004.png differ diff --git a/layout_masks/4/00005.png b/layout_masks/4/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..bb8eb4b9f2b1e3720fbb855272f51eb6d58fe088 Binary files /dev/null and b/layout_masks/4/00005.png differ diff --git a/layout_masks/4/00006.png b/layout_masks/4/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..aa5e0fa381abc9974df280fa46f12d0fbdf4b05c Binary files /dev/null and b/layout_masks/4/00006.png differ diff --git a/layout_masks/4/00007.png b/layout_masks/4/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..70c3c23fac4b9a463614175c6417087e78ca3dfd Binary files /dev/null and b/layout_masks/4/00007.png differ diff --git a/layout_masks/4/00008.png b/layout_masks/4/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..4ec2ad232250ec909c3d3ca6ad8b75b34c95de53 Binary files /dev/null and b/layout_masks/4/00008.png differ diff --git a/layout_masks/4/00009.png b/layout_masks/4/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..728d8a6afd50cdcaf29f02849a391cd818083fdb Binary files /dev/null and b/layout_masks/4/00009.png differ diff --git a/layout_masks/4/00010.png b/layout_masks/4/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..c3e535983a79824835e18b403c17e92af1ec0526 Binary files /dev/null and b/layout_masks/4/00010.png differ diff --git a/layout_masks/4/00011.png b/layout_masks/4/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..13f46271568f07222e3b6d179c828a92cf79a3a2 Binary files /dev/null and b/layout_masks/4/00011.png differ diff --git a/layout_masks/4/00012.png b/layout_masks/4/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..aa7917bedf6cdde8cc058158d6d5034756fd9d55 Binary files /dev/null and b/layout_masks/4/00012.png differ diff --git a/layout_masks/4/00013.png b/layout_masks/4/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..c17de554e346994e279d7cfe7f06e4f161092916 Binary files /dev/null and b/layout_masks/4/00013.png differ diff --git a/layout_masks/4/00014.png b/layout_masks/4/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..e961c1b98bc7c9aaef5b0c96ae1c3250fed47301 Binary files /dev/null and b/layout_masks/4/00014.png differ diff --git a/layout_masks/4/00015.png b/layout_masks/4/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..9238c69e236870cc82515b741a0317cdcebf5323 Binary files /dev/null and b/layout_masks/4/00015.png differ diff --git a/layout_masks/5/00000.png b/layout_masks/5/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..dae3f061be6d9bccefa8190e12c97a21005be06d Binary files /dev/null and b/layout_masks/5/00000.png differ diff --git a/layout_masks/5/00001.png b/layout_masks/5/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..c92427abeebdcb30f9406b5424956c08effa04f6 Binary files /dev/null and b/layout_masks/5/00001.png differ diff --git a/layout_masks/5/00002.png b/layout_masks/5/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..4e36988cb728a24b248c627a2461c2e92443b3af Binary files /dev/null and b/layout_masks/5/00002.png differ diff --git a/layout_masks/5/00003.png b/layout_masks/5/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..6fb40eb7079b0dba82db5c5702df6bffa231ddc1 Binary files /dev/null and b/layout_masks/5/00003.png differ diff --git a/layout_masks/5/00004.png b/layout_masks/5/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..e22c4b41011c6fa2a877c9b94c0af2bef0048233 Binary files /dev/null and b/layout_masks/5/00004.png differ diff --git a/layout_masks/5/00005.png b/layout_masks/5/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..1f5f13fad902e05dad089032b1019d2adf293214 Binary files /dev/null and b/layout_masks/5/00005.png differ diff --git a/layout_masks/5/00006.png b/layout_masks/5/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..f55c316dde5b4a37d15f5c180cd07cde329a9698 Binary files /dev/null and b/layout_masks/5/00006.png differ diff --git a/layout_masks/5/00007.png b/layout_masks/5/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..699b405a372926ea706056df21730901fc011264 Binary files /dev/null and b/layout_masks/5/00007.png differ diff --git a/layout_masks/5/00008.png b/layout_masks/5/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..81d40959777ff8732c5bed002959fc3067866707 Binary files /dev/null and b/layout_masks/5/00008.png differ diff --git a/layout_masks/5/00009.png b/layout_masks/5/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..a6215b70a7fdf66f5b901268528bbf27b0c3546c Binary files /dev/null and b/layout_masks/5/00009.png differ diff --git a/layout_masks/5/00010.png b/layout_masks/5/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..8d71e2af297c30f4ec6d03422f58c883235fe936 Binary files /dev/null and b/layout_masks/5/00010.png differ diff --git a/layout_masks/5/00011.png b/layout_masks/5/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..26f56e5e34ed5d7ff0120df0d111905d22ff302a Binary files /dev/null and b/layout_masks/5/00011.png differ diff --git a/layout_masks/5/00012.png b/layout_masks/5/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..51ee657df326c58a9f1b5993a4bbd7acb78d0772 Binary files /dev/null and b/layout_masks/5/00012.png differ diff --git a/layout_masks/5/00013.png b/layout_masks/5/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..ae4aab594ca9d4ba75c08ef5291314a547d0edf1 Binary files /dev/null and b/layout_masks/5/00013.png differ diff --git a/layout_masks/5/00014.png b/layout_masks/5/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..793ccdc5daee4472243a82f8af611d5e8b63cae5 Binary files /dev/null and b/layout_masks/5/00014.png differ diff --git a/layout_masks/5/00015.png b/layout_masks/5/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..d523d9331021f250cf2f4dffc4efc950821a6748 Binary files /dev/null and b/layout_masks/5/00015.png differ diff --git a/test.py b/test.py index 1c20870ba5357aa41e5c621364ff1aa876a33091..812a9677c80336481eb4b868c75eea29e0b0cea1 100644 --- a/test.py +++ b/test.py @@ -108,7 +108,6 @@ def test( subfolder="vae", ) - # unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") unet = UNetPseudo3DConditionModel.from_2d_model( os.path.join(pretrained_model_path, "unet"), model_config=model_config ) @@ -236,6 +235,8 @@ def test( images = rearrange(images.to(dtype=torch.float32), "b c f h w -> (b f) h w c") control_type = control_config['control_type'] + print('control_type',control_type) + apply_control = get_control(control_type) control = [] @@ -352,6 +353,7 @@ def test( layouts = batch["layouts"].to(dtype=weight_dtype) #layouts = f s c h w + if accelerator.is_main_process: if validation_sample_logger is not None: @@ -379,8 +381,10 @@ def test( attn_inversion_dict = attn_inversion_dict, ) - accelerator.end_training() - + accelerator.end_training() + save_path = os.path.join(logdir,'sample/step_0.gif') + print('save_path',save_path) + return save_path @click.command() @click.option("--config", type=str, default="config/shape/exp_config/single_object/tennis_3.yaml") def run(config): diff --git a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc index 0deec1bd056c375bdd752a8057fbee09d8b13004..d5260fbe71864236dbc3e72b22285a4b64332cfe 100644 Binary files a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc and b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc differ diff --git a/video_diffusion/data/dataset.py b/video_diffusion/data/dataset.py index 965961c4846d91578c108985092c9a30da50fab2..76dd4ba2cafdbe52cd521fc83e6c0d7c54827b94 100644 --- a/video_diffusion/data/dataset.py +++ b/video_diffusion/data/dataset.py @@ -1,5 +1,4 @@ import os - import numpy as np from PIL import Image from einops import rearrange @@ -11,26 +10,24 @@ from torch.utils.data import Dataset from .transform import short_size_scale, random_crop, center_crop, offset_crop from ..common.image_util import IMAGE_EXTENSION import cv2 +import imageio +import shutil class ImageSequenceDataset(Dataset): def __init__( self, - path: str, - layout_mask_dir: str, - layout_mask_order: list, + path: str, # 输入视频,如果是 mp4 则转换到固定目录 './input-video' + layout_files: list, # 上传的 layout mask 文件列表(mp4 或目录),转换后存放到固定目录 './layout_masks/1', './layout_masks/2', ... prompt_ids: torch.Tensor, prompt: str, - start_sample_frame: int=0, + start_sample_frame: int = 0, n_sample_frame: int = 8, sampling_rate: int = 1, - stride: int = -1, # only used during tuning to sample a long video + stride: int = -1, # tuning 时用于对长视频进行采样 image_mode: str = "RGB", image_size: int = 512, crop: str = "center", - - class_data_root: str = None, - class_prompt_ids: torch.Tensor = None, - + offset: dict = { "left": 0, "right": 0, @@ -38,33 +35,42 @@ class ImageSequenceDataset(Dataset): "bottom": 0 }, **args - ): - self.path = path - self.images = self.get_image_list(path) - # - self.layout_mask_dir = layout_mask_dir - self.layout_mask_order = list(layout_mask_order) + # 若输入视频是 mp4,则转换到固定目录 './input-video' + if path.endswith('.mp4'): + self.path = self.mp4_to_png(path, target_dir='./input-video') + else: + self.path = path + self.images = self.get_image_list(self.path) - layout_mask_dir0 = os.path.join(self.layout_mask_dir,self.layout_mask_order[0]) - self.masks_index = self.get_image_list(layout_mask_dir0) + # 对每个上传的 layout 文件进行处理 + # 若是 mp4,则转换到固定目录 './layout_masks/{i+1}' + self.layout_mask_dirs = [] + for idx, file in enumerate(layout_files): + if file.endswith('.mp4'): + folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx+1}') + else: + folder = file + self.layout_mask_dirs.append(folder) + # 保持上传顺序作为 layout_mask_order(此处仅用索引表示顺序) + self.layout_mask_order = list(range(len(self.layout_mask_dirs))) + # 用第一个 layout mask 目录获取 mask 图像索引(用于判断帧数) + self.masks_index = self.get_image_list(self.layout_mask_dirs[0]) - # self.n_images = len(self.images) self.offset = offset self.start_sample_frame = start_sample_frame if n_sample_frame < 0: - n_sample_frame = len(self.images) + n_sample_frame = len(self.images) self.n_sample_frame = n_sample_frame - # local sampling rate from the video self.sampling_rate = sampling_rate self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1 if self.n_images < self.sequence_length: - raise ValueError(f"self.n_images {self.n_images } < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images }") + raise ValueError(f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images}") - # During tuning if video is too long, we sample the long video every self.stride globally - self.stride = stride if stride > 0 else (self.n_images+1) + # 若视频太长,则全局采样 + self.stride = stride if stride > 0 else (self.n_images + 1) self.video_len = (self.n_images - self.sequence_length) // self.stride + 1 self.image_mode = image_mode @@ -74,67 +80,53 @@ class ImageSequenceDataset(Dataset): "random": random_crop, } if crop not in crop_methods: - raise ValueError + raise ValueError("Unsupported crop method") self.crop = crop_methods[crop] self.prompt = prompt self.prompt_ids = prompt_ids - # Negative prompt for regularization to avoid overfitting during one-shot tuning - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_images_path = sorted(list(self.class_data_root.iterdir())) - self.num_class_images = len(self.class_images_path) - self.class_prompt_ids = class_prompt_ids def __len__(self): max_len = (self.n_images - self.sequence_length) // self.stride + 1 - if hasattr(self, 'num_class_images'): max_len = max(max_len, self.num_class_images) - return max_len def __getitem__(self, index): return_batch = {} - frame_indices = self.get_frame_indices(index%self.video_len) + frame_indices = self.get_frame_indices(index % self.video_len) frames = [self.load_frame(i) for i in frame_indices] frames = self.transform(frames) layout_ = [] - for layout_name in self.layout_mask_order: - frame_indices = self.get_frame_indices(index%self.video_len) - layout_mask_dir = os.path.join(self.layout_mask_dir,layout_name) - mask = [self._read_mask(layout_mask_dir,i) for i in frame_indices] - masks = np.stack(mask) + # 遍历每个 layout mask 目录(顺序与用户上传顺序一致) + for layout_dir in self.layout_mask_dirs: + # 对于每个 layout 目录,根据帧索引读取对应的 mask 图像(PNG 文件) + frame_indices_local = self.get_frame_indices(index % self.video_len) + mask = [self._read_mask(layout_dir, i) for i in frame_indices_local] + masks = np.stack(mask) # shape: (n_sample_frame, c, h, w) layout_.append(masks) - layout_ = np.stack(layout_) + layout_ = np.stack(layout_) # shape: (num_layouts, n_sample_frame, c, h, w) + merged_masks = [] for i in range(int(self.n_sample_frame)): - merged_mask_frame = np.sum(layout_[:,i,:,:,:], axis=0) - merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8) + merged_mask_frame = np.sum(layout_[:, i, :, :, :], axis=0) + merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8) merged_masks.append(merged_mask_frame) masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w") masks = torch.from_numpy(masks).half() - layouts = rearrange(layout_,"s f c h w -> f s c h w" ) + layouts = rearrange(layout_, "s f c h w -> f s c h w") layouts = torch.from_numpy(layouts).half() - return_batch.update( - { + return_batch.update({ "images": frames, - "masks":masks, - "layouts":layouts, + "masks": masks, + "layouts": layouts, "prompt_ids": self.prompt_ids, - } - ) - - if hasattr(self, 'class_data_root'): - class_index = index % (self.num_class_images - self.n_sample_frame) - class_indices = self.get_class_indices(class_index) - frames = [self.load_class_frame(i) for i in class_indices] - return_batch["class_images"] = self.tensorize_frames(frames) - return_batch["class_prompt_ids"] = self.class_prompt_ids + }) + return return_batch def transform(self, frames): @@ -149,24 +141,18 @@ class ImageSequenceDataset(Dataset): frames = rearrange(np.stack(frames), "f h w c -> c f h w") return torch.from_numpy(frames).div(255) * 2 - 1 - def _read_mask(self, mask_path,index: int): - ### read mask by pil - - mask_path = os.path.join(mask_path,f"{index:05d}.png") - - ### read mask by cv2 + def _read_mask(self, mask_dir, index: int): + # 构造 mask 文件名(png 格式) + mask_path = os.path.join(mask_dir, f"{index:05d}.png") mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) mask = (mask > 0).astype(np.uint8) - # Determine dynamic destination size + # 根据原图大小动态缩放(这里缩小8倍) height, width = mask.shape dest_size = (width // 8, height // 8) - # Resize using nearest neighbor interpolation - mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) #cv2.INTER_CUBIC + mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) mask = mask[np.newaxis, ...] - return mask - def load_frame(self, index): image_path = os.path.join(self.path, self.images[index]) return Image.open(image_path).convert(self.image_mode) @@ -184,12 +170,31 @@ class ImageSequenceDataset(Dataset): def get_class_indices(self, index): frame_start = index - return (frame_start + i for i in range(self.n_sample_frame)) + return (frame_start + i for i in range(self.n_sample_frame)) @staticmethod def get_image_list(path): images = [] + # 如果传入的是 mp4 文件,则先转换成 PNG 图像目录 + if path.endswith('.mp4'): + path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video') for file in sorted(os.listdir(path)): if file.endswith(IMAGE_EXTENSION): images.append(file) return images + + @staticmethod + def mp4_to_png(video_source: str, target_dir: str): + """ + Convert an mp4 video to a sequence of PNG images, storing them in target_dir. + target_dir 为固定路径,例如:'./input-video' 或 './layout_masks/1' + """ + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir, exist_ok=True) + + reader = imageio.get_reader(video_source) + for i, im in enumerate(reader): + path = os.path.join(target_dir, f"{i:05d}.png") + cv2.imwrite(path, im[:, :, ::-1]) + return target_dir diff --git a/webui/__pycache__/merge_config_gradio.cpython-310.pyc b/webui/__pycache__/merge_config_gradio.cpython-310.pyc index 4a89b04da28d5cf498d8da056888cb91e4cd7541..d8a67134df6f17874b5e3b5c497deb2c97cea964 100644 Binary files a/webui/__pycache__/merge_config_gradio.cpython-310.pyc and b/webui/__pycache__/merge_config_gradio.cpython-310.pyc differ diff --git a/webui/merge_config_gradio.py b/webui/merge_config_gradio.py index 9181e00714f0ee458b44faa107433cc052333f20..d2f8720a3ffe7d8b08f453f40a56b4500123a479 100644 --- a/webui/merge_config_gradio.py +++ b/webui/merge_config_gradio.py @@ -5,60 +5,31 @@ import gradio as gr class merge_config_then_run(): def __init__(self) -> None: - # Load the tokenizer - pretrained_model_path = '/home/xianyang/Data/code/FateZero/ckpt/stable-diffusion-v1-5' - self.tokenizer = None - self.text_encoder = None - self.vae = None - self.unet = None - - cache_ckpt = True - if cache_ckpt: - self.tokenizer = AutoTokenizer.from_pretrained( - pretrained_model_path, - # 'FateZero/ckpt/stable-diffusion-v1-4', - subfolder="tokenizer", - use_fast=False, - ) - - # Load models and create wrapper for stable diffusion - self.text_encoder = CLIPTextModel.from_pretrained( - pretrained_model_path, - subfolder="text_encoder", - ) - - self.vae = AutoencoderKL.from_pretrained( - pretrained_model_path, - subfolder="vae", - ) - - self.unet = UNetPseudo3DConditionModel.from_2d_model( - os.path.join(pretrained_model_path, "unet"), model_config=model_config - ) + # Load the tokenizer + self.pretrained_model_path = '/home/xianyang/Data/code/FateZero/ckpt/stable-diffusion-v1-5' + # load controlnet + def run( self, - # def merge_config_then_run( + user_input_video, + num_layouts, + layout_file1, + layout_file2, + layout_file3, + layout_file4, + layout_file5, + prompt, model_id, - data_path, - source_prompt, - target_prompt, - cross_replace_steps, - self_replace_steps, - enhance_words, - enhance_words_value, - num_steps, - guidance_scale, - user_input_video=None, - - # Temporal and spatial crop of the video - start_sample_frame=0, - n_sample_frame=8, - stride=1, - left_crop=0, - right_crop=0, - top_crop=0, - bottom_crop=0, + n_sample_frame, + start_sample_frame, + sampling_rate, + control_type, + dwpose_options, + controlnet_conditioning_scale, + use_pnp, + pnp_inject_steps, + flatten_res, ): # , ] = inputs default_edit_config='config/demo_config.yaml' @@ -66,25 +37,30 @@ class merge_config_then_run(): dataset_time_string = get_time_string() config_now = copy.deepcopy(Omegadict_default_edit_config) - print(f"config_now['pretrained_model_path'] = model_id {model_id}") - # config_now['pretrained_model_path'] = model_id - config_now['dataset_config']['prompt'] = source_prompt - config_now['dataset_config']['path'] = data_path - # ImageSequenceDataset_dict = { } - offset_dict = { - "left": left_crop, - "right": right_crop, - "top": top_crop, - "bottom": bottom_crop, - } - ImageSequenceDataset_dict = { - "start_sample_frame" : start_sample_frame, - "n_sample_frame" : n_sample_frame, - "sampling_rate" : stride, - "offset": offset_dict, - } - config_now['dataset_config'].update(ImageSequenceDataset_dict) - if user_input_video and data_path is None: + + config_now['pretrained_model_path'] = self.pretrained_model_path + print(f"config_now['pretrained_model_path'] = model_id {self.pretrained_model_path}") + + + + #==========update datset_config===============# + + # 将所有 layout 文件放入列表中 + all_layout_files = [layout_file1, layout_file2, layout_file3, layout_file4, layout_file5] + # 根据 num_layouts 转换为整数,并只使用前 N 个 + n_layouts = int(num_layouts) + layout_files = all_layout_files[:n_layouts] + + + config_now['dataset_config']['prompt'] = '' + config_now['dataset_config']['path'] = user_input_video + config_now['dataset_config']['n_sample_frame'] = n_sample_frame + config_now['dataset_config']['start_sample_frame'] = start_sample_frame + config_now['dataset_config']['sampling_rate'] = sampling_rate + config_now['dataset_config']['layout_files'] = layout_files + + + if user_input_video is None: raise gr.Error('You need to upload a video or choose a provided video') if user_input_video is not None: if isinstance(user_input_video, str): @@ -92,21 +68,52 @@ class merge_config_then_run(): elif hasattr(user_input_video, 'name') and user_input_video.name is not None: config_now['dataset_config']['path'] = user_input_video.name + # 检查每个 layout file 是否存在 + layout_files_checked = [] + for idx, lf in enumerate(layout_files): + if lf is None: + raise gr.Error(f'Layout file {idx+1} is missing') + if isinstance(lf, str): + lf_path = lf + elif hasattr(lf, 'name') and lf.name is not None: + lf_path = lf.name + else: + raise gr.Error(f'Layout file {idx+1} is invalid') + if not os.path.exists(lf_path): + raise gr.Error(f'Layout file "{lf_path}" does not exist') + layout_files_checked.append(lf_path) + config_now['dataset_config']['layout_files'] = layout_files_checked + #==========update datset_config===============# - # editing config - config_now['editing_config']['prompts'] = [target_prompt] - config_now['editing_config']['guidance_scale'] = guidance_scale - config_now['editing_config']['num_inference_steps'] = num_steps - + #==========update control_config===============# + config_now['control_config']['control_type'] = control_type + config_now['control_config']['controlnet_conditioning_scale'] = float(controlnet_conditioning_scale) + config_now['control_config']['hand'] = 'hand' in dwpose_options + config_now['control_config']['face'] = 'face' in dwpose_options + + + if control_type == "depth_midas": + pretrained_controlnet_path = "/home/xianyang/Data/code/controlvideo/sd-controlnet-depth" + elif control_type == "depth_zoe": + pretrained_controlnet_path = "/home/xianyang/Data/code/FateZero/ckpt/control_v11f1p_sd15_depth" + elif control_type == "dwpose": + pretrained_controlnet_path = "/home/xianyang/Data/code/FateZero/ckpt/control_v11p_sd15_openpose" + #==========update control_config===============# + + + #==========update editing_config===============# + config_now['editing_config']['use_pnp'] = [use_pnp] + config_now['editing_config']['inject_step'] = int(pnp_inject_steps) + config_now['editing_config']['flatten_res'] = [int(x) for x in flatten_res] + config_now['editing_config']['editing_prompts'] = [[x.strip() for x in prompt.split(',')]] + print('editing prompt', prompt) + #==========update editing_config===============# logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_{dataset_time_string}' config_now['logdir'] = logdir print(f'Saving at {logdir}') - save_path = test(tokenizer = self.tokenizer, - text_encoder = self.text_encoder, - vae = self.vae, - unet = self.unet, - config=default_edit_config, **config_now) + save_path = test(config = config_now, + **config_now) mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4') return mp4_path