Upload 27 files
Browse files- .gitattributes +15 -0
- README.md +257 -10
- assets/Compare.png +3 -0
- assets/Pipeline.png +3 -0
- assets/Qualitative-1.png +3 -0
- assets/Qualitative-2-1.png +3 -0
- assets/Qualitative-2-2.png +3 -0
- assets/Qualitative-3-1.png +3 -0
- assets/Qualitative-3-2.png +3 -0
- assets/Qualitative-4-1.png +3 -0
- assets/Qualitative-4-2.png +3 -0
- assets/Qualitative-5-1.png +3 -0
- assets/Qualitative-5-2.png +3 -0
- assets/Quantitative.png +3 -0
- assets/Strategy.png +3 -0
- datasets/README.md +14 -0
- datasets/demo/001.mp4 +0 -0
- datasets/demo/002.mp4 +0 -0
- datasets/demo/003.mp4 +0 -0
- datasets/demo/004.mp4 +0 -0
- datasets/demo/005.mp4 +3 -0
- datasets/demo/006.mp4 +3 -0
- datasets/demo/007.mp4 +0 -0
- eval_metrics.py +256 -0
- inference.sh +75 -0
- inference_script.py +754 -0
- pretrained_models/README.md +1 -0
- requirements.txt +20 -0
.gitattributes
CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/Compare.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/Pipeline.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/Qualitative-1.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/Qualitative-2-1.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/Qualitative-2-2.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/Qualitative-3-1.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/Qualitative-3-2.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/Qualitative-4-1.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/Qualitative-4-2.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/Qualitative-5-1.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/Qualitative-5-2.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/Quantitative.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/Strategy.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
datasets/demo/005.mp4 filter=lfs diff=lfs merge=lfs -text
|
50 |
+
datasets/demo/006.mp4 filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution
|
2 |
+
|
3 |
+
[Zheng Chen](https://zhengchen1999.github.io/), [Zichen Zou](https://github.com/zzctmd), [Kewei Zhang](), [Xiongfei Su](https://ieeexplore.ieee.org/author/37086348852), [Xin Yuan](https://en.westlake.edu.cn/faculty/xin-yuan.html), [Yong Guo](https://www.guoyongcs.com/), and [Yulun Zhang](http://yulunzhang.com/), "DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution", 2025
|
4 |
+
|
5 |
+
<div>
|
6 |
+
<a href="https://github.com/zhengchen1999/DOVE/releases" target='_blank' style="text-decoration: none;"><img src="https://img.shields.io/github/downloads/zhengchen1999/DOVE/total?color=green&style=flat"></a>
|
7 |
+
<a href="https://github.com/zhengchen1999/DOVE" target='_blank' style="text-decoration: none;"><img src="https://visitor-badge.laobi.icu/badge?page_id=zhengchen1999/DOVE"></a>
|
8 |
+
<a href="https://github.com/zhengchen1999/DOVE/stargazers" target='_blank' style="text-decoration: none;"><img src="https://img.shields.io/github/stars/zhengchen1999/DOVE?style=social"></a>
|
9 |
+
</div>
|
10 |
+
|
11 |
+
|
12 |
+
[[arXiv](https://arxiv.org/abs/2505.16239)] [[supplementary material](https://github.com/zhengchen1999/DOVE/releases/download/v1/Supplementary_Material.pdf)] [[dataset](https://drive.google.com/drive/folders/1e7CyNzfJBa2saWvPr2HI2q_FJhLIc-Ww?usp=drive_link)] [[pretrained models](https://drive.google.com/drive/folders/1wj9jY0fn6prSWJ7BjJOXfxC0bs8skKbQ?usp=sharing)]
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
#### 🔥🔥🔥 News
|
17 |
+
|
18 |
+
- **2025-6-09:** Test datasets, inference scripts, and pretrained models are available. ⭐️⭐️⭐️
|
19 |
+
- **2025-5-22:** This repo is released.
|
20 |
+
|
21 |
---
|
22 |
+
|
23 |
+
> **Abstract:** Diffusion models have demonstrated promising performance in real-world video super-resolution (VSR). However, the dozens of sampling steps they require, make inference extremely slow. Sampling acceleration techniques, particularly single-step, provide a potential solution. Nonetheless, achieving one step in VSR remains challenging, due to the high training overhead on video data and stringent fidelity demands. To tackle the above issues, we propose DOVE, an efficient one-step diffusion model for real-world VSR. DOVE is obtained by fine-tuning a pretrained video diffusion model (*i.e.*, CogVideoX). To effectively train DOVE, we introduce the latent–pixel training strategy. The strategy employs a two-stage scheme to gradually adapt the model to the video super-resolution task.
|
24 |
+
> Meanwhile, we design a video processing pipeline to construct a high-quality dataset tailored for VSR, termed HQ-VSR. Fine-tuning on this dataset further enhances the restoration capability of DOVE. Extensive experiments show that DOVE exhibits comparable or superior performance to multi-step diffusion-based VSR methods. It also offers outstanding inference efficiency, achieving up to a **28×** speed-up over existing methods such as MGLD-VSR.
|
25 |
+
|
26 |
+

|
27 |
+
|
|
|
|
|
|
|
28 |
---
|
29 |
|
30 |
+
|
31 |
+
|
32 |
+
<table border="0" style="width: 100%; text-align: center; margin-top: 20px;">
|
33 |
+
<tr>
|
34 |
+
<td>
|
35 |
+
<video src="https://github.com/user-attachments/assets/4ad0ca78-6cca-48c0-95a5-5d5554093f7d" controls autoplay loop></video>
|
36 |
+
</td>
|
37 |
+
<td>
|
38 |
+
<video src="https://github.com/user-attachments/assets/e5b5d247-28af-43fd-b32c-1f1b5896d9e7" controls autoplay loop></video>
|
39 |
+
</td>
|
40 |
+
</tr>
|
41 |
+
</table>
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
---
|
47 |
+
|
48 |
+
### Training Strategy
|
49 |
+
|
50 |
+

|
51 |
+
|
52 |
+
---
|
53 |
+
|
54 |
+
### Video Processing Pipeline
|
55 |
+
|
56 |
+

|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
## 🔖 TODO
|
62 |
+
|
63 |
+
- [x] Release testing code.
|
64 |
+
- [x] Release pre-trained models.
|
65 |
+
- [ ] Release training code.
|
66 |
+
- [ ] Release video processing pipeline.
|
67 |
+
- [ ] Release HQ-VSR dataset.
|
68 |
+
- [ ] Provide WebUI.
|
69 |
+
- [ ] Provide HuggingFace demo.
|
70 |
+
|
71 |
+
## ⚙️ Dependencies
|
72 |
+
|
73 |
+
- Python 3.11
|
74 |
+
- PyTorch\>=2.5.0
|
75 |
+
- Diffusers
|
76 |
+
|
77 |
+
```bash
|
78 |
+
# Clone the github repo and go to the default directory 'DOVE'.
|
79 |
+
git clone https://github.com/zhengchen1999/DOVE.git
|
80 |
+
conda create -n DOVE python=3.11
|
81 |
+
conda activate DOVE
|
82 |
+
pip install -r requirements.txt
|
83 |
+
pip install diffusers["torch"] transformers
|
84 |
+
pip install pyiqa
|
85 |
+
```
|
86 |
+
|
87 |
+
## 🔗 Contents
|
88 |
+
|
89 |
+
1. [Datasets](#datasets)
|
90 |
+
1. [Models](#models)
|
91 |
+
1. Training
|
92 |
+
1. [Testing](#testing)
|
93 |
+
1. [Results](#results)
|
94 |
+
1. [Acknowledgements](#acknowledgements)
|
95 |
+
|
96 |
+
## <a name="datasets"></a>📁 Datasets
|
97 |
+
|
98 |
+
### 🗳️ Test Datasets
|
99 |
+
|
100 |
+
We provide several real-world and synthetic test datasets for evaluation. All datasets follow a consistent directory structure:
|
101 |
+
|
102 |
+
| Dataset | Type | # Num | Download |
|
103 |
+
| :------ | :--------: | :---: | :----------------------------------------------------------: |
|
104 |
+
| UDM10 | Synthetic | 10 | [Google Drive](https://drive.google.com/file/d/1AmGVSCwMm_OFPd3DKgNyTwj0GG2H-tG4/view?usp=drive_link) |
|
105 |
+
| SPMCS | Synthetic | 30 | [Google Drive](https://drive.google.com/file/d/1b2uktCFPKS-R1fTecWcLFcOnmUFIBNWT/view?usp=drive_link) |
|
106 |
+
| YouHQ40 | Synthetic | 40 | [Google Drive](https://drive.google.com/file/d/1zO23UCStxL3htPJQcDUUnUeMvDrysLTh/view?usp=sharing) |
|
107 |
+
| RealVSR | Real-world | 50 | [Google Drive](https://drive.google.com/file/d/1wr4tTiCvQlqdYPeU1dmnjb5KFY4VjGCO/view?usp=drive_link) |
|
108 |
+
| MVSR4x | Real-world | 15 | [Google Drive](https://drive.google.com/file/d/16sesBD_9Xx_5Grtx18nosBw1w94KlpQt/view?usp=drive_link) |
|
109 |
+
| VideoLQ | Real-world | 50 | [Google Drive](https://drive.google.com/file/d/1lh0vkU_llxE0un1OigJ0DWPQwt1i68Vn/view?usp=drive_link) |
|
110 |
+
|
111 |
+
All datasets are hosted on [here](https://drive.google.com/drive/folders/1yNKG6rtTNtZQY8qL74GoQwA0jgjBUEby?usp=sharing). Make sure the path is correct (`datasets/test/`) before running inference.
|
112 |
+
|
113 |
+
The directory structure is as follows:
|
114 |
+
|
115 |
+
```shell
|
116 |
+
datasets/
|
117 |
+
└── test/
|
118 |
+
└── [DatasetName]/
|
119 |
+
├── GT/ # Ground Truth: folder of high-quality frames (one per clip)
|
120 |
+
├── GT-Video/ # Ground Truth (video version): lossless MKV format
|
121 |
+
├── LQ/ # Low-quality Input: folder of degraded frames (one per clip)
|
122 |
+
└── LQ-Video/ # Low-Quality Input (video version): lossless MKV format
|
123 |
+
```
|
124 |
+
|
125 |
+
## <a name="models"></a>📦 Models
|
126 |
+
|
127 |
+
We provide pretrained weights for DOVE and DOVE-2B.
|
128 |
+
|
129 |
+
| Model Name | Description | HuggingFace | Google Drive | Visual Results |
|
130 |
+
| :--------- | :-------------------------------------: | :---------: | :----------------------------------------------------------: | ------------------------------------------------------------ |
|
131 |
+
| DOVE | Base version, built on CogVideoX1.5-5B; | TODO | [Download](https://drive.google.com/file/d/1Nl3XoJndMtpu6KPFcskUTkI0qWBiSXF2/view?usp=drive_link) | [Download](https://drive.google.com/drive/folders/1J92X1amVijH9dNWGQcz-6Cx44B7EipWr?usp=drive_link) |
|
132 |
+
| DOVE-2B | Smaller version, based on CogVideoX-2B | TODO | TODO | TODO |
|
133 |
+
|
134 |
+
> Place downloaded model files into the `pretrained_models/` folder, e.g., `pretrained_models/DOVE`.
|
135 |
+
|
136 |
+
## <a name="testing"></a>🔨 Testing
|
137 |
+
|
138 |
+
We provide inference commands below. Before running, make sure to download the corresponding pretrained models and test datasets.
|
139 |
+
|
140 |
+
For more options and usage, please refer to [inference_script.py](inference_script.py).
|
141 |
+
|
142 |
+
The full testing commands are provided in the shell script: [inference.sh](inference.sh).
|
143 |
+
|
144 |
+
```shell
|
145 |
+
# 🔹 Demo inference
|
146 |
+
python inference_script.py \
|
147 |
+
--input_dir datasets/demo \
|
148 |
+
--model_path pretrained_models/DOVE \
|
149 |
+
--output_path results/DOVE/demo \
|
150 |
+
--is_vae_st \
|
151 |
+
--save_format yuv420p
|
152 |
+
|
153 |
+
# 🔹 Reproduce paper results
|
154 |
+
python inference_script.py \
|
155 |
+
--input_dir datasets/test/UDM10/LQ-Video \
|
156 |
+
--model_path pretrained_models/DOVE \
|
157 |
+
--output_path results/DOVE/UDM10 \
|
158 |
+
--is_vae_st \
|
159 |
+
|
160 |
+
# 🔹 Evaluate quantitative metrics
|
161 |
+
python eval_metrics.py \
|
162 |
+
--gt datasets/test/UDM10/GT \
|
163 |
+
--pred results/DOVE/UDM10 \
|
164 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
165 |
+
```
|
166 |
+
|
167 |
+
> 💡 If you encounter out-of-memory (OOM) issues, you can enable chunk-based testing by setting the following parameters: tile_size_hw, overlap_hw, chunk_len, and overlap_t.
|
168 |
+
>
|
169 |
+
> 💡 Default save format is `yuv444p`. If playback fails, try `save_format=yuv420p` (may slightly affect metrics).
|
170 |
+
>
|
171 |
+
> **TODO:** Add metric computation scripts for FasterVQA, DOVER, and $E^*_{warp}$.
|
172 |
+
|
173 |
+
## <a name="results"></a>🔎 Results
|
174 |
+
|
175 |
+
We achieve state-of-the-art performance on real-world video super-resolution. Visual results are available at [Google Drive](https://drive.google.com/drive/folders/1J92X1amVijH9dNWGQcz-6Cx44B7EipWr?usp=drive_link).
|
176 |
+
|
177 |
+
<details open>
|
178 |
+
<summary>Quantitative Results (click to expand)</summary>
|
179 |
+
|
180 |
+
- Results in Tab. 2 of the main paper
|
181 |
+
|
182 |
+
<p align="center">
|
183 |
+
<img width="900" src="assets/Quantitative.png">
|
184 |
+
</p>
|
185 |
+
|
186 |
+
</details>
|
187 |
+
|
188 |
+
<details open>
|
189 |
+
<summary>Qualitative Results (click to expand)</summary>
|
190 |
+
|
191 |
+
- Results in Fig. 4 of the main paper
|
192 |
+
|
193 |
+
<p align="center">
|
194 |
+
<img width="900" src="assets/Qualitative-1.png">
|
195 |
+
</p>
|
196 |
+
<details>
|
197 |
+
<summary>More Qualitative Results</summary>
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
- More results in Fig. 3 of the supplementary material
|
203 |
+
|
204 |
+
<p align="center">
|
205 |
+
<img width="900" src="assets/Qualitative-2-1.png">
|
206 |
+
</p>
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
- More results in Fig. 4 of the supplementary material
|
211 |
+
|
212 |
+
<p align="center">
|
213 |
+
<img width="900" src="assets/Qualitative-2-2.png">
|
214 |
+
</p>
|
215 |
+
|
216 |
+
|
217 |
+
- More results in Fig. 5 of the supplementary material
|
218 |
+
|
219 |
+
<p align="center">
|
220 |
+
<img width="900" src="assets/Qualitative-3-1.png">
|
221 |
+
<img width="900" src="assets/Qualitative-3-2.png">
|
222 |
+
</p>
|
223 |
+
|
224 |
+
|
225 |
+
- More results in Fig. 6 of the supplementary material
|
226 |
+
|
227 |
+
<p align="center">
|
228 |
+
<img width="900" src="assets/Qualitative-4-1.png">
|
229 |
+
<img width="900" src="assets/Qualitative-4-2.png">
|
230 |
+
</p>
|
231 |
+
|
232 |
+
|
233 |
+
- More results in Fig. 7 of the supplementary material
|
234 |
+
|
235 |
+
<p align="center">
|
236 |
+
<img width="900" src="assets/Qualitative-5-1.png">
|
237 |
+
<img width="900" src="assets/Qualitative-5-2.png">
|
238 |
+
</p>
|
239 |
+
|
240 |
+
</details>
|
241 |
+
|
242 |
+
</details>
|
243 |
+
|
244 |
+
## <a name="citation"></a>📎 Citation
|
245 |
+
|
246 |
+
If you find the code helpful in your research or work, please cite the following paper(s).
|
247 |
+
|
248 |
+
```
|
249 |
+
@article{chen2025dove,
|
250 |
+
title={DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution},
|
251 |
+
author={Chen, Zheng and Zou, Zichen and Zhang, Kewei and Su, Xiongfei and Yuan, Xin and Guo, Yong and Zhang, Yulun},
|
252 |
+
journal={arXiv preprint arXiv:2505.16239},
|
253 |
+
year={2025}
|
254 |
+
}
|
255 |
+
```
|
256 |
+
|
257 |
+
## <a name="acknowledgements"></a>💡 Acknowledgements
|
258 |
+
|
259 |
+
This project is based on [CogVideo](https://github.com/THUDM/CogVideo) and [Open-Sora](https://github.com/hpcaitech/Open-Sora).
|
260 |
+
|
assets/Compare.png
ADDED
![]() |
Git LFS Details
|
assets/Pipeline.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-1.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-2-1.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-2-2.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-3-1.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-3-2.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-4-1.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-4-2.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-5-1.png
ADDED
![]() |
Git LFS Details
|
assets/Qualitative-5-2.png
ADDED
![]() |
Git LFS Details
|
assets/Quantitative.png
ADDED
![]() |
Git LFS Details
|
assets/Strategy.png
ADDED
![]() |
Git LFS Details
|
datasets/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The directory structure is as follows:
|
2 |
+
|
3 |
+
```shell
|
4 |
+
datasets/
|
5 |
+
└── demo/
|
6 |
+
└── test/
|
7 |
+
└── [DatasetName]/
|
8 |
+
├── GT/ # Ground Truth: folder of high-quality frames (one per clip)
|
9 |
+
├── GT-Video/ # Ground Truth (video version): lossless MKV format
|
10 |
+
├── LQ/ # Low-quality Input: folder of degraded frames (one per clip)
|
11 |
+
└── LQ-Video/ # Low-Quality Input (video version): lossless MKV format
|
12 |
+
```
|
13 |
+
|
14 |
+
All datasets are available [here](https://drive.google.com/drive/folders/1yNKG6rtTNtZQY8qL74GoQwA0jgjBUEby?usp=sharing).
|
datasets/demo/001.mp4
ADDED
Binary file (62.5 kB). View file
|
|
datasets/demo/002.mp4
ADDED
Binary file (97.3 kB). View file
|
|
datasets/demo/003.mp4
ADDED
Binary file (60.4 kB). View file
|
|
datasets/demo/004.mp4
ADDED
Binary file (79.6 kB). View file
|
|
datasets/demo/005.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de2fe395e78a9d556a3763d7a7bdf87102e0c1191e1d146a54d487f78a57d708
|
3 |
+
size 268870
|
datasets/demo/006.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:194615754e590bc57e84f41a11b7bca4d564455ed93b72be700a6300119d34ac
|
3 |
+
size 206748
|
datasets/demo/007.mp4
ADDED
Binary file (48.4 kB). View file
|
|
eval_metrics.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import pyiqa
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
# 0 ~ 1
|
12 |
+
to_tensor = transforms.ToTensor()
|
13 |
+
video_exts = ['.mp4', '.avi', '.mov', '.mkv']
|
14 |
+
fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
|
15 |
+
|
16 |
+
|
17 |
+
def is_video_file(filename):
|
18 |
+
return any(filename.lower().endswith(ext) for ext in video_exts)
|
19 |
+
|
20 |
+
def rgb_to_y(img):
|
21 |
+
# Assumes img is [1, 3, H, W] in [0,1], returns [1, 1, H, W]
|
22 |
+
r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
|
23 |
+
y = 0.257 * r + 0.504 * g + 0.098 * b + 0.0625
|
24 |
+
return y
|
25 |
+
|
26 |
+
def crop_border(img, crop):
|
27 |
+
return img[:, :, crop:-crop, crop:-crop]
|
28 |
+
|
29 |
+
def read_video_frames(video_path):
|
30 |
+
cap = cv2.VideoCapture(video_path)
|
31 |
+
frames = []
|
32 |
+
while True:
|
33 |
+
ret, frame = cap.read()
|
34 |
+
if not ret:
|
35 |
+
break
|
36 |
+
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
37 |
+
frames.append(to_tensor(Image.fromarray(rgb)))
|
38 |
+
cap.release()
|
39 |
+
return torch.stack(frames)
|
40 |
+
|
41 |
+
|
42 |
+
def read_image_folder(folder_path):
|
43 |
+
image_files = sorted([
|
44 |
+
os.path.join(folder_path, f) for f in os.listdir(folder_path)
|
45 |
+
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
|
46 |
+
])
|
47 |
+
frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
|
48 |
+
return torch.stack(frames)
|
49 |
+
|
50 |
+
|
51 |
+
def load_sequence(path):
|
52 |
+
if os.path.isdir(path):
|
53 |
+
return read_image_folder(path)
|
54 |
+
elif os.path.isfile(path):
|
55 |
+
if is_video_file(path):
|
56 |
+
return read_video_frames(path)
|
57 |
+
elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
|
58 |
+
# Treat image as a single-frame video
|
59 |
+
img = to_tensor(Image.open(path).convert("RGB"))
|
60 |
+
return img.unsqueeze(0) # [1, C, H, W]
|
61 |
+
raise ValueError(f"Unsupported input: {path}")
|
62 |
+
|
63 |
+
|
64 |
+
def crop_img_center(img, target_h, target_w):
|
65 |
+
_, h, w = img.shape
|
66 |
+
top = max((h - target_h) // 2, 0)
|
67 |
+
left = max((w - target_w) // 2, 0)
|
68 |
+
return img[:, top:top+target_h, left:left+target_w]
|
69 |
+
|
70 |
+
def crop_img_top_left(img, target_h, target_w):
|
71 |
+
# Crop image from top-left corner to (target_h, target_w)
|
72 |
+
return img[:, :target_h, :target_w]
|
73 |
+
|
74 |
+
def match_resolution(gt_frames, pred_frames, is_center=False, name=None):
|
75 |
+
t = min(gt_frames.shape[0], pred_frames.shape[0])
|
76 |
+
gt_frames = gt_frames[:t]
|
77 |
+
pred_frames = pred_frames[:t]
|
78 |
+
_, _, h_g, w_g = gt_frames.shape
|
79 |
+
_, _, h_p, w_p = pred_frames.shape
|
80 |
+
|
81 |
+
target_h = min(h_g, h_p)
|
82 |
+
target_w = min(w_g, w_p)
|
83 |
+
|
84 |
+
if (h_g != h_p or w_g != w_p) and name:
|
85 |
+
if is_center:
|
86 |
+
print(f"[{name}] Resolution mismatch detected: GT is ({h_g}, {w_g}), Pred is ({h_p}, {w_p}). Both GT and Pred were center cropped to ({target_h}, {target_w}).")
|
87 |
+
else:
|
88 |
+
print(f"[{name}] Resolution mismatch detected: GT is ({h_g}, {w_g}), Pred is ({h_p}, {w_p}). Both GT and Pred were top-left cropped to ({target_h}, {target_w}).")
|
89 |
+
|
90 |
+
if is_center:
|
91 |
+
gt_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in gt_frames])
|
92 |
+
pred_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in pred_frames])
|
93 |
+
else:
|
94 |
+
gt_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in gt_frames])
|
95 |
+
pred_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in pred_frames])
|
96 |
+
|
97 |
+
return gt_frames, pred_frames
|
98 |
+
|
99 |
+
|
100 |
+
def init_models(metrics, device):
|
101 |
+
models = {}
|
102 |
+
for name in metrics:
|
103 |
+
try:
|
104 |
+
models[name] = pyiqa.create_metric(name).to(device).eval()
|
105 |
+
except Exception as e:
|
106 |
+
print(f"Failed to initialize metric '{name}': {e}")
|
107 |
+
return models
|
108 |
+
|
109 |
+
def compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel):
|
110 |
+
if batch_mode:
|
111 |
+
pred_batch = pred_frames.to(device) # [F, C, H, W]
|
112 |
+
gt_batch = gt_frames.to(device) # [F, C, H, W]
|
113 |
+
|
114 |
+
results = {}
|
115 |
+
for name, model in models.items():
|
116 |
+
if name in fr_metrics:
|
117 |
+
pred_eval = pred_batch
|
118 |
+
gt_eval = gt_batch
|
119 |
+
if crop > 0:
|
120 |
+
pred_eval = crop_border(pred_eval, crop)
|
121 |
+
gt_eval = crop_border(gt_eval, crop)
|
122 |
+
if test_y_channel:
|
123 |
+
pred_eval = rgb_to_y(pred_eval)
|
124 |
+
gt_eval = rgb_to_y(gt_eval)
|
125 |
+
values = model(pred_eval, gt_eval) # [F]
|
126 |
+
else:
|
127 |
+
values = model(pred_batch) # no-reference
|
128 |
+
results[name] = round(values.mean().item(), 4)
|
129 |
+
return results
|
130 |
+
|
131 |
+
else:
|
132 |
+
results = {name: [] for name in models}
|
133 |
+
for pred, gt in zip(pred_frames, gt_frames):
|
134 |
+
pred = pred.unsqueeze(0).to(device)
|
135 |
+
gt = gt.unsqueeze(0).to(device)
|
136 |
+
|
137 |
+
for name, model in models.items():
|
138 |
+
if name in fr_metrics:
|
139 |
+
pred_eval = pred
|
140 |
+
gt_eval = gt
|
141 |
+
if crop > 0:
|
142 |
+
pred_eval = crop_border(pred_eval, crop)
|
143 |
+
gt_eval = crop_border(gt_eval, crop)
|
144 |
+
if test_y_channel:
|
145 |
+
pred_eval = rgb_to_y(pred_eval)
|
146 |
+
gt_eval = rgb_to_y(gt_eval)
|
147 |
+
value = model(pred_eval, gt_eval).item()
|
148 |
+
else:
|
149 |
+
value = model(pred).item()
|
150 |
+
results[name].append(value)
|
151 |
+
|
152 |
+
return {k: round(np.mean(v), 4) for k, v in results.items()}
|
153 |
+
|
154 |
+
|
155 |
+
def process(gt_root, pred_root, out_path, metrics, batch_mode, crop, test_y_channel, is_center):
|
156 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
157 |
+
print(f"Using device: {device}")
|
158 |
+
models = init_models(metrics, device)
|
159 |
+
|
160 |
+
has_gt = bool(gt_root and os.path.exists(gt_root))
|
161 |
+
|
162 |
+
if has_gt:
|
163 |
+
gt_files = {os.path.splitext(f)[0]: os.path.join(gt_root, f) for f in os.listdir(gt_root)}
|
164 |
+
pred_files = {os.path.splitext(f)[0]: os.path.join(pred_root, f) for f in os.listdir(pred_root)}
|
165 |
+
|
166 |
+
pred_names = sorted(pred_files.keys())
|
167 |
+
results = {}
|
168 |
+
aggregate = {metric: [] for metric in metrics}
|
169 |
+
|
170 |
+
for name in tqdm(pred_names, desc="Evaluating"):
|
171 |
+
# # valida
|
172 |
+
# name_hr = name.replace('_CAT_A_x4', '').replace('img_', 'img')
|
173 |
+
name_hr = name
|
174 |
+
if has_gt and name_hr not in gt_files:
|
175 |
+
print(f"Skipping {name_hr}: no matching GT file.")
|
176 |
+
continue
|
177 |
+
|
178 |
+
pred_path = pred_files[name]
|
179 |
+
gt_path = gt_files[name_hr] if has_gt else None
|
180 |
+
|
181 |
+
try:
|
182 |
+
pred_frames = load_sequence(pred_path)
|
183 |
+
|
184 |
+
if has_gt:
|
185 |
+
gt_frames = load_sequence(gt_path)
|
186 |
+
gt_frames, pred_frames = match_resolution(gt_frames, pred_frames, is_center=is_center, name=name)
|
187 |
+
scores = compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel)
|
188 |
+
else:
|
189 |
+
nr_models = {k: v for k, v in models.items() if k not in fr_metrics}
|
190 |
+
if not nr_models:
|
191 |
+
print(f"Skipping {name}: GT is not provided and no NR-IQA metrics found.")
|
192 |
+
continue
|
193 |
+
dummy_gt = pred_frames
|
194 |
+
scores = compute_metrics(pred_frames, dummy_gt, nr_models, device, batch_mode, crop, test_y_channel)
|
195 |
+
|
196 |
+
results[name] = scores
|
197 |
+
for k in scores:
|
198 |
+
aggregate[k].append(scores[k])
|
199 |
+
except Exception as e:
|
200 |
+
print(f"Error processing {name}: {e}")
|
201 |
+
|
202 |
+
print("\nPer-sample Results:")
|
203 |
+
for name in sorted(results):
|
204 |
+
print(f"{name}: " + ", ".join(f"{k}={v:.4f}" for k, v in results[name].items()))
|
205 |
+
|
206 |
+
print("\nOverall Average Results:")
|
207 |
+
count = len(results)
|
208 |
+
if count > 0:
|
209 |
+
overall_avg = {k: round(np.mean(v), 4) for k, v in aggregate.items()}
|
210 |
+
for k, v in overall_avg.items():
|
211 |
+
print(f"{k.upper()}: {v:.4f}")
|
212 |
+
else:
|
213 |
+
overall_avg = {}
|
214 |
+
print("No valid samples were processed.")
|
215 |
+
|
216 |
+
print(f"\nProcessed {count} samples.")
|
217 |
+
|
218 |
+
output = {
|
219 |
+
"per_sample": results,
|
220 |
+
"average": overall_avg,
|
221 |
+
"count": count
|
222 |
+
}
|
223 |
+
|
224 |
+
os.makedirs(out_path, exist_ok=True)
|
225 |
+
out_name = 'metrics_'
|
226 |
+
for metric in metrics:
|
227 |
+
out_name += f"{metric}_"
|
228 |
+
out_name = out_name.rstrip('_') + '.json'
|
229 |
+
out_path = os.path.join(out_path, out_name)
|
230 |
+
|
231 |
+
with open(out_path, 'w') as f:
|
232 |
+
json.dump(output, f, indent=2)
|
233 |
+
|
234 |
+
print(f"Results saved to: {out_path}")
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
import argparse
|
238 |
+
parser = argparse.ArgumentParser()
|
239 |
+
parser.add_argument('--gt', type=str, default='', help='Path to GT folder (optional for NR-IQA)')
|
240 |
+
parser.add_argument('--pred', type=str, required=True, help='Path to predicted results folder')
|
241 |
+
parser.add_argument('--out', type=str, default='', help='Path to save JSON output (as directory)')
|
242 |
+
parser.add_argument('--metrics', type=str, default='psnr,ssim,clipiqa',
|
243 |
+
help='Comma-separated list of metrics: psnr,ssim,clipiqa,lpips,...')
|
244 |
+
parser.add_argument('--batch_mode', action='store_true', help='Use batch mode for metrics computation')
|
245 |
+
parser.add_argument('--crop', type=int, default=0, help='Crop border size for PSNR/SSIM')
|
246 |
+
parser.add_argument('--test_y_channel', action='store_true', help='Use Y channel for PSNR/SSIM')
|
247 |
+
parser.add_argument('--is_center', action='store_true', help='Use center crop for PSNR/SSIM')
|
248 |
+
|
249 |
+
args = parser.parse_args()
|
250 |
+
|
251 |
+
if args.out == '':
|
252 |
+
out = args.pred
|
253 |
+
else:
|
254 |
+
out = args.out
|
255 |
+
metric_list = [m.strip().lower() for m in args.metrics.split(',')]
|
256 |
+
process(args.gt, args.pred, out, metric_list, args.batch_mode, args.crop, args.test_y_channel, args.is_center)
|
inference.sh
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# UDM10
|
4 |
+
python inference_script.py \
|
5 |
+
--input_dir datasets/test/UDM10/LQ-Video \
|
6 |
+
--model_path pretrained_models/DOVE \
|
7 |
+
--output_path results/DOVE/UDM10 \
|
8 |
+
--is_vae_st \
|
9 |
+
|
10 |
+
python eval_metrics.py \
|
11 |
+
--gt datasets/test/UDM10/GT \
|
12 |
+
--pred results/DOVE/UDM10 \
|
13 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
14 |
+
|
15 |
+
# SPMCS
|
16 |
+
python inference_script.py \
|
17 |
+
--input_dir datasets/test/SPMCS/LQ-Video \
|
18 |
+
--model_path pretrained_models/DOVE \
|
19 |
+
--output_path results/DOVE/SPMCS \
|
20 |
+
--is_vae_st \
|
21 |
+
|
22 |
+
python eval_metrics.py \
|
23 |
+
--gt datasets/test/UDM10/GT \
|
24 |
+
--pred results/DOVE/SPMCS \
|
25 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
26 |
+
|
27 |
+
# YouHQ40
|
28 |
+
python inference_script.py \
|
29 |
+
--input_dir datasets/test/YouHQ40/LQ-Video \
|
30 |
+
--model_path pretrained_models/DOVE \
|
31 |
+
--output_path results/DOVE/YouHQ40 \
|
32 |
+
--is_vae_st \
|
33 |
+
|
34 |
+
python eval_metrics.py \
|
35 |
+
--gt datasets/test/UDM10/GT \
|
36 |
+
--pred results/DOVE/YouHQ40 \
|
37 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
38 |
+
|
39 |
+
# RealVSR
|
40 |
+
python inference_script.py \
|
41 |
+
--input_dir datasets/test/RealVSR/LQ-Video \
|
42 |
+
--model_path pretrained_models/DOVE \
|
43 |
+
--output_path results/DOVE/RealVSR \
|
44 |
+
--is_vae_st \
|
45 |
+
--upscale 1 \
|
46 |
+
|
47 |
+
python eval_metrics.py \
|
48 |
+
--gt datasets/test/UDM10/GT \
|
49 |
+
--pred results/DOVE/RealVSR \
|
50 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
51 |
+
|
52 |
+
# MVSR4x
|
53 |
+
python inference_script.py \
|
54 |
+
--input_dir datasets/test/MVSR4x/LQ-Video \
|
55 |
+
--model_path pretrained_models/DOVE \
|
56 |
+
--output_path results/DOVE/MVSR4x \
|
57 |
+
--is_vae_st \
|
58 |
+
--upscale 1 \
|
59 |
+
|
60 |
+
python eval_metrics.py \
|
61 |
+
--gt datasets/test/UDM10/GT \
|
62 |
+
--pred results/DOVE/MVSR4x \
|
63 |
+
--metrics psnr,ssim,lpips,dists,clipiqa
|
64 |
+
|
65 |
+
# VideoLQ
|
66 |
+
python inference_script.py \
|
67 |
+
--input_dir datasets/test/VideoLQ/LQ-Video \
|
68 |
+
--model_path pretrained_models/DOVE \
|
69 |
+
--output_path results/DOVE/VideoLQ \
|
70 |
+
--is_vae_st \
|
71 |
+
|
72 |
+
python eval_metrics.py \
|
73 |
+
--gt datasets/test/UDM10/GT \
|
74 |
+
--pred results/DOVE/VideoLQ \
|
75 |
+
--metrics clipiqa
|
inference_script.py
ADDED
@@ -0,0 +1,754 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torchvision import transforms
|
7 |
+
from torchvision.io import write_video
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from diffusers import (
|
11 |
+
CogVideoXDPMScheduler,
|
12 |
+
CogVideoXPipeline,
|
13 |
+
)
|
14 |
+
|
15 |
+
from transformers import set_seed
|
16 |
+
from typing import Dict, Tuple
|
17 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
18 |
+
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import cv2
|
22 |
+
from PIL import Image
|
23 |
+
|
24 |
+
from pathlib import Path
|
25 |
+
import pyiqa
|
26 |
+
import imageio.v3 as iio
|
27 |
+
import glob
|
28 |
+
|
29 |
+
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
30 |
+
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
31 |
+
import decord # isort:skip
|
32 |
+
|
33 |
+
decord.bridge.set_bridge("torch")
|
34 |
+
|
35 |
+
logging.basicConfig(level=logging.INFO)
|
36 |
+
|
37 |
+
# 0 ~ 1
|
38 |
+
to_tensor = transforms.ToTensor()
|
39 |
+
video_exts = ['.mp4', '.avi', '.mov', '.mkv']
|
40 |
+
fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
|
41 |
+
|
42 |
+
|
43 |
+
def no_grad(func):
|
44 |
+
def wrapper(*args, **kwargs):
|
45 |
+
with torch.no_grad():
|
46 |
+
return func(*args, **kwargs)
|
47 |
+
return wrapper
|
48 |
+
|
49 |
+
|
50 |
+
def is_video_file(filename):
|
51 |
+
return any(filename.lower().endswith(ext) for ext in video_exts)
|
52 |
+
|
53 |
+
|
54 |
+
def read_video_frames(video_path):
|
55 |
+
cap = cv2.VideoCapture(video_path)
|
56 |
+
frames = []
|
57 |
+
while True:
|
58 |
+
ret, frame = cap.read()
|
59 |
+
if not ret:
|
60 |
+
break
|
61 |
+
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
62 |
+
frames.append(to_tensor(Image.fromarray(rgb)))
|
63 |
+
cap.release()
|
64 |
+
return torch.stack(frames)
|
65 |
+
|
66 |
+
|
67 |
+
def read_image_folder(folder_path):
|
68 |
+
image_files = sorted([
|
69 |
+
os.path.join(folder_path, f) for f in os.listdir(folder_path)
|
70 |
+
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
|
71 |
+
])
|
72 |
+
frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
|
73 |
+
return torch.stack(frames)
|
74 |
+
|
75 |
+
|
76 |
+
def load_sequence(path):
|
77 |
+
# return a tensor of shape [F, C, H, W] // 0, 1
|
78 |
+
if os.path.isdir(path):
|
79 |
+
return read_image_folder(path)
|
80 |
+
elif os.path.isfile(path):
|
81 |
+
if is_video_file(path):
|
82 |
+
return read_video_frames(path)
|
83 |
+
elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
|
84 |
+
# Treat image as a single-frame video
|
85 |
+
img = to_tensor(Image.open(path).convert("RGB"))
|
86 |
+
return img.unsqueeze(0) # [1, C, H, W]
|
87 |
+
raise ValueError(f"Unsupported input: {path}")
|
88 |
+
|
89 |
+
@no_grad
|
90 |
+
def compute_metrics(pred_frames, gt_frames, metrics_model, metric_accumulator, file_name):
|
91 |
+
|
92 |
+
print(f"\n\n[{file_name}] Metrics:", end=" ")
|
93 |
+
for name, model in metrics_model.items():
|
94 |
+
scores = []
|
95 |
+
for i in range(pred_frames.shape[0]):
|
96 |
+
pred = pred_frames[i].unsqueeze(0)
|
97 |
+
if gt_frames != None:
|
98 |
+
gt = gt_frames[i].unsqueeze(0)
|
99 |
+
if name in fr_metrics:
|
100 |
+
score = model(pred, gt).item()
|
101 |
+
else:
|
102 |
+
score = model(pred).item()
|
103 |
+
scores.append(score)
|
104 |
+
val = sum(scores) / len(scores)
|
105 |
+
metric_accumulator[name].append(val)
|
106 |
+
print(f"{name.upper()}={val:.4f}", end=" ")
|
107 |
+
print()
|
108 |
+
|
109 |
+
|
110 |
+
def save_frames_as_png(video, output_dir, fps=8):
|
111 |
+
"""
|
112 |
+
Save video frames as PNG sequence.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
|
116 |
+
output_dir (str): directory to save PNG files
|
117 |
+
fps (int): kept for API compatibility
|
118 |
+
"""
|
119 |
+
video = video[0] # Remove batch dimension
|
120 |
+
video = video.permute(1, 2, 3, 0) # [F, H, W, C]
|
121 |
+
|
122 |
+
os.makedirs(output_dir, exist_ok=True)
|
123 |
+
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
|
124 |
+
|
125 |
+
for i, frame in enumerate(frames):
|
126 |
+
filename = os.path.join(output_dir, f"{i:03d}.png")
|
127 |
+
Image.fromarray(frame).save(filename)
|
128 |
+
|
129 |
+
|
130 |
+
def save_video_with_imageio_lossless(video, output_path, fps=8):
|
131 |
+
"""
|
132 |
+
Save a video tensor to .mkv using imageio.v3.imwrite with ffmpeg backend.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
|
136 |
+
output_path (str): where to save the .mkv file
|
137 |
+
fps (int): frames per second
|
138 |
+
"""
|
139 |
+
video = video[0]
|
140 |
+
video = video.permute(1, 2, 3, 0)
|
141 |
+
|
142 |
+
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
|
143 |
+
|
144 |
+
iio.imwrite(
|
145 |
+
output_path,
|
146 |
+
frames,
|
147 |
+
fps=fps,
|
148 |
+
codec='libx264rgb',
|
149 |
+
pixelformat='rgb24',
|
150 |
+
macro_block_size=None,
|
151 |
+
ffmpeg_params=['-crf', '0'],
|
152 |
+
)
|
153 |
+
|
154 |
+
|
155 |
+
def save_video_with_imageio(video, output_path, fps=8, format='yuv444p'):
|
156 |
+
"""
|
157 |
+
Save a video tensor to .mp4 using imageio.v3.imwrite with ffmpeg backend.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
|
161 |
+
output_path (str): where to save the .mp4 file
|
162 |
+
fps (int): frames per second
|
163 |
+
"""
|
164 |
+
video = video[0]
|
165 |
+
video = video.permute(1, 2, 3, 0)
|
166 |
+
|
167 |
+
frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
|
168 |
+
|
169 |
+
if format == 'yuv444p':
|
170 |
+
iio.imwrite(
|
171 |
+
output_path,
|
172 |
+
frames,
|
173 |
+
fps=fps,
|
174 |
+
codec='libx264',
|
175 |
+
pixelformat='yuv444p',
|
176 |
+
macro_block_size=None,
|
177 |
+
ffmpeg_params=['-crf', '0'],
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
iio.imwrite(
|
181 |
+
output_path,
|
182 |
+
frames,
|
183 |
+
fps=fps,
|
184 |
+
codec='libx264',
|
185 |
+
pixelformat='yuv420p',
|
186 |
+
macro_block_size=None,
|
187 |
+
ffmpeg_params=['-crf', '10'],
|
188 |
+
)
|
189 |
+
|
190 |
+
|
191 |
+
def preprocess_video_match(
|
192 |
+
video_path: Path | str,
|
193 |
+
is_match: bool = False,
|
194 |
+
) -> torch.Tensor:
|
195 |
+
"""
|
196 |
+
Loads a single video.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
video_path: Path to the video file.
|
200 |
+
Returns:
|
201 |
+
A torch.Tensor with shape [F, C, H, W] where:
|
202 |
+
F = number of frames
|
203 |
+
C = number of channels (3 for RGB)
|
204 |
+
H = height
|
205 |
+
W = width
|
206 |
+
"""
|
207 |
+
if isinstance(video_path, str):
|
208 |
+
video_path = Path(video_path)
|
209 |
+
video_reader = decord.VideoReader(uri=video_path.as_posix())
|
210 |
+
video_num_frames = len(video_reader)
|
211 |
+
frames = video_reader.get_batch(list(range(video_num_frames)))
|
212 |
+
F, H, W, C = frames.shape
|
213 |
+
original_shape = (F, H, W, C)
|
214 |
+
|
215 |
+
pad_f = 0
|
216 |
+
pad_h = 0
|
217 |
+
pad_w = 0
|
218 |
+
|
219 |
+
if is_match:
|
220 |
+
remainder = (F - 1) % 8
|
221 |
+
if remainder != 0:
|
222 |
+
last_frame = frames[-1:]
|
223 |
+
pad_f = 8 - remainder
|
224 |
+
repeated_frames = last_frame.repeat(pad_f, 1, 1, 1)
|
225 |
+
frames = torch.cat([frames, repeated_frames], dim=0)
|
226 |
+
|
227 |
+
pad_h = (16 - H % 16) % 16
|
228 |
+
pad_w = (16 - W % 16) % 16
|
229 |
+
if pad_h > 0 or pad_w > 0:
|
230 |
+
# pad = (w_left, w_right, h_top, h_bottom)
|
231 |
+
frames = torch.nn.functional.pad(frames, pad=(0, 0, 0, pad_w, 0, pad_h)) # pad right and bottom
|
232 |
+
|
233 |
+
# to F, C, H, W
|
234 |
+
return frames.float().permute(0, 3, 1, 2).contiguous(), pad_f, pad_h, pad_w, original_shape
|
235 |
+
|
236 |
+
|
237 |
+
def remove_padding_and_extra_frames(video, pad_F, pad_H, pad_W):
|
238 |
+
if pad_F > 0:
|
239 |
+
video = video[:, :, :-pad_F, :, :]
|
240 |
+
if pad_H > 0:
|
241 |
+
video = video[:, :, :, :-pad_H, :]
|
242 |
+
if pad_W > 0:
|
243 |
+
video = video[:, :, :, :, :-pad_W]
|
244 |
+
|
245 |
+
return video
|
246 |
+
|
247 |
+
|
248 |
+
def make_temporal_chunks(F, chunk_len, overlap_t=8):
|
249 |
+
"""
|
250 |
+
Args:
|
251 |
+
F: total number of frames
|
252 |
+
chunk_len: int, chunk length in time (excluding overlap)
|
253 |
+
overlap: int, number of overlapping frames between chunks
|
254 |
+
Returns:
|
255 |
+
time_chunks: List of (start_t, end_t) tuples
|
256 |
+
"""
|
257 |
+
if chunk_len == 0:
|
258 |
+
return [(0, F)]
|
259 |
+
|
260 |
+
effective_stride = chunk_len - overlap_t
|
261 |
+
if effective_stride <= 0:
|
262 |
+
raise ValueError("chunk_len must be greater than overlap")
|
263 |
+
|
264 |
+
chunk_starts = list(range(0, F - overlap_t, effective_stride))
|
265 |
+
if chunk_starts[-1] + chunk_len < F:
|
266 |
+
chunk_starts.append(F - chunk_len)
|
267 |
+
|
268 |
+
time_chunks = []
|
269 |
+
for i, t_start in enumerate(chunk_starts):
|
270 |
+
t_end = min(t_start + chunk_len, F)
|
271 |
+
time_chunks.append((t_start, t_end))
|
272 |
+
|
273 |
+
if len(time_chunks) >= 2 and time_chunks[-1][1] - time_chunks[-1][0] < chunk_len:
|
274 |
+
last = time_chunks.pop()
|
275 |
+
prev_start, _ = time_chunks[-1]
|
276 |
+
time_chunks[-1] = (prev_start, last[1])
|
277 |
+
|
278 |
+
return time_chunks
|
279 |
+
|
280 |
+
|
281 |
+
def make_spatial_tiles(H, W, tile_size_hw, overlap_hw=(32, 32)):
|
282 |
+
"""
|
283 |
+
Args:
|
284 |
+
H, W: height and width of the frame
|
285 |
+
tile_size_hw: Tuple (tile_height, tile_width)
|
286 |
+
overlap_hw: Tuple (overlap_height, overlap_width)
|
287 |
+
Returns:
|
288 |
+
spatial_tiles: List of (start_h, end_h, start_w, end_w) tuples
|
289 |
+
"""
|
290 |
+
tile_height, tile_width = tile_size_hw
|
291 |
+
overlap_h, overlap_w = overlap_hw
|
292 |
+
|
293 |
+
if tile_height == 0 or tile_width == 0:
|
294 |
+
return [(0, H, 0, W)]
|
295 |
+
|
296 |
+
tile_stride_h = tile_height - overlap_h
|
297 |
+
tile_stride_w = tile_width - overlap_w
|
298 |
+
|
299 |
+
if tile_stride_h <= 0 or tile_stride_w <= 0:
|
300 |
+
raise ValueError("Tile size must be greater than overlap")
|
301 |
+
|
302 |
+
h_tiles = list(range(0, H - overlap_h, tile_stride_h))
|
303 |
+
if not h_tiles or h_tiles[-1] + tile_height < H:
|
304 |
+
h_tiles.append(H - tile_height)
|
305 |
+
|
306 |
+
# Merge last row if needed
|
307 |
+
if len(h_tiles) >= 2 and h_tiles[-1] + tile_height > H:
|
308 |
+
h_tiles.pop()
|
309 |
+
|
310 |
+
w_tiles = list(range(0, W - overlap_w, tile_stride_w))
|
311 |
+
if not w_tiles or w_tiles[-1] + tile_width < W:
|
312 |
+
w_tiles.append(W - tile_width)
|
313 |
+
|
314 |
+
# Merge last column if needed
|
315 |
+
if len(w_tiles) >= 2 and w_tiles[-1] + tile_width > W:
|
316 |
+
w_tiles.pop()
|
317 |
+
|
318 |
+
spatial_tiles = []
|
319 |
+
for h_start in h_tiles:
|
320 |
+
h_end = min(h_start + tile_height, H)
|
321 |
+
if h_end + tile_stride_h > H:
|
322 |
+
h_end = H
|
323 |
+
for w_start in w_tiles:
|
324 |
+
w_end = min(w_start + tile_width, W)
|
325 |
+
if w_end + tile_stride_w > W:
|
326 |
+
w_end = W
|
327 |
+
spatial_tiles.append((h_start, h_end, w_start, w_end))
|
328 |
+
return spatial_tiles
|
329 |
+
|
330 |
+
|
331 |
+
def get_valid_tile_region(t_start, t_end, h_start, h_end, w_start, w_end,
|
332 |
+
video_shape, overlap_t, overlap_h, overlap_w):
|
333 |
+
_, _, F, H, W = video_shape
|
334 |
+
|
335 |
+
t_len = t_end - t_start
|
336 |
+
h_len = h_end - h_start
|
337 |
+
w_len = w_end - w_start
|
338 |
+
|
339 |
+
valid_t_start = 0 if t_start == 0 else overlap_t // 2
|
340 |
+
valid_t_end = t_len if t_end == F else t_len - overlap_t // 2
|
341 |
+
valid_h_start = 0 if h_start == 0 else overlap_h // 2
|
342 |
+
valid_h_end = h_len if h_end == H else h_len - overlap_h // 2
|
343 |
+
valid_w_start = 0 if w_start == 0 else overlap_w // 2
|
344 |
+
valid_w_end = w_len if w_end == W else w_len - overlap_w // 2
|
345 |
+
|
346 |
+
out_t_start = t_start + valid_t_start
|
347 |
+
out_t_end = t_start + valid_t_end
|
348 |
+
out_h_start = h_start + valid_h_start
|
349 |
+
out_h_end = h_start + valid_h_end
|
350 |
+
out_w_start = w_start + valid_w_start
|
351 |
+
out_w_end = w_start + valid_w_end
|
352 |
+
|
353 |
+
return {
|
354 |
+
"valid_t_start": valid_t_start, "valid_t_end": valid_t_end,
|
355 |
+
"valid_h_start": valid_h_start, "valid_h_end": valid_h_end,
|
356 |
+
"valid_w_start": valid_w_start, "valid_w_end": valid_w_end,
|
357 |
+
"out_t_start": out_t_start, "out_t_end": out_t_end,
|
358 |
+
"out_h_start": out_h_start, "out_h_end": out_h_end,
|
359 |
+
"out_w_start": out_w_start, "out_w_end": out_w_end,
|
360 |
+
}
|
361 |
+
|
362 |
+
|
363 |
+
def prepare_rotary_positional_embeddings(
|
364 |
+
height: int,
|
365 |
+
width: int,
|
366 |
+
num_frames: int,
|
367 |
+
transformer_config: Dict,
|
368 |
+
vae_scale_factor_spatial: int,
|
369 |
+
device: torch.device,
|
370 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
371 |
+
|
372 |
+
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
|
373 |
+
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
|
374 |
+
|
375 |
+
if transformer_config.patch_size_t is None:
|
376 |
+
base_num_frames = num_frames
|
377 |
+
else:
|
378 |
+
base_num_frames = (
|
379 |
+
num_frames + transformer_config.patch_size_t - 1
|
380 |
+
) // transformer_config.patch_size_t
|
381 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
382 |
+
embed_dim=transformer_config.attention_head_dim,
|
383 |
+
crops_coords=None,
|
384 |
+
grid_size=(grid_height, grid_width),
|
385 |
+
temporal_size=base_num_frames,
|
386 |
+
grid_type="slice",
|
387 |
+
max_size=(grid_height, grid_width),
|
388 |
+
device=device,
|
389 |
+
)
|
390 |
+
|
391 |
+
return freqs_cos, freqs_sin
|
392 |
+
|
393 |
+
@no_grad
|
394 |
+
def process_video(
|
395 |
+
pipe: CogVideoXPipeline,
|
396 |
+
video: torch.Tensor,
|
397 |
+
prompt: str = '',
|
398 |
+
noise_step: int = 0,
|
399 |
+
sr_noise_step: int = 399,
|
400 |
+
):
|
401 |
+
# SR the video frames based on the prompt.
|
402 |
+
# `num_frames` is the Number of frames to generate.
|
403 |
+
|
404 |
+
# Decode video
|
405 |
+
video = video.to(pipe.vae.device, dtype=pipe.vae.dtype)
|
406 |
+
latent_dist = pipe.vae.encode(video).latent_dist
|
407 |
+
latent = latent_dist.sample() * pipe.vae.config.scaling_factor
|
408 |
+
|
409 |
+
patch_size_t = pipe.transformer.config.patch_size_t
|
410 |
+
if patch_size_t is not None:
|
411 |
+
ncopy = latent.shape[2] % patch_size_t
|
412 |
+
# Copy the first frame ncopy times to match patch_size_t
|
413 |
+
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
414 |
+
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
415 |
+
|
416 |
+
assert latent.shape[2] % patch_size_t == 0
|
417 |
+
|
418 |
+
batch_size, num_channels, num_frames, height, width = latent.shape
|
419 |
+
|
420 |
+
# Get prompt embeddings
|
421 |
+
prompt_token_ids = pipe.tokenizer(
|
422 |
+
prompt,
|
423 |
+
padding="max_length",
|
424 |
+
max_length=pipe.transformer.config.max_text_seq_length,
|
425 |
+
truncation=True,
|
426 |
+
add_special_tokens=True,
|
427 |
+
return_tensors="pt",
|
428 |
+
)
|
429 |
+
prompt_token_ids = prompt_token_ids.input_ids
|
430 |
+
prompt_embedding = pipe.text_encoder(
|
431 |
+
prompt_token_ids.to(latent.device)
|
432 |
+
)[0]
|
433 |
+
_, seq_len, _ = prompt_embedding.shape
|
434 |
+
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
|
435 |
+
|
436 |
+
latent = latent.permute(0, 2, 1, 3, 4)
|
437 |
+
|
438 |
+
# Add noise to latent (Select)
|
439 |
+
if noise_step != 0:
|
440 |
+
noise = torch.randn_like(latent)
|
441 |
+
add_timesteps = torch.full(
|
442 |
+
(batch_size,),
|
443 |
+
fill_value=noise_step,
|
444 |
+
dtype=torch.long,
|
445 |
+
device=latent.device,
|
446 |
+
)
|
447 |
+
latent = pipe.scheduler.add_noise(latent, noise, add_timesteps)
|
448 |
+
|
449 |
+
timesteps = torch.full(
|
450 |
+
(batch_size,),
|
451 |
+
fill_value=sr_noise_step,
|
452 |
+
dtype=torch.long,
|
453 |
+
device=latent.device,
|
454 |
+
)
|
455 |
+
|
456 |
+
# Prepare rotary embeds
|
457 |
+
vae_scale_factor_spatial = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
|
458 |
+
transformer_config = pipe.transformer.config
|
459 |
+
rotary_emb = (
|
460 |
+
prepare_rotary_positional_embeddings(
|
461 |
+
height=height * vae_scale_factor_spatial,
|
462 |
+
width=width * vae_scale_factor_spatial,
|
463 |
+
num_frames=num_frames,
|
464 |
+
transformer_config=transformer_config,
|
465 |
+
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
466 |
+
device=latent.device,
|
467 |
+
)
|
468 |
+
if pipe.transformer.config.use_rotary_positional_embeddings
|
469 |
+
else None
|
470 |
+
)
|
471 |
+
|
472 |
+
# Predict noise
|
473 |
+
predicted_noise = pipe.transformer(
|
474 |
+
hidden_states=latent,
|
475 |
+
encoder_hidden_states=prompt_embedding,
|
476 |
+
timestep=timesteps,
|
477 |
+
image_rotary_emb=rotary_emb,
|
478 |
+
return_dict=False,
|
479 |
+
)[0]
|
480 |
+
|
481 |
+
latent_generate = pipe.scheduler.get_velocity(
|
482 |
+
predicted_noise, latent, timesteps
|
483 |
+
)
|
484 |
+
|
485 |
+
# generate video
|
486 |
+
if patch_size_t is not None and ncopy > 0:
|
487 |
+
latent_generate = latent_generate[:, ncopy:, :, :, :]
|
488 |
+
|
489 |
+
# [B, C, F, H, W]
|
490 |
+
video_generate = pipe.decode_latents(latent_generate)
|
491 |
+
video_generate = (video_generate * 0.5 + 0.5).clamp(0.0, 1.0)
|
492 |
+
|
493 |
+
return video_generate
|
494 |
+
|
495 |
+
|
496 |
+
if __name__ == "__main__":
|
497 |
+
parser = argparse.ArgumentParser(description="VSR using DOVE")
|
498 |
+
|
499 |
+
parser.add_argument("--input_dir", type=str)
|
500 |
+
|
501 |
+
parser.add_argument("--input_json", type=str, default=None)
|
502 |
+
|
503 |
+
parser.add_argument("--gt_dir", type=str, default=None)
|
504 |
+
|
505 |
+
parser.add_argument("--eval_metrics", type=str, default='') # 'psnr,ssim,lpips,dists,clipiqa,musiq,maniqa,niqe'
|
506 |
+
|
507 |
+
parser.add_argument("--model_path", type=str)
|
508 |
+
|
509 |
+
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
|
510 |
+
|
511 |
+
parser.add_argument("--output_path", type=str, default="./results", help="The path save generated video")
|
512 |
+
|
513 |
+
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
|
514 |
+
|
515 |
+
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
|
516 |
+
|
517 |
+
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
518 |
+
|
519 |
+
parser.add_argument("--upscale_mode", type=str, default="bilinear")
|
520 |
+
|
521 |
+
parser.add_argument("--upscale", type=int, default=4)
|
522 |
+
|
523 |
+
parser.add_argument("--noise_step", type=int, default=0)
|
524 |
+
|
525 |
+
parser.add_argument("--sr_noise_step", type=int, default=399)
|
526 |
+
|
527 |
+
parser.add_argument("--is_cpu_offload", action="store_true", help="Enable CPU offload for the model")
|
528 |
+
|
529 |
+
parser.add_argument("--is_vae_st", action="store_true", help="Enable VAE slicing and tiling")
|
530 |
+
|
531 |
+
parser.add_argument("--png_save", action="store_true", help="Save output as PNG sequence")
|
532 |
+
|
533 |
+
parser.add_argument("--save_format", type=str, default="yuv444p", help="Save output as PNG sequence")
|
534 |
+
|
535 |
+
# Crop and Tiling Parameters
|
536 |
+
parser.add_argument("--tile_size_hw", type=int, nargs=2, default=(0, 0), help="Tile size for spatial tiling (height, width)")
|
537 |
+
|
538 |
+
parser.add_argument("--overlap_hw", type=int, nargs=2, default=(32, 32))
|
539 |
+
|
540 |
+
parser.add_argument("--chunk_len", type=int, default=0, help="Chunk length for temporal chunking")
|
541 |
+
|
542 |
+
parser.add_argument("--overlap_t", type=int, default=8)
|
543 |
+
|
544 |
+
args = parser.parse_args()
|
545 |
+
|
546 |
+
if args.dtype == "float16":
|
547 |
+
dtype = torch.float16
|
548 |
+
elif args.dtype == "bfloat16":
|
549 |
+
dtype = torch.bfloat16
|
550 |
+
elif args.dtype == "float32":
|
551 |
+
dtype = torch.float32
|
552 |
+
else:
|
553 |
+
raise ValueError("Invalid dtype. Choose from 'float16', 'bfloat16', or 'float32'.")
|
554 |
+
|
555 |
+
if args.chunk_len > 0:
|
556 |
+
print(f"Chunking video into {args.chunk_len} frames with {args.overlap_t} overlap")
|
557 |
+
overlap_t = args.overlap_t
|
558 |
+
else:
|
559 |
+
overlap_t = 0
|
560 |
+
if args.tile_size_hw != (0, 0):
|
561 |
+
print(f"Tiling video into {args.tile_size_hw} frames with {args.overlap_hw} overlap")
|
562 |
+
overlap_hw = args.overlap_hw
|
563 |
+
else:
|
564 |
+
overlap_hw = (0, 0)
|
565 |
+
|
566 |
+
# Set seed
|
567 |
+
set_seed(args.seed)
|
568 |
+
|
569 |
+
if args.input_json is not None:
|
570 |
+
with open(args.input_json, 'r') as f:
|
571 |
+
video_prompt_dict = json.load(f)
|
572 |
+
else:
|
573 |
+
video_prompt_dict = {}
|
574 |
+
|
575 |
+
# Get all video files from input directory
|
576 |
+
video_files = []
|
577 |
+
for ext in video_exts:
|
578 |
+
video_files.extend(glob.glob(os.path.join(args.input_dir, f'*{ext}')))
|
579 |
+
video_files = sorted(video_files) # Sort files for consistent ordering
|
580 |
+
|
581 |
+
if not video_files:
|
582 |
+
raise ValueError(f"No video files found in {args.input_dir}")
|
583 |
+
|
584 |
+
os.makedirs(args.output_path, exist_ok=True)
|
585 |
+
|
586 |
+
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
587 |
+
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
588 |
+
# function to use Multi GPUs.
|
589 |
+
|
590 |
+
pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=dtype)
|
591 |
+
|
592 |
+
# If you're using with lora, add this code
|
593 |
+
if args.lora_path:
|
594 |
+
print(f"Loading LoRA weights from {args.lora_path}")
|
595 |
+
pipe.load_lora_weights(
|
596 |
+
args.lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
|
597 |
+
)
|
598 |
+
pipe.fuse_lora(components=["transformer"], lora_scale=1.0) # lora_scale = lora_alpha / rank
|
599 |
+
|
600 |
+
# 2. Set Scheduler.
|
601 |
+
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
|
602 |
+
# We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B.
|
603 |
+
# using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
|
604 |
+
|
605 |
+
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
606 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
607 |
+
pipe.scheduler.config, timestep_spacing="trailing"
|
608 |
+
)
|
609 |
+
|
610 |
+
# 3. Enable CPU offload for the model.
|
611 |
+
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
|
612 |
+
# and enable to("cuda")
|
613 |
+
|
614 |
+
if args.is_cpu_offload:
|
615 |
+
# pipe.enable_model_cpu_offload()
|
616 |
+
pipe.enable_sequential_cpu_offload()
|
617 |
+
else:
|
618 |
+
pipe.to("cuda")
|
619 |
+
|
620 |
+
if args.is_vae_st:
|
621 |
+
pipe.vae.enable_slicing()
|
622 |
+
pipe.vae.enable_tiling()
|
623 |
+
|
624 |
+
# pipe.transformer.eval()
|
625 |
+
# torch.set_grad_enabled(False)
|
626 |
+
|
627 |
+
# 4. Set the metircs
|
628 |
+
if args.eval_metrics != '':
|
629 |
+
metrics_list = [m.strip().lower() for m in args.eval_metrics.split(',')]
|
630 |
+
metrics_models = {}
|
631 |
+
for name in metrics_list:
|
632 |
+
try:
|
633 |
+
metrics_models[name] = pyiqa.create_metric(name).to(pipe.device).eval()
|
634 |
+
except Exception as e:
|
635 |
+
print(f"Failed to initialize metric '{name}': {e}")
|
636 |
+
metric_accumulator = {name: [] for name in metrics_list}
|
637 |
+
else:
|
638 |
+
metrics_models = None
|
639 |
+
metric_accumulator = None
|
640 |
+
|
641 |
+
for video_path in tqdm(video_files, desc="Processing videos"):
|
642 |
+
video_name = os.path.basename(video_path)
|
643 |
+
prompt = video_prompt_dict.get(video_name, "")
|
644 |
+
if os.path.exists(video_path):
|
645 |
+
# Read video
|
646 |
+
# [F, C, H, W]
|
647 |
+
video, pad_f, pad_h, pad_w, original_shape = preprocess_video_match(video_path, is_match=True)
|
648 |
+
H_, W_ = video.shape[2], video.shape[3]
|
649 |
+
video = torch.nn.functional.interpolate(video, size=(H_*args.upscale, W_*args.upscale), mode=args.upscale_mode, align_corners=False)
|
650 |
+
__frame_transform = transforms.Compose(
|
651 |
+
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] # -1, 1
|
652 |
+
)
|
653 |
+
video = torch.stack([__frame_transform(f) for f in video], dim=0)
|
654 |
+
video = video.unsqueeze(0)
|
655 |
+
# [B, C, F, H, W]
|
656 |
+
video = video.permute(0, 2, 1, 3, 4).contiguous()
|
657 |
+
|
658 |
+
_B, _C, _F, _H, _W = video.shape
|
659 |
+
time_chunks = make_temporal_chunks(_F, args.chunk_len, overlap_t)
|
660 |
+
spatial_tiles = make_spatial_tiles(_H, _W, args.tile_size_hw, overlap_hw)
|
661 |
+
|
662 |
+
output_video = torch.zeros_like(video)
|
663 |
+
write_count = torch.zeros_like(video, dtype=torch.int)
|
664 |
+
|
665 |
+
print(f"Process video: {video_name} | Prompt: {prompt} | Frame: {_F} (ori: {original_shape[0]}; pad: {pad_f}) | Target Resolution: {_H}, {_W} (ori: {original_shape[1]*args.upscale}, {original_shape[2]*args.upscale}; pad: {pad_h}, {pad_w}) | Chunk Num: {len(time_chunks)*len(spatial_tiles)}")
|
666 |
+
|
667 |
+
for t_start, t_end in time_chunks:
|
668 |
+
for h_start, h_end, w_start, w_end in spatial_tiles:
|
669 |
+
video_chunk = video[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
670 |
+
# print(f"video_chunk: {video_chunk.shape} | t: {t_start}:{t_end} | h: {h_start}:{h_end} | w: {w_start}:{w_end}")
|
671 |
+
|
672 |
+
# [B, C, F, H, W]
|
673 |
+
_video_generate = process_video(
|
674 |
+
pipe=pipe,
|
675 |
+
video=video_chunk,
|
676 |
+
prompt=prompt,
|
677 |
+
noise_step=args.noise_step,
|
678 |
+
sr_noise_step=args.sr_noise_step,
|
679 |
+
)
|
680 |
+
|
681 |
+
region = get_valid_tile_region(
|
682 |
+
t_start, t_end, h_start, h_end, w_start, w_end,
|
683 |
+
video_shape=video.shape,
|
684 |
+
overlap_t=overlap_t,
|
685 |
+
overlap_h=overlap_hw[0],
|
686 |
+
overlap_w=overlap_hw[1],
|
687 |
+
)
|
688 |
+
output_video[:, :, region["out_t_start"]:region["out_t_end"],
|
689 |
+
region["out_h_start"]:region["out_h_end"],
|
690 |
+
region["out_w_start"]:region["out_w_end"]] = \
|
691 |
+
_video_generate[:, :, region["valid_t_start"]:region["valid_t_end"],
|
692 |
+
region["valid_h_start"]:region["valid_h_end"],
|
693 |
+
region["valid_w_start"]:region["valid_w_end"]]
|
694 |
+
write_count[:, :, region["out_t_start"]:region["out_t_end"],
|
695 |
+
region["out_h_start"]:region["out_h_end"],
|
696 |
+
region["out_w_start"]:region["out_w_end"]] += 1
|
697 |
+
|
698 |
+
video_generate = output_video
|
699 |
+
|
700 |
+
if (write_count == 0).any():
|
701 |
+
print("Error: Lack of write in region !!!")
|
702 |
+
exit()
|
703 |
+
if (write_count > 1).any():
|
704 |
+
print("Error: Write count > 1 in region !!!")
|
705 |
+
exit()
|
706 |
+
|
707 |
+
video_generate = remove_padding_and_extra_frames(video_generate, pad_f, pad_h*4, pad_w*4)
|
708 |
+
file_name = os.path.basename(video_path)
|
709 |
+
output_path = os.path.join(args.output_path, file_name)
|
710 |
+
|
711 |
+
if metrics_models is not None:
|
712 |
+
# [1, C, F, H, W] -> [F, C, H, W]
|
713 |
+
pred_frames = video_generate[0]
|
714 |
+
pred_frames = pred_frames.permute(1, 0, 2, 3).contiguous()
|
715 |
+
if args.gt_dir is not None:
|
716 |
+
gt_frames = load_sequence(os.path.join(args.gt_dir, file_name))
|
717 |
+
else:
|
718 |
+
gt_frames = None
|
719 |
+
compute_metrics(pred_frames, gt_frames, metrics_models, metric_accumulator, file_name)
|
720 |
+
|
721 |
+
if args.png_save:
|
722 |
+
# Save as PNG sequence
|
723 |
+
output_dir = output_path.rsplit('.', 1)[0] # Remove extension
|
724 |
+
save_frames_as_png(video_generate, output_dir, fps=args.fps)
|
725 |
+
else:
|
726 |
+
output_path = output_path.replace('.mkv', '.mp4')
|
727 |
+
save_video_with_imageio(video_generate, output_path, fps=args.fps, format=args.save_format)
|
728 |
+
else:
|
729 |
+
print(f"Warning: {video_name} not found in {args.input_dir}")
|
730 |
+
|
731 |
+
if metrics_models is not None:
|
732 |
+
print("\n=== Overall Average Metrics ===")
|
733 |
+
count = len(next(iter(metric_accumulator.values())))
|
734 |
+
overall_avg = {metric: 0 for metric in metrics_list}
|
735 |
+
out_name = 'metrics_'
|
736 |
+
for metric in metrics_list:
|
737 |
+
out_name += f"{metric}_"
|
738 |
+
scores = metric_accumulator[metric]
|
739 |
+
if scores:
|
740 |
+
avg = sum(scores) / len(scores)
|
741 |
+
overall_avg[metric] = avg
|
742 |
+
print(f"{metric.upper()}: {avg:.4f}")
|
743 |
+
|
744 |
+
out_name = out_name.rstrip('_') + '.json'
|
745 |
+
out_path = os.path.join(args.output_path, out_name)
|
746 |
+
output = {
|
747 |
+
"per_sample": metric_accumulator,
|
748 |
+
"average": overall_avg,
|
749 |
+
"count": count
|
750 |
+
}
|
751 |
+
with open(out_path, 'w') as f:
|
752 |
+
json.dump(output, f, indent=2)
|
753 |
+
|
754 |
+
print("All videos processed.")
|
pretrained_models/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Place pretrained models here.
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate>=1.1.1
|
2 |
+
transformers>=4.46.2
|
3 |
+
numpy==1.26.0
|
4 |
+
torch>=2.5.0
|
5 |
+
torchvision>=0.20.0
|
6 |
+
sentencepiece>=0.2.0
|
7 |
+
SwissArmyTransformer>=0.4.12
|
8 |
+
gradio>=5.5.0
|
9 |
+
imageio>=2.35.1
|
10 |
+
imageio-ffmpeg>=0.5.1
|
11 |
+
openai>=1.54.0
|
12 |
+
moviepy>=2.0.0
|
13 |
+
scikit-video>=1.1.11
|
14 |
+
pydantic>=2.10.3
|
15 |
+
wandb
|
16 |
+
peft
|
17 |
+
opencv-python
|
18 |
+
decord
|
19 |
+
av
|
20 |
+
torchdiffeq
|