roll-ai committed
Commit 49f0d37 · verified · 1 Parent(s): e942da4

Upload 27 files
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/Compare.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/Pipeline.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/Qualitative-1.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/Qualitative-2-1.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/Qualitative-2-2.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/Qualitative-3-1.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/Qualitative-3-2.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/Qualitative-4-1.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/Qualitative-4-2.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/Qualitative-5-1.png filter=lfs diff=lfs merge=lfs -text
46
+ assets/Qualitative-5-2.png filter=lfs diff=lfs merge=lfs -text
47
+ assets/Quantitative.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/Strategy.png filter=lfs diff=lfs merge=lfs -text
49
+ datasets/demo/005.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ datasets/demo/006.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,260 @@
1
  ---
2
- title: Dove
3
- emoji:
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.35.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution
2
+
3
+ [Zheng Chen](https://zhengchen1999.github.io/), [Zichen Zou](https://github.com/zzctmd), [Kewei Zhang](), [Xiongfei Su](https://ieeexplore.ieee.org/author/37086348852), [Xin Yuan](https://en.westlake.edu.cn/faculty/xin-yuan.html), [Yong Guo](https://www.guoyongcs.com/), and [Yulun Zhang](http://yulunzhang.com/), "DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution", 2025
4
+
5
+ <div>
6
+ <a href="https://github.com/zhengchen1999/DOVE/releases" target='_blank' style="text-decoration: none;"><img src="https://img.shields.io/github/downloads/zhengchen1999/DOVE/total?color=green&style=flat"></a>
7
+ <a href="https://github.com/zhengchen1999/DOVE" target='_blank' style="text-decoration: none;"><img src="https://visitor-badge.laobi.icu/badge?page_id=zhengchen1999/DOVE"></a>
8
+ <a href="https://github.com/zhengchen1999/DOVE/stargazers" target='_blank' style="text-decoration: none;"><img src="https://img.shields.io/github/stars/zhengchen1999/DOVE?style=social"></a>
9
+ </div>
10
+
11
+
12
+ [[arXiv](https://arxiv.org/abs/2505.16239)] [[supplementary material](https://github.com/zhengchen1999/DOVE/releases/download/v1/Supplementary_Material.pdf)] [[dataset](https://drive.google.com/drive/folders/1e7CyNzfJBa2saWvPr2HI2q_FJhLIc-Ww?usp=drive_link)] [[pretrained models](https://drive.google.com/drive/folders/1wj9jY0fn6prSWJ7BjJOXfxC0bs8skKbQ?usp=sharing)]
13
+
14
+
15
+
16
+ #### 🔥🔥🔥 News
17
+
18
+ - **2025-06-09:** Test datasets, inference scripts, and pretrained models are available. ⭐️⭐️⭐️
19
+ - **2025-05-22:** This repo is released.
20
+
21
  ---
22
+
23
+ > **Abstract:** Diffusion models have demonstrated promising performance in real-world video super-resolution (VSR). However, the dozens of sampling steps they require make inference extremely slow. Sampling acceleration techniques, particularly single-step sampling, offer a potential solution. Nonetheless, achieving one-step sampling in VSR remains challenging due to the high training overhead on video data and stringent fidelity demands. To tackle these issues, we propose DOVE, an efficient one-step diffusion model for real-world VSR. DOVE is obtained by fine-tuning a pretrained video diffusion model (*i.e.*, CogVideoX). To train DOVE effectively, we introduce the latent–pixel training strategy, which employs a two-stage scheme to gradually adapt the model to the video super-resolution task.
24
+ > Meanwhile, we design a video processing pipeline to construct a high-quality dataset tailored for VSR, termed HQ-VSR. Fine-tuning on this dataset further enhances the restoration capability of DOVE. Extensive experiments show that DOVE exhibits comparable or superior performance to multi-step diffusion-based VSR methods. It also offers outstanding inference efficiency, achieving up to a **28×** speed-up over existing methods such as MGLD-VSR.
25
+
26
+ ![](./assets/Compare.png)
27
+
28
  ---
29
 
30
+
31
+
32
+ <table border="0" style="width: 100%; text-align: center; margin-top: 20px;">
33
+ <tr>
34
+ <td>
35
+ <video src="https://github.com/user-attachments/assets/4ad0ca78-6cca-48c0-95a5-5d5554093f7d" controls autoplay loop></video>
36
+ </td>
37
+ <td>
38
+ <video src="https://github.com/user-attachments/assets/e5b5d247-28af-43fd-b32c-1f1b5896d9e7" controls autoplay loop></video>
39
+ </td>
40
+ </tr>
41
+ </table>
42
+
43
+
44
+
45
+
46
+ ---
47
+
48
+ ### Training Strategy
49
+
50
+ ![](./assets/Strategy.png)
51
+
52
+ ---
53
+
54
+ ### Video Processing Pipeline
55
+
56
+ ![](./assets/Pipeline.png)
57
+
58
+
59
+
60
+
61
+ ## 🔖 TODO
62
+
63
+ - [x] Release testing code.
64
+ - [x] Release pre-trained models.
65
+ - [ ] Release training code.
66
+ - [ ] Release video processing pipeline.
67
+ - [ ] Release HQ-VSR dataset.
68
+ - [ ] Provide WebUI.
69
+ - [ ] Provide HuggingFace demo.
70
+
71
+ ## ⚙️ Dependencies
72
+
73
+ - Python 3.11
74
+ - PyTorch >= 2.5.0
75
+ - Diffusers
76
+
77
+ ```bash
78
+ # Clone the GitHub repo and enter the repo directory 'DOVE'.
79
+ git clone https://github.com/zhengchen1999/DOVE.git
+ cd DOVE
80
+ conda create -n DOVE python=3.11
81
+ conda activate DOVE
82
+ pip install -r requirements.txt
83
+ pip install diffusers["torch"] transformers
84
+ pip install pyiqa
85
+ ```
86
+
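As a quick sanity check after installation (a minimal sketch, not part of the official setup; it only confirms that the packages above import and reports whether CUDA is visible):

```bash
# Verify that the core dependencies import and a GPU is available.
python -c "import torch, diffusers, transformers, pyiqa; print(torch.__version__, torch.cuda.is_available())"
```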
87
+ ## 🔗 Contents
88
+
89
+ 1. [Datasets](#datasets)
90
+ 1. [Models](#models)
91
+ 1. Training
92
+ 1. [Testing](#testing)
93
+ 1. [Results](#results)
94
+ 1. [Acknowledgements](#acknowledgements)
95
+
96
+ ## <a name="datasets"></a>📁 Datasets
97
+
98
+ ### 🗳️ Test Datasets
99
+
100
+ We provide several real-world and synthetic test datasets for evaluation. All datasets follow a consistent directory structure:
101
+
102
+ | Dataset | Type | # Num | Download |
103
+ | :------ | :--------: | :---: | :----------------------------------------------------------: |
104
+ | UDM10 | Synthetic | 10 | [Google Drive](https://drive.google.com/file/d/1AmGVSCwMm_OFPd3DKgNyTwj0GG2H-tG4/view?usp=drive_link) |
105
+ | SPMCS | Synthetic | 30 | [Google Drive](https://drive.google.com/file/d/1b2uktCFPKS-R1fTecWcLFcOnmUFIBNWT/view?usp=drive_link) |
106
+ | YouHQ40 | Synthetic | 40 | [Google Drive](https://drive.google.com/file/d/1zO23UCStxL3htPJQcDUUnUeMvDrysLTh/view?usp=sharing) |
107
+ | RealVSR | Real-world | 50 | [Google Drive](https://drive.google.com/file/d/1wr4tTiCvQlqdYPeU1dmnjb5KFY4VjGCO/view?usp=drive_link) |
108
+ | MVSR4x | Real-world | 15 | [Google Drive](https://drive.google.com/file/d/16sesBD_9Xx_5Grtx18nosBw1w94KlpQt/view?usp=drive_link) |
109
+ | VideoLQ | Real-world | 50 | [Google Drive](https://drive.google.com/file/d/1lh0vkU_llxE0un1OigJ0DWPQwt1i68Vn/view?usp=drive_link) |
110
+
111
+ All datasets are hosted [here](https://drive.google.com/drive/folders/1yNKG6rtTNtZQY8qL74GoQwA0jgjBUEby?usp=sharing). Make sure they are placed under `datasets/test/` before running inference.
112
+
113
+ The directory structure is as follows:
114
+
115
+ ```shell
116
+ datasets/
117
+ └── test/
118
+ └── [DatasetName]/
119
+ ├── GT/ # Ground Truth: folder of high-quality frames (one per clip)
120
+ ├── GT-Video/ # Ground Truth (video version): lossless MKV format
121
+ ├── LQ/ # Low-quality Input: folder of degraded frames (one per clip)
122
+ └── LQ-Video/ # Low-Quality Input (video version): lossless MKV format
123
+ ```
124
+
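To confirm that a test set is unpacked where the scripts expect it, a minimal check such as the following can be used (UDM10 is only an example; substitute any dataset name from the table above):

```bash
# Each test set should expose GT, GT-Video, LQ, and LQ-Video subfolders.
for sub in GT GT-Video LQ LQ-Video; do
  [ -d "datasets/test/UDM10/$sub" ] && echo "found: $sub" || echo "missing: $sub"
done
```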
125
+ ## <a name="models"></a>📦 Models
126
+
127
+ We provide pretrained weights for DOVE and DOVE-2B.
128
+
129
+ | Model Name | Description | HuggingFace | Google Drive | Visual Results |
130
+ | :--------- | :-------------------------------------: | :---------: | :----------------------------------------------------------: | ------------------------------------------------------------ |
131
+ | DOVE | Base version, built on CogVideoX1.5-5B | TODO | [Download](https://drive.google.com/file/d/1Nl3XoJndMtpu6KPFcskUTkI0qWBiSXF2/view?usp=drive_link) | [Download](https://drive.google.com/drive/folders/1J92X1amVijH9dNWGQcz-6Cx44B7EipWr?usp=drive_link) |
132
+ | DOVE-2B | Smaller version, based on CogVideoX-2B | TODO | TODO | TODO |
133
+
134
+ > Place downloaded model files into the `pretrained_models/` folder, e.g., `pretrained_models/DOVE`.
135
+
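The inference script loads the checkpoint as a diffusers pipeline directory (via `--model_path`), so after downloading you can sanity-check the placement with a simple listing (a minimal check; the exact file list depends on the release):

```bash
# The folder passed to --model_path should contain the diffusers pipeline files,
# typically model_index.json plus component subfolders (transformer, vae, text_encoder, scheduler, ...).
ls pretrained_models/DOVE
```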
136
+ ## <a name="testing"></a>🔨 Testing
137
+
138
+ We provide inference commands below. Before running, make sure to download the corresponding pretrained models and test datasets.
139
+
140
+ For more options and usage, please refer to [inference_script.py](inference_script.py).
141
+
142
+ The full testing commands are provided in the shell script: [inference.sh](inference.sh).
143
+
144
+ ```shell
145
+ # 🔹 Demo inference
146
+ python inference_script.py \
147
+ --input_dir datasets/demo \
148
+ --model_path pretrained_models/DOVE \
149
+ --output_path results/DOVE/demo \
150
+ --is_vae_st \
151
+ --save_format yuv420p
152
+
153
+ # 🔹 Reproduce paper results
154
+ python inference_script.py \
155
+ --input_dir datasets/test/UDM10/LQ-Video \
156
+ --model_path pretrained_models/DOVE \
157
+ --output_path results/DOVE/UDM10 \
158
+ --is_vae_st \
159
+
160
+ # 🔹 Evaluate quantitative metrics
161
+ python eval_metrics.py \
162
+ --gt datasets/test/UDM10/GT \
163
+ --pred results/DOVE/UDM10 \
164
+ --metrics psnr,ssim,lpips,dists,clipiqa
165
+ ```
166
+
167
+ > 💡 If you encounter out-of-memory (OOM) issues, you can enable chunk-based testing by setting the following parameters: `--tile_size_hw`, `--overlap_hw`, `--chunk_len`, and `--overlap_t` (see the example command after these notes).
168
+ >
169
+ > 💡 The default save format is `yuv444p`. If playback fails, try `--save_format yuv420p` (this may slightly affect metrics).
170
+ >
171
+ > **TODO:** Add metric computation scripts for FasterVQA, DOVER, and $E^*_{warp}$.
172
+
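For reference, a chunk-based run might look like the sketch below. The flags are those defined in [inference_script.py](inference_script.py); the tile and chunk sizes are illustrative values, not settings from the paper, and may need adjusting to your GPU memory:

```bash
# Chunked/tiled inference to reduce peak GPU memory (example values).
python inference_script.py \
    --input_dir datasets/test/UDM10/LQ-Video \
    --model_path pretrained_models/DOVE \
    --output_path results/DOVE/UDM10 \
    --is_vae_st \
    --tile_size_hw 512 512 \
    --overlap_hw 32 32 \
    --chunk_len 17 \
    --overlap_t 8
```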
173
+ ## <a name="results"></a>🔎 Results
174
+
175
+ We achieve state-of-the-art performance on real-world video super-resolution. Visual results are available at [Google Drive](https://drive.google.com/drive/folders/1J92X1amVijH9dNWGQcz-6Cx44B7EipWr?usp=drive_link).
176
+
177
+ <details open>
178
+ <summary>Quantitative Results (click to expand)</summary>
179
+
180
+ - Results in Tab. 2 of the main paper
181
+
182
+ <p align="center">
183
+ <img width="900" src="assets/Quantitative.png">
184
+ </p>
185
+
186
+ </details>
187
+
188
+ <details open>
189
+ <summary>Qualitative Results (click to expand)</summary>
190
+
191
+ - Results in Fig. 4 of the main paper
192
+
193
+ <p align="center">
194
+ <img width="900" src="assets/Qualitative-1.png">
195
+ </p>
196
+ <details>
197
+ <summary>More Qualitative Results</summary>
198
+
199
+
200
+
201
+
202
+ - More results in Fig. 3 of the supplementary material
203
+
204
+ <p align="center">
205
+ <img width="900" src="assets/Qualitative-2-1.png">
206
+ </p>
207
+
208
+
209
+
210
+ - More results in Fig. 4 of the supplementary material
211
+
212
+ <p align="center">
213
+ <img width="900" src="assets/Qualitative-2-2.png">
214
+ </p>
215
+
216
+
217
+ - More results in Fig. 5 of the supplementary material
218
+
219
+ <p align="center">
220
+ <img width="900" src="assets/Qualitative-3-1.png">
221
+ <img width="900" src="assets/Qualitative-3-2.png">
222
+ </p>
223
+
224
+
225
+ - More results in Fig. 6 of the supplementary material
226
+
227
+ <p align="center">
228
+ <img width="900" src="assets/Qualitative-4-1.png">
229
+ <img width="900" src="assets/Qualitative-4-2.png">
230
+ </p>
231
+
232
+
233
+ - More results in Fig. 7 of the supplementary material
234
+
235
+ <p align="center">
236
+ <img width="900" src="assets/Qualitative-5-1.png">
237
+ <img width="900" src="assets/Qualitative-5-2.png">
238
+ </p>
239
+
240
+ </details>
241
+
242
+ </details>
243
+
244
+ ## <a name="citation"></a>📎 Citation
245
+
246
+ If you find the code helpful in your research or work, please cite the following paper(s).
247
+
248
+ ```
249
+ @article{chen2025dove,
250
+ title={DOVE: Efficient One-Step Diffusion Model for Real-World Video Super-Resolution},
251
+ author={Chen, Zheng and Zou, Zichen and Zhang, Kewei and Su, Xiongfei and Yuan, Xin and Guo, Yong and Zhang, Yulun},
252
+ journal={arXiv preprint arXiv:2505.16239},
253
+ year={2025}
254
+ }
255
+ ```
256
+
257
+ ## <a name="acknowledgements"></a>💡 Acknowledgements
258
+
259
+ This project is based on [CogVideo](https://github.com/THUDM/CogVideo) and [Open-Sora](https://github.com/hpcaitech/Open-Sora).
260
+
assets/Compare.png ADDED

Git LFS Details

  • SHA256: ac8b786acbad04433ad7aec57b9502d7db88ba982a8cfb5a99a299a8abb4839c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.29 MB
assets/Pipeline.png ADDED

Git LFS Details

  • SHA256: bff7a4f0dea33326d5d6a757ee12ab4b647b423a83cf24e1771a60f584ac9bdd
  • Pointer size: 132 Bytes
  • Size of remote file: 5.28 MB
assets/Qualitative-1.png ADDED

Git LFS Details

  • SHA256: 54b239dec2d0e98820ef364e7836e4497b9961ed867cfaa942d4b22b82128a4a
  • Pointer size: 132 Bytes
  • Size of remote file: 9.42 MB
assets/Qualitative-2-1.png ADDED

Git LFS Details

  • SHA256: 81beb758072810952a99e1959cbbff89d980011895917f4fd550b803d406d3f1
  • Pointer size: 132 Bytes
  • Size of remote file: 6.53 MB
assets/Qualitative-2-2.png ADDED

Git LFS Details

  • SHA256: 7f98f52085042d95ebfe1ae61defdae1214cd775492163bbfbad6d5b951422df
  • Pointer size: 132 Bytes
  • Size of remote file: 7.66 MB
assets/Qualitative-3-1.png ADDED

Git LFS Details

  • SHA256: 73d2a09a2ce624a1db5810cc8abd89ee5a62a850a7713c6037f1ae6664330091
  • Pointer size: 132 Bytes
  • Size of remote file: 6.45 MB
assets/Qualitative-3-2.png ADDED

Git LFS Details

  • SHA256: 4714918f7f05801c30d9ed943d922e57b7441fb0325255983ed88b60f90a8370
  • Pointer size: 132 Bytes
  • Size of remote file: 5.1 MB
assets/Qualitative-4-1.png ADDED

Git LFS Details

  • SHA256: 3dc6d0b452f3eaa2b71a4d2778863b6b81d7bd5755538354cfa5992e244aafd3
  • Pointer size: 132 Bytes
  • Size of remote file: 4.6 MB
assets/Qualitative-4-2.png ADDED

Git LFS Details

  • SHA256: fd7055d2ba23242b0a5bfc960df40d8dffd6b8620009647b74306fcc530e4493
  • Pointer size: 132 Bytes
  • Size of remote file: 6.21 MB
assets/Qualitative-5-1.png ADDED

Git LFS Details

  • SHA256: fd17c8fcc1441b9f48eab8879db02aacbe8b3e2bea2cd0e66605d01ab254483e
  • Pointer size: 132 Bytes
  • Size of remote file: 6.13 MB
assets/Qualitative-5-2.png ADDED

Git LFS Details

  • SHA256: ce7e4dc356f9270748512b22df6c643f8e6f597fe9d47766a1e9f1852cddc462
  • Pointer size: 132 Bytes
  • Size of remote file: 5.08 MB
assets/Quantitative.png ADDED

Git LFS Details

  • SHA256: 2f1c45636885b196986ac0646a719488ac73e939696b46830151f65836231c9c
  • Pointer size: 131 Bytes
  • Size of remote file: 830 kB
assets/Strategy.png ADDED

Git LFS Details

  • SHA256: b3c3a3e382cf96dfb2d36f766ff4662018cf6042a1e02555fa359f44aad98f4d
  • Pointer size: 132 Bytes
  • Size of remote file: 4.58 MB
datasets/README.md ADDED
@@ -0,0 +1,14 @@
1
+ The directory structure is as follows:
2
+
3
+ ```shell
4
+ datasets/
5
+ └── demo/
6
+ └── test/
7
+ └── [DatasetName]/
8
+ ├── GT/ # Ground Truth: folder of high-quality frames (one per clip)
9
+ ├── GT-Video/ # Ground Truth (video version): lossless MKV format
10
+ ├── LQ/ # Low-quality Input: folder of degraded frames (one per clip)
11
+ └── LQ-Video/ # Low-Quality Input (video version): lossless MKV format
12
+ ```
13
+
14
+ All datasets are available [here](https://drive.google.com/drive/folders/1yNKG6rtTNtZQY8qL74GoQwA0jgjBUEby?usp=sharing).
datasets/demo/001.mp4 ADDED
Binary file (62.5 kB).
 
datasets/demo/002.mp4 ADDED
Binary file (97.3 kB).
 
datasets/demo/003.mp4 ADDED
Binary file (60.4 kB).
 
datasets/demo/004.mp4 ADDED
Binary file (79.6 kB).
 
datasets/demo/005.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de2fe395e78a9d556a3763d7a7bdf87102e0c1191e1d146a54d487f78a57d708
3
+ size 268870
datasets/demo/006.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:194615754e590bc57e84f41a11b7bca4d564455ed93b72be700a6300119d34ac
3
+ size 206748
datasets/demo/007.mp4 ADDED
Binary file (48.4 kB).
 
eval_metrics.py ADDED
@@ -0,0 +1,256 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import torch
5
+ import pyiqa
6
+ import numpy as np
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from torchvision import transforms
10
+
11
+ # 0 ~ 1
12
+ to_tensor = transforms.ToTensor()
13
+ video_exts = ['.mp4', '.avi', '.mov', '.mkv']
14
+ fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
15
+
16
+
17
+ def is_video_file(filename):
18
+ return any(filename.lower().endswith(ext) for ext in video_exts)
19
+
20
+ def rgb_to_y(img):
21
+ # Assumes img is [1, 3, H, W] in [0,1], returns [1, 1, H, W]
22
+ r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
23
+ y = 0.257 * r + 0.504 * g + 0.098 * b + 0.0625
24
+ return y
25
+
26
+ def crop_border(img, crop):
27
+ return img[:, :, crop:-crop, crop:-crop]
28
+
29
+ def read_video_frames(video_path):
30
+ cap = cv2.VideoCapture(video_path)
31
+ frames = []
32
+ while True:
33
+ ret, frame = cap.read()
34
+ if not ret:
35
+ break
36
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
37
+ frames.append(to_tensor(Image.fromarray(rgb)))
38
+ cap.release()
39
+ return torch.stack(frames)
40
+
41
+
42
+ def read_image_folder(folder_path):
43
+ image_files = sorted([
44
+ os.path.join(folder_path, f) for f in os.listdir(folder_path)
45
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))
46
+ ])
47
+ frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
48
+ return torch.stack(frames)
49
+
50
+
51
+ def load_sequence(path):
52
+ if os.path.isdir(path):
53
+ return read_image_folder(path)
54
+ elif os.path.isfile(path):
55
+ if is_video_file(path):
56
+ return read_video_frames(path)
57
+ elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
58
+ # Treat image as a single-frame video
59
+ img = to_tensor(Image.open(path).convert("RGB"))
60
+ return img.unsqueeze(0) # [1, C, H, W]
61
+ raise ValueError(f"Unsupported input: {path}")
62
+
63
+
64
+ def crop_img_center(img, target_h, target_w):
65
+ _, h, w = img.shape
66
+ top = max((h - target_h) // 2, 0)
67
+ left = max((w - target_w) // 2, 0)
68
+ return img[:, top:top+target_h, left:left+target_w]
69
+
70
+ def crop_img_top_left(img, target_h, target_w):
71
+ # Crop image from top-left corner to (target_h, target_w)
72
+ return img[:, :target_h, :target_w]
73
+
74
+ def match_resolution(gt_frames, pred_frames, is_center=False, name=None):
75
+ t = min(gt_frames.shape[0], pred_frames.shape[0])
76
+ gt_frames = gt_frames[:t]
77
+ pred_frames = pred_frames[:t]
78
+ _, _, h_g, w_g = gt_frames.shape
79
+ _, _, h_p, w_p = pred_frames.shape
80
+
81
+ target_h = min(h_g, h_p)
82
+ target_w = min(w_g, w_p)
83
+
84
+ if (h_g != h_p or w_g != w_p) and name:
85
+ if is_center:
86
+ print(f"[{name}] Resolution mismatch detected: GT is ({h_g}, {w_g}), Pred is ({h_p}, {w_p}). Both GT and Pred were center cropped to ({target_h}, {target_w}).")
87
+ else:
88
+ print(f"[{name}] Resolution mismatch detected: GT is ({h_g}, {w_g}), Pred is ({h_p}, {w_p}). Both GT and Pred were top-left cropped to ({target_h}, {target_w}).")
89
+
90
+ if is_center:
91
+ gt_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in gt_frames])
92
+ pred_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in pred_frames])
93
+ else:
94
+ gt_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in gt_frames])
95
+ pred_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in pred_frames])
96
+
97
+ return gt_frames, pred_frames
98
+
99
+
100
+ def init_models(metrics, device):
101
+ models = {}
102
+ for name in metrics:
103
+ try:
104
+ models[name] = pyiqa.create_metric(name).to(device).eval()
105
+ except Exception as e:
106
+ print(f"Failed to initialize metric '{name}': {e}")
107
+ return models
108
+
109
+ def compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel):
110
+ if batch_mode:
111
+ pred_batch = pred_frames.to(device) # [F, C, H, W]
112
+ gt_batch = gt_frames.to(device) # [F, C, H, W]
113
+
114
+ results = {}
115
+ for name, model in models.items():
116
+ if name in fr_metrics:
117
+ pred_eval = pred_batch
118
+ gt_eval = gt_batch
119
+ if crop > 0:
120
+ pred_eval = crop_border(pred_eval, crop)
121
+ gt_eval = crop_border(gt_eval, crop)
122
+ if test_y_channel:
123
+ pred_eval = rgb_to_y(pred_eval)
124
+ gt_eval = rgb_to_y(gt_eval)
125
+ values = model(pred_eval, gt_eval) # [F]
126
+ else:
127
+ values = model(pred_batch) # no-reference
128
+ results[name] = round(values.mean().item(), 4)
129
+ return results
130
+
131
+ else:
132
+ results = {name: [] for name in models}
133
+ for pred, gt in zip(pred_frames, gt_frames):
134
+ pred = pred.unsqueeze(0).to(device)
135
+ gt = gt.unsqueeze(0).to(device)
136
+
137
+ for name, model in models.items():
138
+ if name in fr_metrics:
139
+ pred_eval = pred
140
+ gt_eval = gt
141
+ if crop > 0:
142
+ pred_eval = crop_border(pred_eval, crop)
143
+ gt_eval = crop_border(gt_eval, crop)
144
+ if test_y_channel:
145
+ pred_eval = rgb_to_y(pred_eval)
146
+ gt_eval = rgb_to_y(gt_eval)
147
+ value = model(pred_eval, gt_eval).item()
148
+ else:
149
+ value = model(pred).item()
150
+ results[name].append(value)
151
+
152
+ return {k: round(np.mean(v), 4) for k, v in results.items()}
153
+
154
+
155
+ def process(gt_root, pred_root, out_path, metrics, batch_mode, crop, test_y_channel, is_center):
156
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157
+ print(f"Using device: {device}")
158
+ models = init_models(metrics, device)
159
+
160
+ has_gt = bool(gt_root and os.path.exists(gt_root))
161
+
162
+ if has_gt:
163
+ gt_files = {os.path.splitext(f)[0]: os.path.join(gt_root, f) for f in os.listdir(gt_root)}
164
+ pred_files = {os.path.splitext(f)[0]: os.path.join(pred_root, f) for f in os.listdir(pred_root)}
165
+
166
+ pred_names = sorted(pred_files.keys())
167
+ results = {}
168
+ aggregate = {metric: [] for metric in metrics}
169
+
170
+ for name in tqdm(pred_names, desc="Evaluating"):
171
+ # # valida
172
+ # name_hr = name.replace('_CAT_A_x4', '').replace('img_', 'img')
173
+ name_hr = name
174
+ if has_gt and name_hr not in gt_files:
175
+ print(f"Skipping {name_hr}: no matching GT file.")
176
+ continue
177
+
178
+ pred_path = pred_files[name]
179
+ gt_path = gt_files[name_hr] if has_gt else None
180
+
181
+ try:
182
+ pred_frames = load_sequence(pred_path)
183
+
184
+ if has_gt:
185
+ gt_frames = load_sequence(gt_path)
186
+ gt_frames, pred_frames = match_resolution(gt_frames, pred_frames, is_center=is_center, name=name)
187
+ scores = compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel)
188
+ else:
189
+ nr_models = {k: v for k, v in models.items() if k not in fr_metrics}
190
+ if not nr_models:
191
+ print(f"Skipping {name}: GT is not provided and no NR-IQA metrics found.")
192
+ continue
193
+ dummy_gt = pred_frames
194
+ scores = compute_metrics(pred_frames, dummy_gt, nr_models, device, batch_mode, crop, test_y_channel)
195
+
196
+ results[name] = scores
197
+ for k in scores:
198
+ aggregate[k].append(scores[k])
199
+ except Exception as e:
200
+ print(f"Error processing {name}: {e}")
201
+
202
+ print("\nPer-sample Results:")
203
+ for name in sorted(results):
204
+ print(f"{name}: " + ", ".join(f"{k}={v:.4f}" for k, v in results[name].items()))
205
+
206
+ print("\nOverall Average Results:")
207
+ count = len(results)
208
+ if count > 0:
209
+ overall_avg = {k: round(np.mean(v), 4) for k, v in aggregate.items()}
210
+ for k, v in overall_avg.items():
211
+ print(f"{k.upper()}: {v:.4f}")
212
+ else:
213
+ overall_avg = {}
214
+ print("No valid samples were processed.")
215
+
216
+ print(f"\nProcessed {count} samples.")
217
+
218
+ output = {
219
+ "per_sample": results,
220
+ "average": overall_avg,
221
+ "count": count
222
+ }
223
+
224
+ os.makedirs(out_path, exist_ok=True)
225
+ out_name = 'metrics_'
226
+ for metric in metrics:
227
+ out_name += f"{metric}_"
228
+ out_name = out_name.rstrip('_') + '.json'
229
+ out_path = os.path.join(out_path, out_name)
230
+
231
+ with open(out_path, 'w') as f:
232
+ json.dump(output, f, indent=2)
233
+
234
+ print(f"Results saved to: {out_path}")
235
+
236
+ if __name__ == "__main__":
237
+ import argparse
238
+ parser = argparse.ArgumentParser()
239
+ parser.add_argument('--gt', type=str, default='', help='Path to GT folder (optional for NR-IQA)')
240
+ parser.add_argument('--pred', type=str, required=True, help='Path to predicted results folder')
241
+ parser.add_argument('--out', type=str, default='', help='Path to save JSON output (as directory)')
242
+ parser.add_argument('--metrics', type=str, default='psnr,ssim,clipiqa',
243
+ help='Comma-separated list of metrics: psnr,ssim,clipiqa,lpips,...')
244
+ parser.add_argument('--batch_mode', action='store_true', help='Use batch mode for metrics computation')
245
+ parser.add_argument('--crop', type=int, default=0, help='Crop border size for PSNR/SSIM')
246
+ parser.add_argument('--test_y_channel', action='store_true', help='Use Y channel for PSNR/SSIM')
247
+ parser.add_argument('--is_center', action='store_true', help='Use center crop for PSNR/SSIM')
248
+
249
+ args = parser.parse_args()
250
+
251
+ if args.out == '':
252
+ out = args.pred
253
+ else:
254
+ out = args.out
255
+ metric_list = [m.strip().lower() for m in args.metrics.split(',')]
256
+ process(args.gt, args.pred, out, metric_list, args.batch_mode, args.crop, args.test_y_channel, args.is_center)
inference.sh ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # UDM10
4
+ python inference_script.py \
5
+ --input_dir datasets/test/UDM10/LQ-Video \
6
+ --model_path pretrained_models/DOVE \
7
+ --output_path results/DOVE/UDM10 \
8
+ --is_vae_st \
9
+
10
+ python eval_metrics.py \
11
+ --gt datasets/test/UDM10/GT \
12
+ --pred results/DOVE/UDM10 \
13
+ --metrics psnr,ssim,lpips,dists,clipiqa
14
+
15
+ # SPMCS
16
+ python inference_script.py \
17
+ --input_dir datasets/test/SPMCS/LQ-Video \
18
+ --model_path pretrained_models/DOVE \
19
+ --output_path results/DOVE/SPMCS \
20
+ --is_vae_st \
21
+
22
+ python eval_metrics.py \
23
+ --gt datasets/test/SPMCS/GT \
24
+ --pred results/DOVE/SPMCS \
25
+ --metrics psnr,ssim,lpips,dists,clipiqa
26
+
27
+ # YouHQ40
28
+ python inference_script.py \
29
+ --input_dir datasets/test/YouHQ40/LQ-Video \
30
+ --model_path pretrained_models/DOVE \
31
+ --output_path results/DOVE/YouHQ40 \
32
+ --is_vae_st \
33
+
34
+ python eval_metrics.py \
35
+ --gt datasets/test/YouHQ40/GT \
36
+ --pred results/DOVE/YouHQ40 \
37
+ --metrics psnr,ssim,lpips,dists,clipiqa
38
+
39
+ # RealVSR
40
+ python inference_script.py \
41
+ --input_dir datasets/test/RealVSR/LQ-Video \
42
+ --model_path pretrained_models/DOVE \
43
+ --output_path results/DOVE/RealVSR \
44
+ --is_vae_st \
45
+ --upscale 1 \
46
+
47
+ python eval_metrics.py \
48
+ --gt datasets/test/RealVSR/GT \
49
+ --pred results/DOVE/RealVSR \
50
+ --metrics psnr,ssim,lpips,dists,clipiqa
51
+
52
+ # MVSR4x
53
+ python inference_script.py \
54
+ --input_dir datasets/test/MVSR4x/LQ-Video \
55
+ --model_path pretrained_models/DOVE \
56
+ --output_path results/DOVE/MVSR4x \
57
+ --is_vae_st \
58
+ --upscale 1 \
59
+
60
+ python eval_metrics.py \
61
+ --gt datasets/test/MVSR4x/GT \
62
+ --pred results/DOVE/MVSR4x \
63
+ --metrics psnr,ssim,lpips,dists,clipiqa
64
+
65
+ # VideoLQ
66
+ python inference_script.py \
67
+ --input_dir datasets/test/VideoLQ/LQ-Video \
68
+ --model_path pretrained_models/DOVE \
69
+ --output_path results/DOVE/VideoLQ \
70
+ --is_vae_st \
71
+
72
+ python eval_metrics.py \
74
+ --pred results/DOVE/VideoLQ \
75
+ --metrics clipiqa
inference_script.py ADDED
@@ -0,0 +1,754 @@
1
+ from pathlib import Path
2
+ import argparse
3
+ import logging
4
+
5
+ import torch
6
+ from torchvision import transforms
7
+ from torchvision.io import write_video
8
+ from tqdm import tqdm
9
+
10
+ from diffusers import (
11
+ CogVideoXDPMScheduler,
12
+ CogVideoXPipeline,
13
+ )
14
+
15
+ from transformers import set_seed
16
+ from typing import Dict, Tuple
17
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
18
+
19
+ import json
20
+ import os
21
+ import cv2
22
+ from PIL import Image
23
+
24
+ from pathlib import Path
25
+ import pyiqa
26
+ import imageio.v3 as iio
27
+ import glob
28
+
29
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
30
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
31
+ import decord # isort:skip
32
+
33
+ decord.bridge.set_bridge("torch")
34
+
35
+ logging.basicConfig(level=logging.INFO)
36
+
37
+ # 0 ~ 1
38
+ to_tensor = transforms.ToTensor()
39
+ video_exts = ['.mp4', '.avi', '.mov', '.mkv']
40
+ fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
41
+
42
+
43
+ def no_grad(func):
44
+ def wrapper(*args, **kwargs):
45
+ with torch.no_grad():
46
+ return func(*args, **kwargs)
47
+ return wrapper
48
+
49
+
50
+ def is_video_file(filename):
51
+ return any(filename.lower().endswith(ext) for ext in video_exts)
52
+
53
+
54
+ def read_video_frames(video_path):
55
+ cap = cv2.VideoCapture(video_path)
56
+ frames = []
57
+ while True:
58
+ ret, frame = cap.read()
59
+ if not ret:
60
+ break
61
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
62
+ frames.append(to_tensor(Image.fromarray(rgb)))
63
+ cap.release()
64
+ return torch.stack(frames)
65
+
66
+
67
+ def read_image_folder(folder_path):
68
+ image_files = sorted([
69
+ os.path.join(folder_path, f) for f in os.listdir(folder_path)
70
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))
71
+ ])
72
+ frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
73
+ return torch.stack(frames)
74
+
75
+
76
+ def load_sequence(path):
77
+ # return a tensor of shape [F, C, H, W] // 0, 1
78
+ if os.path.isdir(path):
79
+ return read_image_folder(path)
80
+ elif os.path.isfile(path):
81
+ if is_video_file(path):
82
+ return read_video_frames(path)
83
+ elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
84
+ # Treat image as a single-frame video
85
+ img = to_tensor(Image.open(path).convert("RGB"))
86
+ return img.unsqueeze(0) # [1, C, H, W]
87
+ raise ValueError(f"Unsupported input: {path}")
88
+
89
+ @no_grad
90
+ def compute_metrics(pred_frames, gt_frames, metrics_model, metric_accumulator, file_name):
91
+
92
+ print(f"\n\n[{file_name}] Metrics:", end=" ")
93
+ for name, model in metrics_model.items():
94
+ scores = []
95
+ for i in range(pred_frames.shape[0]):
96
+ pred = pred_frames[i].unsqueeze(0)
97
+ if gt_frames != None:
98
+ gt = gt_frames[i].unsqueeze(0)
99
+ if name in fr_metrics:
100
+ score = model(pred, gt).item()
101
+ else:
102
+ score = model(pred).item()
103
+ scores.append(score)
104
+ val = sum(scores) / len(scores)
105
+ metric_accumulator[name].append(val)
106
+ print(f"{name.upper()}={val:.4f}", end=" ")
107
+ print()
108
+
109
+
110
+ def save_frames_as_png(video, output_dir, fps=8):
111
+ """
112
+ Save video frames as PNG sequence.
113
+
114
+ Args:
115
+ video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
116
+ output_dir (str): directory to save PNG files
117
+ fps (int): kept for API compatibility
118
+ """
119
+ video = video[0] # Remove batch dimension
120
+ video = video.permute(1, 2, 3, 0) # [F, H, W, C]
121
+
122
+ os.makedirs(output_dir, exist_ok=True)
123
+ frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
124
+
125
+ for i, frame in enumerate(frames):
126
+ filename = os.path.join(output_dir, f"{i:03d}.png")
127
+ Image.fromarray(frame).save(filename)
128
+
129
+
130
+ def save_video_with_imageio_lossless(video, output_path, fps=8):
131
+ """
132
+ Save a video tensor to .mkv using imageio.v3.imwrite with ffmpeg backend.
133
+
134
+ Args:
135
+ video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
136
+ output_path (str): where to save the .mkv file
137
+ fps (int): frames per second
138
+ """
139
+ video = video[0]
140
+ video = video.permute(1, 2, 3, 0)
141
+
142
+ frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
143
+
144
+ iio.imwrite(
145
+ output_path,
146
+ frames,
147
+ fps=fps,
148
+ codec='libx264rgb',
149
+ pixelformat='rgb24',
150
+ macro_block_size=None,
151
+ ffmpeg_params=['-crf', '0'],
152
+ )
153
+
154
+
155
+ def save_video_with_imageio(video, output_path, fps=8, format='yuv444p'):
156
+ """
157
+ Save a video tensor to .mp4 using imageio.v3.imwrite with ffmpeg backend.
158
+
159
+ Args:
160
+ video (torch.Tensor): shape [B, C, F, H, W], float in [0, 1]
161
+ output_path (str): where to save the .mp4 file
162
+ fps (int): frames per second
163
+ """
164
+ video = video[0]
165
+ video = video.permute(1, 2, 3, 0)
166
+
167
+ frames = (video * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
168
+
169
+ if format == 'yuv444p':
170
+ iio.imwrite(
171
+ output_path,
172
+ frames,
173
+ fps=fps,
174
+ codec='libx264',
175
+ pixelformat='yuv444p',
176
+ macro_block_size=None,
177
+ ffmpeg_params=['-crf', '0'],
178
+ )
179
+ else:
180
+ iio.imwrite(
181
+ output_path,
182
+ frames,
183
+ fps=fps,
184
+ codec='libx264',
185
+ pixelformat='yuv420p',
186
+ macro_block_size=None,
187
+ ffmpeg_params=['-crf', '10'],
188
+ )
189
+
190
+
191
+ def preprocess_video_match(
192
+ video_path: Path | str,
193
+ is_match: bool = False,
194
+ ) -> torch.Tensor:
195
+ """
196
+ Loads a single video.
197
+
198
+ Args:
199
+ video_path: Path to the video file.
200
+ Returns:
201
+ A torch.Tensor with shape [F, C, H, W] where:
202
+ F = number of frames
203
+ C = number of channels (3 for RGB)
204
+ H = height
205
+ W = width
206
+ """
207
+ if isinstance(video_path, str):
208
+ video_path = Path(video_path)
209
+ video_reader = decord.VideoReader(uri=video_path.as_posix())
210
+ video_num_frames = len(video_reader)
211
+ frames = video_reader.get_batch(list(range(video_num_frames)))
212
+ F, H, W, C = frames.shape
213
+ original_shape = (F, H, W, C)
214
+
215
+ pad_f = 0
216
+ pad_h = 0
217
+ pad_w = 0
218
+
219
+ if is_match:
220
+ remainder = (F - 1) % 8
221
+ if remainder != 0:
222
+ last_frame = frames[-1:]
223
+ pad_f = 8 - remainder
224
+ repeated_frames = last_frame.repeat(pad_f, 1, 1, 1)
225
+ frames = torch.cat([frames, repeated_frames], dim=0)
226
+
227
+ pad_h = (16 - H % 16) % 16
228
+ pad_w = (16 - W % 16) % 16
229
+ if pad_h > 0 or pad_w > 0:
230
+ # pad = (w_left, w_right, h_top, h_bottom)
231
+ frames = torch.nn.functional.pad(frames, pad=(0, 0, 0, pad_w, 0, pad_h)) # pad right and bottom
232
+
233
+ # to F, C, H, W
234
+ return frames.float().permute(0, 3, 1, 2).contiguous(), pad_f, pad_h, pad_w, original_shape
235
+
236
+
237
+ def remove_padding_and_extra_frames(video, pad_F, pad_H, pad_W):
238
+ if pad_F > 0:
239
+ video = video[:, :, :-pad_F, :, :]
240
+ if pad_H > 0:
241
+ video = video[:, :, :, :-pad_H, :]
242
+ if pad_W > 0:
243
+ video = video[:, :, :, :, :-pad_W]
244
+
245
+ return video
246
+
247
+
248
+ def make_temporal_chunks(F, chunk_len, overlap_t=8):
249
+ """
250
+ Args:
251
+ F: total number of frames
252
+ chunk_len: int, chunk length in time (excluding overlap)
253
+ overlap: int, number of overlapping frames between chunks
254
+ Returns:
255
+ time_chunks: List of (start_t, end_t) tuples
256
+ """
257
+ if chunk_len == 0:
258
+ return [(0, F)]
259
+
260
+ effective_stride = chunk_len - overlap_t
261
+ if effective_stride <= 0:
262
+ raise ValueError("chunk_len must be greater than overlap")
263
+
264
+ chunk_starts = list(range(0, F - overlap_t, effective_stride))
265
+ if chunk_starts[-1] + chunk_len < F:
266
+ chunk_starts.append(F - chunk_len)
267
+
268
+ time_chunks = []
269
+ for i, t_start in enumerate(chunk_starts):
270
+ t_end = min(t_start + chunk_len, F)
271
+ time_chunks.append((t_start, t_end))
272
+
273
+ if len(time_chunks) >= 2 and time_chunks[-1][1] - time_chunks[-1][0] < chunk_len:
274
+ last = time_chunks.pop()
275
+ prev_start, _ = time_chunks[-1]
276
+ time_chunks[-1] = (prev_start, last[1])
277
+
278
+ return time_chunks
279
+
280
+
281
+ def make_spatial_tiles(H, W, tile_size_hw, overlap_hw=(32, 32)):
282
+ """
283
+ Args:
284
+ H, W: height and width of the frame
285
+ tile_size_hw: Tuple (tile_height, tile_width)
286
+ overlap_hw: Tuple (overlap_height, overlap_width)
287
+ Returns:
288
+ spatial_tiles: List of (start_h, end_h, start_w, end_w) tuples
289
+ """
290
+ tile_height, tile_width = tile_size_hw
291
+ overlap_h, overlap_w = overlap_hw
292
+
293
+ if tile_height == 0 or tile_width == 0:
294
+ return [(0, H, 0, W)]
295
+
296
+ tile_stride_h = tile_height - overlap_h
297
+ tile_stride_w = tile_width - overlap_w
298
+
299
+ if tile_stride_h <= 0 or tile_stride_w <= 0:
300
+ raise ValueError("Tile size must be greater than overlap")
301
+
302
+ h_tiles = list(range(0, H - overlap_h, tile_stride_h))
303
+ if not h_tiles or h_tiles[-1] + tile_height < H:
304
+ h_tiles.append(H - tile_height)
305
+
306
+ # Merge last row if needed
307
+ if len(h_tiles) >= 2 and h_tiles[-1] + tile_height > H:
308
+ h_tiles.pop()
309
+
310
+ w_tiles = list(range(0, W - overlap_w, tile_stride_w))
311
+ if not w_tiles or w_tiles[-1] + tile_width < W:
312
+ w_tiles.append(W - tile_width)
313
+
314
+ # Merge last column if needed
315
+ if len(w_tiles) >= 2 and w_tiles[-1] + tile_width > W:
316
+ w_tiles.pop()
317
+
318
+ spatial_tiles = []
319
+ for h_start in h_tiles:
320
+ h_end = min(h_start + tile_height, H)
321
+ if h_end + tile_stride_h > H:
322
+ h_end = H
323
+ for w_start in w_tiles:
324
+ w_end = min(w_start + tile_width, W)
325
+ if w_end + tile_stride_w > W:
326
+ w_end = W
327
+ spatial_tiles.append((h_start, h_end, w_start, w_end))
328
+ return spatial_tiles
329
+
330
+
331
+ def get_valid_tile_region(t_start, t_end, h_start, h_end, w_start, w_end,
332
+ video_shape, overlap_t, overlap_h, overlap_w):
333
+ _, _, F, H, W = video_shape
334
+
335
+ t_len = t_end - t_start
336
+ h_len = h_end - h_start
337
+ w_len = w_end - w_start
338
+
339
+ valid_t_start = 0 if t_start == 0 else overlap_t // 2
340
+ valid_t_end = t_len if t_end == F else t_len - overlap_t // 2
341
+ valid_h_start = 0 if h_start == 0 else overlap_h // 2
342
+ valid_h_end = h_len if h_end == H else h_len - overlap_h // 2
343
+ valid_w_start = 0 if w_start == 0 else overlap_w // 2
344
+ valid_w_end = w_len if w_end == W else w_len - overlap_w // 2
345
+
346
+ out_t_start = t_start + valid_t_start
347
+ out_t_end = t_start + valid_t_end
348
+ out_h_start = h_start + valid_h_start
349
+ out_h_end = h_start + valid_h_end
350
+ out_w_start = w_start + valid_w_start
351
+ out_w_end = w_start + valid_w_end
352
+
353
+ return {
354
+ "valid_t_start": valid_t_start, "valid_t_end": valid_t_end,
355
+ "valid_h_start": valid_h_start, "valid_h_end": valid_h_end,
356
+ "valid_w_start": valid_w_start, "valid_w_end": valid_w_end,
357
+ "out_t_start": out_t_start, "out_t_end": out_t_end,
358
+ "out_h_start": out_h_start, "out_h_end": out_h_end,
359
+ "out_w_start": out_w_start, "out_w_end": out_w_end,
360
+ }
361
+
362
+
363
+ def prepare_rotary_positional_embeddings(
364
+ height: int,
365
+ width: int,
366
+ num_frames: int,
367
+ transformer_config: Dict,
368
+ vae_scale_factor_spatial: int,
369
+ device: torch.device,
370
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
371
+
372
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
373
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
374
+
375
+ if transformer_config.patch_size_t is None:
376
+ base_num_frames = num_frames
377
+ else:
378
+ base_num_frames = (
379
+ num_frames + transformer_config.patch_size_t - 1
380
+ ) // transformer_config.patch_size_t
381
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
382
+ embed_dim=transformer_config.attention_head_dim,
383
+ crops_coords=None,
384
+ grid_size=(grid_height, grid_width),
385
+ temporal_size=base_num_frames,
386
+ grid_type="slice",
387
+ max_size=(grid_height, grid_width),
388
+ device=device,
389
+ )
390
+
391
+ return freqs_cos, freqs_sin
392
+
393
+ @no_grad
394
+ def process_video(
395
+ pipe: CogVideoXPipeline,
396
+ video: torch.Tensor,
397
+ prompt: str = '',
398
+ noise_step: int = 0,
399
+ sr_noise_step: int = 399,
400
+ ):
401
+ # SR the video frames based on the prompt.
402
+ # `num_frames` is the Number of frames to generate.
403
+
404
+ # Decode video
405
+ video = video.to(pipe.vae.device, dtype=pipe.vae.dtype)
406
+ latent_dist = pipe.vae.encode(video).latent_dist
407
+ latent = latent_dist.sample() * pipe.vae.config.scaling_factor
408
+
409
+ patch_size_t = pipe.transformer.config.patch_size_t
410
+ if patch_size_t is not None:
411
+ ncopy = latent.shape[2] % patch_size_t
412
+ # Copy the first frame ncopy times to match patch_size_t
413
+ first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
414
+ latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
415
+
416
+ assert latent.shape[2] % patch_size_t == 0
417
+
418
+ batch_size, num_channels, num_frames, height, width = latent.shape
419
+
420
+ # Get prompt embeddings
421
+ prompt_token_ids = pipe.tokenizer(
422
+ prompt,
423
+ padding="max_length",
424
+ max_length=pipe.transformer.config.max_text_seq_length,
425
+ truncation=True,
426
+ add_special_tokens=True,
427
+ return_tensors="pt",
428
+ )
429
+ prompt_token_ids = prompt_token_ids.input_ids
430
+ prompt_embedding = pipe.text_encoder(
431
+ prompt_token_ids.to(latent.device)
432
+ )[0]
433
+ _, seq_len, _ = prompt_embedding.shape
434
+ prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
435
+
436
+ latent = latent.permute(0, 2, 1, 3, 4)
437
+
438
+ # Add noise to latent (Select)
439
+ if noise_step != 0:
440
+ noise = torch.randn_like(latent)
441
+ add_timesteps = torch.full(
442
+ (batch_size,),
443
+ fill_value=noise_step,
444
+ dtype=torch.long,
445
+ device=latent.device,
446
+ )
447
+ latent = pipe.scheduler.add_noise(latent, noise, add_timesteps)
448
+
449
+ timesteps = torch.full(
450
+ (batch_size,),
451
+ fill_value=sr_noise_step,
452
+ dtype=torch.long,
453
+ device=latent.device,
454
+ )
455
+
456
+ # Prepare rotary embeds
457
+ vae_scale_factor_spatial = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
458
+ transformer_config = pipe.transformer.config
459
+ rotary_emb = (
460
+ prepare_rotary_positional_embeddings(
461
+ height=height * vae_scale_factor_spatial,
462
+ width=width * vae_scale_factor_spatial,
463
+ num_frames=num_frames,
464
+ transformer_config=transformer_config,
465
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
466
+ device=latent.device,
467
+ )
468
+ if pipe.transformer.config.use_rotary_positional_embeddings
469
+ else None
470
+ )
471
+
472
+ # Predict noise
473
+ predicted_noise = pipe.transformer(
474
+ hidden_states=latent,
475
+ encoder_hidden_states=prompt_embedding,
476
+ timestep=timesteps,
477
+ image_rotary_emb=rotary_emb,
478
+ return_dict=False,
479
+ )[0]
480
+
481
+ latent_generate = pipe.scheduler.get_velocity(
482
+ predicted_noise, latent, timesteps
483
+ )
484
+
485
+ # generate video
486
+ if patch_size_t is not None and ncopy > 0:
487
+ latent_generate = latent_generate[:, ncopy:, :, :, :]
488
+
489
+ # [B, C, F, H, W]
490
+ video_generate = pipe.decode_latents(latent_generate)
491
+ video_generate = (video_generate * 0.5 + 0.5).clamp(0.0, 1.0)
492
+
493
+ return video_generate
494
+
495
+
496
+ if __name__ == "__main__":
497
+ parser = argparse.ArgumentParser(description="VSR using DOVE")
498
+
499
+ parser.add_argument("--input_dir", type=str)
500
+
501
+ parser.add_argument("--input_json", type=str, default=None)
502
+
503
+ parser.add_argument("--gt_dir", type=str, default=None)
504
+
505
+ parser.add_argument("--eval_metrics", type=str, default='') # 'psnr,ssim,lpips,dists,clipiqa,musiq,maniqa,niqe'
506
+
507
+ parser.add_argument("--model_path", type=str)
508
+
509
+ parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
510
+
511
+ parser.add_argument("--output_path", type=str, default="./results", help="The path save generated video")
512
+
513
+ parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
514
+
515
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
516
+
517
+ parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
518
+
519
+ parser.add_argument("--upscale_mode", type=str, default="bilinear")
520
+
521
+ parser.add_argument("--upscale", type=int, default=4)
522
+
523
+ parser.add_argument("--noise_step", type=int, default=0)
524
+
525
+ parser.add_argument("--sr_noise_step", type=int, default=399)
526
+
527
+ parser.add_argument("--is_cpu_offload", action="store_true", help="Enable CPU offload for the model")
528
+
529
+ parser.add_argument("--is_vae_st", action="store_true", help="Enable VAE slicing and tiling")
530
+
531
+ parser.add_argument("--png_save", action="store_true", help="Save output as PNG sequence")
532
+
533
+ parser.add_argument("--save_format", type=str, default="yuv444p", help="Save output as PNG sequence")
534
+
535
+ # Crop and Tiling Parameters
536
+ parser.add_argument("--tile_size_hw", type=int, nargs=2, default=(0, 0), help="Tile size for spatial tiling (height, width)")
537
+
538
+ parser.add_argument("--overlap_hw", type=int, nargs=2, default=(32, 32))
539
+
540
+ parser.add_argument("--chunk_len", type=int, default=0, help="Chunk length for temporal chunking")
541
+
542
+ parser.add_argument("--overlap_t", type=int, default=8)
543
+
544
+ args = parser.parse_args()
545
+
546
+ if args.dtype == "float16":
547
+ dtype = torch.float16
548
+ elif args.dtype == "bfloat16":
549
+ dtype = torch.bfloat16
550
+ elif args.dtype == "float32":
551
+ dtype = torch.float32
552
+ else:
553
+ raise ValueError("Invalid dtype. Choose from 'float16', 'bfloat16', or 'float32'.")
554
+
555
+ if args.chunk_len > 0:
556
+ print(f"Chunking video into {args.chunk_len} frames with {args.overlap_t} overlap")
557
+ overlap_t = args.overlap_t
558
+ else:
559
+ overlap_t = 0
560
+ if args.tile_size_hw != (0, 0):
561
+ print(f"Tiling video into {args.tile_size_hw} frames with {args.overlap_hw} overlap")
562
+ overlap_hw = args.overlap_hw
563
+ else:
564
+ overlap_hw = (0, 0)
565
+
566
+ # Set seed
567
+ set_seed(args.seed)
568
+
569
+ if args.input_json is not None:
570
+ with open(args.input_json, 'r') as f:
571
+ video_prompt_dict = json.load(f)
572
+ else:
573
+ video_prompt_dict = {}
574
+
575
+ # Get all video files from input directory
576
+ video_files = []
577
+ for ext in video_exts:
578
+ video_files.extend(glob.glob(os.path.join(args.input_dir, f'*{ext}')))
579
+ video_files = sorted(video_files) # Sort files for consistent ordering
580
+
581
+ if not video_files:
582
+ raise ValueError(f"No video files found in {args.input_dir}")
583
+
584
+ os.makedirs(args.output_path, exist_ok=True)
585
+
586
+ # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
587
+ # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
588
+ # function to use Multi GPUs.
589
+
590
+ pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=dtype)
591
+
592
+ # If you're using with lora, add this code
593
+ if args.lora_path:
594
+ print(f"Loading LoRA weights from {args.lora_path}")
595
+ pipe.load_lora_weights(
596
+ args.lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
597
+ )
598
+ pipe.fuse_lora(components=["transformer"], lora_scale=1.0) # lora_scale = lora_alpha / rank
599
+
600
+ # 2. Set Scheduler.
601
+ # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
602
+ # We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B.
603
+ # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
604
+
605
+ # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
606
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(
607
+ pipe.scheduler.config, timestep_spacing="trailing"
608
+ )
609
+
610
+ # 3. Enable CPU offload for the model.
611
+ # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
612
+ # and enable to("cuda")
613
+
614
+ if args.is_cpu_offload:
615
+ # pipe.enable_model_cpu_offload()
616
+ pipe.enable_sequential_cpu_offload()
617
+ else:
618
+ pipe.to("cuda")
619
+
620
+ if args.is_vae_st:
621
+ pipe.vae.enable_slicing()
622
+ pipe.vae.enable_tiling()
623
+
624
+ # pipe.transformer.eval()
625
+ # torch.set_grad_enabled(False)
626
+
627
+ # 4. Set up the metrics
628
+ if args.eval_metrics != '':
629
+ metrics_list = [m.strip().lower() for m in args.eval_metrics.split(',')]
630
+ metrics_models = {}
631
+ for name in metrics_list:
632
+ try:
633
+ metrics_models[name] = pyiqa.create_metric(name).to(pipe.device).eval()
634
+ except Exception as e:
635
+ print(f"Failed to initialize metric '{name}': {e}")
636
+ metric_accumulator = {name: [] for name in metrics_list}
637
+ else:
638
+ metrics_models = None
639
+ metric_accumulator = None
640
+
641
+ for video_path in tqdm(video_files, desc="Processing videos"):
642
+ video_name = os.path.basename(video_path)
643
+ prompt = video_prompt_dict.get(video_name, "")
644
+ if os.path.exists(video_path):
645
+ # Read video
646
+ # [F, C, H, W]
647
+ video, pad_f, pad_h, pad_w, original_shape = preprocess_video_match(video_path, is_match=True)
648
+ H_, W_ = video.shape[2], video.shape[3]
649
+ video = torch.nn.functional.interpolate(video, size=(H_*args.upscale, W_*args.upscale), mode=args.upscale_mode, align_corners=False)
650
+ __frame_transform = transforms.Compose(
651
+ [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] # -1, 1
652
+ )
653
+ video = torch.stack([__frame_transform(f) for f in video], dim=0)
654
+ video = video.unsqueeze(0)
655
+ # [B, C, F, H, W]
656
+ video = video.permute(0, 2, 1, 3, 4).contiguous()
657
+
658
+ _B, _C, _F, _H, _W = video.shape
659
+ time_chunks = make_temporal_chunks(_F, args.chunk_len, overlap_t)
660
+ spatial_tiles = make_spatial_tiles(_H, _W, args.tile_size_hw, overlap_hw)
661
+
662
+ output_video = torch.zeros_like(video)
663
+ write_count = torch.zeros_like(video, dtype=torch.int)
664
+
665
+ print(f"Process video: {video_name} | Prompt: {prompt} | Frame: {_F} (ori: {original_shape[0]}; pad: {pad_f}) | Target Resolution: {_H}, {_W} (ori: {original_shape[1]*args.upscale}, {original_shape[2]*args.upscale}; pad: {pad_h}, {pad_w}) | Chunk Num: {len(time_chunks)*len(spatial_tiles)}")
666
+
667
+ for t_start, t_end in time_chunks:
668
+ for h_start, h_end, w_start, w_end in spatial_tiles:
669
+ video_chunk = video[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
670
+ # print(f"video_chunk: {video_chunk.shape} | t: {t_start}:{t_end} | h: {h_start}:{h_end} | w: {w_start}:{w_end}")
671
+
672
+ # [B, C, F, H, W]
673
+ _video_generate = process_video(
674
+ pipe=pipe,
675
+ video=video_chunk,
676
+ prompt=prompt,
677
+ noise_step=args.noise_step,
678
+ sr_noise_step=args.sr_noise_step,
679
+ )
680
+
681
+ region = get_valid_tile_region(
682
+ t_start, t_end, h_start, h_end, w_start, w_end,
683
+ video_shape=video.shape,
684
+ overlap_t=overlap_t,
685
+ overlap_h=overlap_hw[0],
686
+ overlap_w=overlap_hw[1],
687
+ )
688
+ output_video[:, :, region["out_t_start"]:region["out_t_end"],
689
+ region["out_h_start"]:region["out_h_end"],
690
+ region["out_w_start"]:region["out_w_end"]] = \
691
+ _video_generate[:, :, region["valid_t_start"]:region["valid_t_end"],
692
+ region["valid_h_start"]:region["valid_h_end"],
693
+ region["valid_w_start"]:region["valid_w_end"]]
694
+ write_count[:, :, region["out_t_start"]:region["out_t_end"],
695
+ region["out_h_start"]:region["out_h_end"],
696
+ region["out_w_start"]:region["out_w_end"]] += 1
697
+
698
+ video_generate = output_video
699
+
700
+ if (write_count == 0).any():
701
+ print("Error: Lack of write in region !!!")
702
+ exit()
703
+ if (write_count > 1).any():
704
+ print("Error: Write count > 1 in region !!!")
705
+ exit()
706
+
707
+ video_generate = remove_padding_and_extra_frames(video_generate, pad_f, pad_h*4, pad_w*4)
708
+ file_name = os.path.basename(video_path)
709
+ output_path = os.path.join(args.output_path, file_name)
710
+
711
+ if metrics_models is not None:
712
+ # [1, C, F, H, W] -> [F, C, H, W]
713
+ pred_frames = video_generate[0]
714
+ pred_frames = pred_frames.permute(1, 0, 2, 3).contiguous()
715
+ if args.gt_dir is not None:
716
+ gt_frames = load_sequence(os.path.join(args.gt_dir, file_name))
717
+ else:
718
+ gt_frames = None
719
+ compute_metrics(pred_frames, gt_frames, metrics_models, metric_accumulator, file_name)
720
+
721
+ if args.png_save:
722
+ # Save as PNG sequence
723
+ output_dir = output_path.rsplit('.', 1)[0] # Remove extension
724
+ save_frames_as_png(video_generate, output_dir, fps=args.fps)
725
+ else:
726
+ output_path = output_path.replace('.mkv', '.mp4')
727
+ save_video_with_imageio(video_generate, output_path, fps=args.fps, format=args.save_format)
728
+ else:
729
+ print(f"Warning: {video_name} not found in {args.input_dir}")
730
+
731
+ if metrics_models is not None:
732
+ print("\n=== Overall Average Metrics ===")
733
+ count = len(next(iter(metric_accumulator.values())))
734
+ overall_avg = {metric: 0 for metric in metrics_list}
735
+ out_name = 'metrics_'
736
+ for metric in metrics_list:
737
+ out_name += f"{metric}_"
738
+ scores = metric_accumulator[metric]
739
+ if scores:
740
+ avg = sum(scores) / len(scores)
741
+ overall_avg[metric] = avg
742
+ print(f"{metric.upper()}: {avg:.4f}")
743
+
744
+ out_name = out_name.rstrip('_') + '.json'
745
+ out_path = os.path.join(args.output_path, out_name)
746
+ output = {
747
+ "per_sample": metric_accumulator,
748
+ "average": overall_avg,
749
+ "count": count
750
+ }
751
+ with open(out_path, 'w') as f:
752
+ json.dump(output, f, indent=2)
753
+
754
+ print("All videos processed.")
pretrained_models/README.md ADDED
@@ -0,0 +1 @@
1
+ Place pretrained models here.
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ accelerate>=1.1.1
2
+ transformers>=4.46.2
3
+ numpy==1.26.0
4
+ torch>=2.5.0
5
+ torchvision>=0.20.0
6
+ sentencepiece>=0.2.0
7
+ SwissArmyTransformer>=0.4.12
8
+ gradio>=5.5.0
9
+ imageio>=2.35.1
10
+ imageio-ffmpeg>=0.5.1
11
+ openai>=1.54.0
12
+ moviepy>=2.0.0
13
+ scikit-video>=1.1.11
14
+ pydantic>=2.10.3
15
+ wandb
16
+ peft
17
+ opencv-python
18
+ decord
19
+ av
20
+ torchdiffeq