Spaces:
Configuration error
Configuration error
Commit
·
952c41a
1
Parent(s):
fdab143
huggingface dataset
Browse files
README.md
CHANGED
@@ -1,12 +1,371 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VideoGrain: Modulating Space-Time Attention for Multi-Grained Video Editing (ICLR 2025)
|
2 |
+
## [<a href="https://knightyxp.github.io/VideoGrain_project_page/" target="_blank">Project Page</a>]
|
3 |
+
|
4 |
+
[](https://arxiv.org/abs/2502.17258)
|
5 |
+
[](https://huggingface.co/papers/2502.17258)
|
6 |
+
[](https://knightyxp.github.io/VideoGrain_project_page/)
|
7 |
+
[](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link)
|
8 |
+

|
9 |
+
[](https://www.youtube.com/watch?v=XEM4Pex7F9E)
|
10 |
+
[](https://huggingface.co/datasets/XiangpengYang/VideoGrain-dataset)
|
11 |
+
|
12 |
+
|
13 |
+
## Introduction
|
14 |
+
VideoGrain is a zero-shot method for class-level, instance-level, and part-level video editing.
|
15 |
+
- **Multi-grained Video Editing**
|
16 |
+
- class-level: Editing objects within the same class (previous SOTA limited to this level)
|
17 |
+
- instance-level: Editing each individual instance to distinct object
|
18 |
+
- part-level: Adding new objects or modifying existing attributes at the part-level
|
19 |
+
- **Training-Free**
|
20 |
+
- Does not require any training/fine-tuning
|
21 |
+
- **One-Prompt Multi-region Control & Deep investigations about cross/self attn**
|
22 |
+
- modulating cross-attn for multi-regions control (visualizations available)
|
23 |
+
- modulating self-attn for feature decoupling (clustering are available)
|
24 |
+
|
25 |
+
<table class="center" border="1" cellspacing="0" cellpadding="5">
|
26 |
+
<tr>
|
27 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/class_level.gif" style="width:250px; height:auto;"></td>
|
28 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/instance_part.gif" style="width:250px; height:auto;"></td>
|
29 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/2monkeys.gif" style="width:250px; height:auto;"></td>
|
30 |
+
</tr>
|
31 |
+
<tr>
|
32 |
+
<!-- <td colspan="1" style="text-align:right; width:125px;"> </td> -->
|
33 |
+
<td colspan="2" style="text-align:right; width:250px;"> class level</td>
|
34 |
+
<td colspan="1" style="text-align:center; width:125px;">instance level</td>
|
35 |
+
<td colspan="1" style="text-align:center; width:125px;">part level</td>
|
36 |
+
<td colspan="2" style="text-align:center; width:250px;">animal instances</td>
|
37 |
+
</tr>
|
38 |
+
|
39 |
+
<tr>
|
40 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/2cats.gif" style="width:250px; height:auto;"></td>
|
41 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/soap-box.gif" style="width:250px; height:auto;"></td>
|
42 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/man-text-message.gif" style="width:250px; height:auto;"></td>
|
43 |
+
</tr>
|
44 |
+
<tr>
|
45 |
+
<td colspan="2" style="text-align:center; width:250px;">animal instances</td>
|
46 |
+
<td colspan="2" style="text-align:center; width:250px;">human instances</td>
|
47 |
+
<td colspan="2" style="text-align:center; width:250px;">part-level modification</td>
|
48 |
+
</tr>
|
49 |
+
</table>
|
50 |
+
|
51 |
+
## 📀 Demo Video
|
52 |
+
<!-- [](https://www.youtube.com/watch?v=XEM4Pex7F9E "Demo Video of VideoGrain") -->
|
53 |
+
https://github.com/user-attachments/assets/9bec92fc-21bd-4459-86fa-62404d8762bf
|
54 |
+
|
55 |
+
|
56 |
+
## 📣 News
|
57 |
+
* **[2025/2/25]** Our VideoGrain is posted and recommended by Gradio on [LinkedIn](https://www.linkedin.com/posts/gradio_just-dropped-videograin-a-new-zero-shot-activity-7300094635094261760-hoiE) and [Twitter](https://x.com/Gradio/status/1894328911154028566), and recommended by [AK](https://x.com/_akhaliq/status/1894254599223017622).
|
58 |
+
* **[2025/2/25]** Our VideoGrain is submited by AK to [HuggingFace-daily papers](https://huggingface.co/papers?date=2025-02-25), and rank [#1](https://huggingface.co/papers/2502.17258) paper of that day.
|
59 |
+
* **[2025/2/24]** We release our paper on [arxiv](https://arxiv.org/abs/2502.17258), we also release [code](https://github.com/knightyxp/VideoGrain) and [full-data](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link) on google drive.
|
60 |
+
* **[2025/1/23]** Our paper is accepted to [ICLR2025](https://openreview.net/forum?id=SSslAtcPB6)! Welcome to **watch** 👀 this repository for the latest updates.
|
61 |
+
|
62 |
+
|
63 |
+
## 🍻 Setup Environment
|
64 |
+
Our method is tested using cuda12.1, fp16 of accelerator and xformers on a single L40.
|
65 |
+
|
66 |
+
```bash
|
67 |
+
# Step 1: Create and activate Conda environment
|
68 |
+
conda create -n videograin python==3.10
|
69 |
+
conda activate videograin
|
70 |
+
|
71 |
+
# Step 2: Install PyTorch, CUDA and Xformers
|
72 |
+
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
73 |
+
pip install --pre -U xformers==0.0.27
|
74 |
+
# Step 3: Install additional dependencies with pip
|
75 |
+
pip install -r requirements.txt
|
76 |
+
```
|
77 |
+
|
78 |
+
`xformers` is recommended to save memory and running time.
|
79 |
+
|
80 |
+
</details>
|
81 |
+
|
82 |
+
You may download all the base model checkpoints using the following bash command
|
83 |
+
```bash
|
84 |
+
## download sd 1.5, controlnet depth/pose v10/v11
|
85 |
+
bash download_all.sh
|
86 |
+
```
|
87 |
+
|
88 |
+
<details><summary>Click for ControlNet annotator weights (if you can not access to huggingface)</summary>
|
89 |
+
|
90 |
+
You can download all the annotator checkpoints (such as DW-Pose, depth_zoe, depth_midas, and OpenPose, cost around 4G) from [baidu](https://pan.baidu.com/s/1sgBFLFkdTCDTn4oqHjGb9A?pwd=pdm5) or [google](https://drive.google.com/file/d/1qOsmWshnFMMr8x1HteaTViTSQLh_4rle/view?usp=drive_link)
|
91 |
+
Then extract them into ./annotator/ckpts
|
92 |
+
|
93 |
+
</details>
|
94 |
+
|
95 |
+
## ⚡️ Prepare all the data
|
96 |
+
|
97 |
+
### Full VideoGrain Data
|
98 |
+
We have provided `all the video data and layout masks in VideoGrain` at following link. Please download unzip the data and put them in the `./data' root directory.
|
99 |
+
```
|
100 |
+
gdown https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link
|
101 |
+
tar -zxvf videograin_data.tar.gz
|
102 |
+
```
|
103 |
+
### Customize Your Own Data
|
104 |
+
**prepare video to frames**
|
105 |
+
If the input video is mp4 file, using the following command to process it to frames:
|
106 |
+
```bash
|
107 |
+
python image_util/sample_video2frames.py --video_path 'your video path' --output_dir './data/video_name/video_name'
|
108 |
+
```
|
109 |
+
**prepare layout masks**
|
110 |
+
We segment videos using our ReLER lab's [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). I suggest using the `app.py` in SAM-Track for `graio` mode to manually select which region in the video your want to edit. Here, we also provided an script ` image_util/process_webui_mask.py` to process masks from SAM-Track path to VideoGrain path.
|
111 |
+
|
112 |
+
|
113 |
+
## 🔥🔥🔥 VideoGrain Editing
|
114 |
+
|
115 |
+
### 🎨 Inference
|
116 |
+
Your can reproduce the instance + part level results in our teaser by running:
|
117 |
+
|
118 |
+
```bash
|
119 |
+
bash test.sh
|
120 |
+
#or
|
121 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml
|
122 |
+
```
|
123 |
+
|
124 |
+
For other instance/part/class results in VideoGrain project page or teaser, we provide all the data (video frames and layout masks) and corresponding configs to reproduce, check results in [🚀Multi-Grained Video Editing](#multi-grained-video-editing-results).
|
125 |
+
|
126 |
+
<details><summary>The result is saved at `./result` . (Click for directory structure) </summary>
|
127 |
+
|
128 |
+
```
|
129 |
+
result
|
130 |
+
├── run_two_man
|
131 |
+
│ ├── control # control conditon
|
132 |
+
│ ├── infer_samples
|
133 |
+
│ ├── input # the input video frames
|
134 |
+
│ ├── masked_video.mp4 # check whether edit regions are accuratedly covered
|
135 |
+
│ ├── sample
|
136 |
+
│ ├── step_0 # result image folder
|
137 |
+
│ ├── step_0.mp4 # result video
|
138 |
+
│ ├── source_video.mp4 # the input video
|
139 |
+
│ ├── visualization_denoise # cross attention weight
|
140 |
+
│ ├── sd_study # cluster inversion feature
|
141 |
+
```
|
142 |
+
</details>
|
143 |
+
|
144 |
+
|
145 |
+
## Editing guidance for YOUR Video
|
146 |
+
### 🔛prepare your config
|
147 |
+
|
148 |
+
VideoGrain is a training-free framework. To run VideoGrain on your video, modify `./config/demo_config.yaml` based on your needs:
|
149 |
+
|
150 |
+
1. Replace your pretrained model path and controlnet path in your config. you can change the control_type to `dwpose` or `depth_zoe` or `depth`(midas).
|
151 |
+
2. Prepare your video frames and layout masks (edit regions) using SAM-Track or SAM2 in dataset config.
|
152 |
+
3. Change the `prompt`, and extract each `local prompt` in the editing prompts. the local prompt order should be same as layout masks order.
|
153 |
+
4. Your can change flatten resolution with 1->64, 2->16, 4->8. (commonly, flatten at 64 worked best)
|
154 |
+
5. To ensure temporal consistency, you can set `use_pnp: True` and `inject_step:5/10`. (Note: pnp>10 steps will be bad for multi-regions editing)
|
155 |
+
6. If you want to visualize the cross attn weight, set `vis_cross_attn: True`
|
156 |
+
7. If you want to cluster DDIM Inversion spatial temporal video feature, set `cluster_inversion_feature: True`
|
157 |
+
|
158 |
+
### 😍Editing your video
|
159 |
+
|
160 |
+
```bash
|
161 |
+
bash test.sh
|
162 |
+
#or
|
163 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config /path/to/the/config
|
164 |
+
```
|
165 |
+
|
166 |
+
## 🚀Multi-Grained Video Editing Results
|
167 |
+
|
168 |
+
### 🌈 Multi-Grained Definition
|
169 |
+
You can get multi-grained definition result, using the following command:
|
170 |
+
```bash
|
171 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config /config/class_level/running_two_man/man2spider.yaml #class-level
|
172 |
+
# /config/instance_level/running_two_man/4cls_spider_polar.yaml #instance-level
|
173 |
+
#config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml #part-level
|
174 |
+
```
|
175 |
+
<table class="center">
|
176 |
+
<tr>
|
177 |
+
<td width=25% style="text-align:center;">source video</td>
|
178 |
+
<td width=25% style="text-align:center;">class level</td>
|
179 |
+
<td width=25% style="text-align:center;">instance level</td>
|
180 |
+
<td width=25% style="text-align:center;">part level</td>
|
181 |
+
</tr>
|
182 |
+
<tr>
|
183 |
+
<td><img src="./assets/teaser/run_two_man.gif"></td>
|
184 |
+
<td><img src="./assets/teaser/class_level_0.gif"></td>
|
185 |
+
<td><img src="./assets/teaser/instance_level.gif"></td>
|
186 |
+
<td><img src="./assets/teaser/part_level.gif"></td>
|
187 |
+
</tr>
|
188 |
+
</table>
|
189 |
+
|
190 |
+
## 💃 Instance-level Video Editing
|
191 |
+
You can get instance-level video editing results, using the following command:
|
192 |
+
```bash
|
193 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/running_3cls_iron_spider.yaml
|
194 |
+
```
|
195 |
+
|
196 |
+
<table class="center">
|
197 |
+
<tr>
|
198 |
+
<td width=50% style="text-align:center;">running_two_man/3cls_iron_spider.yaml</td>
|
199 |
+
<td width=50% style="text-align:center;">2_monkeys/2cls_teddy_bear_koala.yaml</td>
|
200 |
+
</tr>
|
201 |
+
<tr>
|
202 |
+
<td><img src="assets/instance-level/left_iron_right_spider.gif"></td>
|
203 |
+
<td><img src="assets/instance-level/teddy_koala.gif"></td>
|
204 |
+
</tr>
|
205 |
+
<tr>
|
206 |
+
<td width=50% style="text-align:center;">badminton/2cls_wonder_woman_spiderman.yaml</td>
|
207 |
+
<td width=50% style="text-align:center;">soap-box/soap-box.yaml</td>
|
208 |
+
</tr>
|
209 |
+
<tr>
|
210 |
+
<td><img src="assets/instance-level/badminton.gif"></td>
|
211 |
+
<td><img src="assets/teaser/soap-box.gif"></td>
|
212 |
+
</tr>
|
213 |
+
<tr>
|
214 |
+
<td width=50% style="text-align:center;">2_cats/4cls_panda_vs_poddle.yaml</td>
|
215 |
+
<td width=50% style="text-align:center;">2_cars/left_firetruck_right_bus.yaml</td>
|
216 |
+
</tr>
|
217 |
+
<tr>
|
218 |
+
<td><img src="assets/instance-level/panda_vs_poddle.gif"></td>
|
219 |
+
<td><img src="assets/instance-level/2cars.gif"></td>
|
220 |
+
</tr>
|
221 |
+
</table>
|
222 |
+
|
223 |
+
## 🕺 Part-level Video Editing
|
224 |
+
You can get part-level video editing results, using the following command:
|
225 |
+
```bash
|
226 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/modification/man_text_message/blue_shirt.yaml
|
227 |
+
```
|
228 |
+
|
229 |
+
<table class="center">
|
230 |
+
<tr>
|
231 |
+
<td><img src="assets/part-level/man_text_message.gif"></td>
|
232 |
+
<td><img src="assets/part-level/blue-shirt.gif"></td>
|
233 |
+
<td><img src="assets/part-level/black-suit.gif"></td>
|
234 |
+
<td><img src="assets/part-level/cat_flower.gif"></td>
|
235 |
+
<td><img src="assets/part-level/ginger_head.gif"></td>
|
236 |
+
<td><img src="assets/part-level/ginger_body.gif"></td>
|
237 |
+
</tr>
|
238 |
+
<tr>
|
239 |
+
<td width=15% style="text-align:center;">source video</td>
|
240 |
+
<td width=15% style="text-align:center;">blue shirt</td>
|
241 |
+
<td width=15% style="text-align:center;">black suit</td>
|
242 |
+
<td width=15% style="text-align:center;">source video</td>
|
243 |
+
<td width=15% style="text-align:center;">ginger head </td>
|
244 |
+
<td width=15% style="text-align:center;">ginger body</td>
|
245 |
+
</tr>
|
246 |
+
<tr>
|
247 |
+
<td><img src="assets/part-level/man_text_message.gif"></td>
|
248 |
+
<td><img src="assets/part-level/superman.gif"></td>
|
249 |
+
<td><img src="assets/part-level/superman+cap.gif"></td>
|
250 |
+
<td><img src="assets/part-level/spin-ball.gif"></td>
|
251 |
+
<td><img src="assets/part-level/superman_spin.gif"></td>
|
252 |
+
<td><img src="assets/part-level/super_sunglass_spin.gif"></td>
|
253 |
+
</tr>
|
254 |
+
<tr>
|
255 |
+
<td width=15% style="text-align:center;">source video</td>
|
256 |
+
<td width=15% style="text-align:center;">superman</td>
|
257 |
+
<td width=15% style="text-align:center;">superman + cap</td>
|
258 |
+
<td width=15% style="text-align:center;">source video</td>
|
259 |
+
<td width=15% style="text-align:center;">superman </td>
|
260 |
+
<td width=15% style="text-align:center;">superman + sunglasses</td>
|
261 |
+
</tr>
|
262 |
+
</table>
|
263 |
+
|
264 |
+
## 🥳 Class-level Video Editing
|
265 |
+
You can get class-level video editing results, using the following command:
|
266 |
+
```bash
|
267 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/class_level/wolf/wolf.yaml
|
268 |
+
```
|
269 |
+
|
270 |
+
<table class="center">
|
271 |
+
<tr>
|
272 |
+
<td><img src="assets/class-level/wolf.gif"></td>
|
273 |
+
<td><img src="assets/class-level/pig.gif"></td>
|
274 |
+
<td><img src="assets/class-level/husky.gif"></td>
|
275 |
+
<td><img src="assets/class-level/bear.gif"></td>
|
276 |
+
<td><img src="assets/class-level/tiger.gif"></td>
|
277 |
+
</tr>
|
278 |
+
<tr>
|
279 |
+
<td width=15% style="text-align:center;">input</td>
|
280 |
+
<td width=15% style="text-align:center;">pig</td>
|
281 |
+
<td width=15% style="text-align:center;">husky</td>
|
282 |
+
<td width=15% style="text-align:center;">bear</td>
|
283 |
+
<td width=15% style="text-align:center;">tiger</td>
|
284 |
+
</tr>
|
285 |
+
<tr>
|
286 |
+
<td><img src="assets/class-level/tennis.gif"></td>
|
287 |
+
<td><img src="assets/class-level/tennis_1cls.gif"></td>
|
288 |
+
<td><img src="assets/class-level/tennis_3cls.gif"></td>
|
289 |
+
<td><img src="assets/class-level/car-1.gif"></td>
|
290 |
+
<td><img src="assets/class-level/posche.gif"></td>
|
291 |
+
</tr>
|
292 |
+
<tr>
|
293 |
+
<td width=15% style="text-align:center;">input</td>
|
294 |
+
<td width=15% style="text-align:center;">iron man</td>
|
295 |
+
<td width=15% style="text-align:center;">Batman + snow court + iced wall</td>
|
296 |
+
<td width=15% style="text-align:center;">input </td>
|
297 |
+
<td width=15% style="text-align:center;">posche</td>
|
298 |
+
</tr>
|
299 |
+
</table>
|
300 |
+
|
301 |
+
|
302 |
+
## Soely Edit on specific subjects, keep background unchanged
|
303 |
+
You can get soely video editing results, using the following command:
|
304 |
+
```bash
|
305 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/soely_edit/only_left.yaml
|
306 |
+
#--config config/instance_level/soely_edit/only_right.yaml
|
307 |
+
#--config config/instance_level/soely_edit/joint_edit.yaml
|
308 |
+
```
|
309 |
+
|
310 |
+
<table class="center">
|
311 |
+
<tr>
|
312 |
+
<td><img src="assets/soely_edit/input.gif"></td>
|
313 |
+
<td><img src="assets/soely_edit/left.gif"></td>
|
314 |
+
<td><img src="assets/soely_edit/right.gif"></td>
|
315 |
+
<td><img src="assets/soely_edit/joint.gif"></td>
|
316 |
+
</tr>
|
317 |
+
<tr>
|
318 |
+
<td width=25% style="text-align:center;">source video</td>
|
319 |
+
<td width=25% style="text-align:center;">left→Iron Man</td>
|
320 |
+
<td width=25% style="text-align:center;">right→Spiderman</td>
|
321 |
+
<td width=25% style="text-align:center;">joint edit</td>
|
322 |
+
</tr>
|
323 |
+
</table>
|
324 |
+
|
325 |
+
## 🔍 Visualize Cross Attention Weight
|
326 |
+
You can get visulize attention weight editing results, using the following command:
|
327 |
+
```bash
|
328 |
+
#setting vis_cross_attn: True in your config
|
329 |
+
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/3cls_spider_polar_vis_weight.yaml
|
330 |
+
```
|
331 |
+
|
332 |
+
<table class="center">
|
333 |
+
<tr>
|
334 |
+
<td><img src="assets/soely_edit/input.gif"></td>
|
335 |
+
<td><img src="assets/vis/edit.gif"></td>
|
336 |
+
<td><img src="assets/vis/spiderman_weight.gif"></td>
|
337 |
+
<td><img src="assets/vis/bear_weight.gif"></td>
|
338 |
+
<td><img src="/assets/vis/cherry_weight.gif"></td>
|
339 |
+
</tr>
|
340 |
+
<tr>
|
341 |
+
<td width=20% style="text-align:center;">source video</td>
|
342 |
+
<td width=20% style="text-align:center;">left→spiderman, right→polar bear, trees→cherry blossoms</td>
|
343 |
+
<td width=20% style="text-align:center;">spiderman weight</td>
|
344 |
+
<td width=20% style="text-align:center;">bear weight</td>
|
345 |
+
<td width=20% style="text-align:center;">cherry weight</td>
|
346 |
+
</tr>
|
347 |
+
</table>
|
348 |
+
|
349 |
+
## ✏️ Citation
|
350 |
+
If you think this project is helpful, please feel free to leave a star⭐️⭐️⭐️ and cite our paper:
|
351 |
+
```bibtex
|
352 |
+
@article{yang2025videograin,
|
353 |
+
title={VideoGrain: Modulating Space-Time Attention for Multi-grained Video Editing},
|
354 |
+
author={Yang, Xiangpeng and Zhu, Linchao and Fan, Hehe and Yang, Yi},
|
355 |
+
journal={arXiv preprint arXiv:2502.17258},
|
356 |
+
year={2025}
|
357 |
+
}
|
358 |
+
```
|
359 |
+
|
360 |
+
## 📞 Contact Authors
|
361 |
+
Xiangpeng Yang [@knightyxp](https://github.com/knightyxp), email: [email protected]/[email protected]
|
362 |
+
|
363 |
+
## ✨ Acknowledgements
|
364 |
+
|
365 |
+
- This code builds on [diffusers](https://github.com/huggingface/diffusers), and [FateZero](https://github.com/ChenyangQiQi/FateZero). Thanks for open-sourcing!
|
366 |
+
- We would like to thank [AK(@_akhaliq)](https://x.com/_akhaliq/status/1894254599223017622) and Gradio team for recommendation!
|
367 |
+
|
368 |
+
|
369 |
+
## ⭐️ Star History
|
370 |
+
|
371 |
+
[](https://star-history.com/#knightyxp/VideoGrain&Date)
|
config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
pretrained_model_path: "ckpt/stable-diffusion-v1-5"
|
2 |
logdir: ./result/part_level/run_two_man/left2spider_right2polar-sunglasses
|
3 |
|
4 |
dataset_config:
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-5"
|
2 |
logdir: ./result/part_level/run_two_man/left2spider_right2polar-sunglasses
|
3 |
|
4 |
dataset_config:
|
image.png
ADDED
![]() |
video_diffusion/data/__pycache__/dataset.cpython-310.pyc
CHANGED
Binary files a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc and b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc differ
|
|
video_diffusion/data/dataset.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
from einops import rearrange
|
@@ -10,24 +11,26 @@ from torch.utils.data import Dataset
|
|
10 |
from .transform import short_size_scale, random_crop, center_crop, offset_crop
|
11 |
from ..common.image_util import IMAGE_EXTENSION
|
12 |
import cv2
|
13 |
-
import imageio
|
14 |
-
import shutil
|
15 |
|
16 |
class ImageSequenceDataset(Dataset):
|
17 |
def __init__(
|
18 |
self,
|
19 |
-
path: str,
|
20 |
-
|
|
|
21 |
prompt_ids: torch.Tensor,
|
22 |
prompt: str,
|
23 |
-
start_sample_frame: int
|
24 |
n_sample_frame: int = 8,
|
25 |
sampling_rate: int = 1,
|
26 |
-
stride: int = -1,
|
27 |
image_mode: str = "RGB",
|
28 |
image_size: int = 512,
|
29 |
crop: str = "center",
|
30 |
-
|
|
|
|
|
|
|
31 |
offset: dict = {
|
32 |
"left": 0,
|
33 |
"right": 0,
|
@@ -35,42 +38,33 @@ class ImageSequenceDataset(Dataset):
|
|
35 |
"bottom": 0
|
36 |
},
|
37 |
**args
|
|
|
38 |
):
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
self.images = self.get_image_list(self.path)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
self.layout_mask_dirs = []
|
49 |
-
for idx, file in enumerate(layout_files):
|
50 |
-
if file.endswith('.mp4'):
|
51 |
-
folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx+1}')
|
52 |
-
else:
|
53 |
-
folder = file
|
54 |
-
self.layout_mask_dirs.append(folder)
|
55 |
-
# 保持上传顺序作为 layout_mask_order(此处仅用索引表示顺序)
|
56 |
-
self.layout_mask_order = list(range(len(self.layout_mask_dirs)))
|
57 |
-
# 用第一个 layout mask 目录获取 mask 图像索引(用于判断帧数)
|
58 |
-
self.masks_index = self.get_image_list(self.layout_mask_dirs[0])
|
59 |
|
|
|
60 |
self.n_images = len(self.images)
|
61 |
self.offset = offset
|
62 |
self.start_sample_frame = start_sample_frame
|
63 |
if n_sample_frame < 0:
|
64 |
-
n_sample_frame = len(self.images)
|
65 |
self.n_sample_frame = n_sample_frame
|
|
|
66 |
self.sampling_rate = sampling_rate
|
67 |
|
68 |
self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
|
69 |
if self.n_images < self.sequence_length:
|
70 |
-
raise ValueError(f"self.n_images
|
71 |
|
72 |
-
#
|
73 |
-
self.stride = stride if stride > 0 else (self.n_images
|
74 |
self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
|
75 |
|
76 |
self.image_mode = image_mode
|
@@ -80,53 +74,67 @@ class ImageSequenceDataset(Dataset):
|
|
80 |
"random": random_crop,
|
81 |
}
|
82 |
if crop not in crop_methods:
|
83 |
-
raise ValueError
|
84 |
self.crop = crop_methods[crop]
|
85 |
|
86 |
self.prompt = prompt
|
87 |
self.prompt_ids = prompt_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
def __len__(self):
|
91 |
max_len = (self.n_images - self.sequence_length) // self.stride + 1
|
|
|
92 |
if hasattr(self, 'num_class_images'):
|
93 |
max_len = max(max_len, self.num_class_images)
|
|
|
94 |
return max_len
|
95 |
|
96 |
def __getitem__(self, index):
|
97 |
return_batch = {}
|
98 |
-
frame_indices = self.get_frame_indices(index
|
99 |
frames = [self.load_frame(i) for i in frame_indices]
|
100 |
frames = self.transform(frames)
|
101 |
|
102 |
layout_ = []
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
masks = np.stack(mask) # shape: (n_sample_frame, c, h, w)
|
109 |
layout_.append(masks)
|
110 |
-
layout_ = np.stack(layout_)
|
111 |
-
|
112 |
merged_masks = []
|
113 |
for i in range(int(self.n_sample_frame)):
|
114 |
-
merged_mask_frame = np.sum(layout_[:,
|
115 |
-
merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
|
116 |
merged_masks.append(merged_mask_frame)
|
117 |
masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
|
118 |
masks = torch.from_numpy(masks).half()
|
119 |
|
120 |
-
layouts = rearrange(layout_,
|
121 |
layouts = torch.from_numpy(layouts).half()
|
122 |
|
123 |
-
return_batch.update(
|
|
|
124 |
"images": frames,
|
125 |
-
"masks":
|
126 |
-
"layouts":
|
127 |
"prompt_ids": self.prompt_ids,
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
return return_batch
|
131 |
|
132 |
def transform(self, frames):
|
@@ -141,18 +149,24 @@ class ImageSequenceDataset(Dataset):
|
|
141 |
frames = rearrange(np.stack(frames), "f h w c -> c f h w")
|
142 |
return torch.from_numpy(frames).div(255) * 2 - 1
|
143 |
|
144 |
-
def _read_mask(self,
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
147 |
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
148 |
mask = (mask > 0).astype(np.uint8)
|
149 |
-
#
|
150 |
height, width = mask.shape
|
151 |
dest_size = (width // 8, height // 8)
|
152 |
-
|
|
|
153 |
mask = mask[np.newaxis, ...]
|
|
|
154 |
return mask
|
155 |
|
|
|
156 |
def load_frame(self, index):
|
157 |
image_path = os.path.join(self.path, self.images[index])
|
158 |
return Image.open(image_path).convert(self.image_mode)
|
@@ -170,31 +184,12 @@ class ImageSequenceDataset(Dataset):
|
|
170 |
|
171 |
def get_class_indices(self, index):
|
172 |
frame_start = index
|
173 |
-
return (frame_start + i
|
174 |
|
175 |
@staticmethod
|
176 |
def get_image_list(path):
|
177 |
images = []
|
178 |
-
# 如果传入的是 mp4 文件,则先转换成 PNG 图像目录
|
179 |
-
if path.endswith('.mp4'):
|
180 |
-
path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
|
181 |
for file in sorted(os.listdir(path)):
|
182 |
if file.endswith(IMAGE_EXTENSION):
|
183 |
images.append(file)
|
184 |
-
return images
|
185 |
-
|
186 |
-
@staticmethod
|
187 |
-
def mp4_to_png(video_source: str, target_dir: str):
|
188 |
-
"""
|
189 |
-
Convert an mp4 video to a sequence of PNG images, storing them in target_dir.
|
190 |
-
target_dir 为固定路径,例如:'./input-video' 或 './layout_masks/1'
|
191 |
-
"""
|
192 |
-
if os.path.exists(target_dir):
|
193 |
-
shutil.rmtree(target_dir)
|
194 |
-
os.makedirs(target_dir, exist_ok=True)
|
195 |
-
|
196 |
-
reader = imageio.get_reader(video_source)
|
197 |
-
for i, im in enumerate(reader):
|
198 |
-
path = os.path.join(target_dir, f"{i:05d}.png")
|
199 |
-
cv2.imwrite(path, im[:, :, ::-1])
|
200 |
-
return target_dir
|
|
|
1 |
import os
|
2 |
+
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
from einops import rearrange
|
|
|
11 |
from .transform import short_size_scale, random_crop, center_crop, offset_crop
|
12 |
from ..common.image_util import IMAGE_EXTENSION
|
13 |
import cv2
|
|
|
|
|
14 |
|
15 |
class ImageSequenceDataset(Dataset):
|
16 |
def __init__(
|
17 |
self,
|
18 |
+
path: str,
|
19 |
+
layout_mask_dir: str,
|
20 |
+
layout_mask_order: list,
|
21 |
prompt_ids: torch.Tensor,
|
22 |
prompt: str,
|
23 |
+
start_sample_frame: int=0,
|
24 |
n_sample_frame: int = 8,
|
25 |
sampling_rate: int = 1,
|
26 |
+
stride: int = -1, # only used during tuning to sample a long video
|
27 |
image_mode: str = "RGB",
|
28 |
image_size: int = 512,
|
29 |
crop: str = "center",
|
30 |
+
|
31 |
+
class_data_root: str = None,
|
32 |
+
class_prompt_ids: torch.Tensor = None,
|
33 |
+
|
34 |
offset: dict = {
|
35 |
"left": 0,
|
36 |
"right": 0,
|
|
|
38 |
"bottom": 0
|
39 |
},
|
40 |
**args
|
41 |
+
|
42 |
):
|
43 |
+
self.path = path
|
44 |
+
self.images = self.get_image_list(path)
|
45 |
+
#
|
46 |
+
self.layout_mask_dir = layout_mask_dir
|
47 |
+
self.layout_mask_order = list(layout_mask_order)
|
|
|
48 |
|
49 |
+
layout_mask_dir0 = os.path.join(self.layout_mask_dir,self.layout_mask_order[0])
|
50 |
+
self.masks_index = self.get_image_list(layout_mask_dir0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
#
|
53 |
self.n_images = len(self.images)
|
54 |
self.offset = offset
|
55 |
self.start_sample_frame = start_sample_frame
|
56 |
if n_sample_frame < 0:
|
57 |
+
n_sample_frame = len(self.images)
|
58 |
self.n_sample_frame = n_sample_frame
|
59 |
+
# local sampling rate from the video
|
60 |
self.sampling_rate = sampling_rate
|
61 |
|
62 |
self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
|
63 |
if self.n_images < self.sequence_length:
|
64 |
+
raise ValueError(f"self.n_images {self.n_images } < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images }")
|
65 |
|
66 |
+
# During tuning if video is too long, we sample the long video every self.stride globally
|
67 |
+
self.stride = stride if stride > 0 else (self.n_images+1)
|
68 |
self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
|
69 |
|
70 |
self.image_mode = image_mode
|
|
|
74 |
"random": random_crop,
|
75 |
}
|
76 |
if crop not in crop_methods:
|
77 |
+
raise ValueError
|
78 |
self.crop = crop_methods[crop]
|
79 |
|
80 |
self.prompt = prompt
|
81 |
self.prompt_ids = prompt_ids
|
82 |
+
# Negative prompt for regularization to avoid overfitting during one-shot tuning
|
83 |
+
if class_data_root is not None:
|
84 |
+
self.class_data_root = Path(class_data_root)
|
85 |
+
self.class_images_path = sorted(list(self.class_data_root.iterdir()))
|
86 |
+
self.num_class_images = len(self.class_images_path)
|
87 |
+
self.class_prompt_ids = class_prompt_ids
|
88 |
|
89 |
|
90 |
def __len__(self):
|
91 |
max_len = (self.n_images - self.sequence_length) // self.stride + 1
|
92 |
+
|
93 |
if hasattr(self, 'num_class_images'):
|
94 |
max_len = max(max_len, self.num_class_images)
|
95 |
+
|
96 |
return max_len
|
97 |
|
98 |
def __getitem__(self, index):
|
99 |
return_batch = {}
|
100 |
+
frame_indices = self.get_frame_indices(index%self.video_len)
|
101 |
frames = [self.load_frame(i) for i in frame_indices]
|
102 |
frames = self.transform(frames)
|
103 |
|
104 |
layout_ = []
|
105 |
+
for layout_name in self.layout_mask_order:
|
106 |
+
frame_indices = self.get_frame_indices(index%self.video_len)
|
107 |
+
layout_mask_dir = os.path.join(self.layout_mask_dir,layout_name)
|
108 |
+
mask = [self._read_mask(layout_mask_dir,i) for i in frame_indices]
|
109 |
+
masks = np.stack(mask)
|
|
|
110 |
layout_.append(masks)
|
111 |
+
layout_ = np.stack(layout_)
|
|
|
112 |
merged_masks = []
|
113 |
for i in range(int(self.n_sample_frame)):
|
114 |
+
merged_mask_frame = np.sum(layout_[:,i,:,:,:], axis=0)
|
115 |
+
merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
|
116 |
merged_masks.append(merged_mask_frame)
|
117 |
masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
|
118 |
masks = torch.from_numpy(masks).half()
|
119 |
|
120 |
+
layouts = rearrange(layout_,"s f c h w -> f s c h w" )
|
121 |
layouts = torch.from_numpy(layouts).half()
|
122 |
|
123 |
+
return_batch.update(
|
124 |
+
{
|
125 |
"images": frames,
|
126 |
+
"masks":masks,
|
127 |
+
"layouts":layouts,
|
128 |
"prompt_ids": self.prompt_ids,
|
129 |
+
}
|
130 |
+
)
|
131 |
+
|
132 |
+
if hasattr(self, 'class_data_root'):
|
133 |
+
class_index = index % (self.num_class_images - self.n_sample_frame)
|
134 |
+
class_indices = self.get_class_indices(class_index)
|
135 |
+
frames = [self.load_class_frame(i) for i in class_indices]
|
136 |
+
return_batch["class_images"] = self.tensorize_frames(frames)
|
137 |
+
return_batch["class_prompt_ids"] = self.class_prompt_ids
|
138 |
return return_batch
|
139 |
|
140 |
def transform(self, frames):
|
|
|
149 |
frames = rearrange(np.stack(frames), "f h w c -> c f h w")
|
150 |
return torch.from_numpy(frames).div(255) * 2 - 1
|
151 |
|
152 |
+
def _read_mask(self, mask_path,index: int):
|
153 |
+
### read mask by pil
|
154 |
+
|
155 |
+
mask_path = os.path.join(mask_path,f"{index:05d}.png")
|
156 |
+
|
157 |
+
### read mask by cv2
|
158 |
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
159 |
mask = (mask > 0).astype(np.uint8)
|
160 |
+
# Determine dynamic destination size
|
161 |
height, width = mask.shape
|
162 |
dest_size = (width // 8, height // 8)
|
163 |
+
# Resize using nearest neighbor interpolation
|
164 |
+
mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) #cv2.INTER_CUBIC
|
165 |
mask = mask[np.newaxis, ...]
|
166 |
+
|
167 |
return mask
|
168 |
|
169 |
+
|
170 |
def load_frame(self, index):
|
171 |
image_path = os.path.join(self.path, self.images[index])
|
172 |
return Image.open(image_path).convert(self.image_mode)
|
|
|
184 |
|
185 |
def get_class_indices(self, index):
|
186 |
frame_start = index
|
187 |
+
return (frame_start + i for i in range(self.n_sample_frame))
|
188 |
|
189 |
@staticmethod
|
190 |
def get_image_list(path):
|
191 |
images = []
|
|
|
|
|
|
|
192 |
for file in sorted(os.listdir(path)):
|
193 |
if file.endswith(IMAGE_EXTENSION):
|
194 |
images.append(file)
|
195 |
+
return images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|