XiangpengYang committed on
Commit 952c41a · 1 Parent(s): fdab143

huggingface dataset

README.md CHANGED
@@ -1,12 +1,371 @@
1
- ---
2
- title: VideoGrain
3
- emoji: 🔥
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.21.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # VideoGrain: Modulating Space-Time Attention for Multi-Grained Video Editing (ICLR 2025)
2
+ ## [<a href="https://knightyxp.github.io/VideoGrain_project_page/" target="_blank">Project Page</a>]
3
+
4
+ [![arXiv](https://img.shields.io/badge/arXiv-2502.17258-B31B1B.svg)](https://arxiv.org/abs/2502.17258)
5
+ [![HuggingFace Daily Papers Top1](https://img.shields.io/static/v1?label=HuggingFace%20Daily%20Papers&message=Top1&color=blue)](https://huggingface.co/papers/2502.17258)
6
+ [![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://knightyxp.github.io/VideoGrain_project_page/)
7
+ [![Full Data](https://img.shields.io/badge/Full-Data-brightgreen)](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link)
8
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=knightyxp.VideoGrain&left_color=green&right_color=red)
9
+ [![Youtube Video - VideoGrain](https://img.shields.io/badge/Demo_Video-VideoGrain-red)](https://www.youtube.com/watch?v=XEM4Pex7F9E)
10
+ [![Hugging Face Dataset](https://img.shields.io/badge/HuggingFace-Dataset-blue?style=for-the-badge&logo=huggingface&logoColor=white)](https://huggingface.co/datasets/XiangpengYang/VideoGrain-dataset)
11
+
12
+
13
+ ## Introduction
14
+ VideoGrain is a zero-shot method for class-level, instance-level, and part-level video editing.
15
+ - **Multi-grained Video Editing**
16
+   - class-level: Editing objects within the same class (previous SOTA methods were limited to this level)
17
+   - instance-level: Editing each individual instance into a distinct object
18
+   - part-level: Adding new objects or modifying existing attributes at the part level
19
+ - **Training-Free**
20
+   - Does not require any training or fine-tuning
21
+ - **One-Prompt Multi-region Control & Deep Investigation of Cross-/Self-Attention**
22
+   - Modulating cross-attention for multi-region control (visualizations available)
23
+   - Modulating self-attention for feature decoupling (clustering visualizations available)
24
+
25
+ <table class="center" border="1" cellspacing="0" cellpadding="5">
26
+ <tr>
27
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/class_level.gif" style="width:250px; height:auto;"></td>
28
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/instance_part.gif" style="width:250px; height:auto;"></td>
29
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/2monkeys.gif" style="width:250px; height:auto;"></td>
30
+ </tr>
31
+ <tr>
32
+ <!-- <td colspan="1" style="text-align:right; width:125px;"> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td> -->
33
+ <td colspan="2" style="text-align:right; width:250px;"> class level</td>
34
+ <td colspan="1" style="text-align:center; width:125px;">instance level</td>
35
+ <td colspan="1" style="text-align:center; width:125px;">part level</td>
36
+ <td colspan="2" style="text-align:center; width:250px;">animal instances</td>
37
+ </tr>
38
+
39
+ <tr>
40
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/2cats.gif" style="width:250px; height:auto;"></td>
41
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/soap-box.gif" style="width:250px; height:auto;"></td>
42
+ <td colspan="2" style="text-align:center;"><img src="assets/teaser/man-text-message.gif" style="width:250px; height:auto;"></td>
43
+ </tr>
44
+ <tr>
45
+ <td colspan="2" style="text-align:center; width:250px;">animal instances</td>
46
+ <td colspan="2" style="text-align:center; width:250px;">human instances</td>
47
+ <td colspan="2" style="text-align:center; width:250px;">part-level modification</td>
48
+ </tr>
49
+ </table>
50
+
51
+ ## 📀 Demo Video
52
+ <!-- [![Demo Video of VideoGrain](https://res.cloudinary.com/dii3btvh8/image/upload/v1740987943/cover_video_y6cjfe.png)](https://www.youtube.com/watch?v=XEM4Pex7F9E "Demo Video of VideoGrain") -->
53
+ https://github.com/user-attachments/assets/9bec92fc-21bd-4459-86fa-62404d8762bf
54
+
55
+
56
+ ## 📣 News
57
+ * **[2025/2/25]** VideoGrain was featured by Gradio on [LinkedIn](https://www.linkedin.com/posts/gradio_just-dropped-videograin-a-new-zero-shot-activity-7300094635094261760-hoiE) and [Twitter](https://x.com/Gradio/status/1894328911154028566), and recommended by [AK](https://x.com/_akhaliq/status/1894254599223017622).
58
+ * **[2025/2/25]** VideoGrain was submitted by AK to [HuggingFace daily papers](https://huggingface.co/papers?date=2025-02-25) and ranked as the [#1](https://huggingface.co/papers/2502.17258) paper of the day.
59
+ * **[2025/2/24]** We released our paper on [arXiv](https://arxiv.org/abs/2502.17258), along with the [code](https://github.com/knightyxp/VideoGrain) and [full data](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link) on Google Drive.
60
+ * **[2025/1/23]** Our paper was accepted to [ICLR 2025](https://openreview.net/forum?id=SSslAtcPB6)! Welcome to **watch** 👀 this repository for the latest updates.
61
+
62
+
63
+ ## 🍻 Setup Environment
64
+ Our method was tested with CUDA 12.1, fp16 via accelerate, and xformers on a single L40 GPU.
65
+
66
+ ```bash
67
+ # Step 1: Create and activate Conda environment
68
+ conda create -n videograin python==3.10
69
+ conda activate videograin
70
+
71
+ # Step 2: Install PyTorch, CUDA and Xformers
72
+ conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
73
+ pip install --pre -U xformers==0.0.27
74
+ # Step 3: Install additional dependencies with pip
75
+ pip install -r requirements.txt
76
+ ```
77
+
78
+ `xformers` is recommended to reduce memory usage and running time.
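+
+ A minimal sketch of enabling xformers attention in a diffusers pipeline (the model path below is a placeholder; `enable_xformers_memory_efficient_attention` is the standard diffusers hook):
+ ```python
+ import torch
+ from diffusers import StableDiffusionPipeline
+
+ # Load SD 1.5 in fp16 (use your local ./ckpt/stable-diffusion-v1-5 from the download step below)
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "./ckpt/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Memory-efficient attention via xformers; fall back to default attention if it is missing
+ try:
+     pipe.enable_xformers_memory_efficient_attention()
+ except Exception as err:
+     print(f"xformers not available, using default attention: {err}")
+ ```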
79
+
80
81
+
82
+ You can download all the base model checkpoints using the following bash command:
83
+ ```bash
84
+ ## download sd 1.5, controlnet depth/pose v10/v11
85
+ bash download_all.sh
86
+ ```
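+
+ If you prefer fetching the weights from Python, a rough equivalent uses `huggingface_hub` (the exact repo ids pulled by `download_all.sh` may differ; the SD 1.5 repo in particular has moved between mirrors over time, so treat the ids below as examples):
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Base Stable Diffusion 1.5 weights into ./ckpt
+ snapshot_download("runwayml/stable-diffusion-v1-5", local_dir="./ckpt/stable-diffusion-v1-5")
+
+ # ControlNet depth / pose conditions
+ for repo in ["lllyasviel/sd-controlnet-depth", "lllyasviel/control_v11p_sd15_openpose"]:
+     snapshot_download(repo, local_dir=f"./ckpt/{repo.split('/')[-1]}")
+ ```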
87
+
88
+ <details><summary>Click for ControlNet annotator weights (if you cannot access Hugging Face)</summary>
89
+
90
+ You can download all the annotator checkpoints (such as DW-Pose, depth_zoe, depth_midas, and OpenPose; about 4 GB in total) from [Baidu](https://pan.baidu.com/s/1sgBFLFkdTCDTn4oqHjGb9A?pwd=pdm5) or [Google Drive](https://drive.google.com/file/d/1qOsmWshnFMMr8x1HteaTViTSQLh_4rle/view?usp=drive_link),
91
+ then extract them into `./annotator/ckpts`.
92
+
93
+ </details>
94
+
95
+ ## ⚡️ Prepare all the data
96
+
97
+ ### Full VideoGrain Data
98
+ We provide `all the video data and layout masks in VideoGrain` at the following link. Please download and unzip the data, then place it in the `./data` root directory.
99
+ ```
100
+ gdown https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link
101
+ tar -zxvf videograin_data.tar.gz
102
+ ```
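+
+ The data also appears to be available via the Hugging Face dataset linked in the badges above. A sketch for pulling it with `huggingface_hub` (the file layout inside the dataset repo may differ from the Google Drive archive):
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Download the VideoGrain dataset repo into ./data
+ snapshot_download(
+     repo_id="XiangpengYang/VideoGrain-dataset",
+     repo_type="dataset",
+     local_dir="./data",
+ )
+ ```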
103
+ ### Customize Your Own Data
104
+ **Prepare video frames**
105
+ If the input video is an mp4 file, use the following command to split it into frames:
106
+ ```bash
107
+ python image_util/sample_video2frames.py --video_path 'your video path' --output_dir './data/video_name/video_name'
108
+ ```
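+
+ Roughly, the conversion writes zero-padded PNG frames in the layout the dataset loader expects (a sketch using `imageio` and `cv2`; `sample_video2frames.py` may handle additional options):
+ ```python
+ import os
+ import cv2
+ import imageio
+
+ def video_to_frames(video_path: str, output_dir: str) -> None:
+     """Dump every frame of an mp4 as 00000.png, 00001.png, ... into output_dir."""
+     os.makedirs(output_dir, exist_ok=True)
+     reader = imageio.get_reader(video_path)
+     for i, frame in enumerate(reader):
+         # imageio yields RGB frames; cv2.imwrite expects BGR
+         cv2.imwrite(os.path.join(output_dir, f"{i:05d}.png"), frame[:, :, ::-1])
+
+ video_to_frames("./data/run_two_man/run_two_man.mp4", "./data/run_two_man/run_two_man")
+ ```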
109
+ **Prepare layout masks**
110
+ We segment videos using our ReLER lab's [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). We suggest running `app.py` in SAM-Track in `gradio` mode to manually select which regions of the video you want to edit. We also provide a script, `image_util/process_webui_mask.py`, that converts masks from the SAM-Track layout to the VideoGrain layout; the expected structure is sketched below.
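+
+ The dataset loader (`video_diffusion/data/dataset.py`) expects one sub-folder per edited region under `layout_masks`, each holding per-frame masks named `00000.png`, `00001.png`, ... A minimal re-organization sketch (the SAM-Track output naming used below is an assumption; adapt it to your export):
+ ```python
+ import os
+ import shutil
+
+ def collect_region_masks(samtrack_dir: str, out_root: str, region_ids: list) -> None:
+     """Copy per-object masks into <out_root>/<region>/<frame:05d>.png."""
+     for region in region_ids:
+         region_dir = os.path.join(out_root, str(region))
+         os.makedirs(region_dir, exist_ok=True)
+         # Assumption: SAM-Track saved one mask per frame and object as <frame>_<object_id>.png
+         frames = sorted(f for f in os.listdir(samtrack_dir) if f.endswith(f"_{region}.png"))
+         for i, name in enumerate(frames):
+             shutil.copy(os.path.join(samtrack_dir, name),
+                         os.path.join(region_dir, f"{i:05d}.png"))
+
+ collect_region_masks("./samtrack_output/masks", "./data/run_two_man/layout_masks", [1, 2])
+ ```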
111
+
112
+
113
+ ## 🔥🔥🔥 VideoGrain Editing
114
+
115
+ ### 🎨 Inference
116
+ You can reproduce the instance- and part-level results in our teaser by running:
117
+
118
+ ```bash
119
+ bash test.sh
120
+ #or
121
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml
122
+ ```
123
+
124
+ For the other instance-, part-, and class-level results shown on the VideoGrain project page and in the teaser, we provide all the data (video frames and layout masks) and the corresponding configs; see [🚀Multi-Grained Video Editing](#multi-grained-video-editing-results).
125
+
126
+ <details><summary>The result is saved at `./result`. (Click for directory structure)</summary>
127
+
128
+ ```
129
+ result
130
+ ├── run_two_man
131
+ │ ├── control # control condition
132
+ │ ├── infer_samples
133
+ │ ├── input # the input video frames
134
+ │ ├── masked_video.mp4 # check whether edit regions are accurately covered
135
+ │ ├── sample
136
+ │ ├── step_0 # result image folder
137
+ │ ├── step_0.mp4 # result video
138
+ │ ├── source_video.mp4 # the input video
139
+ │ ├── visualization_denoise # cross attention weight
140
+ │ ├── sd_study # cluster inversion feature
141
+ ```
142
+ </details>
143
+
144
+
145
+ ## Editing guidance for YOUR Video
146
+ ### 🔛 Prepare your config
147
+
148
+ VideoGrain is a training-free framework. To run VideoGrain on your own video, modify `./config/demo_config.yaml` based on your needs:
149
+
150
+ 1. Replace the pretrained model path and ControlNet path in your config. You can set `control_type` to `dwpose`, `depth_zoe`, or `depth` (MiDaS).
151
+ 2. Prepare your video frames and layout masks (edit regions) with SAM-Track or SAM2, and point the dataset config at them.
152
+ 3. Change the `prompt` and extract each `local prompt` from the editing prompt. The local prompt order must match the layout mask order.
153
+ 4. You can change the flatten resolution with 1->64, 2->16, 4->8 (flattening at resolution 64 usually works best).
154
+ 5. To improve temporal consistency, you can set `use_pnp: True` and `inject_step: 5` (or `10`). Note that PnP injection for more than 10 steps hurts multi-region editing.
155
+ 6. If you want to visualize the cross-attention weights, set `vis_cross_attn: True`.
156
+ 7. If you want to cluster the DDIM inversion spatio-temporal video features, set `cluster_inversion_feature: True`. A config sketch covering these fields follows the list.
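+
+ A minimal sketch of the fields discussed above, written from Python for illustration (the key names follow the repo's configs and `dataset.py`; `demo_config.yaml` may contain additional fields):
+ ```python
+ import yaml  # PyYAML
+
+ config = {
+     "pretrained_model_path": "./ckpt/stable-diffusion-v1-5",
+     "logdir": "./result/my_video_edit",
+     "dataset_config": {
+         "path": "./data/my_video/my_video",                 # input video frames
+         "layout_mask_dir": "./data/my_video/layout_masks",  # one sub-folder per region
+         "layout_mask_order": ["1", "2"],                    # must match the local prompt order
+         "prompt": "a Spider-Man and a polar bear are running",
+     },
+     "control_type": "dwpose",        # or depth_zoe / depth (MiDaS)
+     "use_pnp": True,
+     "inject_step": 5,                # keep <= 10 for multi-region edits
+     "vis_cross_attn": False,
+     "cluster_inversion_feature": False,
+ }
+
+ with open("./config/my_config.yaml", "w") as f:
+     yaml.safe_dump(config, f, sort_keys=False)
+ ```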
157
+
158
+ ### 😍 Editing your video
159
+
160
+ ```bash
161
+ bash test.sh
162
+ #or
163
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config /path/to/the/config
164
+ ```
165
+
166
+ ## 🚀Multi-Grained Video Editing Results
167
+
168
+ ### 🌈 Multi-Grained Definition
169
+ You can reproduce the multi-grained definition results with the following commands:
170
+ ```bash
171
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/class_level/running_two_man/man2spider.yaml #class-level
172
+ # config/instance_level/running_two_man/4cls_spider_polar.yaml #instance-level
173
+ # config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml #part-level
174
+ ```
175
+ <table class="center">
176
+ <tr>
177
+ <td width=25% style="text-align:center;">source video</td>
178
+ <td width=25% style="text-align:center;">class level</td>
179
+ <td width=25% style="text-align:center;">instance level</td>
180
+ <td width=25% style="text-align:center;">part level</td>
181
+ </tr>
182
+ <tr>
183
+ <td><img src="./assets/teaser/run_two_man.gif"></td>
184
+ <td><img src="./assets/teaser/class_level_0.gif"></td>
185
+ <td><img src="./assets/teaser/instance_level.gif"></td>
186
+ <td><img src="./assets/teaser/part_level.gif"></td>
187
+ </tr>
188
+ </table>
189
+
190
+ ## 💃 Instance-level Video Editing
191
+ You can reproduce instance-level video editing results with the following command:
192
+ ```bash
193
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/running_3cls_iron_spider.yaml
194
+ ```
195
+
196
+ <table class="center">
197
+ <tr>
198
+ <td width=50% style="text-align:center;">running_two_man/3cls_iron_spider.yaml</td>
199
+ <td width=50% style="text-align:center;">2_monkeys/2cls_teddy_bear_koala.yaml</td>
200
+ </tr>
201
+ <tr>
202
+ <td><img src="assets/instance-level/left_iron_right_spider.gif"></td>
203
+ <td><img src="assets/instance-level/teddy_koala.gif"></td>
204
+ </tr>
205
+ <tr>
206
+ <td width=50% style="text-align:center;">badminton/2cls_wonder_woman_spiderman.yaml</td>
207
+ <td width=50% style="text-align:center;">soap-box/soap-box.yaml</td>
208
+ </tr>
209
+ <tr>
210
+ <td><img src="assets/instance-level/badminton.gif"></td>
211
+ <td><img src="assets/teaser/soap-box.gif"></td>
212
+ </tr>
213
+ <tr>
214
+ <td width=50% style="text-align:center;">2_cats/4cls_panda_vs_poddle.yaml</td>
215
+ <td width=50% style="text-align:center;">2_cars/left_firetruck_right_bus.yaml</td>
216
+ </tr>
217
+ <tr>
218
+ <td><img src="assets/instance-level/panda_vs_poddle.gif"></td>
219
+ <td><img src="assets/instance-level/2cars.gif"></td>
220
+ </tr>
221
+ </table>
222
+
223
+ ## 🕺 Part-level Video Editing
224
+ You can reproduce part-level video editing results with the following command:
225
+ ```bash
226
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/modification/man_text_message/blue_shirt.yaml
227
+ ```
228
+
229
+ <table class="center">
230
+ <tr>
231
+ <td><img src="assets/part-level/man_text_message.gif"></td>
232
+ <td><img src="assets/part-level/blue-shirt.gif"></td>
233
+ <td><img src="assets/part-level/black-suit.gif"></td>
234
+ <td><img src="assets/part-level/cat_flower.gif"></td>
235
+ <td><img src="assets/part-level/ginger_head.gif"></td>
236
+ <td><img src="assets/part-level/ginger_body.gif"></td>
237
+ </tr>
238
+ <tr>
239
+ <td width=15% style="text-align:center;">source video</td>
240
+ <td width=15% style="text-align:center;">blue shirt</td>
241
+ <td width=15% style="text-align:center;">black suit</td>
242
+ <td width=15% style="text-align:center;">source video</td>
243
+ <td width=15% style="text-align:center;">ginger head </td>
244
+ <td width=15% style="text-align:center;">ginger body</td>
245
+ </tr>
246
+ <tr>
247
+ <td><img src="assets/part-level/man_text_message.gif"></td>
248
+ <td><img src="assets/part-level/superman.gif"></td>
249
+ <td><img src="assets/part-level/superman+cap.gif"></td>
250
+ <td><img src="assets/part-level/spin-ball.gif"></td>
251
+ <td><img src="assets/part-level/superman_spin.gif"></td>
252
+ <td><img src="assets/part-level/super_sunglass_spin.gif"></td>
253
+ </tr>
254
+ <tr>
255
+ <td width=15% style="text-align:center;">source video</td>
256
+ <td width=15% style="text-align:center;">superman</td>
257
+ <td width=15% style="text-align:center;">superman + cap</td>
258
+ <td width=15% style="text-align:center;">source video</td>
259
+ <td width=15% style="text-align:center;">superman </td>
260
+ <td width=15% style="text-align:center;">superman + sunglasses</td>
261
+ </tr>
262
+ </table>
263
+
264
+ ## 🥳 Class-level Video Editing
265
+ You can reproduce class-level video editing results with the following command:
266
+ ```bash
267
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/class_level/wolf/wolf.yaml
268
+ ```
269
+
270
+ <table class="center">
271
+ <tr>
272
+ <td><img src="assets/class-level/wolf.gif"></td>
273
+ <td><img src="assets/class-level/pig.gif"></td>
274
+ <td><img src="assets/class-level/husky.gif"></td>
275
+ <td><img src="assets/class-level/bear.gif"></td>
276
+ <td><img src="assets/class-level/tiger.gif"></td>
277
+ </tr>
278
+ <tr>
279
+ <td width=15% style="text-align:center;">input</td>
280
+ <td width=15% style="text-align:center;">pig</td>
281
+ <td width=15% style="text-align:center;">husky</td>
282
+ <td width=15% style="text-align:center;">bear</td>
283
+ <td width=15% style="text-align:center;">tiger</td>
284
+ </tr>
285
+ <tr>
286
+ <td><img src="assets/class-level/tennis.gif"></td>
287
+ <td><img src="assets/class-level/tennis_1cls.gif"></td>
288
+ <td><img src="assets/class-level/tennis_3cls.gif"></td>
289
+ <td><img src="assets/class-level/car-1.gif"></td>
290
+ <td><img src="assets/class-level/posche.gif"></td>
291
+ </tr>
292
+ <tr>
293
+ <td width=15% style="text-align:center;">input</td>
294
+ <td width=15% style="text-align:center;">iron man</td>
295
+ <td width=15% style="text-align:center;">Batman + snow court + iced wall</td>
296
+ <td width=15% style="text-align:center;">input </td>
297
+ <td width=15% style="text-align:center;">posche</td>
298
+ </tr>
299
+ </table>
300
+
301
+
302
+ ## Solely edit specific subjects, keeping the background unchanged
303
+ You can reproduce subject-only editing results with the following command:
304
+ ```bash
305
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/soely_edit/only_left.yaml
306
+ #--config config/instance_level/soely_edit/only_right.yaml
307
+ #--config config/instance_level/soely_edit/joint_edit.yaml
308
+ ```
309
+
310
+ <table class="center">
311
+ <tr>
312
+ <td><img src="assets/soely_edit/input.gif"></td>
313
+ <td><img src="assets/soely_edit/left.gif"></td>
314
+ <td><img src="assets/soely_edit/right.gif"></td>
315
+ <td><img src="assets/soely_edit/joint.gif"></td>
316
+ </tr>
317
+ <tr>
318
+ <td width=25% style="text-align:center;">source video</td>
319
+ <td width=25% style="text-align:center;">left→Iron Man</td>
320
+ <td width=25% style="text-align:center;">right→Spiderman</td>
321
+ <td width=25% style="text-align:center;">joint edit</td>
322
+ </tr>
323
+ </table>
324
+
325
+ ## 🔍 Visualize Cross Attention Weight
326
+ You can visualize the cross-attention weights of an edit with the following command:
327
+ ```bash
328
+ # set vis_cross_attn: True in your config
329
+ CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/3cls_spider_polar_vis_weight.yaml
330
+ ```
331
+
332
+ <table class="center">
333
+ <tr>
334
+ <td><img src="assets/soely_edit/input.gif"></td>
335
+ <td><img src="assets/vis/edit.gif"></td>
336
+ <td><img src="assets/vis/spiderman_weight.gif"></td>
337
+ <td><img src="assets/vis/bear_weight.gif"></td>
338
+ <td><img src="/assets/vis/cherry_weight.gif"></td>
339
+ </tr>
340
+ <tr>
341
+ <td width=20% style="text-align:center;">source video</td>
342
+ <td width=20% style="text-align:center;">left→spiderman, right→polar bear, trees→cherry blossoms</td>
343
+ <td width=20% style="text-align:center;">spiderman weight</td>
344
+ <td width=20% style="text-align:center;">bear weight</td>
345
+ <td width=20% style="text-align:center;">cherry weight</td>
346
+ </tr>
347
+ </table>
348
+
349
+ ## ✏️ Citation
350
+ If you find this project helpful, please feel free to leave a star ⭐️ and cite our paper:
351
+ ```bibtex
352
+ @article{yang2025videograin,
353
+ title={VideoGrain: Modulating Space-Time Attention for Multi-grained Video Editing},
354
+ author={Yang, Xiangpeng and Zhu, Linchao and Fan, Hehe and Yang, Yi},
355
+ journal={arXiv preprint arXiv:2502.17258},
356
+ year={2025}
357
+ }
358
+ ```
359
+
360
+ ## 📞 Contact Authors
361
+ Xiangpeng Yang [@knightyxp](https://github.com/knightyxp), email: [email protected]/[email protected]
362
+
363
+ ## ✨ Acknowledgements
364
+
365
+ - This code builds on [diffusers](https://github.com/huggingface/diffusers) and [FateZero](https://github.com/ChenyangQiQi/FateZero). Thanks for open-sourcing!
366
+ - We would like to thank [AK (@_akhaliq)](https://x.com/_akhaliq/status/1894254599223017622) and the Gradio team for the recommendation!
367
+
368
+
369
+ ## ⭐️ Star History
370
+
371
+ [![Star History Chart](https://api.star-history.com/svg?repos=knightyxp/VideoGrain&type=Date)](https://star-history.com/#knightyxp/VideoGrain&Date)
config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml CHANGED
@@ -1,4 +1,4 @@
1
- pretrained_model_path: "ckpt/stable-diffusion-v1-5"
2
  logdir: ./result/part_level/run_two_man/left2spider_right2polar-sunglasses
3
 
4
  dataset_config:
 
1
+ pretrained_model_path: "./ckpt/stable-diffusion-v1-5"
2
  logdir: ./result/part_level/run_two_man/left2spider_right2polar-sunglasses
3
 
4
  dataset_config:
image.png ADDED
video_diffusion/data/__pycache__/dataset.cpython-310.pyc CHANGED
Binary files a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc and b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc differ
 
video_diffusion/data/dataset.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import numpy as np
3
  from PIL import Image
4
  from einops import rearrange
@@ -10,24 +11,26 @@ from torch.utils.data import Dataset
10
  from .transform import short_size_scale, random_crop, center_crop, offset_crop
11
  from ..common.image_util import IMAGE_EXTENSION
12
  import cv2
13
- import imageio
14
- import shutil
15
 
16
  class ImageSequenceDataset(Dataset):
17
  def __init__(
18
  self,
19
- path: str, # input video; if it is an mp4 it is converted into the fixed directory './input-video'
20
- layout_files: list, # uploaded layout mask files (mp4 files or directories), converted into the fixed directories './layout_masks/1', './layout_masks/2', ...
 
21
  prompt_ids: torch.Tensor,
22
  prompt: str,
23
- start_sample_frame: int = 0,
24
  n_sample_frame: int = 8,
25
  sampling_rate: int = 1,
26
- stride: int = -1, # used to sample long videos during tuning
27
  image_mode: str = "RGB",
28
  image_size: int = 512,
29
  crop: str = "center",
30
-
 
 
 
31
  offset: dict = {
32
  "left": 0,
33
  "right": 0,
@@ -35,42 +38,33 @@ class ImageSequenceDataset(Dataset):
35
  "bottom": 0
36
  },
37
  **args
 
38
  ):
39
- # if the input video is an mp4, convert it into the fixed directory './input-video'
40
- if path.endswith('.mp4'):
41
- self.path = self.mp4_to_png(path, target_dir='./input-video')
42
- else:
43
- self.path = path
44
- self.images = self.get_image_list(self.path)
45
 
46
- # process each uploaded layout file
47
- # if it is an mp4, convert it into the fixed directory './layout_masks/{i+1}'
48
- self.layout_mask_dirs = []
49
- for idx, file in enumerate(layout_files):
50
- if file.endswith('.mp4'):
51
- folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx+1}')
52
- else:
53
- folder = file
54
- self.layout_mask_dirs.append(folder)
55
- # keep the upload order as layout_mask_order (plain indices are enough here)
56
- self.layout_mask_order = list(range(len(self.layout_mask_dirs)))
57
- # use the first layout mask directory to index the mask images (used to determine the frame count)
58
- self.masks_index = self.get_image_list(self.layout_mask_dirs[0])
59
 
 
60
  self.n_images = len(self.images)
61
  self.offset = offset
62
  self.start_sample_frame = start_sample_frame
63
  if n_sample_frame < 0:
64
- n_sample_frame = len(self.images)
65
  self.n_sample_frame = n_sample_frame
 
66
  self.sampling_rate = sampling_rate
67
 
68
  self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
69
  if self.n_images < self.sequence_length:
70
- raise ValueError(f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images}")
71
 
72
- # if the video is too long, sample it globally
73
- self.stride = stride if stride > 0 else (self.n_images + 1)
74
  self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
75
 
76
  self.image_mode = image_mode
@@ -80,53 +74,67 @@ class ImageSequenceDataset(Dataset):
80
  "random": random_crop,
81
  }
82
  if crop not in crop_methods:
83
- raise ValueError("Unsupported crop method")
84
  self.crop = crop_methods[crop]
85
 
86
  self.prompt = prompt
87
  self.prompt_ids = prompt_ids
 
 
 
 
 
 
88
 
89
 
90
  def __len__(self):
91
  max_len = (self.n_images - self.sequence_length) // self.stride + 1
 
92
  if hasattr(self, 'num_class_images'):
93
  max_len = max(max_len, self.num_class_images)
 
94
  return max_len
95
 
96
  def __getitem__(self, index):
97
  return_batch = {}
98
- frame_indices = self.get_frame_indices(index % self.video_len)
99
  frames = [self.load_frame(i) for i in frame_indices]
100
  frames = self.transform(frames)
101
 
102
  layout_ = []
103
- # iterate over every layout mask directory (in the same order the user uploaded them)
104
- for layout_dir in self.layout_mask_dirs:
105
- # for each layout directory, read the mask image (PNG file) for every frame index
106
- frame_indices_local = self.get_frame_indices(index % self.video_len)
107
- mask = [self._read_mask(layout_dir, i) for i in frame_indices_local]
108
- masks = np.stack(mask) # shape: (n_sample_frame, c, h, w)
109
  layout_.append(masks)
110
- layout_ = np.stack(layout_) # shape: (num_layouts, n_sample_frame, c, h, w)
111
-
112
  merged_masks = []
113
  for i in range(int(self.n_sample_frame)):
114
- merged_mask_frame = np.sum(layout_[:, i, :, :, :], axis=0)
115
- merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
116
  merged_masks.append(merged_mask_frame)
117
  masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
118
  masks = torch.from_numpy(masks).half()
119
 
120
- layouts = rearrange(layout_, "s f c h w -> f s c h w")
121
  layouts = torch.from_numpy(layouts).half()
122
 
123
- return_batch.update({
 
124
  "images": frames,
125
- "masks": masks,
126
- "layouts": layouts,
127
  "prompt_ids": self.prompt_ids,
128
- })
129
-
 
 
 
 
 
 
 
130
  return return_batch
131
 
132
  def transform(self, frames):
@@ -141,18 +149,24 @@ class ImageSequenceDataset(Dataset):
141
  frames = rearrange(np.stack(frames), "f h w c -> c f h w")
142
  return torch.from_numpy(frames).div(255) * 2 - 1
143
 
144
- def _read_mask(self, mask_dir, index: int):
145
- # build the mask file name (png format)
146
- mask_path = os.path.join(mask_dir, f"{index:05d}.png")
 
 
 
147
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
148
  mask = (mask > 0).astype(np.uint8)
149
- # scale dynamically based on the original image size (downscaled by 8x here)
150
  height, width = mask.shape
151
  dest_size = (width // 8, height // 8)
152
- mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
 
153
  mask = mask[np.newaxis, ...]
 
154
  return mask
155
 
 
156
  def load_frame(self, index):
157
  image_path = os.path.join(self.path, self.images[index])
158
  return Image.open(image_path).convert(self.image_mode)
@@ -170,31 +184,12 @@ class ImageSequenceDataset(Dataset):
170
 
171
  def get_class_indices(self, index):
172
  frame_start = index
173
- return (frame_start + i for i in range(self.n_sample_frame))
174
 
175
  @staticmethod
176
  def get_image_list(path):
177
  images = []
178
- # 如果传入的是 mp4 文件,则先转换成 PNG 图像目录
179
- if path.endswith('.mp4'):
180
- path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
181
  for file in sorted(os.listdir(path)):
182
  if file.endswith(IMAGE_EXTENSION):
183
  images.append(file)
184
- return images
185
-
186
- @staticmethod
187
- def mp4_to_png(video_source: str, target_dir: str):
188
- """
189
- Convert an mp4 video to a sequence of PNG images, storing them in target_dir.
190
- target_dir is a fixed path, e.g. './input-video' or './layout_masks/1'
191
- """
192
- if os.path.exists(target_dir):
193
- shutil.rmtree(target_dir)
194
- os.makedirs(target_dir, exist_ok=True)
195
-
196
- reader = imageio.get_reader(video_source)
197
- for i, im in enumerate(reader):
198
- path = os.path.join(target_dir, f"{i:05d}.png")
199
- cv2.imwrite(path, im[:, :, ::-1])
200
- return target_dir
 
1
  import os
2
+
3
  import numpy as np
4
  from PIL import Image
5
  from einops import rearrange
 
11
  from .transform import short_size_scale, random_crop, center_crop, offset_crop
12
  from ..common.image_util import IMAGE_EXTENSION
13
  import cv2
 
 
14
 
15
  class ImageSequenceDataset(Dataset):
16
  def __init__(
17
  self,
18
+ path: str,
19
+ layout_mask_dir: str,
20
+ layout_mask_order: list,
21
  prompt_ids: torch.Tensor,
22
  prompt: str,
23
+ start_sample_frame: int=0,
24
  n_sample_frame: int = 8,
25
  sampling_rate: int = 1,
26
+ stride: int = -1, # only used during tuning to sample a long video
27
  image_mode: str = "RGB",
28
  image_size: int = 512,
29
  crop: str = "center",
30
+
31
+ class_data_root: str = None,
32
+ class_prompt_ids: torch.Tensor = None,
33
+
34
  offset: dict = {
35
  "left": 0,
36
  "right": 0,
 
38
  "bottom": 0
39
  },
40
  **args
41
+
42
  ):
43
+ self.path = path
44
+ self.images = self.get_image_list(path)
45
+ #
46
+ self.layout_mask_dir = layout_mask_dir
47
+ self.layout_mask_order = list(layout_mask_order)
 
48
 
49
+ layout_mask_dir0 = os.path.join(self.layout_mask_dir,self.layout_mask_order[0])
50
+ self.masks_index = self.get_image_list(layout_mask_dir0)
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ #
53
  self.n_images = len(self.images)
54
  self.offset = offset
55
  self.start_sample_frame = start_sample_frame
56
  if n_sample_frame < 0:
57
+ n_sample_frame = len(self.images)
58
  self.n_sample_frame = n_sample_frame
59
+ # local sampling rate from the video
60
  self.sampling_rate = sampling_rate
61
 
62
  self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
63
  if self.n_images < self.sequence_length:
64
+ raise ValueError(f"self.n_images {self.n_images } < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images }")
65
 
66
+ # During tuning if video is too long, we sample the long video every self.stride globally
67
+ self.stride = stride if stride > 0 else (self.n_images+1)
68
  self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
69
 
70
  self.image_mode = image_mode
 
74
  "random": random_crop,
75
  }
76
  if crop not in crop_methods:
77
+ raise ValueError
78
  self.crop = crop_methods[crop]
79
 
80
  self.prompt = prompt
81
  self.prompt_ids = prompt_ids
82
+ # Negative prompt for regularization to avoid overfitting during one-shot tuning
83
+ if class_data_root is not None:
84
+ self.class_data_root = Path(class_data_root)
85
+ self.class_images_path = sorted(list(self.class_data_root.iterdir()))
86
+ self.num_class_images = len(self.class_images_path)
87
+ self.class_prompt_ids = class_prompt_ids
88
 
89
 
90
  def __len__(self):
91
  max_len = (self.n_images - self.sequence_length) // self.stride + 1
92
+
93
  if hasattr(self, 'num_class_images'):
94
  max_len = max(max_len, self.num_class_images)
95
+
96
  return max_len
97
 
98
  def __getitem__(self, index):
99
  return_batch = {}
100
+ frame_indices = self.get_frame_indices(index%self.video_len)
101
  frames = [self.load_frame(i) for i in frame_indices]
102
  frames = self.transform(frames)
103
 
104
  layout_ = []
105
+ for layout_name in self.layout_mask_order:
106
+ frame_indices = self.get_frame_indices(index%self.video_len)
107
+ layout_mask_dir = os.path.join(self.layout_mask_dir,layout_name)
108
+ mask = [self._read_mask(layout_mask_dir,i) for i in frame_indices]
109
+ masks = np.stack(mask)
 
110
  layout_.append(masks)
111
+ layout_ = np.stack(layout_)
 
112
  merged_masks = []
113
  for i in range(int(self.n_sample_frame)):
114
+ merged_mask_frame = np.sum(layout_[:,i,:,:,:], axis=0)
115
+ merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
116
  merged_masks.append(merged_mask_frame)
117
  masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
118
  masks = torch.from_numpy(masks).half()
119
 
120
+ layouts = rearrange(layout_,"s f c h w -> f s c h w" )
121
  layouts = torch.from_numpy(layouts).half()
122
 
123
+ return_batch.update(
124
+ {
125
  "images": frames,
126
+ "masks":masks,
127
+ "layouts":layouts,
128
  "prompt_ids": self.prompt_ids,
129
+ }
130
+ )
131
+
132
+ if hasattr(self, 'class_data_root'):
133
+ class_index = index % (self.num_class_images - self.n_sample_frame)
134
+ class_indices = self.get_class_indices(class_index)
135
+ frames = [self.load_class_frame(i) for i in class_indices]
136
+ return_batch["class_images"] = self.tensorize_frames(frames)
137
+ return_batch["class_prompt_ids"] = self.class_prompt_ids
138
  return return_batch
139
 
140
  def transform(self, frames):
 
149
  frames = rearrange(np.stack(frames), "f h w c -> c f h w")
150
  return torch.from_numpy(frames).div(255) * 2 - 1
151
 
152
+ def _read_mask(self, mask_path,index: int):
153
+ ### read mask by pil
154
+
155
+ mask_path = os.path.join(mask_path,f"{index:05d}.png")
156
+
157
+ ### read mask by cv2
158
  mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
159
  mask = (mask > 0).astype(np.uint8)
160
+ # Determine dynamic destination size
161
  height, width = mask.shape
162
  dest_size = (width // 8, height // 8)
163
+ # Resize using nearest neighbor interpolation
164
+ mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) #cv2.INTER_CUBIC
165
  mask = mask[np.newaxis, ...]
166
+
167
  return mask
168
 
169
+
170
  def load_frame(self, index):
171
  image_path = os.path.join(self.path, self.images[index])
172
  return Image.open(image_path).convert(self.image_mode)
 
184
 
185
  def get_class_indices(self, index):
186
  frame_start = index
187
+ return (frame_start + i for i in range(self.n_sample_frame))
188
 
189
  @staticmethod
190
  def get_image_list(path):
191
  images = []
 
 
 
192
  for file in sorted(os.listdir(path)):
193
  if file.endswith(IMAGE_EXTENSION):
194
  images.append(file)
195
+ return images
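
For reference, a minimal usage sketch of the updated `ImageSequenceDataset` (the tokenizer handling is an assumption; in the repo the prompt ids are produced by the pipeline's CLIP tokenizer):
```python
import torch
from transformers import CLIPTokenizer
from video_diffusion.data.dataset import ImageSequenceDataset

tokenizer = CLIPTokenizer.from_pretrained("./ckpt/stable-diffusion-v1-5", subfolder="tokenizer")
prompt = "a Spider-Man and a polar bear are running"
prompt_ids = tokenizer(
    prompt, truncation=True, padding="max_length",
    max_length=tokenizer.model_max_length, return_tensors="pt",
).input_ids[0]

dataset = ImageSequenceDataset(
    path="./data/run_two_man/run_two_man",              # frame folder (00000.png, ...)
    layout_mask_dir="./data/run_two_man/layout_masks",  # one sub-folder per region
    layout_mask_order=["1", "2"],                       # same order as the local prompts
    prompt=prompt,
    prompt_ids=prompt_ids,
    n_sample_frame=8,
    sampling_rate=1,
)
sample = dataset[0]
print(sample["images"].shape, sample["masks"].shape, sample["layouts"].shape)
```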