Aaryaman Vasishta (hangg-sai) committed
Commit a342aa8 · 0 Parent(s)

Initial commit


Co-authored-by: Aaryaman Vasishta <[email protected]>

.gitattributes ADDED
@@ -0,0 +1 @@
+ assets/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,50 @@
+ .envrc
+ .venv/
+ .gradio/
+ work_dirs*
+ logs*
+ pull_changes.sh
+
+ # Byte-compiled files
+ __pycache__/
+ *.py[cod]
+
+ # Virtual environments
+ env/
+ venv/
+ ENV/
+ .VENV/
+
+ # Distribution files
+ build/
+ dist/
+ *.egg-info/
+
+ # Logs and temporary files
+ *.log
+ *.tmp
+ *.bak
+ *.swp
+
+ # IDE files
+ .idea/
+ .vscode/
+ *.sublime-workspace
+ *.sublime-project
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # Testing and coverage
+ htmlcov/
+ .coverage
+ *.cover
+ *.py,cover
+ .cache/
+
+ # Jupyter Notebook checkpoints
+ .ipynb_checkpoints/
+
+ # Pre-commit hooks
+ .pre-commit-config.yaml~
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "third_party/dust3r"]
+ path = third_party/dust3r
+ url = https://github.com/jensenstability/dust3r
LICENSE ADDED
@@ -0,0 +1,124 @@
+ Stability AI Non-Commercial License Agreement
+ Last Updated: February 20, 2025
+
+ I. INTRODUCTION
+
+ This Stability AI Non-Commercial License Agreement (the “Agreement”) applies to any individual person or entity
+ (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or
+ Derivative Works thereof for any Research & Non-Commercial use. Capitalized terms not otherwise defined herein
+ are defined in Section IV below.
+
+ This Agreement is intended to allow research and non-commercial uses of the Model free of charge.
+
+ By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials
+ or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement.
+
+ If You are acting on behalf of a company, organization, or other entity, then “You” includes you and that entity,
+ and You agree that You:
+ (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and
+ (ii) You agree to the terms of this Agreement on that entity’s behalf.
+
+ ---
+
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
+
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable,
+ non-sublicensable, revocable, and royalty-free limited license under Stability AI’s intellectual property or other
+ rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create
+ Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose.
+
+ - **“Research Purpose”** means academic or scientific advancement, and in each case, is not primarily intended
+ for commercial advantage or monetary compensation to You or others.
+ - **“Non-Commercial Purpose”** means any purpose other than a Research Purpose that is not primarily intended
+ for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist)
+ or evaluation and testing.
+
+ ---
+
+ III. GENERAL TERMS
+
+ Your Research or Non-Commercial license under this Agreement is subject to the following terms.
+
+ ### a. Distribution & Attribution
+ If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product
+ or service that uses any portion of them, You shall:
+ 1. Provide a copy of this Agreement to that third party.
+ 2. Retain the following attribution notice within a **"Notice"** text file distributed as a part of such copies:
+
+ **"This Stability AI Model is licensed under the Stability AI Non-Commercial License,
+ Copyright © Stability AI Ltd. All Rights Reserved."**
+
+ 3. Prominently display **“Powered by Stability AI”** on a related website, user interface, blog post,
+ about page, or product documentation.
+ 4. If You create a Derivative Work, You may add your own attribution notice(s) to the **"Notice"** text file
+ included with that Derivative Work, provided that You clearly indicate which attributions apply to the
+ Stability AI Materials and state in the **"Notice"** text file that You changed the Stability AI Materials
+ and how it was modified.
+
+ ### b. Use Restrictions
+ Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability
+ AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control
+ Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby
+ incorporated by reference.
+
+ Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the
+ Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model
+ (excluding the Model or Derivative Works).
+
+ ### c. Intellectual Property
+
+ #### (i) Trademark License
+ No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials
+ or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of
+ its Affiliates, except as required under Section IV(a) herein.
+
+ #### (ii) Ownership of Derivative Works
+ As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s
+ ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
+
+ #### (iii) Ownership of Outputs
+ As between You and Stability AI, You own any outputs generated from the Model or Derivative Works to the extent
+ permitted by applicable law.
+
+ #### (iv) Disputes
+ If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works, or
+ associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual
+ property or other rights owned or licensable by You, then any licenses granted to You under this Agreement
+ shall terminate as of the date such litigation or claim is filed or instituted.
+
+ You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out
+ of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of
+ this Agreement.
+
+ #### (v) Feedback
+ From time to time, You may provide Stability AI with verbal and/or written suggestions, comments, or other
+ feedback related to Stability AI’s existing or prospective technology, products, or services (collectively,
+ “Feedback”).
+
+ You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant
+ Stability AI a **perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive,
+ worldwide right and license** to exploit the Feedback in any manner without restriction.
+
+ Your Feedback is provided **“AS IS”** and You make no warranties whatsoever about any Feedback.
+
+ ---
+
+ IV. DEFINITIONS
+
+ - **“Affiliate(s)”** means any entity that directly or indirectly controls, is controlled by, or is under common
+ control with the subject entity. For purposes of this definition, “control” means direct or indirect ownership
+ or control of more than 50% of the voting interests of the subject entity.
+ - **“AUP”** means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may
+ be updated from time to time.
+ - **"Derivative Work(s)"** means:
+ (a) Any derivative work of the Stability AI Materials as recognized by U.S. copyright laws.
+ (b) Any modifications to a Model, and any other model created which is based on or derived from the Model or
+ the Model’s output, including **fine-tune** and **low-rank adaptation** models derived from a Model or
+ a Model’s output, but does not include the output of any Model.
+ - **“Model”** means Stability AI’s Stable Virtual Camera model.
+ - **"Stability AI" or "we"** means Stability AI Ltd. and its Affiliates.
+ - **"Software"** means Stability AI’s proprietary software made available under this Agreement now or in the future.
+ - **“Stability AI Materials”** means, collectively, Stability’s proprietary Model, Software, and Documentation
+ (and any portion or combination thereof) made available under this Agreement.
+ - **“Trade Control Laws”** means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md ADDED
@@ -0,0 +1,17 @@
+ ---
+ title: Stable Virtual Camera
+ emoji: ⚡
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.17.0
+ app_file: demo_gr.py
+ pinned: false
+ ---
+
+ - **Project Page**: [https://stable-virtual-camera.github.io/](https://stable-virtual-camera.github.io/)
+ - **Paper**: [https://stable-virtual-camera.github.io/assets/paper.pdf](https://stable-virtual-camera.github.io/assets/paper.pdf)
+ - **Blog**: [https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control](https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control)
+ - **Code**: [https://github.com/Stability-AI/stable-virtual-camera](https://github.com/Stability-AI/stable-virtual-camera)
+ - **Model Card**: [https://huggingface.co/stabilityai/stable-virtual-camera](https://huggingface.co/stabilityai/stable-virtual-camera)
+ - **Video**: [https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ](https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ)
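For reference, a minimal local-launch sketch (assuming a CUDA GPU and the packages pinned in `requirements.txt`); the Space entrypoint is `demo_gr.py`, which exposes `main(server_port, share)` through `tyro.cli`:

```python
# Minimal sketch: run the Gradio demo locally rather than on the Space.
# Importing demo_gr loads DUSt3R, the Seva model, the autoencoder, and the CLIP
# conditioner onto "cuda:0" at import time, so a CUDA-capable GPU is required.
from demo_gr import main

# Equivalent to `python demo_gr.py` with the corresponding tyro CLI overrides;
# 7860 is just gradio's usual default port, picked here as an example.
main(server_port=7860, share=False)
```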
assets/advance/backyard-7_0.jpg ADDED

Git LFS Details

  • SHA256: 102c0b5ce669e41b9a1ac7e7d8096b5bdd848faaed6e0227e927dde6e9f8ffe5
  • Pointer size: 130 Bytes
  • Size of remote file: 67.7 kB
assets/advance/backyard-7_1.jpg ADDED

Git LFS Details

  • SHA256: 256919ceb20bfcc9bd4c03baec1bf06a96c5cc28ce34549999b92df2cc3a61c0
  • Pointer size: 130 Bytes
  • Size of remote file: 59.7 kB
assets/advance/backyard-7_2.jpg ADDED

Git LFS Details

  • SHA256: 26a280e4814feba832c961025c1f1c1bd9f248b96611625323ecc5f9940325f4
  • Pointer size: 130 Bytes
  • Size of remote file: 70.4 kB
assets/advance/backyard-7_3.jpg ADDED

Git LFS Details

  • SHA256: 9fe7e28c41968d1f31c70cc64c69b885e404d5637195f68c1d7659cab75c6403
  • Pointer size: 130 Bytes
  • Size of remote file: 71.1 kB
assets/advance/backyard-7_4.jpg ADDED

Git LFS Details

  • SHA256: f5ea4fe7da9a22bc5c4936d1fae0f366344b2ed38b0aea95b8cb9389e895c9fe
  • Pointer size: 130 Bytes
  • Size of remote file: 77.2 kB
assets/advance/backyard-7_5.jpg ADDED

Git LFS Details

  • SHA256: a186074f378e1c8da814c61d53b99b037267a23fa9e72ce8a3c7d1bab9da1bf0
  • Pointer size: 130 Bytes
  • Size of remote file: 81.4 kB
assets/advance/backyard-7_6.jpg ADDED

Git LFS Details

  • SHA256: 1040d5738f1039dafcaac28f5c0e180d45e5e8453c18739966c109a47138282e
  • Pointer size: 130 Bytes
  • Size of remote file: 76.4 kB
assets/advance/blue-car.jpg ADDED
assets/advance/garden-4_0.jpg ADDED

Git LFS Details

  • SHA256: 38fbe78f699fc84a1f4268ef8bacef9ddacfd32e9eb8fbcb605e46cfd52b988e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/advance/garden-4_1.jpg ADDED

Git LFS Details

  • SHA256: 1975effeffc9b2011a28f6eb04d1b0bd2f37f765c194249c95e6b3783d698a42
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/advance/garden-4_2.jpg ADDED

Git LFS Details

  • SHA256: 4112ff5f2ceaa3b469bb402853e7cde10396f858e5a2ceba93b095e1e3d8d335
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
assets/advance/garden-4_3.jpg ADDED

Git LFS Details

  • SHA256: a750b648c389f78f2f6b26d78f753eace13a41d355f725850c2667f864f709cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
assets/advance/telebooth-2_0.jpg ADDED

Git LFS Details

  • SHA256: 8822927954b3dda7c40ebc1311a5e97bae5875d5c2ee06331de5c39ef69b36c9
  • Pointer size: 130 Bytes
  • Size of remote file: 78.5 kB
assets/advance/telebooth-2_1.jpg ADDED

Git LFS Details

  • SHA256: b6718ec1dcb569ac36b5aa69154d505653985e8045aacf25c3e0d3df9798a2f5
  • Pointer size: 130 Bytes
  • Size of remote file: 78 kB
assets/advance/vgg-lab-4_0.png ADDED

Git LFS Details

  • SHA256: d1442eb509af02273cf7168f5212b3221142df4db99991b38395f42f8b239960
  • Pointer size: 131 Bytes
  • Size of remote file: 412 kB
assets/advance/vgg-lab-4_1.png ADDED

Git LFS Details

  • SHA256: c2bb10b9574247ceb0948aa00afea588f001f0271f51908b8132d63587fc43d0
  • Pointer size: 131 Bytes
  • Size of remote file: 443 kB
assets/advance/vgg-lab-4_2.png ADDED

Git LFS Details

  • SHA256: 7fa884bb6d783fd9385bd38042f3461f430bec8311e7b2171474b6a906538030
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
assets/advance/vgg-lab-4_3.png ADDED

Git LFS Details

  • SHA256: 99469f816604c92c9c27a7cff119cb3649d3dfa4c41dcef89525b7b3cbd885a4
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB
assets/basic/blue-car.jpg ADDED

Git LFS Details

  • SHA256: 0cf493d0f738830223949fd24bb3ab0a1c078804fdb744efa95a1fdcfcfb5332
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
assets/basic/hilly-countryside.jpg ADDED

Git LFS Details

  • SHA256: 4ae3b8cb5d989b62ceaf4930afea55790048657fa459f383f8bd809b3bdcfca0
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
assets/basic/lily-dragon.png ADDED

Git LFS Details

  • SHA256: c545057ee2feeced73566f708311bf758350ef0ded844d7bd438e48fca7f5bd2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
assets/basic/llff-room.jpg ADDED

Git LFS Details

  • SHA256: 3529ea8355bb564784390a1a3ba5f0d2e71c4bbaed652dd2c778049766eedebf
  • Pointer size: 130 Bytes
  • Size of remote file: 35.1 kB
assets/basic/mountain-lake.jpg ADDED

Git LFS Details

  • SHA256: 5c6ef051d69e4e08ab29508d8d4b5171384935719030e88e706584db2f409c3b
  • Pointer size: 130 Bytes
  • Size of remote file: 55.6 kB
assets/basic/vasedeck.jpg ADDED

Git LFS Details

  • SHA256: 334fde31dd688bd8343ac0fba1354b4aa4779c09de03c7cb20a0386f660049cd
  • Pointer size: 130 Bytes
  • Size of remote file: 61.5 kB
assets/basic/vgg-lab-4_0.png ADDED
demo_gr.py ADDED
@@ -0,0 +1,1238 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import queue
6
+ import secrets
7
+ import threading
8
+ import time
9
+ from datetime import datetime
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from typing import Literal
13
+
14
+ import gradio as gr
15
+ import httpx
16
+ import imageio.v3 as iio
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import tyro
21
+ import viser
22
+ import viser.transforms as vt
23
+ from einops import rearrange
24
+ from gradio import networking
25
+ from gradio.context import LocalContext
26
+ from gradio.tunneling import CERTIFICATE_PATH, Tunnel
27
+
28
+ from seva.eval import (
29
+ IS_TORCH_NIGHTLY,
30
+ chunk_input_and_test,
31
+ create_transforms_simple,
32
+ infer_prior_stats,
33
+ run_one_scene,
34
+ transform_img_and_K,
35
+ )
36
+ from seva.geometry import (
37
+ DEFAULT_FOV_RAD,
38
+ get_default_intrinsics,
39
+ get_preset_pose_fov,
40
+ normalize_scene,
41
+ )
42
+ from seva.gui import define_gui
43
+ from seva.model import SGMWrapper
44
+ from seva.modules.autoencoder import AutoEncoder
45
+ from seva.modules.conditioner import CLIPConditioner
46
+ from seva.modules.preprocessor import Dust3rPipeline
47
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
48
+ from seva.utils import load_model
49
+
50
+ device = "cuda:0"
51
+
52
+
53
+ # Constants.
54
+ WORK_DIR = "work_dirs/demo_gr"
55
+ MAX_SESSIONS = 1
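+ # Each entry below pairs the thumbnail shown in the "Example" gallery with the
+ # full list of input views loaded when that example is confirmed.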
56
+ ADVANCE_EXAMPLE_MAP = [
57
+ (
58
+ "assets/advance/blue-car.jpg",
59
+ ["assets/advance/blue-car.jpg"],
60
+ ),
61
+ (
62
+ "assets/advance/garden-4_0.jpg",
63
+ [
64
+ "assets/advance/garden-4_0.jpg",
65
+ "assets/advance/garden-4_1.jpg",
66
+ "assets/advance/garden-4_2.jpg",
67
+ "assets/advance/garden-4_3.jpg",
68
+ ],
69
+ ),
70
+ (
71
+ "assets/advance/vgg-lab-4_0.png",
72
+ [
73
+ "assets/advance/vgg-lab-4_0.png",
74
+ "assets/advance/vgg-lab-4_1.png",
75
+ "assets/advance/vgg-lab-4_2.png",
76
+ "assets/advance/vgg-lab-4_3.png",
77
+ ],
78
+ ),
79
+ (
80
+ "assets/advance/telebooth-2_0.jpg",
81
+ [
82
+ "assets/advance/telebooth-2_0.jpg",
83
+ "assets/advance/telebooth-2_1.jpg",
84
+ ],
85
+ ),
86
+ (
87
+ "assets/advance/backyard-7_0.jpg",
88
+ [
89
+ "assets/advance/backyard-7_0.jpg",
90
+ "assets/advance/backyard-7_1.jpg",
91
+ "assets/advance/backyard-7_2.jpg",
92
+ "assets/advance/backyard-7_3.jpg",
93
+ "assets/advance/backyard-7_4.jpg",
94
+ "assets/advance/backyard-7_5.jpg",
95
+ "assets/advance/backyard-7_6.jpg",
96
+ ],
97
+ ),
98
+ ]
99
+
100
+ if IS_TORCH_NIGHTLY:
101
+ COMPILE = True
102
+ os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
103
+ os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
104
+ else:
105
+ COMPILE = False
106
+
107
+ # Shared global variables across sessions.
108
+ DUST3R = Dust3rPipeline(device=device) # type: ignore
109
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
110
+ AE = AutoEncoder(chunk_size=1).to(device)
111
+ CONDITIONER = CLIPConditioner().to(device)
112
+ DISCRETIZATION = DDPMDiscretization()
113
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
114
+ VERSION_DICT = {
115
+ "H": 576,
116
+ "W": 576,
117
+ "T": 21,
118
+ "C": 4,
119
+ "f": 8,
120
+ "options": {},
121
+ }
122
+ SERVERS = {}
123
+ ABORT_EVENTS = {}
124
+
125
+ if COMPILE:
126
+ MODEL = torch.compile(MODEL)
127
+ CONDITIONER = torch.compile(CONDITIONER)
128
+ AE = torch.compile(AE)
129
+
130
+
131
+ class SevaRenderer(object):
132
+ def __init__(self, server: viser.ViserServer):
133
+ self.server = server
134
+ self.gui_state = None
135
+
136
+ def preprocess(
137
+ self, input_img_path_or_tuples: list[tuple[str, None]] | str
138
+ ) -> tuple[dict, dict, dict]:
139
+ # Simply hardcode these such that the aspect ratio is always kept and the
+ # shorter side is resized to 576. This is only to keep the GUI options
+ # fewer; changing it still works.
142
+ shorter: int = 576
143
+ # Has to be 64 multiple for the network.
144
+ shorter = round(shorter / 64) * 64
145
+
146
+ if isinstance(input_img_path_or_tuples, str):
147
+ # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
148
+ input_imgs = torch.as_tensor(
149
+ iio.imread(input_img_path_or_tuples) / 255.0, dtype=torch.float32
150
+ )[None, ..., :3]
151
+ input_imgs = transform_img_and_K(
152
+ input_imgs.permute(0, 3, 1, 2),
153
+ shorter,
154
+ K=None,
155
+ size_stride=64,
156
+ )[0].permute(0, 2, 3, 1)
157
+ input_Ks = get_default_intrinsics(
158
+ aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
159
+ )
160
+ input_c2ws = torch.eye(4)[None]
161
+ # Simulate a small time interval such that gradio can update
162
+ # progress properly.
163
+ time.sleep(0.1)
164
+ return (
165
+ {
166
+ "input_imgs": input_imgs,
167
+ "input_Ks": input_Ks,
168
+ "input_c2ws": input_c2ws,
169
+ "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
170
+ "points": [np.zeros((0, 3))],
171
+ "point_colors": [np.zeros((0, 3))],
172
+ "scene_scale": 1.0,
173
+ },
174
+ gr.update(visible=False),
175
+ gr.update(),
176
+ )
177
+ else:
178
+ # Assume `Advanced` demo mode: use DUSt3R to extract camera parameters and points.
179
+ img_paths = [p for (p, _) in input_img_path_or_tuples]
180
+ (
181
+ input_imgs,
182
+ input_Ks,
183
+ input_c2ws,
184
+ points,
185
+ point_colors,
186
+ ) = DUST3R.infer_cameras_and_points(img_paths)
187
+ num_inputs = len(img_paths)
188
+ if num_inputs == 1:
189
+ input_imgs, input_Ks, input_c2ws, points, point_colors = (
190
+ input_imgs[:1],
191
+ input_Ks[:1],
192
+ input_c2ws[:1],
193
+ points[:1],
194
+ point_colors[:1],
195
+ )
196
+ input_imgs = [img[..., :3] for img in input_imgs]
197
+ # Normalize the scene.
198
+ point_chunks = [p.shape[0] for p in points]
199
+ point_indices = np.cumsum(point_chunks)[:-1]
200
+ input_c2ws, points, _ = normalize_scene( # type: ignore
201
+ input_c2ws,
202
+ np.concatenate(points, 0),
203
+ camera_center_method="poses",
204
+ )
205
+ points = np.split(points, point_indices, 0)
206
+ # Scale camera and points for viewport visualization.
207
+ scene_scale = np.median(
208
+ np.ptp(np.concatenate([input_c2ws[:, :3, 3], *points], 0), -1)
209
+ )
210
+ input_c2ws[:, :3, 3] /= scene_scale
211
+ points = [point / scene_scale for point in points]
212
+ input_imgs = [
213
+ torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_imgs
214
+ ]
215
+ input_Ks = torch.as_tensor(input_Ks)
216
+ input_c2ws = torch.as_tensor(input_c2ws)
217
+ new_input_imgs, new_input_Ks = [], []
218
+ for img, K in zip(input_imgs, input_Ks):
219
+ img = rearrange(img, "h w c -> 1 c h w")
220
+ # If you don't want to keep aspect ratio and want to always center crop, use this:
221
+ # img, K = transform_img_and_K(img, (shorter, shorter), K=K[None])
222
+ img, K = transform_img_and_K(img, shorter, K=K[None], size_stride=64)
223
+ assert isinstance(K, torch.Tensor)
224
+ K = K / K.new_tensor([img.shape[-1], img.shape[-2], 1])[:, None]
225
+ new_input_imgs.append(img)
226
+ new_input_Ks.append(K)
227
+ input_imgs = torch.cat(new_input_imgs, 0)
228
+ input_imgs = rearrange(input_imgs, "b c h w -> b h w c")[..., :3]
229
+ input_Ks = torch.cat(new_input_Ks, 0)
230
+ return (
231
+ {
232
+ "input_imgs": input_imgs,
233
+ "input_Ks": input_Ks,
234
+ "input_c2ws": input_c2ws,
235
+ "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
236
+ "points": points,
237
+ "point_colors": point_colors,
238
+ "scene_scale": scene_scale,
239
+ },
240
+ gr.update(visible=False),
241
+ gr.update()
242
+ if num_inputs <= 10
243
+ else gr.update(choices=["interp"], value="interp"),
244
+ )
245
+
246
+ def visualize_scene(self, preprocessed: dict):
247
+ server = self.server
248
+ server.scene.reset()
249
+ server.gui.reset()
250
+ set_bkgd_color(server)
251
+
252
+ (
253
+ input_imgs,
254
+ input_Ks,
255
+ input_c2ws,
256
+ input_wh,
257
+ points,
258
+ point_colors,
259
+ scene_scale,
260
+ ) = (
261
+ preprocessed["input_imgs"],
262
+ preprocessed["input_Ks"],
263
+ preprocessed["input_c2ws"],
264
+ preprocessed["input_wh"],
265
+ preprocessed["points"],
266
+ preprocessed["point_colors"],
267
+ preprocessed["scene_scale"],
268
+ )
269
+ W, H = input_wh
270
+
271
+ server.scene.set_up_direction(-input_c2ws[..., :3, 1].mean(0).numpy())
272
+
273
+ # Use first image as default fov.
274
+ assert input_imgs[0].shape[:2] == (H, W)
275
+ if H > W:
276
+ init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 0, 0].item()))
277
+ else:
278
+ init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 1, 1].item()))
279
+ init_fov_deg = float(init_fov / np.pi * 180.0)
280
+
281
+ frustum_nodes, pcd_nodes = [], []
282
+ for i in range(len(input_imgs)):
283
+ K = input_Ks[i]
284
+ frustum = server.scene.add_camera_frustum(
285
+ f"/scene_assets/cameras/{i}",
286
+ fov=2 * np.arctan(1 / (2 * K[1, 1].item())),
287
+ aspect=W / H,
288
+ scale=0.1 * scene_scale,
289
+ image=(input_imgs[i].numpy() * 255.0).astype(np.uint8),
290
+ wxyz=vt.SO3.from_matrix(input_c2ws[i, :3, :3].numpy()).wxyz,
291
+ position=input_c2ws[i, :3, 3].numpy(),
292
+ )
293
+
294
+ def get_handler(frustum):
295
+ def handler(event: viser.GuiEvent) -> None:
296
+ assert event.client_id is not None
297
+ client = server.get_clients()[event.client_id]
298
+ with client.atomic():
299
+ client.camera.position = frustum.position
300
+ client.camera.wxyz = frustum.wxyz
301
+ # Set look_at as the projected origin onto the
302
+ # frustum's forward direction.
303
+ look_direction = vt.SO3(frustum.wxyz).as_matrix()[:, 2]
304
+ position_origin = -frustum.position
305
+ client.camera.look_at = (
306
+ frustum.position
307
+ + np.dot(look_direction, position_origin)
308
+ / np.linalg.norm(position_origin)
309
+ * look_direction
310
+ )
311
+
312
+ return handler
313
+
314
+ frustum.on_click(get_handler(frustum)) # type: ignore
315
+ frustum_nodes.append(frustum)
316
+
317
+ pcd = server.scene.add_point_cloud(
318
+ f"/scene_assets/points/{i}",
319
+ points[i],
320
+ point_colors[i],
321
+ point_size=0.01 * scene_scale,
322
+ point_shape="circle",
323
+ )
324
+ pcd_nodes.append(pcd)
325
+
326
+ with server.gui.add_folder("Scene scale", expand_by_default=False, order=200):
327
+ camera_scale_slider = server.gui.add_slider(
328
+ "Log camera scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
329
+ )
330
+
331
+ @camera_scale_slider.on_update
332
+ def _(_) -> None:
333
+ for i in range(len(frustum_nodes)):
334
+ frustum_nodes[i].scale = (
335
+ 0.1 * scene_scale * 10**camera_scale_slider.value
336
+ )
337
+
338
+ point_scale_slider = server.gui.add_slider(
339
+ "Log point scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
340
+ )
341
+
342
+ @point_scale_slider.on_update
343
+ def _(_) -> None:
344
+ for i in range(len(pcd_nodes)):
345
+ pcd_nodes[i].point_size = (
346
+ 0.01 * scene_scale * 10**point_scale_slider.value
347
+ )
348
+
349
+ self.gui_state = define_gui(
350
+ server,
351
+ init_fov=init_fov_deg,
352
+ img_wh=input_wh,
353
+ scene_scale=scene_scale,
354
+ )
355
+
356
+ def get_target_c2ws_and_Ks_from_gui(self, preprocessed: dict):
357
+ input_wh = preprocessed["input_wh"]
358
+ W, H = input_wh
359
+ gui_state = self.gui_state
360
+ assert gui_state is not None and gui_state.camera_traj_list is not None
361
+ target_c2ws, target_Ks = [], []
362
+ for item in gui_state.camera_traj_list:
363
+ target_c2ws.append(item["w2c"])
364
+ assert item["img_wh"] == input_wh
365
+ K = np.array(item["K"]).reshape(3, 3) / np.array([W, H, 1])[:, None]
366
+ target_Ks.append(K)
367
+ target_c2ws = torch.as_tensor(
368
+ np.linalg.inv(np.array(target_c2ws).reshape(-1, 4, 4))
369
+ )
370
+ target_Ks = torch.as_tensor(np.array(target_Ks).reshape(-1, 3, 3))
371
+ return target_c2ws, target_Ks
372
+
373
+ def get_target_c2ws_and_Ks_from_preset(
374
+ self,
375
+ preprocessed: dict,
376
+ preset_traj: Literal[
377
+ "orbit",
378
+ "spiral",
379
+ "lemniscate",
380
+ "zoom-in",
381
+ "zoom-out",
382
+ "dolly zoom-in",
383
+ "dolly zoom-out",
384
+ "move-forward",
385
+ "move-backward",
386
+ "move-up",
387
+ "move-down",
388
+ "move-left",
389
+ "move-right",
390
+ ],
391
+ num_frames: int,
392
+ zoom_factor: float | None,
393
+ ):
394
+ img_wh = preprocessed["input_wh"]
395
+ start_c2w = preprocessed["input_c2ws"][0]
396
+ start_w2c = torch.linalg.inv(start_c2w)
397
+ look_at = torch.tensor([0, 0, 10])
398
+ start_fov = DEFAULT_FOV_RAD
399
+ target_c2ws, target_fovs = get_preset_pose_fov(
400
+ preset_traj,
401
+ num_frames,
402
+ start_w2c,
403
+ look_at,
404
+ -start_c2w[:3, 1],
405
+ start_fov,
406
+ spiral_radii=[1.0, 1.0, 0.5],
407
+ zoom_factor=zoom_factor,
408
+ )
409
+ target_c2ws = torch.as_tensor(target_c2ws)
410
+ target_fovs = torch.as_tensor(target_fovs)
411
+ target_Ks = get_default_intrinsics(
412
+ target_fovs, # type: ignore
413
+ aspect_ratio=img_wh[0] / img_wh[1],
414
+ )
415
+ return target_c2ws, target_Ks
416
+
417
+ def export_output_data(self, preprocessed: dict, output_dir: str):
418
+ input_imgs, input_Ks, input_c2ws, input_wh = (
419
+ preprocessed["input_imgs"],
420
+ preprocessed["input_Ks"],
421
+ preprocessed["input_c2ws"],
422
+ preprocessed["input_wh"],
423
+ )
424
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
425
+
426
+ num_inputs = len(input_imgs)
427
+ num_targets = len(target_c2ws)
428
+
429
+ input_imgs = (input_imgs.cpu().numpy() * 255.0).astype(np.uint8)
430
+ input_c2ws = input_c2ws.cpu().numpy()
431
+ input_Ks = input_Ks.cpu().numpy()
432
+ target_c2ws = target_c2ws.cpu().numpy()
433
+ target_Ks = target_Ks.cpu().numpy()
434
+ img_whs = np.array(input_wh)[None].repeat(len(input_imgs) + len(target_Ks), 0)
435
+
436
+ os.makedirs(output_dir, exist_ok=True)
437
+ img_paths = []
438
+ for i, img in enumerate(input_imgs):
439
+ iio.imwrite(img_path := osp.join(output_dir, f"{i:03d}.png"), img)
440
+ img_paths.append(img_path)
441
+ for i in range(num_targets):
442
+ iio.imwrite(
443
+ img_path := osp.join(output_dir, f"{i + num_inputs:03d}.png"),
444
+ np.zeros((input_wh[1], input_wh[0], 3), dtype=np.uint8),
445
+ )
446
+ img_paths.append(img_path)
447
+
448
+ # Convert from OpenCV to OpenGL camera format.
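+ # (OpenCV: +x right, +y down, +z forward; OpenGL: +x right, +y up, -z forward.
+ # Right-multiplying the c2ws by diag(1, -1, -1, 1) flips the y/z camera axes.)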
449
+ all_c2ws = np.concatenate([input_c2ws, target_c2ws])
450
+ all_Ks = np.concatenate([input_Ks, target_Ks])
451
+ all_c2ws = all_c2ws @ np.diag([1, -1, -1, 1])
452
+ create_transforms_simple(output_dir, img_paths, img_whs, all_c2ws, all_Ks)
453
+ split_dict = {
454
+ "train_ids": list(range(num_inputs)),
455
+ "test_ids": list(range(num_inputs, num_inputs + num_targets)),
456
+ }
457
+ with open(
458
+ osp.join(output_dir, f"train_test_split_{num_inputs}.json"), "w"
459
+ ) as f:
460
+ json.dump(split_dict, f, indent=4)
461
+ gr.Info(f"Output data saved to {output_dir}", duration=1)
462
+
463
+ def render(
464
+ self,
465
+ preprocessed: dict,
466
+ session_hash: str,
467
+ seed: int,
468
+ chunk_strategy: str,
469
+ cfg: float,
470
+ preset_traj: Literal[
471
+ "orbit",
472
+ "spiral",
473
+ "lemniscate",
474
+ "zoom-in",
475
+ "zoom-out",
476
+ "dolly zoom-in",
477
+ "dolly zoom-out",
478
+ "move-forward",
479
+ "move-backward",
480
+ "move-up",
481
+ "move-down",
482
+ "move-left",
483
+ "move-right",
484
+ ]
485
+ | None,
486
+ num_frames: int | None,
487
+ zoom_factor: float | None,
488
+ camera_scale: float,
489
+ ):
490
+ render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
491
+ render_dir = osp.join(WORK_DIR, render_name)
492
+
493
+ input_imgs, input_Ks, input_c2ws, (W, H) = (
494
+ preprocessed["input_imgs"],
495
+ preprocessed["input_Ks"],
496
+ preprocessed["input_c2ws"],
497
+ preprocessed["input_wh"],
498
+ )
499
+ num_inputs = len(input_imgs)
500
+ if preset_traj is None:
501
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
502
+ else:
503
+ assert num_frames is not None
504
+ assert num_inputs == 1
505
+ input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
506
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
507
+ preprocessed, preset_traj, num_frames, zoom_factor
508
+ )
509
+ all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
510
+ all_Ks = (
511
+ torch.cat([input_Ks, target_Ks], 0)
512
+ * input_Ks.new_tensor([W, H, 1])[:, None]
513
+ )
514
+ num_targets = len(target_c2ws)
515
+ input_indices = list(range(num_inputs))
516
+ target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
517
+ # Get anchor cameras.
518
+ T = VERSION_DICT["T"]
519
+ version_dict = copy.deepcopy(VERSION_DICT)
520
+ num_anchors = infer_prior_stats(
521
+ T,
522
+ num_inputs,
523
+ num_total_frames=num_targets,
524
+ version_dict=version_dict,
525
+ )
526
+ # infer_prior_stats modifies T in-place.
527
+ T = version_dict["T"]
528
+ assert isinstance(num_anchors, int)
529
+ anchor_indices = np.linspace(
530
+ num_inputs,
531
+ num_inputs + num_targets - 1,
532
+ num_anchors,
533
+ ).tolist()
534
+ anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
535
+ anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
536
+ # Create image conditioning.
537
+ all_imgs_np = (
538
+ F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
539
+ * 255.0
540
+ ).astype(np.uint8)
541
+ image_cond = {
542
+ "img": all_imgs_np,
543
+ "input_indices": input_indices,
544
+ "prior_indices": anchor_indices,
545
+ }
546
+ # Create camera conditioning (K is unnormalized).
547
+ camera_cond = {
548
+ "c2w": all_c2ws,
549
+ "K": all_Ks,
550
+ "input_indices": list(range(num_inputs + num_targets)),
551
+ }
552
+ # Run rendering.
553
+ num_steps = 50
554
+ options_ori = VERSION_DICT["options"]
555
+ options = copy.deepcopy(options_ori)
556
+ options["chunk_strategy"] = chunk_strategy
557
+ options["video_save_fps"] = 30.0
558
+ options["beta_linear_start"] = 5e-6
559
+ options["log_snr_shift"] = 2.4
560
+ options["guider_types"] = [1, 2]
561
+ options["cfg"] = [
562
+ float(cfg),
563
+ 3.0 if num_inputs >= 9 else 2.0,
564
+ ] # We define semi-dense-view regime to have 9 input views.
565
+ options["camera_scale"] = camera_scale
566
+ options["num_steps"] = num_steps
567
+ options["cfg_min"] = 1.2
568
+ options["encoding_t"] = 1
569
+ options["decoding_t"] = 1
570
+ assert session_hash in ABORT_EVENTS
571
+ abort_event = ABORT_EVENTS[session_hash]
572
+ abort_event.clear()
573
+ options["abort_event"] = abort_event
574
+ task = "img2trajvid"
575
+ # Get number of first pass chunks.
576
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
577
+ chunk_strategy_first_pass = options.get(
578
+ "chunk_strategy_first_pass", "gt-nearest"
579
+ )
580
+ num_chunks_0 = len(
581
+ chunk_input_and_test(
582
+ T_first_pass,
583
+ input_c2ws,
584
+ anchor_c2ws,
585
+ input_indices,
586
+ image_cond["prior_indices"],
587
+ options={**options, "sampler_verbose": False},
588
+ task=task,
589
+ chunk_strategy=chunk_strategy_first_pass,
590
+ gt_input_inds=list(range(input_c2ws.shape[0])),
591
+ )[1]
592
+ )
593
+ # Get number of second pass chunks.
594
+ anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
595
+ anchor_indices = np.array(input_indices + anchor_indices)[
596
+ anchor_argsort
597
+ ].tolist()
598
+ gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
599
+ anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
600
+ anchor_argsort
601
+ ]
602
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
603
+ chunk_strategy = options.get("chunk_strategy", "nearest")
604
+ num_chunks_1 = len(
605
+ chunk_input_and_test(
606
+ T_second_pass,
607
+ anchor_c2ws_second_pass,
608
+ target_c2ws,
609
+ anchor_indices,
610
+ target_indices,
611
+ options={**options, "sampler_verbose": False},
612
+ task=task,
613
+ chunk_strategy=chunk_strategy,
614
+ gt_input_inds=gt_input_inds,
615
+ )[1]
616
+ )
617
+ second_pass_pbar = gr.Progress().tqdm(
618
+ iterable=None,
619
+ desc="Second pass sampling",
620
+ total=num_chunks_1 * num_steps,
621
+ )
622
+ first_pass_pbar = gr.Progress().tqdm(
623
+ iterable=None,
624
+ desc="First pass sampling",
625
+ total=num_chunks_0 * num_steps,
626
+ )
627
+ video_path_generator = run_one_scene(
628
+ task=task,
629
+ version_dict={
630
+ "H": H,
631
+ "W": W,
632
+ "T": T,
633
+ "C": VERSION_DICT["C"],
634
+ "f": VERSION_DICT["f"],
635
+ "options": options,
636
+ },
637
+ model=MODEL,
638
+ ae=AE,
639
+ conditioner=CONDITIONER,
640
+ denoiser=DENOISER,
641
+ image_cond=image_cond,
642
+ camera_cond=camera_cond,
643
+ save_path=render_dir,
644
+ use_traj_prior=True,
645
+ traj_prior_c2ws=anchor_c2ws,
646
+ traj_prior_Ks=anchor_Ks,
647
+ seed=seed,
648
+ gradio=True,
649
+ first_pass_pbar=first_pass_pbar,
650
+ second_pass_pbar=second_pass_pbar,
651
+ abort_event=abort_event,
652
+ )
653
+ output_queue = queue.Queue()
654
+
655
+ blocks = LocalContext.blocks.get()
656
+ event_id = LocalContext.event_id.get()
657
+
658
+ def worker():
659
+ # gradio doesn't support threading with progress intentionally, so
660
+ # we need to hack this.
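+ # Copying these ContextVars into the worker thread lets the gr.Progress bars
+ # created above keep resolving against the correct Blocks app and event.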
661
+ LocalContext.blocks.set(blocks)
662
+ LocalContext.event_id.set(event_id)
663
+ for i, video_path in enumerate(video_path_generator):
664
+ if i == 0:
665
+ output_queue.put(
666
+ (
667
+ video_path,
668
+ gr.update(),
669
+ gr.update(),
670
+ gr.update(),
671
+ )
672
+ )
673
+ elif i == 1:
674
+ output_queue.put(
675
+ (
676
+ video_path,
677
+ gr.update(visible=True),
678
+ gr.update(visible=False),
679
+ gr.update(visible=False),
680
+ )
681
+ )
682
+ else:
683
+ gr.Error("More than two passes during rendering.")
684
+
685
+ thread = threading.Thread(target=worker, daemon=True)
686
+ thread.start()
687
+
688
+ while thread.is_alive() or not output_queue.empty():
689
+ if abort_event.is_set():
690
+ thread.join()
691
+ abort_event.clear()
692
+ yield (
693
+ gr.update(),
694
+ gr.update(visible=True),
695
+ gr.update(visible=False),
696
+ gr.update(visible=False),
697
+ )
698
+ time.sleep(0.1)
699
+ while not output_queue.empty():
700
+ yield output_queue.get()
701
+
702
+
703
+ # This is basically a copy of the original `networking.setup_tunnel` function,
704
+ # but it also returns the tunnel object for proper cleanup.
705
+ def setup_tunnel(
706
+ local_host: str, local_port: int, share_token: str, share_server_address: str | None
707
+ ) -> tuple[str, Tunnel]:
708
+ share_server_address = (
709
+ networking.GRADIO_SHARE_SERVER_ADDRESS
710
+ if share_server_address is None
711
+ else share_server_address
712
+ )
713
+ if share_server_address is None:
714
+ try:
715
+ response = httpx.get(networking.GRADIO_API_SERVER, timeout=30)
716
+ payload = response.json()[0]
717
+ remote_host, remote_port = payload["host"], int(payload["port"])
718
+ certificate = payload["root_ca"]
719
+ Path(CERTIFICATE_PATH).parent.mkdir(parents=True, exist_ok=True)
720
+ with open(CERTIFICATE_PATH, "w") as f:
721
+ f.write(certificate)
722
+ except Exception as e:
723
+ raise RuntimeError(
724
+ "Could not get share link from Gradio API Server."
725
+ ) from e
726
+ else:
727
+ remote_host, remote_port = share_server_address.split(":")
728
+ remote_port = int(remote_port)
729
+ tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
730
+ address = tunnel.start_tunnel()
731
+ return address, tunnel
732
+
733
+
734
+ def set_bkgd_color(server: viser.ViserServer | viser.ClientHandle):
735
+ server.scene.set_background_image(np.array([[[39, 39, 42]]], dtype=np.uint8))
736
+
737
+
738
+ def start_server_and_abort_event(request: gr.Request):
739
+ server = viser.ViserServer()
740
+
741
+ @server.on_client_connect
742
+ def _(client: viser.ClientHandle):
743
+ # Force dark mode that blends well with gradio's dark theme.
744
+ client.gui.configure_theme(
745
+ dark_mode=True,
746
+ show_share_button=False,
747
+ control_layout="collapsible",
748
+ )
749
+ set_bkgd_color(client)
750
+
751
+ print(f"Starting server {server.get_port()}")
752
+ server_url, tunnel = setup_tunnel(
753
+ local_host=server.get_host(),
754
+ local_port=server.get_port(),
755
+ share_token=secrets.token_urlsafe(32),
756
+ share_server_address=None,
757
+ )
758
+ SERVERS[request.session_hash] = (server, tunnel)
759
+ if server_url is None:
760
+ raise gr.Error(
761
+ "Failed to get a viewport URL. Please check your network connection."
762
+ )
763
+ # Give it enough time to start.
764
+ time.sleep(1)
765
+
766
+ ABORT_EVENTS[request.session_hash] = threading.Event()
767
+
768
+ return (
769
+ SevaRenderer(server),
770
+ gr.HTML(
771
+ f'<iframe src="{server_url}" style="display: block; margin: auto; width: 100%; height: min(60vh, 600px);" frameborder="0"></iframe>',
772
+ container=True,
773
+ ),
774
+ request.session_hash,
775
+ )
776
+
777
+
778
+ def stop_server_and_abort_event(request: gr.Request):
779
+ if request.session_hash in SERVERS:
780
+ print(f"Stopping server {request.session_hash}")
781
+ server, tunnel = SERVERS.pop(request.session_hash)
782
+ server.stop()
783
+ tunnel.kill()
784
+
785
+ if request.session_hash in ABORT_EVENTS:
786
+ print(f"Setting abort event {request.session_hash}")
787
+ ABORT_EVENTS[request.session_hash].set()
788
+ # Give it enough time to abort jobs.
789
+ time.sleep(5)
790
+ ABORT_EVENTS.pop(request.session_hash)
791
+
792
+
793
+ def set_abort_event(request: gr.Request):
794
+ if request.session_hash in ABORT_EVENTS:
795
+ print(f"Setting abort event {request.session_hash}")
796
+ ABORT_EVENTS[request.session_hash].set()
797
+
798
+
799
+ def get_advance_examples(selection: gr.SelectData):
800
+ index = selection.index
801
+ return (
802
+ gr.Gallery(ADVANCE_EXAMPLE_MAP[index][1], visible=True),
803
+ gr.update(visible=True),
804
+ gr.update(visible=True),
805
+ gr.Gallery(visible=False),
806
+ )
807
+
808
+
809
+ def get_preamble():
810
+ gr.Markdown("""
811
+ # Stable Virtual Camera
812
+ <span style="display: flex; flex-wrap: wrap; gap: 5px;">
813
+ <a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
814
+ <a href="https://stable-virtual-camera.github.io/pdf/paper.pdf"><img src="https://img.shields.io/badge/%F0%9F%93%84%20Paper-gray.svg"></a>
815
+ <a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
816
+ <a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
817
+ <a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
818
+ <a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
819
+ </span>
820
+
821
+ Welcome to the demo of <strong>Stable Virtual Camera (Seva)</strong>! Given any number of input views and their cameras, this demo will allow you to generate novel views of a scene at any target camera of interest.
822
+
823
+ We provide two ways to use our demo (selected by the tab below, documented [here](https://github.com/Stability-AI/stable-virtual-camera/blob/main/docs/GR_USAGE.md)):
824
+ 1. **[Basic](https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705)**: Given a single image, you can generate a video following one of our preset camera trajectories.
825
+ 2. **[Advanced](https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba)**: Given any number of input images, you can generate a video following any camera trajectory of your choice by our key-frame-based interface.
826
+
827
+ > This is a research preview and comes with a few [limitations](https://stable-virtual-camera.github.io/#limitations):
828
+ > - Limited quality in certain subjects due to training data, including humans, animals, and dynamic textures.
829
+ > - Limited quality in some highly ambiguous scenes and camera trajectories, including extreme views and collision into objects.
830
+ """)
831
+
832
+
833
+ # Make sure that gradio uses dark theme.
834
+ _APP_JS = """
835
+ function refresh() {
836
+ const url = new URL(window.location);
837
+ if (url.searchParams.get('__theme') !== 'dark') {
838
+ url.searchParams.set('__theme', 'dark');
+ window.location.href = url.href;
+ }
840
+ }
841
+ """
842
+
843
+
844
+ def main(server_port: int | None = None, share: bool = True):
845
+ with gr.Blocks(js=_APP_JS) as app:
846
+ renderer = gr.State()
847
+ session_hash = gr.State()
848
+ _ = get_preamble()
849
+ with gr.Tabs():
850
+ with gr.Tab("Basic"):
851
+ render_btn = gr.Button("Render video", interactive=False, render=False)
852
+ with gr.Row():
853
+ with gr.Column():
854
+ with gr.Group():
855
+ preprocess_btn = gr.Button("Preprocess images")
856
+ preprocess_progress = gr.Textbox(
857
+ label="",
858
+ visible=False,
859
+ interactive=False,
860
+ )
861
+ with gr.Group():
862
+ input_imgs = gr.Image(
863
+ type="filepath",
864
+ label="Input",
865
+ height=200,
866
+ )
867
+ _ = gr.Examples(
868
+ examples=sorted(glob("assets/basic/*")),
869
+ inputs=[input_imgs],
870
+ label="Example",
871
+ )
872
+ chunk_strategy = gr.Dropdown(
873
+ ["interp", "interp-gt"],
874
+ label="Chunk strategy",
875
+ render=False,
876
+ )
877
+ preprocessed = gr.State()
878
+ preprocess_btn.click(
879
+ lambda r, *args: [
880
+ *r.preprocess(*args),
881
+ gr.update(interactive=True),
882
+ ],
883
+ inputs=[renderer, input_imgs],
884
+ outputs=[
885
+ preprocessed,
886
+ preprocess_progress,
887
+ chunk_strategy,
888
+ render_btn,
889
+ ],
890
+ show_progress_on=[preprocess_progress],
891
+ concurrency_limit=1,
892
+ concurrency_id="gpu_queue",
893
+ )
894
+ preprocess_btn.click(
895
+ lambda: gr.update(visible=True),
896
+ outputs=[preprocess_progress],
897
+ )
898
+ with gr.Row():
899
+ preset_traj = gr.Dropdown(
900
+ choices=[
901
+ "orbit",
902
+ "spiral",
903
+ "lemniscate",
904
+ "zoom-in",
905
+ "zoom-out",
906
+ "dolly zoom-in",
907
+ "dolly zoom-out",
908
+ "move-forward",
909
+ "move-backward",
910
+ "move-up",
911
+ "move-down",
912
+ "move-left",
913
+ "move-right",
914
+ ],
915
+ label="Preset trajectory",
916
+ value="orbit",
917
+ )
918
+ num_frames = gr.Slider(30, 150, 80, label="#Frames")
919
+ zoom_factor = gr.Slider(
920
+ step=0.01, label="Zoom factor", visible=False
921
+ )
922
+ with gr.Row():
923
+ seed = gr.Number(value=23, label="Random seed")
924
+ chunk_strategy.render()
925
+ cfg = gr.Slider(1.0, 7.0, value=4.0, label="CFG value")
926
+ with gr.Row():
927
+ camera_scale = gr.Slider(
928
+ 0.1,
929
+ 15.0,
930
+ value=2.0,
931
+ label="Camera scale",
932
+ )
933
+
934
+ def default_cfg_preset_traj(traj):
935
+ # These are just some hand-tuned values that we
936
+ # found work the best.
937
+ if traj in ["zoom-out", "move-down"]:
938
+ value = 5.0
939
+ elif traj in [
940
+ "orbit",
941
+ "dolly zoom-out",
942
+ "move-backward",
943
+ "move-up",
944
+ "move-left",
945
+ "move-right",
946
+ ]:
947
+ value = 4.0
948
+ else:
949
+ value = 3.0
950
+ return value
951
+
952
+ preset_traj.change(
953
+ default_cfg_preset_traj,
954
+ inputs=[preset_traj],
955
+ outputs=[cfg],
956
+ )
957
+ preset_traj.change(
958
+ lambda traj: gr.update(
959
+ value=(
960
+ 10.0 if "dolly" in traj or "pan" in traj else 2.0
961
+ )
962
+ ),
963
+ inputs=[preset_traj],
964
+ outputs=[camera_scale],
965
+ )
966
+
967
+ def zoom_factor_preset_traj(traj):
968
+ visible = traj in [
969
+ "zoom-in",
970
+ "zoom-out",
971
+ "dolly zoom-in",
972
+ "dolly zoom-out",
973
+ ]
974
+ is_zoomin = traj.endswith("zoom-in")
975
+ if is_zoomin:
976
+ minimum = 0.1
977
+ maximum = 0.5
978
+ value = 0.28
979
+ else:
980
+ minimum = 1.2
981
+ maximum = 3
982
+ value = 1.5
983
+ return gr.update(
984
+ visible=visible,
985
+ minimum=minimum,
986
+ maximum=maximum,
987
+ value=value,
988
+ )
989
+
990
+ preset_traj.change(
991
+ zoom_factor_preset_traj,
992
+ inputs=[preset_traj],
993
+ outputs=[zoom_factor],
994
+ )
995
+ with gr.Column():
996
+ with gr.Group():
997
+ abort_btn = gr.Button("Abort rendering", visible=False)
998
+ render_btn.render()
999
+ render_progress = gr.Textbox(
1000
+ label="", visible=False, interactive=False
1001
+ )
1002
+ output_video = gr.Video(
1003
+ label="Output", interactive=False, autoplay=True, loop=True
1004
+ )
1005
+ render_btn.click(
1006
+ lambda r, *args: (yield from r.render(*args)),
1007
+ inputs=[
1008
+ renderer,
1009
+ preprocessed,
1010
+ session_hash,
1011
+ seed,
1012
+ chunk_strategy,
1013
+ cfg,
1014
+ preset_traj,
1015
+ num_frames,
1016
+ zoom_factor,
1017
+ camera_scale,
1018
+ ],
1019
+ outputs=[
1020
+ output_video,
1021
+ render_btn,
1022
+ abort_btn,
1023
+ render_progress,
1024
+ ],
1025
+ show_progress_on=[render_progress],
1026
+ concurrency_id="gpu_queue",
1027
+ )
1028
+ render_btn.click(
1029
+ lambda: [
1030
+ gr.update(visible=False),
1031
+ gr.update(visible=True),
1032
+ gr.update(visible=True),
1033
+ ],
1034
+ outputs=[render_btn, abort_btn, render_progress],
1035
+ )
1036
+ abort_btn.click(set_abort_event)
1037
+ with gr.Tab("Advanced"):
1038
+ render_btn = gr.Button("Render video", interactive=False, render=False)
1039
+ viewport = gr.HTML(container=True, render=False)
1040
+ gr.Timer(0.1).tick(
1041
+ lambda renderer: gr.update(
1042
+ interactive=renderer is not None
1043
+ and renderer.gui_state is not None
1044
+ and renderer.gui_state.camera_traj_list is not None
1045
+ ),
1046
+ inputs=[renderer],
1047
+ outputs=[render_btn],
1048
+ )
1049
+ with gr.Row():
1050
+ viewport.render()
1051
+ with gr.Row():
1052
+ with gr.Column():
1053
+ with gr.Group():
1054
+ preprocess_btn = gr.Button("Preprocess images")
1055
+ preprocess_progress = gr.Textbox(
1056
+ label="",
1057
+ visible=False,
1058
+ interactive=False,
1059
+ )
1060
+ with gr.Group():
1061
+ input_imgs = gr.Gallery(
1062
+ interactive=True,
1063
+ label="Input",
1064
+ columns=4,
1065
+ height=200,
1066
+ )
1067
+ # Define example images (gradio doesn't support variable length
1068
+ # examples so we need to hack it).
1069
+ example_imgs = gr.Gallery(
1070
+ [e[0] for e in ADVANCE_EXAMPLE_MAP],
1071
+ allow_preview=False,
1072
+ preview=False,
1073
+ label="Example",
1074
+ columns=20,
1075
+ rows=1,
1076
+ height=115,
1077
+ )
1078
+ example_imgs_expander = gr.Gallery(
1079
+ visible=False,
1080
+ interactive=False,
1081
+ label="Example",
1082
+ preview=True,
1083
+ columns=20,
1084
+ rows=1,
1085
+ )
1086
+ chunk_strategy = gr.Dropdown(
1087
+ ["interp-gt", "interp"],
1088
+ label="Chunk strategy",
1089
+ value="interp-gt",
1090
+ render=False,
1091
+ )
1092
+ with gr.Row():
1093
+ example_imgs_backer = gr.Button(
1094
+ "Go back", visible=False
1095
+ )
1096
+ example_imgs_confirmer = gr.Button(
1097
+ "Confirm", visible=False
1098
+ )
1099
+ example_imgs.select(
1100
+ get_advance_examples,
1101
+ outputs=[
1102
+ example_imgs_expander,
1103
+ example_imgs_confirmer,
1104
+ example_imgs_backer,
1105
+ example_imgs,
1106
+ ],
1107
+ )
1108
+ example_imgs_confirmer.click(
1109
+ lambda x: (
1110
+ x,
1111
+ gr.update(visible=False),
1112
+ gr.update(visible=False),
1113
+ gr.update(visible=False),
1114
+ gr.update(visible=True),
1115
+ ),
1116
+ inputs=[example_imgs_expander],
1117
+ outputs=[
1118
+ input_imgs,
1119
+ example_imgs_expander,
1120
+ example_imgs_confirmer,
1121
+ example_imgs_backer,
1122
+ example_imgs,
1123
+ ],
1124
+ )
1125
+ example_imgs_backer.click(
1126
+ lambda: (
1127
+ gr.update(visible=False),
1128
+ gr.update(visible=False),
1129
+ gr.update(visible=False),
1130
+ gr.update(visible=True),
1131
+ ),
1132
+ outputs=[
1133
+ example_imgs_expander,
1134
+ example_imgs_confirmer,
1135
+ example_imgs_backer,
1136
+ example_imgs,
1137
+ ],
1138
+ )
1139
+ preprocessed = gr.State()
1140
+ preprocess_btn.click(
1141
+ lambda r, *args: r.preprocess(*args),
1142
+ inputs=[renderer, input_imgs],
1143
+ outputs=[
1144
+ preprocessed,
1145
+ preprocess_progress,
1146
+ chunk_strategy,
1147
+ ],
1148
+ show_progress_on=[preprocess_progress],
1149
+ concurrency_id="gpu_queue",
1150
+ )
1151
+ preprocess_btn.click(
1152
+ lambda: gr.update(visible=True),
1153
+ outputs=[preprocess_progress],
1154
+ )
1155
+ preprocessed.change(
1156
+ lambda r, *args: r.visualize_scene(*args),
1157
+ inputs=[renderer, preprocessed],
1158
+ )
1159
+ with gr.Row():
1160
+ seed = gr.Number(value=23, label="Random seed")
1161
+ chunk_strategy.render()
1162
+ cfg = gr.Slider(1.0, 7.0, value=3.0, label="CFG value")
1163
+ with gr.Row():
1164
+ camera_scale = gr.Slider(
1165
+ 0.1,
1166
+ 15.0,
1167
+ value=2.0,
1168
+ label="Camera scale (useful for single-view input)",
1169
+ )
1170
+ with gr.Group():
1171
+ output_data_dir = gr.Textbox(label="Output data directory")
1172
+ output_data_btn = gr.Button("Export output data")
1173
+ output_data_btn.click(
1174
+ lambda r, *args: r.export_output_data(*args),
1175
+ inputs=[renderer, preprocessed, output_data_dir],
1176
+ )
1177
+ with gr.Column():
1178
+ with gr.Group():
1179
+ abort_btn = gr.Button("Abort rendering", visible=False)
1180
+ render_btn.render()
1181
+ render_progress = gr.Textbox(
1182
+ label="", visible=False, interactive=False
1183
+ )
1184
+ output_video = gr.Video(
1185
+ label="Output", interactive=False, autoplay=True, loop=True
1186
+ )
1187
+ render_btn.click(
1188
+ lambda r, *args: (yield from r.render(*args)),
1189
+ inputs=[
1190
+ renderer,
1191
+ preprocessed,
1192
+ session_hash,
1193
+ seed,
1194
+ chunk_strategy,
1195
+ cfg,
1196
+ gr.State(),
1197
+ gr.State(),
1198
+ gr.State(),
1199
+ camera_scale,
1200
+ ],
1201
+ outputs=[
1202
+ output_video,
1203
+ render_btn,
1204
+ abort_btn,
1205
+ render_progress,
1206
+ ],
1207
+ show_progress_on=[render_progress],
1208
+ concurrency_id="gpu_queue",
1209
+ )
1210
+ render_btn.click(
1211
+ lambda: [
1212
+ gr.update(visible=False),
1213
+ gr.update(visible=True),
1214
+ gr.update(visible=True),
1215
+ ],
1216
+ outputs=[render_btn, abort_btn, render_progress],
1217
+ )
1218
+ abort_btn.click(set_abort_event)
1219
+
1220
+ # Register the session initialization and cleanup functions.
1221
+ app.load(
1222
+ start_server_and_abort_event,
1223
+ outputs=[renderer, viewport, session_hash],
1224
+ )
1225
+ app.unload(stop_server_and_abort_event)
1226
+
1227
+ app.queue(max_size=5).launch(
1228
+ share=share,
1229
+ server_port=server_port,
1230
+ show_error=True,
1231
+ allowed_paths=[WORK_DIR],
1232
+ # Badge rendering will be broken otherwise.
1233
+ ssr_mode=False,
1234
+ )
1235
+
1236
+
1237
+ if __name__ == "__main__":
1238
+ tyro.cli(main)
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/nightly/cu124
2
+ torch==2.7.0.dev20250218+cu124
3
+ torchvision==0.22.0.dev20250219+cu124
4
+ roma
5
+ gradio==5.17.0
6
+ matplotlib
7
+ tqdm
8
+ opencv-python
9
+ scipy
10
+ einops
11
+ trimesh
12
+ tensorboard
13
+ git+https://github.com/jensenz-sai/pycolmap@543266bc316df2fe407b3a33d454b310b1641042
14
+ pyglet<2
15
+ huggingface-hub[torch]>=0.22
16
+ pillow-heif # add heif/heic image support
17
+ pyrender # for rendering depths in scannetpp
18
+ kapture # for visloc data loading
19
+ kapture-localization
20
+ numpy==1.24.4
21
+ numpy-quaternion
22
+ pycolmap # for pnp
23
+ poselib # for pnp
24
+ viser
25
+ tyro
26
+ ninja
27
+ colorama
28
+ pytorch-lightning
29
+ splines
30
+ diffusers
31
+ kornia
32
+ open-clip-torch
33
+ accelerate
34
+ pyav
35
+ imageio[ffmpeg]
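As a quick sanity check after installing these pins (a sketch, not part of the repository), the nightly torch build can be verified from Python; `seva/eval.py` below keys off the same `dev` marker in the version string:

import torch

print(torch.__version__)          # expected to contain "dev", e.g. "2.7.0.dev20250218+cu124"
print(torch.cuda.is_available())  # the pinned wheels target CUDA 12.4 (cu124)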
seva/__init__.py ADDED
File without changes
seva/data_io.py ADDED
@@ -0,0 +1,553 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ from glob import glob
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import imageio.v3 as iio
9
+ import numpy as np
10
+ import torch
11
+
12
+ from seva.geometry import (
13
+ align_principle_axes,
14
+ similarity_from_cameras,
15
+ transform_cameras,
16
+ transform_points,
17
+ )
18
+
19
+
20
+ def _get_rel_paths(path_dir: str) -> List[str]:
21
+ """Recursively get relative paths of files in a directory."""
22
+ paths = []
23
+ for dp, _, fn in os.walk(path_dir):
24
+ for f in fn:
25
+ paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
26
+ return paths
27
+
28
+
29
+ class BaseParser(object):
30
+ def __init__(
31
+ self,
32
+ data_dir: str,
33
+ factor: int = 1,
34
+ normalize: bool = False,
35
+ test_every: Optional[int] = 8,
36
+ ):
37
+ self.data_dir = data_dir
38
+ self.factor = factor
39
+ self.normalize = normalize
40
+ self.test_every = test_every
41
+
42
+ self.image_names: List[str] = [] # (num_images,)
43
+ self.image_paths: List[str] = [] # (num_images,)
44
+ self.camtoworlds: np.ndarray = np.zeros((0, 4, 4)) # (num_images, 4, 4)
45
+ self.camera_ids: List[int] = [] # (num_images,)
46
+ self.Ks_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> K
47
+ self.params_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> params
48
+ self.imsize_dict: Dict[
49
+ int, Tuple[int, int]
50
+ ] = {} # Dict of camera_id -> (width, height)
51
+ self.points: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
52
+ self.points_err: np.ndarray = np.zeros((0,)) # (num_points,)
53
+ self.points_rgb: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
54
+ self.point_indices: Dict[str, np.ndarray] = {} # Dict of image_name -> (M,)
55
+ self.transform: np.ndarray = np.zeros((4, 4)) # (4, 4)
56
+
57
+ self.mapx_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
58
+ self.mapy_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
59
+ self.roi_undist_dict: Dict[int, Tuple[int, int, int, int]] = (
60
+ dict()
61
+ ) # Dict of camera_id -> (x, y, w, h)
62
+ self.scene_scale: float = 1.0
63
+
64
+
65
+ class DirectParser(BaseParser):
66
+ def __init__(
67
+ self,
68
+ imgs: List[np.ndarray],
69
+ c2ws: np.ndarray,
70
+ Ks: np.ndarray,
71
+ points: Optional[np.ndarray] = None,
72
+ points_rgb: Optional[np.ndarray] = None, # uint8
73
+ mono_disps: Optional[List[np.ndarray]] = None,
74
+ normalize: bool = False,
75
+ test_every: Optional[int] = None,
76
+ ):
77
+ super().__init__("", 1, normalize, test_every)
78
+
79
+ self.image_names = [f"{i:06d}" for i in range(len(imgs))]
80
+ self.image_paths = ["null" for _ in range(len(imgs))]
81
+ self.camtoworlds = c2ws
82
+ self.camera_ids = [i for i in range(len(imgs))]
83
+ self.Ks_dict = {i: K for i, K in enumerate(Ks)}
84
+ self.imsize_dict = {
85
+ i: (img.shape[1], img.shape[0]) for i, img in enumerate(imgs)
86
+ }
87
+ if points is not None:
88
+ self.points = points
89
+ assert points_rgb is not None
90
+ self.points_rgb = points_rgb
91
+ self.points_err = np.zeros((len(points),))
92
+
93
+ self.imgs = imgs
94
+ self.mono_disps = mono_disps
95
+
96
+ # Normalize the world space.
97
+ if normalize:
98
+ T1 = similarity_from_cameras(self.camtoworlds)
99
+ self.camtoworlds = transform_cameras(T1, self.camtoworlds)
100
+
101
+ if points is not None:
102
+ self.points = transform_points(T1, self.points)
103
+ T2 = align_principle_axes(self.points)
104
+ self.camtoworlds = transform_cameras(T2, self.camtoworlds)
105
+ self.points = transform_points(T2, self.points)
106
+ else:
107
+ T2 = np.eye(4)
108
+
109
+ self.transform = T2 @ T1
110
+ else:
111
+ self.transform = np.eye(4)
112
+
113
+ # size of the scene measured by cameras
114
+ camera_locations = self.camtoworlds[:, :3, 3]
115
+ scene_center = np.mean(camera_locations, axis=0)
116
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
117
+ self.scene_scale = np.max(dists)
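A hedged construction sketch for `DirectParser`, with shapes inferred from how the arguments are used above: `imgs` is a list of HxWx3 arrays, `c2ws` is an (N, 4, 4) stack of camera-to-world matrices, and `Ks` is an (N, 3, 3) stack of intrinsics. The numbers are illustrative only:

import numpy as np
from seva.data_io import DirectParser

N, H, W = 4, 480, 640
imgs = [np.zeros((H, W, 3), dtype=np.uint8) for _ in range(N)]  # placeholder frames
c2ws = np.repeat(np.eye(4)[None], N, axis=0)                    # camera-to-world poses
c2ws[:, 0, 3] = np.linspace(-1.0, 1.0, N)                       # spread cameras along x
Ks = np.repeat(np.array([[500.0, 0.0, W / 2],
                         [0.0, 500.0, H / 2],
                         [0.0, 0.0, 1.0]])[None], N, axis=0)
parser = DirectParser(imgs, c2ws, Ks, normalize=False)
print(parser.scene_scale)  # max camera distance from the mean camera center, here 1.0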
118
+
119
+
120
+ class COLMAPParser(BaseParser):
121
+ """COLMAP parser."""
122
+
123
+ def __init__(
124
+ self,
125
+ data_dir: str,
126
+ factor: int = 1,
127
+ normalize: bool = False,
128
+ test_every: Optional[int] = 8,
129
+ image_folder: str = "images",
130
+ colmap_folder: str = "sparse/0",
131
+ ):
132
+ super().__init__(data_dir, factor, normalize, test_every)
133
+
134
+ colmap_dir = os.path.join(data_dir, colmap_folder)
135
+ assert os.path.exists(
136
+ colmap_dir
137
+ ), f"COLMAP directory {colmap_dir} does not exist."
138
+
139
+ try:
140
+ from pycolmap import SceneManager
141
+ except ImportError:
142
+ raise ImportError(
143
+ "Please install pycolmap to use the data parsers: "
144
+ " `pip install git+https://github.com/jensenz-sai/pycolmap.git@543266bc316df2fe407b3a33d454b310b1641042`"
145
+ )
146
+
147
+ manager = SceneManager(colmap_dir)
148
+ manager.load_cameras()
149
+ manager.load_images()
150
+ manager.load_points3D()
151
+
152
+ # Extract extrinsic matrices in world-to-camera format.
153
+ imdata = manager.images
154
+ w2c_mats = []
155
+ camera_ids = []
156
+ Ks_dict = dict()
157
+ params_dict = dict()
158
+ imsize_dict = dict() # width, height
159
+ bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
160
+ for k in imdata:
161
+ im = imdata[k]
162
+ rot = im.R()
163
+ trans = im.tvec.reshape(3, 1)
164
+ w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
165
+ w2c_mats.append(w2c)
166
+
167
+ # support different camera intrinsics
168
+ camera_id = im.camera_id
169
+ camera_ids.append(camera_id)
170
+
171
+ # camera intrinsics
172
+ cam = manager.cameras[camera_id]
173
+ fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
174
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
175
+ K[:2, :] /= factor
176
+ Ks_dict[camera_id] = K
177
+
178
+ # Get distortion parameters.
179
+ type_ = cam.camera_type
180
+ if type_ == 0 or type_ == "SIMPLE_PINHOLE":
181
+ params = np.empty(0, dtype=np.float32)
182
+ camtype = "perspective"
183
+ elif type_ == 1 or type_ == "PINHOLE":
184
+ params = np.empty(0, dtype=np.float32)
185
+ camtype = "perspective"
186
+ elif type_ == 2 or type_ == "SIMPLE_RADIAL":
187
+ params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
188
+ camtype = "perspective"
189
+ elif type_ == 3 or type_ == "RADIAL":
190
+ params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
191
+ camtype = "perspective"
192
+ elif type_ == 4 or type_ == "OPENCV":
193
+ params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
194
+ camtype = "perspective"
195
+ elif type_ == 5 or type_ == "OPENCV_FISHEYE":
196
+ params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
197
+ camtype = "fisheye"
198
+ assert (
199
+ camtype == "perspective" # type: ignore
200
+ ), f"Only perspective camera models are supported, got {type_}"
201
+
202
+ params_dict[camera_id] = params # type: ignore
203
+
204
+ # image size
205
+ imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
206
+
207
+ print(
208
+ f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
209
+ )
210
+
211
+ if len(imdata) == 0:
212
+ raise ValueError("No images found in COLMAP.")
213
+ if not (type_ == 0 or type_ == 1): # type: ignore
214
+ print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
215
+
216
+ w2c_mats = np.stack(w2c_mats, axis=0)
217
+
218
+ # Convert extrinsics to camera-to-world.
219
+ camtoworlds = np.linalg.inv(w2c_mats)
220
+
221
+ # Image names from COLMAP. No need for permuting the poses according to
222
+ # image names anymore.
223
+ image_names = [imdata[k].name for k in imdata]
224
+
225
+ # Previous NeRF results were generated with images sorted by filename,
226
+ # ensure metrics are reported on the same test set.
227
+ inds = np.argsort(image_names)
228
+ image_names = [image_names[i] for i in inds]
229
+ camtoworlds = camtoworlds[inds]
230
+ camera_ids = [camera_ids[i] for i in inds]
231
+
232
+ # Load images.
233
+ if factor > 1:
234
+ image_dir_suffix = f"_{factor}"
235
+ else:
236
+ image_dir_suffix = ""
237
+ colmap_image_dir = os.path.join(data_dir, image_folder)
238
+ image_dir = os.path.join(data_dir, image_folder + image_dir_suffix)
239
+ for d in [image_dir, colmap_image_dir]:
240
+ if not os.path.exists(d):
241
+ raise ValueError(f"Image folder {d} does not exist.")
242
+
243
+ # Downsampled images may have different names vs images used for COLMAP,
244
+ # so we need to map between the two sorted lists of files.
245
+ colmap_files = sorted(_get_rel_paths(colmap_image_dir))
246
+ image_files = sorted(_get_rel_paths(image_dir))
247
+ colmap_to_image = dict(zip(colmap_files, image_files))
248
+ image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
249
+
250
+ # 3D points and {image_name -> [point_idx]}
251
+ points = manager.points3D.astype(np.float32) # type: ignore
252
+ points_err = manager.point3D_errors.astype(np.float32) # type: ignore
253
+ points_rgb = manager.point3D_colors.astype(np.uint8) # type: ignore
254
+ point_indices = dict()
255
+
256
+ image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
257
+ for point_id, data in manager.point3D_id_to_images.items():
258
+ for image_id, _ in data:
259
+ image_name = image_id_to_name[image_id]
260
+ point_idx = manager.point3D_id_to_point3D_idx[point_id]
261
+ point_indices.setdefault(image_name, []).append(point_idx)
262
+ point_indices = {
263
+ k: np.array(v).astype(np.int32) for k, v in point_indices.items()
264
+ }
265
+
266
+ # Normalize the world space.
267
+ if normalize:
268
+ T1 = similarity_from_cameras(camtoworlds)
269
+ camtoworlds = transform_cameras(T1, camtoworlds)
270
+ points = transform_points(T1, points)
271
+
272
+ T2 = align_principle_axes(points)
273
+ camtoworlds = transform_cameras(T2, camtoworlds)
274
+ points = transform_points(T2, points)
275
+
276
+ transform = T2 @ T1
277
+ else:
278
+ transform = np.eye(4)
279
+
280
+ self.image_names = image_names # List[str], (num_images,)
281
+ self.image_paths = image_paths # List[str], (num_images,)
282
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
283
+ self.camera_ids = camera_ids # List[int], (num_images,)
284
+ self.Ks_dict = Ks_dict # Dict of camera_id -> K
285
+ self.params_dict = params_dict # Dict of camera_id -> params
286
+ self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
287
+ self.points = points # np.ndarray, (num_points, 3)
288
+ self.points_err = points_err # np.ndarray, (num_points,)
289
+ self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
290
+ self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
291
+ self.transform = transform # np.ndarray, (4, 4)
292
+
293
+ # undistortion
294
+ self.mapx_dict = dict()
295
+ self.mapy_dict = dict()
296
+ self.roi_undist_dict = dict()
297
+ for camera_id in self.params_dict.keys():
298
+ params = self.params_dict[camera_id]
299
+ if len(params) == 0:
300
+ continue # no distortion
301
+ assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
302
+ assert (
303
+ camera_id in self.params_dict
304
+ ), f"Missing params for camera {camera_id}"
305
+ K = self.Ks_dict[camera_id]
306
+ width, height = self.imsize_dict[camera_id]
307
+ K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
308
+ K, params, (width, height), 0
309
+ )
310
+ mapx, mapy = cv2.initUndistortRectifyMap(
311
+ K,
312
+ params,
313
+ None,
314
+ K_undist,
315
+ (width, height),
316
+ cv2.CV_32FC1, # type: ignore
317
+ )
318
+ self.Ks_dict[camera_id] = K_undist
319
+ self.mapx_dict[camera_id] = mapx
320
+ self.mapy_dict[camera_id] = mapy
321
+ self.roi_undist_dict[camera_id] = roi_undist # type: ignore
322
+
323
+ # size of the scene measured by cameras
324
+ camera_locations = camtoworlds[:, :3, 3]
325
+ scene_center = np.mean(camera_locations, axis=0)
326
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
327
+ self.scene_scale = np.max(dists)
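A small numeric sketch (toy values) of the world-to-camera assembly and inversion performed in the loop above, showing where the camera center ends up:

import numpy as np

R = np.eye(3)                        # toy rotation (im.R() in the parser)
t = np.array([[0.0], [0.0], [2.0]])  # toy translation (im.tvec)
bottom = np.array([[0.0, 0.0, 0.0, 1.0]])
w2c = np.concatenate([np.concatenate([R, t], axis=1), bottom], axis=0)
c2w = np.linalg.inv(w2c)
print(c2w[:3, 3])  # camera center in world coordinates: [0. 0. -2.]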
328
+
329
+
330
+ class ReconfusionParser(BaseParser):
331
+ def __init__(self, data_dir: str, normalize: bool = False):
332
+ super().__init__(data_dir, 1, normalize, test_every=None)
333
+
334
+ def get_num(p):
335
+ return p.split("_")[-1].removesuffix(".json")
336
+
337
+ splits_per_num_input_frames = {}
338
+ num_input_frames = [
339
+ int(get_num(p)) if get_num(p).isdigit() else get_num(p)
340
+ for p in sorted(glob(osp.join(data_dir, "train_test_split_*.json")))
341
+ ]
342
+ for num_input_frames in num_input_frames:
343
+ with open(
344
+ osp.join(
345
+ data_dir,
346
+ f"train_test_split_{num_input_frames}.json",
347
+ )
348
+ ) as f:
349
+ splits_per_num_input_frames[num_input_frames] = json.load(f)
350
+ self.splits_per_num_input_frames = splits_per_num_input_frames
351
+
352
+ with open(osp.join(data_dir, "transforms.json")) as f:
353
+ metadata = json.load(f)
354
+
355
+ image_names, image_paths, camtoworlds = [], [], []
356
+ for frame in metadata["frames"]:
357
+ if frame["file_path"] is None:
358
+ image_path = image_name = None
359
+ else:
360
+ image_path = osp.join(data_dir, frame["file_path"])
361
+ image_name = osp.basename(image_path)
362
+ image_paths.append(image_path)
363
+ image_names.append(image_name)
364
+ camtoworld = np.array(frame["transform_matrix"])
365
+ if "applied_transform" in metadata:
366
+ applied_transform = np.concatenate(
367
+ [metadata["applied_transform"], [[0, 0, 0, 1]]], axis=0
368
+ )
369
+ camtoworld = applied_transform @ camtoworld
370
+ camtoworlds.append(camtoworld)
371
+ camtoworlds = np.array(camtoworlds)
372
+ camtoworlds[:, :, [1, 2]] *= -1
373
+
374
+ # Normalize the world space.
375
+ if normalize:
376
+ T1 = similarity_from_cameras(camtoworlds)
377
+ camtoworlds = transform_cameras(T1, camtoworlds)
378
+ self.transform = T1
379
+ else:
380
+ self.transform = np.eye(4)
381
+
382
+ self.image_names = image_names
383
+ self.image_paths = image_paths
384
+ self.camtoworlds = camtoworlds
385
+ self.camera_ids = list(range(len(image_paths)))
386
+ self.Ks_dict = {
387
+ i: np.array(
388
+ [
389
+ [
390
+ metadata.get("fl_x", frame.get("fl_x", None)),
391
+ 0.0,
392
+ metadata.get("cx", frame.get("cx", None)),
393
+ ],
394
+ [
395
+ 0.0,
396
+ metadata.get("fl_y", frame.get("fl_y", None)),
397
+ metadata.get("cy", frame.get("cy", None)),
398
+ ],
399
+ [0.0, 0.0, 1.0],
400
+ ]
401
+ )
402
+ for i, frame in enumerate(metadata["frames"])
403
+ }
404
+ self.imsize_dict = {
405
+ i: (
406
+ metadata.get("w", frame.get("w", None)),
407
+ metadata.get("h", frame.get("h", None)),
408
+ )
409
+ for i, frame in enumerate(metadata["frames"])
410
+ }
411
+ # When num_input_frames is None, use all frames for both training and
412
+ # testing.
413
+ # self.splits_per_num_input_frames[None] = {
414
+ # "train_ids": list(range(len(image_paths))),
415
+ # "test_ids": list(range(len(image_paths))),
416
+ # }
417
+
418
+ # size of the scene measured by cameras
419
+ camera_locations = camtoworlds[:, :3, 3]
420
+ scene_center = np.mean(camera_locations, axis=0)
421
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
422
+ self.scene_scale = np.max(dists)
423
+
424
+ self.bounds = None
425
+ if osp.exists(osp.join(data_dir, "bounds.npy")):
426
+ self.bounds = np.load(osp.join(data_dir, "bounds.npy"))
427
+ scaling = np.linalg.norm(self.transform[0, :3])
428
+ self.bounds = self.bounds / scaling
429
+
430
+
431
+ class Dataset(torch.utils.data.Dataset):
432
+ """A simple dataset class."""
433
+
434
+ def __init__(
435
+ self,
436
+ parser: BaseParser,
437
+ split: str = "train",
438
+ num_input_frames: Optional[int] = None,
439
+ patch_size: Optional[int] = None,
440
+ load_depths: bool = False,
441
+ load_mono_disps: bool = False,
442
+ ):
443
+ self.parser = parser
444
+ self.split = split
445
+ self.num_input_frames = num_input_frames
446
+ self.patch_size = patch_size
447
+ self.load_depths = load_depths
448
+ self.load_mono_disps = load_mono_disps
449
+ if load_mono_disps:
450
+ assert isinstance(parser, DirectParser)
451
+ assert parser.mono_disps is not None
452
+ if isinstance(parser, ReconfusionParser):
453
+ ids_per_split = parser.splits_per_num_input_frames[num_input_frames]
454
+ self.indices = ids_per_split[
455
+ "train_ids" if split == "train" else "test_ids"
456
+ ]
457
+ else:
458
+ indices = np.arange(len(self.parser.image_names))
459
+ if split == "train":
460
+ self.indices = (
461
+ indices[indices % self.parser.test_every != 0]
462
+ if self.parser.test_every is not None
463
+ else indices
464
+ )
465
+ else:
466
+ self.indices = (
467
+ indices[indices % self.parser.test_every == 0]
468
+ if self.parser.test_every is not None
469
+ else indices
470
+ )
471
+
472
+ def __len__(self):
473
+ return len(self.indices)
474
+
475
+ def __getitem__(self, item: int) -> Dict[str, Any]:
476
+ index = self.indices[item]
477
+ if isinstance(self.parser, DirectParser):
478
+ image = self.parser.imgs[index]
479
+ else:
480
+ image = iio.imread(self.parser.image_paths[index])[..., :3]
481
+ camera_id = self.parser.camera_ids[index]
482
+ K = self.parser.Ks_dict[camera_id].copy() # undistorted K
483
+ params = self.parser.params_dict.get(camera_id, None)
484
+ camtoworlds = self.parser.camtoworlds[index]
485
+
486
+ x, y, w, h = 0, 0, image.shape[1], image.shape[0]
487
+ if params is not None and len(params) > 0:
488
+ # Images are distorted. Undistort them.
489
+ mapx, mapy = (
490
+ self.parser.mapx_dict[camera_id],
491
+ self.parser.mapy_dict[camera_id],
492
+ )
493
+ image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
494
+ x, y, w, h = self.parser.roi_undist_dict[camera_id]
495
+ image = image[y : y + h, x : x + w]
496
+
497
+ if self.patch_size is not None:
498
+ # Random crop.
499
+ h, w = image.shape[:2]
500
+ x = np.random.randint(0, max(w - self.patch_size, 1))
501
+ y = np.random.randint(0, max(h - self.patch_size, 1))
502
+ image = image[y : y + self.patch_size, x : x + self.patch_size]
503
+ K[0, 2] -= x
504
+ K[1, 2] -= y
505
+
506
+ data = {
507
+ "K": torch.from_numpy(K).float(),
508
+ "camtoworld": torch.from_numpy(camtoworlds).float(),
509
+ "image": torch.from_numpy(image).float(),
510
+ "image_id": item, # the index of the image in the dataset
511
+ }
512
+
513
+ if self.load_depths:
514
+ # projected points to image plane to get depths
515
+ worldtocams = np.linalg.inv(camtoworlds)
516
+ image_name = self.parser.image_names[index]
517
+ point_indices = self.parser.point_indices[image_name]
518
+ points_world = self.parser.points[point_indices]
519
+ points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
520
+ points_proj = (K @ points_cam.T).T
521
+ points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
522
+ depths = points_cam[:, 2] # (M,)
523
+ if self.patch_size is not None:
524
+ points[:, 0] -= x
525
+ points[:, 1] -= y
526
+ # filter out points outside the image
527
+ selector = (
528
+ (points[:, 0] >= 0)
529
+ & (points[:, 0] < image.shape[1])
530
+ & (points[:, 1] >= 0)
531
+ & (points[:, 1] < image.shape[0])
532
+ & (depths > 0)
533
+ )
534
+ points = points[selector]
535
+ depths = depths[selector]
536
+ data["points"] = torch.from_numpy(points).float()
537
+ data["depths"] = torch.from_numpy(depths).float()
538
+ if self.load_mono_disps:
539
+ data["mono_disps"] = torch.from_numpy(self.parser.mono_disps[index]).float() # type: ignore
540
+
541
+ return data
542
+
543
+
544
+ def get_parser(parser_type: str, **kwargs) -> BaseParser:
545
+ if parser_type == "colmap":
546
+ parser = COLMAPParser(**kwargs)
547
+ elif parser_type == "direct":
548
+ parser = DirectParser(**kwargs)
549
+ elif parser_type == "reconfusion":
550
+ parser = ReconfusionParser(**kwargs)
551
+ else:
552
+ raise ValueError(f"Unknown parser type: {parser_type}")
553
+ return parser
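A usage sketch tying this module together; the data directory is hypothetical and assumes the standard COLMAP layout (`images/`, `sparse/0/`) checked by `COLMAPParser` above:

import torch
from seva.data_io import Dataset, get_parser

parser = get_parser("colmap", data_dir="/path/to/scene", factor=1, normalize=True)
trainset = Dataset(parser, split="train", load_depths=True)
loader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape, batch["K"].shape, batch["camtoworld"].shape)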
seva/eval.py ADDED
@@ -0,0 +1,1988 @@
1
+ import collections
2
+ import json
3
+ import math
4
+ import os
5
+ import re
6
+ import threading
7
+ from typing import List, Literal, Optional, Tuple, Union
8
+
9
+ import gradio as gr
10
+ from colorama import Fore, Style, init
11
+
12
+ init(autoreset=True)
13
+
14
+ import imageio.v3 as iio
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchvision.transforms.functional as TF
19
+ from einops import repeat
20
+ from PIL import Image
21
+ from tqdm.auto import tqdm
22
+
23
+ from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
24
+ from seva.sampling import (
25
+ EulerEDMSampler,
26
+ MultiviewCFG,
27
+ MultiviewTemporalCFG,
28
+ VanillaCFG,
29
+ )
30
+ from seva.utils import seed_everything
31
+
32
+ try:
33
+ # Nightly/dev builds contain 'dev' in the version string
34
+ version = torch.__version__
35
+ IS_TORCH_NIGHTLY = "dev" in version
36
+ if IS_TORCH_NIGHTLY:
37
+ torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment]
38
+ torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment]
39
+ torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment]
40
+ except Exception:
41
+ IS_TORCH_NIGHTLY = False
42
+
43
+
44
+ def pad_indices(
45
+ input_indices: List[int],
46
+ test_indices: List[int],
47
+ T: int,
48
+ padding_mode: Literal["first", "last", "none"] = "last",
49
+ ):
50
+ assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
51
+ if padding_mode == "last":
52
+ padded_indices = [
53
+ i for i in range(T) if i not in (input_indices + test_indices)
54
+ ]
55
+ else:
56
+ padded_indices = []
57
+ input_selects = list(range(len(input_indices)))
58
+ test_selects = list(range(len(test_indices)))
59
+ if max(input_indices) > max(test_indices):
60
+ # last elem from input
61
+ input_selects += [input_selects[-1]] * len(padded_indices)
62
+ input_indices = input_indices + padded_indices
63
+ sorted_inds = np.argsort(input_indices)
64
+ input_indices = [input_indices[ind] for ind in sorted_inds]
65
+ input_selects = [input_selects[ind] for ind in sorted_inds]
66
+ else:
67
+ # last elem from test
68
+ test_selects += [test_selects[-1]] * len(padded_indices)
69
+ test_indices = test_indices + padded_indices
70
+ sorted_inds = np.argsort(test_indices)
71
+ test_indices = [test_indices[ind] for ind in sorted_inds]
72
+ test_selects = [test_selects[ind] for ind in sorted_inds]
73
+
74
+ if padding_mode == "last":
75
+ input_maps = np.array([-1] * T)
76
+ test_maps = np.array([-1] * T)
77
+ else:
78
+ input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
79
+ test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
80
+ input_maps[input_indices] = input_selects
81
+ test_maps[test_indices] = test_selects
82
+ return input_indices, test_indices, input_maps, test_maps
83
+
84
+
85
+ def assemble(
86
+ input,
87
+ test,
88
+ input_maps,
89
+ test_maps,
90
+ ):
91
+ T = len(input_maps)
92
+ assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
93
+ assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
94
+ assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
95
+ assert np.logical_xor(input_maps != -1, test_maps != -1).all()
96
+ return assembled
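A toy example of how `pad_indices` and `assemble` interleave input and test frames into one fixed-length sequence (the feature tensors here are stand-ins):

import torch
from seva.eval import assemble, pad_indices

input_inds, test_inds, input_maps, test_maps = pad_indices(
    input_indices=[0, 3], test_indices=[1, 2], T=4, padding_mode="last"
)
# input_maps == [0, -1, -1, 1], test_maps == [-1, 0, 1, -1]
inp = torch.zeros(2, 3)  # stand-in features for the 2 input frames
tst = torch.ones(2, 3)   # stand-in features for the 2 test frames
out = assemble(inp, tst, input_maps, test_maps)
print(out[:, 0])  # tensor([0., 1., 1., 0.]) -> input, test, test, input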
97
+
98
+
99
+ def get_resizing_factor(
100
+ target_shape: Tuple[int, int], # H, W
101
+ current_shape: Tuple[int, int], # H, W
102
+ cover_target: bool = True,
103
+ # If True, the output shape will fully cover the target shape.
104
+ # If False, the target shape will fully cover the output shape.
105
+ ) -> float:
106
+ r_bound = target_shape[1] / target_shape[0]
107
+ aspect_r = current_shape[1] / current_shape[0]
108
+ if r_bound >= 1.0:
109
+ if cover_target:
110
+ if aspect_r >= r_bound:
111
+ factor = min(target_shape) / min(current_shape)
112
+ elif aspect_r < 1.0:
113
+ factor = max(target_shape) / min(current_shape)
114
+ else:
115
+ factor = max(target_shape) / max(current_shape)
116
+ else:
117
+ if aspect_r >= r_bound:
118
+ factor = max(target_shape) / max(current_shape)
119
+ elif aspect_r < 1.0:
120
+ factor = min(target_shape) / max(current_shape)
121
+ else:
122
+ factor = min(target_shape) / min(current_shape)
123
+ else:
124
+ if cover_target:
125
+ if aspect_r <= r_bound:
126
+ factor = min(target_shape) / min(current_shape)
127
+ elif aspect_r > 1.0:
128
+ factor = max(target_shape) / min(current_shape)
129
+ else:
130
+ factor = max(target_shape) / max(current_shape)
131
+ else:
132
+ if aspect_r <= r_bound:
133
+ factor = max(target_shape) / max(current_shape)
134
+ elif aspect_r > 1.0:
135
+ factor = min(target_shape) / max(current_shape)
136
+ else:
137
+ factor = min(target_shape) / min(current_shape)
138
+ return factor
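A worked example of the branchy logic above: resizing a 480x640 (H, W) image so that it fully covers a 576x576 target:

from seva.eval import get_resizing_factor

factor = get_resizing_factor((576, 576), (480, 640), cover_target=True)
# r_bound = 1.0, aspect_r = 640 / 480 ~ 1.33 >= r_bound, so
# factor = min(target) / min(current) = 576 / 480 = 1.2,
# i.e. the image is resized to 576x768 and then cropped to 576x576 downstream.
print(factor)  # 1.2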
139
+
140
+
141
+ def get_unique_embedder_keys_from_conditioner(conditioner):
142
+ keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
143
+ keys = [item for sublist in keys for item in sublist] # Flatten list
144
+ return set(keys)
145
+
146
+
147
+ def get_wh_with_fixed_shortest_side(w, h, size):
148
+ # if size is None or not positive, return the original w, h
149
+ if size is None or size <= 0:
150
+ return w, h
151
+ if w < h:
152
+ new_w = size
153
+ new_h = int(size * h / w)
154
+ else:
155
+ new_h = size
156
+ new_w = int(size * w / h)
157
+ return new_w, new_h
158
+
159
+
160
+ def load_img_and_K(
161
+ image_path_or_size: Union[str, torch.Size],
162
+ size: Optional[Union[int, Tuple[int, int]]],
163
+ scale: float = 1.0,
164
+ center: Tuple[float, float] = (0.5, 0.5),
165
+ K: torch.Tensor | None = None,
166
+ size_stride: int = 1,
167
+ center_crop: bool = False,
168
+ image_as_tensor: bool = True,
169
+ context_rgb: np.ndarray | None = None,
170
+ device: str = "cuda",
171
+ ):
172
+ if isinstance(image_path_or_size, torch.Size):
173
+ image = Image.new("RGBA", image_path_or_size[::-1])
174
+ else:
175
+ image = Image.open(image_path_or_size).convert("RGBA")
176
+
177
+ w, h = image.size
178
+ if size is None:
179
+ size = (w, h)
180
+
181
+ image = np.array(image).astype(np.float32) / 255
182
+ if image.shape[-1] == 4:
183
+ rgb, alpha = image[:, :, :3], image[:, :, 3:]
184
+ if context_rgb is not None:
185
+ image = rgb * alpha + context_rgb * (1 - alpha)
186
+ else:
187
+ image = rgb * alpha + (1 - alpha)
188
+ image = image.transpose(2, 0, 1)
189
+ image = torch.from_numpy(image).to(dtype=torch.float32)
190
+ image = image.unsqueeze(0)
191
+
192
+ if isinstance(size, (tuple, list)):
193
+ # => if size is a tuple or list, we first rescale to fully cover the `size`
194
+ # area and then crop the `size` area from the rescaled image
195
+ W, H = size
196
+ else:
197
+ # => if size is int, we rescale the image to fit the shortest side to size
198
+ # => if size is None, no rescaling is applied
199
+ W, H = get_wh_with_fixed_shortest_side(w, h, size)
200
+ W, H = (
201
+ math.floor(W / size_stride + 0.5) * size_stride,
202
+ math.floor(H / size_stride + 0.5) * size_stride,
203
+ )
204
+
205
+ rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
206
+ resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
207
+ image = torch.nn.functional.interpolate(
208
+ image, resize_size, mode="area", antialias=False
209
+ )
210
+ if scale < 1.0:
211
+ pw = math.ceil((W - resize_size[1]) * 0.5)
212
+ ph = math.ceil((H - resize_size[0]) * 0.5)
213
+ image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
214
+
215
+ cy_center = int(center[1] * image.shape[-2])
216
+ cx_center = int(center[0] * image.shape[-1])
217
+ if center_crop:
218
+ side = min(H, W)
219
+ ct = max(0, cy_center - side // 2)
220
+ cl = max(0, cx_center - side // 2)
221
+ ct = min(ct, image.shape[-2] - side)
222
+ cl = min(cl, image.shape[-1] - side)
223
+ image = TF.crop(image, top=ct, left=cl, height=side, width=side)
224
+ else:
225
+ ct = max(0, cy_center - H // 2)
226
+ cl = max(0, cx_center - W // 2)
227
+ ct = min(ct, image.shape[-2] - H)
228
+ cl = min(cl, image.shape[-1] - W)
229
+ image = TF.crop(image, top=ct, left=cl, height=H, width=W)
230
+
231
+ if K is not None:
232
+ K = K.clone()
233
+ if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
234
+ K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
235
+ else:
236
+ K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
237
+ K[:2, 2] -= K.new_tensor([cl, ct])
238
+
239
+ if image_as_tensor:
240
+ # tensor of shape (1, 3, H, W) with values ranging from (-1, 1)
241
+ image = image.to(device) * 2.0 - 1.0
242
+ else:
243
+ # PIL Image with values ranging from (0, 255)
244
+ image = image.permute(0, 2, 3, 1).numpy()[0]
245
+ image = Image.fromarray((image * 255).astype(np.uint8))
246
+ return image, K
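A hedged usage sketch for `load_img_and_K`; the file name is hypothetical, and `size=576` with `size_stride=64` resizes so the short side lands near 576, rounded to a multiple of the stride, as described in the comments above:

from seva.eval import load_img_and_K

img, K = load_img_and_K(
    "example.png",   # hypothetical input image
    size=576,        # short-side target; a (W, H) tuple is also accepted
    size_stride=64,
    K=None,          # no intrinsics to rescale in this sketch
    device="cpu",
)
print(img.shape, img.min().item(), img.max().item())  # (1, 3, H, W), values in [-1, 1]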
247
+
248
+
249
+ def transform_img_and_K(
250
+ image: torch.Tensor,
251
+ size: Union[int, Tuple[int, int]],
252
+ scale: float = 1.0,
253
+ center: Tuple[float, float] = (0.5, 0.5),
254
+ K: torch.Tensor | None = None,
255
+ size_stride: int = 1,
256
+ mode: str = "crop",
257
+ ):
258
+ assert mode in [
259
+ "crop",
260
+ "pad",
261
+ "stretch",
262
+ ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
263
+
264
+ h, w = image.shape[-2:]
265
+ if isinstance(size, (tuple, list)):
266
+ # => if size is a tuple or list, we first rescale to fully cover the `size`
267
+ # area and then crop the `size` area from the rescaled image
268
+ W, H = size
269
+ else:
270
+ # => if size is int, we rescale the image to fit the shortest side to size
271
+ # => if size is None, no rescaling is applied
272
+ W, H = get_wh_with_fixed_shortest_side(w, h, size)
273
+ W, H = (
274
+ math.floor(W / size_stride + 0.5) * size_stride,
275
+ math.floor(H / size_stride + 0.5) * size_stride,
276
+ )
277
+
278
+ if mode == "stretch":
279
+ rh, rw = H, W
280
+ else:
281
+ rfs = get_resizing_factor(
282
+ (H, W),
283
+ (h, w),
284
+ cover_target=mode != "pad",
285
+ )
286
+ (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
287
+
288
+ rh, rw = int(rh / scale), int(rw / scale)
289
+ image = torch.nn.functional.interpolate(
290
+ image, (rh, rw), mode="area", antialias=False
291
+ )
292
+
293
+ cy_center = int(center[1] * image.shape[-2])
294
+ cx_center = int(center[0] * image.shape[-1])
295
+ if mode != "pad":
296
+ ct = max(0, cy_center - H // 2)
297
+ cl = max(0, cx_center - W // 2)
298
+ ct = min(ct, image.shape[-2] - H)
299
+ cl = min(cl, image.shape[-1] - W)
300
+ image = TF.crop(image, top=ct, left=cl, height=H, width=W)
301
+ pl, pt = 0, 0
302
+ else:
303
+ pt = max(0, H // 2 - cy_center)
304
+ pl = max(0, W // 2 - cx_center)
305
+ pb = max(0, H - pt - image.shape[-2])
306
+ pr = max(0, W - pl - image.shape[-1])
307
+ image = TF.pad(
308
+ image,
309
+ [pl, pt, pr, pb],
310
+ )
311
+ cl, ct = 0, 0
312
+
313
+ if K is not None:
314
+ K = K.clone()
315
+ # K[:, :2, 2] += K.new_tensor([pl, pt])
316
+ if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
317
+ K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
318
+ else:
319
+ K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
320
+ K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
321
+
322
+ return image, K
323
+
324
+
325
+ lowvram_mode = False
326
+
327
+
328
+ def set_lowvram_mode(mode):
329
+ global lowvram_mode
330
+ lowvram_mode = mode
331
+
332
+
333
+ def load_model(model, device: str = "cuda"):
334
+ model.to(device)
335
+
336
+
337
+ def unload_model(model):
338
+ global lowvram_mode
339
+ if lowvram_mode:
340
+ model.cpu()
341
+ torch.cuda.empty_cache()
342
+
343
+
344
+ def infer_prior_stats(
345
+ T,
346
+ num_input_frames,
347
+ num_total_frames,
348
+ version_dict,
349
+ ):
350
+ options = version_dict["options"]
351
+ chunk_strategy = options.get("chunk_strategy", "nearest")
352
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
353
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
354
+ # get traj_prior_c2ws for 2-pass sampling
355
+ if chunk_strategy.startswith("interp"):
356
+ # Start and end have already taken up two slots
357
+ # +1 because we need X + 1 prior frames to bound X forward passes over all test frames
358
+
359
+ # Tuning up `num_prior_frames_ratio` is helpful when you observe sudden jumps in the
360
+ # generated frames due to insufficient prior frames. This option is effective for
361
+ # complicated trajectories and when the `interp` strategy is used (usually the semi-dense-view
362
+ # regime). Recommended range is [1.0 (default), 1.5].
363
+ if num_input_frames >= options.get("num_input_semi_dense", 9):
364
+ num_prior_frames = (
365
+ math.ceil(
366
+ num_total_frames
367
+ / (T_second_pass - 2)
368
+ * options.get("num_prior_frames_ratio", 1.0)
369
+ )
370
+ + 1
371
+ )
372
+
373
+ if num_prior_frames + num_input_frames < T_first_pass:
374
+ num_prior_frames = T_first_pass - num_input_frames
375
+
376
+ num_prior_frames = max(
377
+ num_prior_frames,
378
+ options.get("num_prior_frames", 0),
379
+ )
380
+
381
+ T_first_pass = num_prior_frames + num_input_frames
382
+
383
+ if "gt" in chunk_strategy:
384
+ T_second_pass = T_second_pass + num_input_frames
385
+
386
+ # Dynamically update context window length.
387
+ version_dict["T"] = [T_first_pass, T_second_pass]
388
+
389
+ else:
390
+ num_prior_frames = (
391
+ math.ceil(
392
+ num_total_frames
393
+ / (
394
+ T_second_pass
395
+ - 2
396
+ - (num_input_frames if "gt" in chunk_strategy else 0)
397
+ )
398
+ * options.get("num_prior_frames_ratio", 1.0)
399
+ )
400
+ + 1
401
+ )
402
+
403
+ if num_prior_frames + num_input_frames < T_first_pass:
404
+ num_prior_frames = T_first_pass - num_input_frames
405
+
406
+ num_prior_frames = max(
407
+ num_prior_frames,
408
+ options.get("num_prior_frames", 0),
409
+ )
410
+ else:
411
+ num_prior_frames = max(
412
+ T_first_pass - num_input_frames,
413
+ options.get("num_prior_frames", 0),
414
+ )
415
+
416
+ if num_input_frames >= options.get("num_input_semi_dense", 9):
417
+ T_first_pass = num_prior_frames + num_input_frames
418
+
419
+ # Dynamically update context window length.
420
+ version_dict["T"] = [T_first_pass, T_second_pass]
421
+
422
+ return num_prior_frames
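To make the `interp` branch concrete, a worked example with illustrative numbers (semi-dense regime, default ratio):

# num_total_frames = 300, T = (21, 21), num_input_frames = 9, num_prior_frames_ratio = 1.0
# num_prior_frames = ceil(300 / (21 - 2) * 1.0) + 1 = 16 + 1 = 17
# 17 + 9 >= 21, so no extra padding is needed; the first-pass context window is
# widened to T_first_pass = 17 + 9 = 26, i.e. version_dict["T"] becomes [26, 21].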
423
+
424
+
425
+ def infer_prior_inds(
426
+ c2ws,
427
+ num_prior_frames,
428
+ input_frame_indices,
429
+ options,
430
+ ):
431
+ chunk_strategy = options.get("chunk_strategy", "nearest")
432
+ if chunk_strategy.startswith("interp"):
433
+ prior_frame_indices = np.array(
434
+ [i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
435
+ )
436
+ prior_frame_indices = prior_frame_indices[
437
+ np.ceil(
438
+ np.linspace(
439
+ 0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
440
+ )
441
+ ).astype(int)
442
+ ] # having a ceil here is actually safer for corner cases
443
+ else:
444
+ prior_frame_indices = []
445
+ while len(prior_frame_indices) < num_prior_frames:
446
+ closest_distance = np.abs(
447
+ np.arange(c2ws.shape[0])[None]
448
+ - np.concatenate(
449
+ [np.array(input_frame_indices), np.array(prior_frame_indices)]
450
+ )[:, None]
451
+ ).min(0)
452
+ prior_frame_indices.append(np.argsort(closest_distance)[-1])
453
+ return np.sort(prior_frame_indices)
454
+
455
+
456
+ def compute_relative_inds(
457
+ source_inds,
458
+ target_inds,
459
+ ):
460
+ assert len(source_inds) > 2
461
+ # compute relative indices of target_inds within source_inds
462
+ relative_inds = []
463
+ for ind in target_inds:
464
+ if ind in source_inds:
465
+ relative_ind = int(np.where(source_inds == ind)[0][0])
466
+ elif ind < source_inds[0]:
467
+ # extrapolate
468
+ relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
469
+ elif ind > source_inds[-1]:
470
+ # extrapolate
471
+ relative_ind = len(source_inds) + (
472
+ (ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
473
+ )
474
+ else:
475
+ # interpolate
476
+ lower_inds = source_inds[source_inds < ind]
477
+ upper_inds = source_inds[source_inds > ind]
478
+ if len(lower_inds) > 0 and len(upper_inds) > 0:
479
+ lower_ind = lower_inds[-1]
480
+ upper_ind = upper_inds[0]
481
+ relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
482
+ relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
483
+ relative_ind = relative_lower_ind + (ind - lower_ind) / (
484
+ upper_ind - lower_ind
485
+ ) * (relative_upper_ind - relative_lower_ind)
486
+ else:
487
+ # Out of range
488
+ relative_inds.append(float("nan")) # Or some other placeholder
489
+ relative_inds.append(relative_ind)
490
+ return relative_inds
491
+
492
+
493
+ def find_nearest_source_inds(
494
+ source_c2ws,
495
+ target_c2ws,
496
+ nearest_num=1,
497
+ mode="translation",
498
+ ):
499
+ dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
500
+ sorted_inds = np.argsort(dists, axis=0).T
501
+ return sorted_inds[:, :nearest_num]
502
+
503
+
504
+ def chunk_input_and_test(
505
+ T,
506
+ input_c2ws,
507
+ test_c2ws,
508
+ input_ords, # orders
509
+ test_ords, # orders
510
+ options,
511
+ task: str = "img2img",
512
+ chunk_strategy: str = "gt",
513
+ gt_input_inds: list = [],
514
+ ):
515
+ M, N = input_c2ws.shape[0], test_c2ws.shape[0]
516
+
517
+ chunks = []
518
+ if chunk_strategy.startswith("gt"):
519
+ assert len(gt_input_inds) < T, (
520
+ f"Number of gt input frames {len(gt_input_inds)} should be "
521
+ f"less than {T} when `gt` chunking strategy is used."
522
+ )
523
+ assert (
524
+ list(range(M)) == gt_input_inds
525
+ ), "All input_c2ws should be gt when `gt` chunking strategy is used."
526
+
527
+ # LEGACY CHUNKING STRATEGY
528
+ # num_test_per_chunk = T - len(gt_input_inds)
529
+ # test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds]
530
+ # for i in range(0, test_c2ws.shape[0], num_test_per_chunk):
531
+ # chunk = ["NULL"] * T
532
+ # for j, k in enumerate(gt_input_inds):
533
+ # chunk[k] = f"!{j:03d}"
534
+ # for j, k in enumerate(
535
+ # test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]]
536
+ # ):
537
+ # chunk[k] = f">{i + j:03d}"
538
+ # chunks.append(chunk)
539
+
540
+ num_test_seen = 0
541
+ while num_test_seen < N:
542
+ chunk = [f"!{i:03d}" for i in gt_input_inds]
543
+ if chunk_strategy != "gt" and num_test_seen > 0:
544
+ pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
545
+ if (N - num_test_seen) >= math.floor(
546
+ (T - len(gt_input_inds)) * pseudo_num_ratio
547
+ ):
548
+ pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
549
+ else:
550
+ pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
551
+ pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))
552
+
553
+ if "ltr" in chunk_strategy:
554
+ chunk.extend(
555
+ [
556
+ f"!{i + len(gt_input_inds):03d}"
557
+ for i in range(num_test_seen - pseudo_num, num_test_seen)
558
+ ]
559
+ )
560
+ elif "nearest" in chunk_strategy:
561
+ source_inds = np.concatenate(
562
+ [
563
+ find_nearest_source_inds(
564
+ test_c2ws[:num_test_seen],
565
+ test_c2ws[num_test_seen:],
566
+ nearest_num=1, # pseudo_num,
567
+ mode="rotation",
568
+ ),
569
+ find_nearest_source_inds(
570
+ test_c2ws[:num_test_seen],
571
+ test_c2ws[num_test_seen:],
572
+ nearest_num=1, # pseudo_num,
573
+ mode="translation",
574
+ ),
575
+ ],
576
+ axis=1,
577
+ )
578
+ ####### [HACK ALERT] keep running until pseudo num is stabilized ########
579
+ temp_pseudo_num = pseudo_num
580
+ while True:
581
+ nearest_source_inds = np.concatenate(
582
+ [
583
+ np.sort(
584
+ [
585
+ ind
586
+ for (ind, _) in collections.Counter(
587
+ [
588
+ item
589
+ for item in source_inds[
590
+ : T
591
+ - len(gt_input_inds)
592
+ - temp_pseudo_num
593
+ ]
594
+ .flatten()
595
+ .tolist()
596
+ if item
597
+ != (
598
+ num_test_seen - 1
599
+ ) # exclude the last one here
600
+ ]
601
+ ).most_common(pseudo_num - 1)
602
+ ],
603
+ ).astype(int),
604
+ [num_test_seen - 1], # always keep the last one
605
+ ]
606
+ )
607
+ if len(nearest_source_inds) >= temp_pseudo_num:
608
+ break # stabilized
609
+ else:
610
+ temp_pseudo_num = len(nearest_source_inds)
611
+ pseudo_num = len(nearest_source_inds)
612
+ ########################################################################
613
+ chunk.extend(
614
+ [f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
615
+ )
616
+ else:
617
+ raise NotImplementedError(
618
+ f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
619
+ )
620
+
621
+ chunk.extend(
622
+ [
623
+ f">{i:03d}"
624
+ for i in range(
625
+ num_test_seen,
626
+ min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
627
+ )
628
+ ]
629
+ )
630
+ else:
631
+ chunk.extend(
632
+ [
633
+ f">{i:03d}"
634
+ for i in range(
635
+ num_test_seen,
636
+ min(num_test_seen + T - len(gt_input_inds), N),
637
+ )
638
+ ]
639
+ )
640
+
641
+ num_test_seen += sum([1 for c in chunk if c.startswith(">")])
642
+ if len(chunk) < T:
643
+ chunk.extend(["NULL"] * (T - len(chunk)))
644
+ chunks.append(chunk)
645
+
646
+ elif chunk_strategy.startswith("nearest"):
647
+ input_imgs = np.array([f"!{i:03d}" for i in range(M)])
648
+ test_imgs = np.array([f">{i:03d}" for i in range(N)])
649
+
650
+ match = re.match(r"^nearest-(\d+)$", chunk_strategy)
651
+ if match:
652
+ nearest_num = int(match.group(1))
653
+ assert (
654
+ nearest_num < T
655
+ ), f"Nearest number of {nearest_num} should be less than {T}."
656
+ source_inds = find_nearest_source_inds(
657
+ input_c2ws,
658
+ test_c2ws,
659
+ nearest_num=nearest_num,
660
+ mode="translation", # during the second pass, translation distance alone is enough
661
+ )
662
+
663
+ for i in range(0, N, T - nearest_num):
664
+ nearest_source_inds = np.sort(
665
+ [
666
+ ind
667
+ for (ind, _) in collections.Counter(
668
+ source_inds[i : i + T - nearest_num].flatten().tolist()
669
+ ).most_common(nearest_num)
670
+ ]
671
+ )
672
+ chunk = (
673
+ input_imgs[nearest_source_inds].tolist()
674
+ + test_imgs[i : i + T - nearest_num].tolist()
675
+ )
676
+ chunks.append(chunk + ["NULL"] * (T - len(chunk)))
677
+
678
+ else:
679
+ # do not always condition on gt cond frames
680
+ if "gt" not in chunk_strategy:
681
+ gt_input_inds = []
682
+
683
+ source_inds = find_nearest_source_inds(
684
+ input_c2ws,
685
+ test_c2ws,
686
+ nearest_num=1,
687
+ mode="translation", # during the second pass, translation distance alone is enough
688
+ )[:, 0]
689
+
690
+ test_inds_per_input = {}
691
+ for test_idx, input_idx in enumerate(source_inds):
692
+ if input_idx not in test_inds_per_input:
693
+ test_inds_per_input[input_idx] = []
694
+ test_inds_per_input[input_idx].append(test_idx)
695
+
696
+ num_test_seen = 0
697
+ chunk = input_imgs[gt_input_inds].tolist()
698
+ candidate_input_inds = sorted(list(test_inds_per_input.keys()))
699
+
700
+ while num_test_seen < N:
701
+ input_idx = candidate_input_inds[0]
702
+ test_inds = test_inds_per_input[input_idx]
703
+ input_is_cond = input_idx in gt_input_inds
704
+ prefix_inds = [] if input_is_cond else [input_idx]
705
+
706
+ if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
707
+ if chunk:
708
+ chunk += ["NULL"] * (T - len(chunk))
709
+ chunks.append(chunk)
710
+ chunk = input_imgs[gt_input_inds].tolist()
711
+ if num_test_seen >= N:
712
+ break
713
+ continue
714
+
715
+ candidate_chunk = (
716
+ input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
717
+ )
718
+
719
+ space_left = T - len(chunk)
720
+ if len(candidate_chunk) <= space_left:
721
+ chunk.extend(candidate_chunk)
722
+ num_test_seen += len(test_inds)
723
+ candidate_input_inds.pop(0)
724
+ else:
725
+ chunk.extend(candidate_chunk[:space_left])
726
+ num_input_idx = 0 if input_is_cond else 1
727
+ num_test_seen += space_left - num_input_idx
728
+ test_inds_per_input[input_idx] = test_inds[
729
+ space_left - num_input_idx :
730
+ ]
731
+
732
+ if len(chunk) == T:
733
+ chunks.append(chunk)
734
+ chunk = input_imgs[gt_input_inds].tolist()
735
+
736
+ if chunk and chunk != input_imgs[gt_input_inds].tolist():
737
+ chunks.append(chunk + ["NULL"] * (T - len(chunk)))
738
+
739
+ elif chunk_strategy.startswith("interp"):
740
+ # `interp` chunk requires ordering info
741
+ assert input_ords is not None and test_ords is not None, (
742
+ "When using `interp` chunking strategy, ordering of input "
743
+ "and test frames should be provided."
744
+ )
745
+
746
+ # if chunk_strategy is `interp*` and task is `img2trajvid*`, we will not
747
+ # use input views since their order info within target views is unknown
748
+ if "img2trajvid" in task:
749
+ assert (
750
+ list(range(len(gt_input_inds))) == gt_input_inds
751
+ ), "`img2trajvid` task should put `gt_input_inds` in start."
752
+ input_c2ws = input_c2ws[
753
+ [ind for ind in range(M) if ind not in gt_input_inds]
754
+ ]
755
+ input_ords = [
756
+ input_ords[ind] for ind in range(M) if ind not in gt_input_inds
757
+ ]
758
+ M = input_c2ws.shape[0]
759
+
760
+ input_ords = [0] + input_ords # this is a hack accounting for test views
761
+ # before the first input view
762
+ input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included
763
+ # in the last forward when input_ords[-1] == test_ords[-1]
764
+ input_ords = np.array(input_ords)[:, None]
765
+ input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
766
+ test_ords = np.array(test_ords)[None]
767
+
768
+ in_stop_ranges = np.logical_and(
769
+ np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
770
+ np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
771
+ ) # (M, N)
772
+ assert (in_stop_ranges.sum(1) <= T - 2).all(), (
773
+ "More input frames need to be sampled during the first pass to ensure "
774
+ f"#test frames during each forward in the second pass will not exceed {T - 2}."
775
+ )
776
+ if input_ords[1, 0] <= test_ords[0, 0]:
777
+ assert not in_stop_ranges[0].any()
778
+ if input_ords[-1, 0] >= test_ords[0, -1]:
779
+ assert not in_stop_ranges[-1].any()
780
+
781
+ gt_chunk = (
782
+ [f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
783
+ )
784
+ chunk = gt_chunk + []
785
+ # any test views before the first input views
786
+ if in_stop_ranges[0].any():
787
+ for j, in_range in enumerate(in_stop_ranges[0]):
788
+ if in_range:
789
+ chunk.append(f">{j:03d}")
790
+ in_stop_ranges = in_stop_ranges[1:]
791
+
792
+ i = 0
793
+ base_i = len(gt_input_inds) if "img2trajvid" in task else 0
794
+ chunk.append(f"!{i + base_i:03d}")
795
+ while i < len(in_stop_ranges):
796
+ in_stop_range = in_stop_ranges[i]
797
+ if not in_stop_range.any():
798
+ i += 1
799
+ continue
800
+
801
+ input_left = i + 1 < M
802
+ space_left = T - len(chunk)
803
+ if sum(in_stop_range) + input_left <= space_left:
804
+ for j, in_range in enumerate(in_stop_range):
805
+ if in_range:
806
+ chunk.append(f">{j:03d}")
807
+ i += 1
808
+ if input_left:
809
+ chunk.append(f"!{i + base_i:03d}")
810
+
811
+ else:
812
+ chunk += ["NULL"] * space_left
813
+ chunks.append(chunk)
814
+ chunk = gt_chunk + [f"!{i + base_i:03d}"]
815
+
816
+ if len(chunk) > 1:
817
+ chunk += ["NULL"] * (T - len(chunk))
818
+ chunks.append(chunk)
819
+
820
+ else:
821
+ raise NotImplementedError
822
+
823
+ (
824
+ input_inds_per_chunk,
825
+ input_sels_per_chunk,
826
+ test_inds_per_chunk,
827
+ test_sels_per_chunk,
828
+ ) = (
829
+ [],
830
+ [],
831
+ [],
832
+ [],
833
+ )
834
+ for chunk in chunks:
835
+ input_inds = [
836
+ int(img.removeprefix("!")) for img in chunk if img.startswith("!")
837
+ ]
838
+ input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
839
+ test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
840
+ test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
841
+ input_inds_per_chunk.append(input_inds)
842
+ input_sels_per_chunk.append(input_sels)
843
+ test_inds_per_chunk.append(test_inds)
844
+ test_sels_per_chunk.append(test_sels)
845
+
846
+ if options.get("sampler_verbose", True):
847
+
848
+ def colorize(item):
849
+ if item.startswith("!"):
850
+ return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!'
851
+ elif item.startswith(">"):
852
+ return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>'
853
+ return item # Default color if neither '!' nor '>'
854
+
855
+ print("\nchunks:")
856
+ for chunk in chunks:
857
+ print(", ".join(colorize(item) for item in chunk))
858
+
859
+ return (
860
+ chunks,
861
+ input_inds_per_chunk, # ordering of input in raw sequence
862
+ input_sels_per_chunk, # ordering of input in one-forward sequence of length T
863
+ test_inds_per_chunk, # ordering of test in raw sequence
864
+ test_sels_per_chunk, # ordering of test in one-forward sequence of length T
865
+ )
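To illustrate the chunk encoding used throughout this function ("!NNN" marks an input/conditioning view, ">NNN" a test view, "NULL" padding), here is the layout the plain "gt" strategy produces for T=6, two gt input frames, and seven test frames:

# chunk 1: ['!000', '!001', '>000', '>001', '>002', '>003']
# chunk 2: ['!000', '!001', '>004', '>005', '>006', 'NULL']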
866
+
867
+
868
+ def is_k_in_dict(d, k):
869
+ return any(map(lambda x: x.startswith(k), d.keys()))
870
+
871
+
872
+ def get_k_from_dict(d, k):
873
+ media_d = {}
874
+ for key, value in d.items():
875
+ if key == k:
876
+ return value
877
+ if key.startswith(k):
878
+ media = key.split("/")[-1]
879
+ if media == "raw":
880
+ return value
881
+ media_d[media] = value
882
+ if len(media_d) == 0:
883
+ return torch.tensor([])
884
+ assert (
885
+ len(media_d) == 1
886
+ ), f"multiple media found in {d} for key {k}: {media_d.keys()}"
887
+ return media_d[media]
888
+
889
+
890
+ def update_kv_for_dict(d, k, v):
891
+ for key in d.keys():
892
+ if key.startswith(k):
893
+ d[key] = v
894
+ return d
895
+
896
+
897
+ def extend_dict(ds, d):
898
+ for key in d.keys():
899
+ if key in ds:
900
+ ds[key] = torch.cat([ds[key], d[key]], 0)
901
+ else:
902
+ ds[key] = d[key]
903
+ return ds
904
+
905
+
906
+ def replace_or_include_input_for_dict(
907
+ samples,
908
+ test_indices,
909
+ imgs,
910
+ c2w,
911
+ K,
912
+ ):
913
+ samples_new = {}
914
+ for sample, value in samples.items():
915
+ if "rgb" in sample:
916
+ imgs[test_indices] = (
917
+ value[test_indices] if value.shape[0] == imgs.shape[0] else value
918
+ ).to(device=imgs.device, dtype=imgs.dtype)
919
+ samples_new[sample] = imgs
920
+ elif "c2w" in sample:
921
+ c2w[test_indices] = (
922
+ value[test_indices] if value.shape[0] == c2w.shape[0] else value
923
+ ).to(device=c2w.device, dtype=c2w.dtype)
924
+ samples_new[sample] = c2w
925
+ elif "intrinsics" in sample:
926
+ K[test_indices] = (
927
+ value[test_indices] if value.shape[0] == K.shape[0] else value
928
+ ).to(device=K.device, dtype=K.dtype)
929
+ samples_new[sample] = K
930
+ else:
931
+ samples_new[sample] = value
932
+ return samples_new
933
+
934
+
935
+ def decode_output(
936
+ samples,
937
+ T,
938
+ indices=None,
939
+ ):
940
+ # decode model output into dict if it is not
941
+ if isinstance(samples, dict):
942
+ # model with postprocessor and outputs dict
943
+ for sample, value in samples.items():
944
+ if isinstance(value, torch.Tensor):
945
+ value = value.detach().cpu()
946
+ elif isinstance(value, np.ndarray):
947
+ value = torch.from_numpy(value)
948
+ else:
949
+ value = torch.tensor(value)
950
+
951
+ if indices is not None and value.shape[0] == T:
952
+ value = value[indices]
953
+ samples[sample] = value
954
+ else:
955
+ # model without postprocessor and outputs tensor (rgb)
956
+ samples = samples.detach().cpu()
957
+
958
+ if indices is not None and samples.shape[0] == T:
959
+ samples = samples[indices]
960
+ samples = {"samples-rgb/image": samples}
961
+
962
+ return samples
963
+
964
+
965
+ def save_output(
966
+ samples,
967
+ save_path,
968
+ video_save_fps=2,
969
+ ):
970
+ os.makedirs(save_path, exist_ok=True)
971
+ for sample in samples:
972
+ media_type = "video"
973
+ if "/" in sample:
974
+ sample_, media_type = sample.split("/")
975
+ else:
976
+ sample_ = sample
977
+
978
+ value = samples[sample]
979
+ if isinstance(value, torch.Tensor):
980
+ value = value.detach().cpu()
981
+ elif isinstance(value, np.ndarray):
982
+ value = torch.from_numpy(value)
983
+ else:
984
+ value = torch.tensor(value)
985
+
986
+ if media_type == "image":
987
+ value = (value.permute(0, 2, 3, 1) + 1) / 2.0
988
+ value = (value * 255).clamp(0, 255).to(torch.uint8)
989
+ iio.imwrite(
990
+ os.path.join(save_path, f"{sample_}.mp4")
991
+ if sample_
992
+ else f"{save_path}.mp4",
993
+ value,
994
+ fps=video_save_fps,
995
+ macro_block_size=1,
996
+ ffmpeg_log_level="error",
997
+ )
998
+ os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
999
+ for i, s in enumerate(value):
1000
+ iio.imwrite(
1001
+ os.path.join(save_path, sample_, f"{i:03d}.png"),
1002
+ s,
1003
+ )
1004
+ elif media_type == "video":
1005
+ value = (value.permute(0, 2, 3, 1) + 1) / 2.0
1006
+ value = (value * 255).clamp(0, 255).to(torch.uint8)
1007
+ iio.imwrite(
1008
+ os.path.join(save_path, f"{sample_}.mp4"),
1009
+ value,
1010
+ fps=video_save_fps,
1011
+ macro_block_size=1,
1012
+ ffmpeg_log_level="error",
1013
+ )
1014
+ elif media_type == "raw":
1015
+ torch.save(
1016
+ value,
1017
+ os.path.join(save_path, f"{sample_}.pt"),
1018
+ )
1019
+ else:
1020
+ pass
1021
+
1022
+
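The media suffix after the "/" in each key decides how save_output writes it. A hedged sketch with placeholder tensors and a hypothetical output directory:

import torch

frames = torch.rand(8, 3, 64, 64) * 2 - 1             # placeholder frames in [-1, 1]
latents = torch.randn(8, 4, 8, 8)                     # placeholder raw tensor

save_output(
    {
        "samples-rgb/image": frames,                  # -> samples-rgb.mp4 plus per-frame PNGs
        "pred/video": frames,                         # -> pred.mp4 only
        "latents/raw": latents,                       # -> latents.pt
    },
    save_path="work_dirs/example",                    # hypothetical output directory
    video_save_fps=5,
)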
1023
+ def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
1024
+ import os.path as osp
1025
+
1026
+ out_frames = []
1027
+ for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
1028
+ out_frame = {
1029
+ "fl_x": K[0][0].item(),
1030
+ "fl_y": K[1][1].item(),
1031
+ "cx": K[0][2].item(),
1032
+ "cy": K[1][2].item(),
1033
+ "w": img_wh[0].item(),
1034
+ "h": img_wh[1].item(),
1035
+ "file_path": f"./{osp.relpath(img_path, start=save_path)}"
1036
+ if img_path is not None
1037
+ else None,
1038
+ "transform_matrix": c2w.tolist(),
1039
+ }
1040
+ out_frames.append(out_frame)
1041
+ out = {
1042
+ # "camera_model": "PINHOLE",
1043
+ "orientation_override": "none",
1044
+ "frames": out_frames,
1045
+ }
1046
+ with open(osp.join(save_path, "transforms.json"), "w") as of:
1047
+ json.dump(out, of, indent=5)
1048
+
1049
+
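A small sketch of the transforms.json writer above; the two cameras, intrinsics, and output path are made up:

import os
import torch

os.makedirs("work_dirs/example", exist_ok=True)       # hypothetical output directory
Ks = torch.tensor([[[80.0, 0.0, 32.0], [0.0, 80.0, 32.0], [0.0, 0.0, 1.0]]]).repeat(2, 1, 1)
c2ws = torch.eye(4)[None].repeat(2, 1, 1)
img_whs = torch.tensor([[64, 64], [64, 64]])

create_transforms_simple(
    save_path="work_dirs/example",
    img_paths=[None, None],                           # file_path is left null when None
    img_whs=img_whs,
    c2ws=c2ws,
    Ks=Ks,
)
# Writes work_dirs/example/transforms.json with per-frame fl_x/fl_y/cx/cy/w/h and
# transform_matrix entries, in a nerfstudio-style layout.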
1050
+ class GradioTrackedSampler(EulerEDMSampler):
1051
+ """
1052
+ A thin wrapper around the EulerEDMSampler that allows tracking progress and
1053
+ aborting sampling for gradio demo.
1054
+ """
1055
+
1056
+ def __init__(self, abort_event: threading.Event, *args, **kwargs):
1057
+ super().__init__(*args, **kwargs)
1058
+ self.abort_event = abort_event
1059
+
1060
+ def __call__( # type: ignore
1061
+ self,
1062
+ denoiser,
1063
+ x: torch.Tensor,
1064
+ scale: float | torch.Tensor,
1065
+ cond: dict,
1066
+ uc: dict | None = None,
1067
+ num_steps: int | None = None,
1068
+ verbose: bool = True,
1069
+ global_pbar: gr.Progress | None = None,
1070
+ **guider_kwargs,
1071
+ ) -> torch.Tensor | None:
1072
+ uc = cond if uc is None else uc
1073
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
1074
+ x,
1075
+ cond,
1076
+ uc,
1077
+ num_steps,
1078
+ )
1079
+ for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
1080
+ gamma = (
1081
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
1082
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
1083
+ else 0.0
1084
+ )
1085
+ x = self.sampler_step(
1086
+ s_in * sigmas[i],
1087
+ s_in * sigmas[i + 1],
1088
+ denoiser,
1089
+ x,
1090
+ scale,
1091
+ cond,
1092
+ uc,
1093
+ gamma,
1094
+ **guider_kwargs,
1095
+ )
1096
+ # Allow tracking progress in gradio demo.
1097
+ if global_pbar is not None:
1098
+ global_pbar.update()
1099
+ # Allow aborting sampling in gradio demo.
1100
+ if self.abort_event.is_set():
1101
+ return None
1102
+ return x
1103
+
1104
+
1105
+ def create_samplers(
1106
+ guider_types: int | list[int],
1107
+ discretization,
1108
+ num_frames: list[int] | None,
1109
+ num_steps: int,
1110
+ cfg_min: float = 1.0,
1111
+ device: str | torch.device = "cuda",
1112
+ abort_event: threading.Event | None = None,
1113
+ ):
1114
+ guider_mapping = {
1115
+ 0: VanillaCFG,
1116
+ 1: MultiviewCFG,
1117
+ 2: MultiviewTemporalCFG,
1118
+ }
1119
+ samplers = []
1120
+ if not isinstance(guider_types, (list, tuple)):
1121
+ guider_types = [guider_types]
1122
+ for i, guider_type in enumerate(guider_types):
1123
+ if guider_type not in guider_mapping:
1124
+ raise ValueError(
1125
+ f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
1126
+ )
1127
+ guider_cls = guider_mapping[guider_type]
1128
+ guider_args = ()
1129
+ if guider_type > 0:
1130
+ guider_args += (cfg_min,)
1131
+ if guider_type == 2:
1132
+ assert num_frames is not None
1133
+ guider_args = (num_frames[i], cfg_min)
1134
+ guider = guider_cls(*guider_args)
1135
+
1136
+ if abort_event is not None:
1137
+ sampler = GradioTrackedSampler(
1138
+ abort_event,
1139
+ discretization=discretization,
1140
+ guider=guider,
1141
+ num_steps=num_steps,
1142
+ s_churn=0.0,
1143
+ s_tmin=0.0,
1144
+ s_tmax=999.0,
1145
+ s_noise=1.0,
1146
+ verbose=True,
1147
+ device=device,
1148
+ )
1149
+ else:
1150
+ sampler = EulerEDMSampler(
1151
+ discretization=discretization,
1152
+ guider=guider,
1153
+ num_steps=num_steps,
1154
+ s_churn=0.0,
1155
+ s_tmin=0.0,
1156
+ s_tmax=999.0,
1157
+ s_noise=1.0,
1158
+ verbose=True,
1159
+ device=device,
1160
+ )
1161
+ samplers.append(sampler)
1162
+ return samplers
1163
+
1164
+
1165
+ def get_value_dict(
1166
+ curr_imgs,
1167
+ curr_imgs_clip,
1168
+ curr_input_frame_indices,
1169
+ curr_c2ws,
1170
+ curr_Ks,
1171
+ curr_input_camera_indices,
1172
+ all_c2ws,
1173
+ camera_scale=2.0,
1174
+ ):
1175
+ assert sorted(curr_input_camera_indices) == sorted(
1176
+ range(len(curr_input_camera_indices))
1177
+ )
1178
+ H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
1179
+
1180
+ value_dict = {}
1181
+ value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
1182
+ value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
1183
+ value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
1184
+ value_dict["cond_frames_mask"][curr_input_frame_indices] = True
1185
+ value_dict["cond_aug"] = 0.0
1186
+
1187
+ c2w = to_hom_pose(curr_c2ws.float())
1188
+ w2c = torch.linalg.inv(c2w)
1189
+
1190
+ # camera centering
1191
+ ref_c2ws = all_c2ws
1192
+ camera_dist_2med = torch.norm(
1193
+ ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
1194
+ dim=-1,
1195
+ )
1196
+ valid_mask = camera_dist_2med <= torch.clamp(
1197
+ torch.quantile(camera_dist_2med, 0.97) * 10,
1198
+ max=1e6,
1199
+ )
1200
+ c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
1201
+ w2c = torch.linalg.inv(c2w)
1202
+
1203
+ # camera normalization
1204
+ camera_dists = c2w[:, :3, 3].clone()
1205
+ translation_scaling_factor = (
1206
+ camera_scale
1207
+ if torch.isclose(
1208
+ torch.norm(camera_dists[0]),
1209
+ torch.zeros(1),
1210
+ atol=1e-5,
1211
+ ).any()
1212
+ else (camera_scale / torch.norm(camera_dists[0]))
1213
+ )
1214
+ w2c[:, :3, 3] *= translation_scaling_factor
1215
+ c2w[:, :3, 3] *= translation_scaling_factor
1216
+ value_dict["plucker_coordinate"], _ = get_plucker_coordinates(
1217
+ extrinsics_src=w2c[0],
1218
+ extrinsics=w2c,
1219
+ intrinsics=curr_Ks.float().clone(),
1220
+ mode="plucker",
1221
+ rel_zero_translation=True,
1222
+ target_size=(H // F, W // F),
1223
+ return_grid_cam=True,
1224
+ )
1225
+
1226
+ value_dict["c2w"] = c2w
1227
+ value_dict["K"] = curr_Ks
1228
+ value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
1229
+ value_dict["camera_mask"][curr_input_camera_indices] = True
1230
+
1231
+ return value_dict
1232
+
1233
+
1234
+ def do_sample(
1235
+ model,
1236
+ ae,
1237
+ conditioner,
1238
+ denoiser,
1239
+ sampler,
1240
+ value_dict,
1241
+ H,
1242
+ W,
1243
+ C,
1244
+ F,
1245
+ T,
1246
+ cfg,
1247
+ encoding_t=1,
1248
+ decoding_t=1,
1249
+ verbose=True,
1250
+ global_pbar=None,
1251
+ **_,
1252
+ ):
1253
+ imgs = value_dict["cond_frames"].to("cuda")
1254
+ input_masks = value_dict["cond_frames_mask"].to("cuda")
1255
+ pluckers = value_dict["plucker_coordinate"].to("cuda")
1256
+
1257
+ num_samples = [1, T]
1258
+ with torch.inference_mode(), torch.autocast("cuda"):
1259
+ load_model(ae)
1260
+ load_model(conditioner)
1261
+ latents = torch.nn.functional.pad(
1262
+ ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
1263
+ )
1264
+ c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
1265
+ uc_crossattn = torch.zeros_like(c_crossattn)
1266
+ c_replace = latents.new_zeros(T, *latents.shape[1:])
1267
+ c_replace[input_masks] = latents
1268
+ uc_replace = torch.zeros_like(c_replace)
1269
+ c_concat = torch.cat(
1270
+ [
1271
+ repeat(
1272
+ input_masks,
1273
+ "n -> n 1 h w",
1274
+ h=pluckers.shape[2],
1275
+ w=pluckers.shape[3],
1276
+ ),
1277
+ pluckers,
1278
+ ],
1279
+ 1,
1280
+ )
1281
+ uc_concat = torch.cat(
1282
+ [pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
1283
+ )
1284
+ c_dense_vector = pluckers
1285
+ uc_dense_vector = c_dense_vector
1286
+ # TODO(hangg): concat and dense are problematic.
1287
+ c = {
1288
+ "crossattn": c_crossattn,
1289
+ "replace": c_replace,
1290
+ "concat": c_concat,
1291
+ "dense_vector": c_dense_vector,
1292
+ }
1293
+ uc = {
1294
+ "crossattn": uc_crossattn,
1295
+ "replace": uc_replace,
1296
+ "concat": uc_concat,
1297
+ "dense_vector": uc_dense_vector,
1298
+ }
1299
+ unload_model(ae)
1300
+ unload_model(conditioner)
1301
+
1302
+ additional_model_inputs = {"num_frames": T}
1303
+ additional_sampler_inputs = {
1304
+ "c2w": value_dict["c2w"].to("cuda"),
1305
+ "K": value_dict["K"].to("cuda"),
1306
+ "input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
1307
+ }
1308
+ if global_pbar is not None:
1309
+ additional_sampler_inputs["global_pbar"] = global_pbar
1310
+
1311
+ shape = (math.prod(num_samples), C, H // F, W // F)
1312
+ randn = torch.randn(shape).to("cuda")
1313
+
1314
+ load_model(model)
1315
+ samples_z = sampler(
1316
+ lambda input, sigma, c: denoiser(
1317
+ model,
1318
+ input,
1319
+ sigma,
1320
+ c,
1321
+ **additional_model_inputs,
1322
+ ),
1323
+ randn,
1324
+ scale=cfg,
1325
+ cond=c,
1326
+ uc=uc,
1327
+ verbose=verbose,
1328
+ **additional_sampler_inputs,
1329
+ )
1330
+ if samples_z is None:
1331
+ return
1332
+ unload_model(model)
1333
+
1334
+ load_model(ae)
1335
+ samples = ae.decode(samples_z, decoding_t)
1336
+ unload_model(ae)
1337
+
1338
+ return samples
1339
+
1340
+
1341
+ def run_one_scene(
1342
+ task,
1343
+ version_dict,
1344
+ model,
1345
+ ae,
1346
+ conditioner,
1347
+ denoiser,
1348
+ image_cond,
1349
+ camera_cond,
1350
+ save_path,
1351
+ use_traj_prior,
1352
+ traj_prior_Ks,
1353
+ traj_prior_c2ws,
1354
+ seed=23,
1355
+ gradio=False,
1356
+ abort_event=None,
1357
+ first_pass_pbar=None,
1358
+ second_pass_pbar=None,
1359
+ ):
1360
+ H, W, T, C, F, options = (
1361
+ version_dict["H"],
1362
+ version_dict["W"],
1363
+ version_dict["T"],
1364
+ version_dict["C"],
1365
+ version_dict["f"],
1366
+ version_dict["options"],
1367
+ )
1368
+
1369
+ if isinstance(image_cond, str):
1370
+ image_cond = {"img": [image_cond]}
1371
+ imgs_clip, imgs, img_size = [], [], None
1372
+ for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
1373
+ if isinstance(img, str) or img is None:
1374
+ img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore
1375
+ img_size = img.shape[-2:]
1376
+ if options.get("L_short", -1) == -1:
1377
+ img, K = transform_img_and_K(
1378
+ img,
1379
+ (W, H),
1380
+ K=K[None],
1381
+ mode=(
1382
+ options.get("transform_input", "crop")
1383
+ if i in image_cond["input_indices"]
1384
+ else options.get("transform_target", "crop")
1385
+ ),
1386
+ scale=(
1387
+ 1.0
1388
+ if i in image_cond["input_indices"]
1389
+ else options.get("transform_scale", 1.0)
1390
+ ),
1391
+ )
1392
+ else:
1393
+ downsample = 3
1394
+ assert options["L_short"] % (F * 2**downsample) == 0, (
1395
+ "Short side of the image should be divisible by "
1396
+ f"F*2**{downsample}={F * 2**downsample}."
1397
+ )
1398
+ img, K = transform_img_and_K(
1399
+ img,
1400
+ options["L_short"],
1401
+ K=K[None],
1402
+ size_stride=F * 2**downsample,
1403
+ mode=(
1404
+ options.get("transform_input", "crop")
1405
+ if i in image_cond["input_indices"]
1406
+ else options.get("transform_target", "crop")
1407
+ ),
1408
+ scale=(
1409
+ 1.0
1410
+ if i in image_cond["input_indices"]
1411
+ else options.get("transform_scale", 1.0)
1412
+ ),
1413
+ )
1414
+ version_dict["W"] = W = img.shape[-1]
1415
+ version_dict["H"] = H = img.shape[-2]
1416
+ K = K[0]
1417
+ K[0] /= W
1418
+ K[1] /= H
1419
+ camera_cond["K"][i] = K
1420
+ img_clip = img
1421
+ elif isinstance(img, np.ndarray):
1422
+ img_size = torch.Size(img.shape[:2])
1423
+ img = torch.as_tensor(img).permute(2, 0, 1)
1424
+ img = img.unsqueeze(0)
1425
+ img = img / 255.0 * 2.0 - 1.0
1426
+ if not gradio:
1427
+ img, K = transform_img_and_K(img, (W, H), K=K[None])
1428
+ assert K is not None
1429
+ K = K[0]
1430
+ K[0] /= W
1431
+ K[1] /= H
1432
+ camera_cond["K"][i] = K
1433
+ img_clip = img
1434
+ else:
1435
+ assert (
1436
+ False
1437
+ ), f"Variable `img` got unsupported type {type(img)}."
1438
+ imgs_clip.append(img_clip)
1439
+ imgs.append(img)
1440
+ imgs_clip = torch.cat(imgs_clip, dim=0)
1441
+ imgs = torch.cat(imgs, dim=0)
1442
+
1443
+ if traj_prior_Ks is not None:
1444
+ assert img_size is not None
1445
+ for i, prior_k in enumerate(traj_prior_Ks):
1446
+ img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore
1447
+ img, prior_k = transform_img_and_K(
1448
+ img,
1449
+ (W, H),
1450
+ K=prior_k[None],
1451
+ mode=options.get(
1452
+ "transform_target", "crop"
1453
+ ), # mode for prior is always same as target
1454
+ scale=options.get(
1455
+ "transform_scale", 1.0
1456
+ ), # scale for prior is always same as target
1457
+ )
1458
+ prior_k = prior_k[0]
1459
+ prior_k[0] /= W
1460
+ prior_k[1] /= H
1461
+ traj_prior_Ks[i] = prior_k
1462
+
1463
+ options["num_frames"] = T
1464
+ discretization = denoiser.discretization
1465
+ torch.cuda.empty_cache()
1466
+
1467
+ seed_everything(seed)
1468
+
1469
+ # Get Data
1470
+ input_indices = image_cond["input_indices"]
1471
+ input_imgs = imgs[input_indices]
1472
+ input_imgs_clip = imgs_clip[input_indices]
1473
+ input_c2ws = camera_cond["c2w"][input_indices]
1474
+ input_Ks = camera_cond["K"][input_indices]
1475
+
1476
+ test_indices = [i for i in range(len(imgs)) if i not in input_indices]
1477
+ test_imgs = imgs[test_indices]
1478
+ test_imgs_clip = imgs_clip[test_indices]
1479
+ test_c2ws = camera_cond["c2w"][test_indices]
1480
+ test_Ks = camera_cond["K"][test_indices]
1481
+
1482
+ if options.get("save_input", True):
1483
+ save_output(
1484
+ {"/image": input_imgs},
1485
+ save_path=os.path.join(save_path, "input"),
1486
+ video_save_fps=2,
1487
+ )
1488
+
1489
+ if not use_traj_prior:
1490
+ chunk_strategy = options.get("chunk_strategy", "gt")
1491
+
1492
+ (
1493
+ _,
1494
+ input_inds_per_chunk,
1495
+ input_sels_per_chunk,
1496
+ test_inds_per_chunk,
1497
+ test_sels_per_chunk,
1498
+ ) = chunk_input_and_test(
1499
+ T,
1500
+ input_c2ws,
1501
+ test_c2ws,
1502
+ input_indices,
1503
+ test_indices,
1504
+ options=options,
1505
+ task=task,
1506
+ chunk_strategy=chunk_strategy,
1507
+ gt_input_inds=list(range(input_c2ws.shape[0])),
1508
+ )
1509
+ print(
1510
+ f"One pass - chunking with `{chunk_strategy}` strategy: total "
1511
+ f"{len(input_inds_per_chunk)} forward(s) ..."
1512
+ )
1513
+
1514
+ all_samples = {}
1515
+ all_test_inds = []
1516
+ for i, (
1517
+ chunk_input_inds,
1518
+ chunk_input_sels,
1519
+ chunk_test_inds,
1520
+ chunk_test_sels,
1521
+ ) in tqdm(
1522
+ enumerate(
1523
+ zip(
1524
+ input_inds_per_chunk,
1525
+ input_sels_per_chunk,
1526
+ test_inds_per_chunk,
1527
+ test_sels_per_chunk,
1528
+ )
1529
+ ),
1530
+ total=len(input_inds_per_chunk),
1531
+ leave=False,
1532
+ ):
1533
+ (
1534
+ curr_input_sels,
1535
+ curr_test_sels,
1536
+ curr_input_maps,
1537
+ curr_test_maps,
1538
+ ) = pad_indices(
1539
+ chunk_input_sels,
1540
+ chunk_test_sels,
1541
+ T=T,
1542
+ padding_mode=options.get("t_padding_mode", "last"),
1543
+ )
1544
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1545
+ assemble(
1546
+ input=x[chunk_input_inds],
1547
+ test=y[chunk_test_inds],
1548
+ input_maps=curr_input_maps,
1549
+ test_maps=curr_test_maps,
1550
+ )
1551
+ for x, y in zip(
1552
+ [
1553
+ torch.cat(
1554
+ [
1555
+ input_imgs,
1556
+ get_k_from_dict(all_samples, "samples-rgb").to(
1557
+ input_imgs.device
1558
+ ),
1559
+ ],
1560
+ dim=0,
1561
+ ),
1562
+ torch.cat(
1563
+ [
1564
+ input_imgs_clip,
1565
+ get_k_from_dict(all_samples, "samples-rgb").to(
1566
+ input_imgs.device
1567
+ ),
1568
+ ],
1569
+ dim=0,
1570
+ ),
1571
+ torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
1572
+ torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
1573
+ ],  # procedurally append generated prior views to the input views
1574
+ [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
1575
+ )
1576
+ ]
1577
+ value_dict = get_value_dict(
1578
+ curr_imgs.to("cuda"),
1579
+ curr_imgs_clip.to("cuda"),
1580
+ curr_input_sels
1581
+ + [
1582
+ sel
1583
+ for (ind, sel) in zip(
1584
+ np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
1585
+ curr_test_sels,
1586
+ )
1587
+ if test_indices[ind] in image_cond["input_indices"]
1588
+ ],
1589
+ curr_c2ws,
1590
+ curr_Ks,
1591
+ curr_input_sels
1592
+ + [
1593
+ sel
1594
+ for (ind, sel) in zip(
1595
+ np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
1596
+ curr_test_sels,
1597
+ )
1598
+ if test_indices[ind] in camera_cond["input_indices"]
1599
+ ],
1600
+ all_c2ws=camera_cond["c2w"],
1601
+ )
1602
+ samplers = create_samplers(
1603
+ options["guider_types"],
1604
+ discretization,
1605
+ [len(curr_imgs)],
1606
+ options["num_steps"],
1607
+ options["cfg_min"],
1608
+ abort_event=abort_event,
1609
+ )
1610
+ assert len(samplers) == 1
1611
+ samples = do_sample(
1612
+ model,
1613
+ ae,
1614
+ conditioner,
1615
+ denoiser,
1616
+ samplers[0],
1617
+ value_dict,
1618
+ H,
1619
+ W,
1620
+ C,
1621
+ F,
1622
+ T=len(curr_imgs),
1623
+ cfg=(
1624
+ options["cfg"][0]
1625
+ if isinstance(options["cfg"], (list, tuple))
1626
+ else options["cfg"]
1627
+ ),
1628
+ **{k: options[k] for k in options if k not in ["cfg", "T"]},
1629
+ )
1630
+ samples = decode_output(
1631
+ samples, len(curr_imgs), chunk_test_sels
1632
+ ) # decode into dict
1633
+ if options.get("save_first_pass", False):
1634
+ save_output(
1635
+ replace_or_include_input_for_dict(
1636
+ samples,
1637
+ chunk_test_sels,
1638
+ curr_imgs,
1639
+ curr_c2ws,
1640
+ curr_Ks,
1641
+ ),
1642
+ save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
1643
+ video_save_fps=2,
1644
+ )
1645
+ extend_dict(all_samples, samples)
1646
+ all_test_inds.extend(chunk_test_inds)
1647
+ else:
1648
+ assert traj_prior_c2ws is not None, (
1649
+ "`traj_prior_c2ws` should be set when using 2-pass sampling. One "
1650
+ "potential reason is that the number of input frames is larger than "
1651
+ "T. Set `num_prior_frames` manually to override the inferred stats."
1652
+ )
1653
+ traj_prior_c2ws = torch.as_tensor(
1654
+ traj_prior_c2ws,
1655
+ device=input_c2ws.device,
1656
+ dtype=input_c2ws.dtype,
1657
+ )
1658
+
1659
+ if traj_prior_Ks is None:
1660
+ traj_prior_Ks = test_Ks[:1].repeat_interleave(
1661
+ traj_prior_c2ws.shape[0], dim=0
1662
+ )
1663
+
1664
+ traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
1665
+ traj_prior_imgs_clip = imgs_clip.new_zeros(
1666
+ traj_prior_c2ws.shape[0], *imgs_clip.shape[1:]
1667
+ )
1668
+
1669
+ # ---------------------------------- first pass ----------------------------------
1670
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
1671
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
1672
+ chunk_strategy_first_pass = options.get(
1673
+ "chunk_strategy_first_pass", "gt-nearest"
1674
+ )
1675
+ (
1676
+ _,
1677
+ input_inds_per_chunk,
1678
+ input_sels_per_chunk,
1679
+ prior_inds_per_chunk,
1680
+ prior_sels_per_chunk,
1681
+ ) = chunk_input_and_test(
1682
+ T_first_pass,
1683
+ input_c2ws,
1684
+ traj_prior_c2ws,
1685
+ input_indices,
1686
+ image_cond["prior_indices"],
1687
+ options=options,
1688
+ task=task,
1689
+ chunk_strategy=chunk_strategy_first_pass,
1690
+ gt_input_inds=list(range(input_c2ws.shape[0])),
1691
+ )
1692
+ print(
1693
+ f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
1694
+ f"{len(input_inds_per_chunk)} forward(s) ..."
1695
+ )
1696
+
1697
+ all_samples = {}
1698
+ all_prior_inds = []
1699
+ for i, (
1700
+ chunk_input_inds,
1701
+ chunk_input_sels,
1702
+ chunk_prior_inds,
1703
+ chunk_prior_sels,
1704
+ ) in tqdm(
1705
+ enumerate(
1706
+ zip(
1707
+ input_inds_per_chunk,
1708
+ input_sels_per_chunk,
1709
+ prior_inds_per_chunk,
1710
+ prior_sels_per_chunk,
1711
+ )
1712
+ ),
1713
+ total=len(input_inds_per_chunk),
1714
+ leave=False,
1715
+ ):
1716
+ (
1717
+ curr_input_sels,
1718
+ curr_prior_sels,
1719
+ curr_input_maps,
1720
+ curr_prior_maps,
1721
+ ) = pad_indices(
1722
+ chunk_input_sels,
1723
+ chunk_prior_sels,
1724
+ T=T_first_pass,
1725
+ padding_mode=options.get("t_padding_mode", "last"),
1726
+ )
1727
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1728
+ assemble(
1729
+ input=x[chunk_input_inds],
1730
+ test=y[chunk_prior_inds],
1731
+ input_maps=curr_input_maps,
1732
+ test_maps=curr_prior_maps,
1733
+ )
1734
+ for x, y in zip(
1735
+ [
1736
+ torch.cat(
1737
+ [
1738
+ input_imgs,
1739
+ get_k_from_dict(all_samples, "samples-rgb").to(
1740
+ input_imgs.device
1741
+ ),
1742
+ ],
1743
+ dim=0,
1744
+ ),
1745
+ torch.cat(
1746
+ [
1747
+ input_imgs_clip,
1748
+ get_k_from_dict(all_samples, "samples-rgb").to(
1749
+ input_imgs.device
1750
+ ),
1751
+ ],
1752
+ dim=0,
1753
+ ),
1754
+ torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
1755
+ torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
1756
+ ],  # procedurally append generated prior views to the input views
1757
+ [
1758
+ traj_prior_imgs,
1759
+ traj_prior_imgs_clip,
1760
+ traj_prior_c2ws,
1761
+ traj_prior_Ks,
1762
+ ],
1763
+ )
1764
+ ]
1765
+ value_dict = get_value_dict(
1766
+ curr_imgs.to("cuda"),
1767
+ curr_imgs_clip.to("cuda"),
1768
+ curr_input_sels,
1769
+ curr_c2ws,
1770
+ curr_Ks,
1771
+ list(range(T_first_pass)),
1772
+ all_c2ws=camera_cond["c2w"], # traj_prior_c2ws,
1773
+ )
1774
+ samplers = create_samplers(
1775
+ options["guider_types"],
1776
+ discretization,
1777
+ [T_first_pass, T_second_pass],
1778
+ options["num_steps"],
1779
+ options["cfg_min"],
1780
+ abort_event=abort_event,
1781
+ )
1782
+ samples = do_sample(
1783
+ model,
1784
+ ae,
1785
+ conditioner,
1786
+ denoiser,
1787
+ (
1788
+ samplers[1]
1789
+ if len(samplers) > 1
1790
+ and options.get("ltr_first_pass", False)
1791
+ and chunk_strategy_first_pass != "gt"
1792
+ and i > 0
1793
+ else samplers[0]
1794
+ ),
1795
+ value_dict,
1796
+ H,
1797
+ W,
1798
+ C,
1799
+ F,
1800
+ cfg=(
1801
+ options["cfg"][0]
1802
+ if isinstance(options["cfg"], (list, tuple))
1803
+ else options["cfg"]
1804
+ ),
1805
+ T=T_first_pass,
1806
+ global_pbar=first_pass_pbar,
1807
+ **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
1808
+ )
1809
+ if samples is None:
1810
+ return
1811
+ samples = decode_output(
1812
+ samples, T_first_pass, chunk_prior_sels
1813
+ ) # decode into dict
1814
+ extend_dict(all_samples, samples)
1815
+ all_prior_inds.extend(chunk_prior_inds)
1816
+
1817
+ if options.get("save_first_pass", True):
1818
+ save_output(
1819
+ all_samples,
1820
+ save_path=os.path.join(save_path, "first-pass"),
1821
+ video_save_fps=5,
1822
+ )
1823
+ video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
1824
+ yield video_path_0
1825
+
1826
+ # ---------------------------------- second pass ----------------------------------
1827
+ prior_indices = image_cond["prior_indices"]
1828
+ assert (
1829
+ prior_indices is not None
1830
+ ), "`prior_frame_indices` needs to be set if using 2-pass sampling."
1831
+ prior_argsort = np.argsort(input_indices + prior_indices).tolist()
1832
+ prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
1833
+ gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]
1834
+
1835
+ traj_prior_imgs = torch.cat(
1836
+ [input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
1837
+ )[prior_argsort]
1838
+ traj_prior_imgs_clip = torch.cat(
1839
+ [
1840
+ input_imgs_clip,
1841
+ get_k_from_dict(all_samples, "samples-rgb"),
1842
+ ],
1843
+ dim=0,
1844
+ )[prior_argsort]
1845
+ traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
1846
+ traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]
1847
+
1848
+ update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
1849
+ update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
1850
+ update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)
1851
+
1852
+ chunk_strategy = options.get("chunk_strategy", "nearest")
1853
+ (
1854
+ _,
1855
+ prior_inds_per_chunk,
1856
+ prior_sels_per_chunk,
1857
+ test_inds_per_chunk,
1858
+ test_sels_per_chunk,
1859
+ ) = chunk_input_and_test(
1860
+ T_second_pass,
1861
+ traj_prior_c2ws,
1862
+ test_c2ws,
1863
+ prior_indices,
1864
+ test_indices,
1865
+ options=options,
1866
+ task=task,
1867
+ chunk_strategy=chunk_strategy,
1868
+ gt_input_inds=gt_input_inds,
1869
+ )
1870
+ print(
1871
+ f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
1872
+ f"{len(prior_inds_per_chunk)} forward(s) ..."
1873
+ )
1874
+
1875
+ all_samples = {}
1876
+ all_test_inds = []
1877
+ for i, (
1878
+ chunk_prior_inds,
1879
+ chunk_prior_sels,
1880
+ chunk_test_inds,
1881
+ chunk_test_sels,
1882
+ ) in tqdm(
1883
+ enumerate(
1884
+ zip(
1885
+ prior_inds_per_chunk,
1886
+ prior_sels_per_chunk,
1887
+ test_inds_per_chunk,
1888
+ test_sels_per_chunk,
1889
+ )
1890
+ ),
1891
+ total=len(prior_inds_per_chunk),
1892
+ leave=False,
1893
+ ):
1894
+ (
1895
+ curr_prior_sels,
1896
+ curr_test_sels,
1897
+ curr_prior_maps,
1898
+ curr_test_maps,
1899
+ ) = pad_indices(
1900
+ chunk_prior_sels,
1901
+ chunk_test_sels,
1902
+ T=T_second_pass,
1903
+ padding_mode="last",
1904
+ )
1905
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1906
+ assemble(
1907
+ input=x[chunk_prior_inds],
1908
+ test=y[chunk_test_inds],
1909
+ input_maps=curr_prior_maps,
1910
+ test_maps=curr_test_maps,
1911
+ )
1912
+ for x, y in zip(
1913
+ [
1914
+ traj_prior_imgs,
1915
+ traj_prior_imgs_clip,
1916
+ traj_prior_c2ws,
1917
+ traj_prior_Ks,
1918
+ ],
1919
+ [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
1920
+ )
1921
+ ]
1922
+ value_dict = get_value_dict(
1923
+ curr_imgs.to("cuda"),
1924
+ curr_imgs_clip.to("cuda"),
1925
+ curr_prior_sels,
1926
+ curr_c2ws,
1927
+ curr_Ks,
1928
+ list(range(T_second_pass)),
1929
+ all_c2ws=camera_cond["c2w"], # test_c2ws,
1930
+ )
1931
+ samples = do_sample(
1932
+ model,
1933
+ ae,
1934
+ conditioner,
1935
+ denoiser,
1936
+ samplers[1] if len(samplers) > 1 else samplers[0],
1937
+ value_dict,
1938
+ H,
1939
+ W,
1940
+ C,
1941
+ F,
1942
+ T=T_second_pass,
1943
+ cfg=(
1944
+ options["cfg"][1]
1945
+ if isinstance(options["cfg"], (list, tuple))
1946
+ and len(options["cfg"]) > 1
1947
+ else options["cfg"]
1948
+ ),
1949
+ global_pbar=second_pass_pbar,
1950
+ **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
1951
+ )
1952
+ if samples is None:
1953
+ return
1954
+ samples = decode_output(
1955
+ samples, T_second_pass, chunk_test_sels
1956
+ ) # decode into dict
1957
+ if options.get("save_second_pass", False):
1958
+ save_output(
1959
+ replace_or_include_input_for_dict(
1960
+ samples,
1961
+ chunk_test_sels,
1962
+ curr_imgs,
1963
+ curr_c2ws,
1964
+ curr_Ks,
1965
+ ),
1966
+ save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
1967
+ video_save_fps=2,
1968
+ )
1969
+ extend_dict(all_samples, samples)
1970
+ all_test_inds.extend(chunk_test_inds)
1971
+ all_samples = {
1972
+ key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
1973
+ }
1974
+ save_output(
1975
+ replace_or_include_input_for_dict(
1976
+ all_samples,
1977
+ test_indices,
1978
+ imgs.clone(),
1979
+ camera_cond["c2w"].clone(),
1980
+ camera_cond["K"].clone(),
1981
+ )
1982
+ if options.get("replace_or_include_input", False)
1983
+ else all_samples,
1984
+ save_path=save_path,
1985
+ video_save_fps=options.get("video_save_fps", 2),
1986
+ )
1987
+ video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
1988
+ yield video_path_1
seva/geometry.py ADDED
@@ -0,0 +1,811 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import roma
5
+ import scipy.interpolate
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default
10
+
11
+
12
+ def get_camera_dist(
13
+ source_c2ws: torch.Tensor, # N x 3 x 4
14
+ target_c2ws: torch.Tensor, # M x 3 x 4
15
+ mode: str = "translation",
16
+ ):
17
+ if mode == "rotation":
18
+ dists = torch.acos(
19
+ (
20
+ (
21
+ torch.matmul(
22
+ source_c2ws[:, None, :3, :3],
23
+ target_c2ws[None, :, :3, :3].transpose(-1, -2),
24
+ )
25
+ .diagonal(offset=0, dim1=-2, dim2=-1)
26
+ .sum(-1)
27
+ - 1
28
+ )
29
+ / 2
30
+ ).clamp(-1, 1)
31
+ ) * (180 / torch.pi)
32
+ elif mode == "translation":
33
+ dists = torch.norm(
34
+ source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1
35
+ )
36
+ else:
37
+ raise NotImplementedError(
38
+ f"Mode {mode} is not implemented for finding nearest source indices."
39
+ )
40
+ return dists
41
+
42
+
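For illustration, assuming seva/geometry.py is importable: pairwise distances between two small, hypothetical camera sets. Full 4x4 c2w matrices work because only the top 3x4 block is read.

import torch

from seva.geometry import get_camera_dist

src = torch.eye(4)[None].repeat(3, 1, 1)
tgt = torch.eye(4)[None].repeat(5, 1, 1)
tgt[:, 0, 3] = torch.arange(5, dtype=torch.float32)       # translate targets along +x

t_dists = get_camera_dist(src, tgt, mode="translation")   # (3, 5) Euclidean distances
r_dists = get_camera_dist(src, tgt, mode="rotation")      # (3, 5) geodesic angles in degrees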
43
+ def to_hom(X):
44
+ # get homogeneous coordinates of the input
45
+ X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
46
+ return X_hom
47
+
48
+
49
+ def to_hom_pose(pose):
50
+ # get homogeneous coordinates of the input pose
51
+ if pose.shape[-2:] == (3, 4):
52
+ pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
53
+ pose_hom[:, :3, :] = pose
54
+ return pose_hom
55
+ return pose
56
+
57
+
58
+ def get_default_intrinsics(
59
+ fov_rad=DEFAULT_FOV_RAD,
60
+ aspect_ratio=1.0,
61
+ ):
62
+ if not isinstance(fov_rad, torch.Tensor):
63
+ fov_rad = torch.tensor(
64
+ [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
65
+ )
66
+ if aspect_ratio >= 1.0: # W >= H
67
+ focal_x = 0.5 / torch.tan(0.5 * fov_rad)
68
+ focal_y = focal_x * aspect_ratio
69
+ else: # W < H
70
+ focal_y = 0.5 / torch.tan(0.5 * fov_rad)
71
+ focal_x = focal_y / aspect_ratio
72
+ intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
73
+ intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
74
+ [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
75
+ )
76
+ intrinsics[:, :, -1] = torch.tensor(
77
+ [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
78
+ )
79
+ return intrinsics
80
+
81
+
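A quick sketch of the normalized default intrinsics returned above:

import torch

from seva.geometry import get_default_intrinsics

K = get_default_intrinsics()                              # (1, 3, 3), ~54 degree FOV
assert torch.allclose(K[0, :2, 2], torch.tensor([0.5, 0.5]))  # principal point at the center
# For a wide aspect ratio, focal_y is scaled by W/H in these normalized units.
K_wide = get_default_intrinsics(aspect_ratio=16 / 9)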
82
+ def get_image_grid(img_h, img_w):
83
+ # Adding 0.5 is important, especially when img_h and img_w
84
+ # are not very large (e.g., 72).
85
+ y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
86
+ x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
87
+ Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W]
88
+ xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2]
89
+ return to_hom(xy_grid) # [HW,3]
90
+
91
+
92
+ def img2cam(X, cam_intr):
93
+ return X @ cam_intr.inverse().transpose(-1, -2)
94
+
95
+
96
+ def cam2world(X, pose):
97
+ X_hom = to_hom(X)
98
+ pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
99
+ return X_hom @ pose_inv.transpose(-1, -2)
100
+
101
+
102
+ def get_center_and_ray(
103
+ img_h, img_w, pose, intr, zero_center_for_debugging=False
104
+ ): # [HW,2]
105
+ # given the intrinsic/extrinsic matrices, get the camera center and ray directions
106
+ # assert(opt.camera.model=="perspective")
107
+
108
+ # compute center and ray
109
+ grid_img = get_image_grid(img_h, img_w) # [HW,3]
110
+ grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3]
111
+ center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3]
112
+
113
+ # transform from camera to world coordinates
114
+ grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3]
115
+ center_3D = cam2world(center_3D_cam, pose) # [B,HW,3]
116
+ ray = grid_3D - center_3D # [B,HW,3]
117
+
118
+ return center_3D_cam if zero_center_for_debugging else center_3D, ray, grid_3D_cam
119
+
120
+
121
+ def get_plucker_coordinates(
122
+ extrinsics_src,
123
+ extrinsics,
124
+ intrinsics=None,
125
+ fov_rad=DEFAULT_FOV_RAD,
126
+ mode="plucker",
127
+ rel_zero_translation=True,
128
+ zero_center_for_debugging=False,
129
+ target_size=[72, 72],  # latent grid size, e.g. 72x72 for a 576x576 image
130
+ return_grid_cam=False,  # also return the camera-space grid for later reuse
131
+ ):
132
+ if intrinsics is None:
133
+ intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
134
+ else:
135
+ # for some data preprocessed in the early stage (e.g., MVI and CO3D),
136
+ # intrinsics are expressed in raw pixel space (e.g., 576x576) instead
137
+ # of normalized image coordinates
138
+ if not (
139
+ torch.all(intrinsics[:, :2, -1] >= 0)
140
+ and torch.all(intrinsics[:, :2, -1] <= 1)
141
+ ):
142
+ intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8
143
+ # You should ensure the intrinsics are expressed in
144
+ # resolution-independent normalized image coordinates; we only perform a
145
+ # very simple verification here, checking that the principal points are
146
+ # between 0 and 1.
147
+ assert (
148
+ torch.all(intrinsics[:, :2, -1] >= 0)
149
+ and torch.all(intrinsics[:, :2, -1] <= 1)
150
+ ), "Intrinsics should be expressed in resolution-independent normalized image coordinates."
151
+
152
+ c2w_src = torch.linalg.inv(extrinsics_src)
153
+ if not rel_zero_translation:
154
+ c2w_src[:3, 3] = c2w_src[3, :3] = 0.0
155
+ # transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera
156
+ extrinsics_rel = torch.einsum(
157
+ "vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1)
158
+ )
159
+
160
+ intrinsics[:, :2] *= extrinsics.new_tensor(
161
+ [
162
+ target_size[1], # w
163
+ target_size[0], # h
164
+ ]
165
+ ).view(1, -1, 1)
166
+ centers, rays, grid_cam = get_center_and_ray(
167
+ img_h=target_size[0],
168
+ img_w=target_size[1],
169
+ pose=extrinsics_rel[:, :3, :],
170
+ intr=intrinsics,
171
+ zero_center_for_debugging=zero_center_for_debugging,
172
+ )
173
+
174
+ if mode == "plucker" or "v1" in mode:
175
+ rays = torch.nn.functional.normalize(rays, dim=-1)
176
+ plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
177
+ else:
178
+ raise ValueError(f"Unknown Plucker coordinate mode: {mode}")
179
+
180
+ plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)
181
+ if return_grid_cam:
182
+ return plucker, grid_cam.reshape(-1, *target_size, 3)
183
+ return plucker
184
+
185
+
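For illustration only: Plücker embeddings for four hypothetical cameras relative to the first one, on a 36x64 latent grid (e.g. a 288x512 frame with the F=8 downsampling used elsewhere in this commit).

import torch

from seva.geometry import get_plucker_coordinates

w2cs = torch.eye(4)[None].repeat(4, 1, 1)
w2cs[:, 2, 3] = torch.linspace(0.0, 1.0, 4)               # cameras moving along z

plucker = get_plucker_coordinates(
    extrinsics_src=w2cs[0],
    extrinsics=w2cs,
    target_size=(36, 64),
)
# 6 channels per pixel: the normalized ray direction plus its moment (center x direction).
assert plucker.shape == (4, 6, 36, 64)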
186
+ def rt_to_mat4(
187
+ R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
188
+ ) -> torch.Tensor:
189
+ """
190
+ Args:
191
+ R (torch.Tensor): (..., 3, 3).
192
+ t (torch.Tensor): (..., 3).
193
+ s (torch.Tensor): (...,).
194
+
195
+ Returns:
196
+ torch.Tensor: (..., 4, 4)
197
+ """
198
+ mat34 = torch.cat([R, t[..., None]], dim=-1)
199
+ if s is None:
200
+ bottom = (
201
+ mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
202
+ .reshape((1,) * (mat34.dim() - 2) + (1, 4))
203
+ .expand(mat34.shape[:-2] + (1, 4))
204
+ )
205
+ else:
206
+ bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
207
+ mat4 = torch.cat([mat34, bottom], dim=-2)
208
+ return mat4
209
+
210
+
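A tiny sketch of the homogeneous-matrix helper above:

import torch

from seva.geometry import rt_to_mat4

R = torch.eye(3)
t = torch.tensor([0.0, 0.0, 2.0])

mat = rt_to_mat4(R, t)                                    # (4, 4) rigid transform
assert mat.shape == (4, 4) and mat[3, 3] == 1.0
# With a scale s, the bottom-right entry becomes 1/s (a scaled homogeneous matrix).
mat_scaled = rt_to_mat4(R, t, s=torch.tensor(0.5))
assert mat_scaled[3, 3] == 2.0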
211
+ def get_preset_pose_fov(
212
+ option: Literal[
213
+ "orbit",
214
+ "spiral",
215
+ "lemniscate",
216
+ "zoom-in",
217
+ "zoom-out",
218
+ "dolly zoom-in",
219
+ "dolly zoom-out",
220
+ "move-forward",
221
+ "move-backward",
222
+ "move-up",
223
+ "move-down",
224
+ "move-left",
225
+ "move-right",
226
+ "roll",
227
+ ],
228
+ num_frames: int,
229
+ start_w2c: torch.Tensor,
230
+ look_at: torch.Tensor,
231
+ up_direction: torch.Tensor | None = None,
232
+ fov: float = DEFAULT_FOV_RAD,
233
+ spiral_radii: list[float] = [0.5, 0.5, 0.2],
234
+ zoom_factor: float | None = None,
235
+ ):
236
+ poses = fovs = None
237
+ if option == "orbit":
238
+ poses = torch.linalg.inv(
239
+ get_arc_horizontal_w2cs(
240
+ start_w2c,
241
+ look_at,
242
+ up_direction,
243
+ num_frames=num_frames,
244
+ endpoint=False,
245
+ )
246
+ ).numpy()
247
+ fovs = np.full((num_frames,), fov)
248
+ elif option == "spiral":
249
+ poses = generate_spiral_path(
250
+ torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]),
251
+ np.array([1, 5]),
252
+ n_frames=num_frames,
253
+ n_rots=2,
254
+ zrate=0.5,
255
+ radii=spiral_radii,
256
+ endpoint=False,
257
+ ) @ np.diagflat([1, -1, -1, 1])
258
+ poses = np.concatenate(
259
+ [
260
+ poses,
261
+ np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0),
262
+ ],
263
+ 1,
264
+ )
265
+ # We want the spiral trajectory to always start from start_w2c. Thus we
266
+ # apply the relative pose to get the final trajectory.
267
+ poses = (
268
+ np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses
269
+ )
270
+ fovs = np.full((num_frames,), fov)
271
+ elif option == "lemniscate":
272
+ poses = torch.linalg.inv(
273
+ get_lemniscate_w2cs(
274
+ start_w2c,
275
+ look_at,
276
+ up_direction,
277
+ num_frames,
278
+ degree=60.0,
279
+ endpoint=False,
280
+ )
281
+ ).numpy()
282
+ fovs = np.full((num_frames,), fov)
283
+ elif option == "roll":
284
+ poses = torch.linalg.inv(
285
+ get_roll_w2cs(
286
+ start_w2c,
287
+ look_at,
288
+ None,
289
+ num_frames,
290
+ degree=360.0,
291
+ endpoint=False,
292
+ )
293
+ ).numpy()
294
+ fovs = np.full((num_frames,), fov)
295
+ elif option in [
296
+ "dolly zoom-in",
297
+ "dolly zoom-out",
298
+ "zoom-in",
299
+ "zoom-out",
300
+ ]:
301
+ if option.startswith("dolly"):
302
+ direction = "backward" if option == "dolly zoom-in" else "forward"
303
+ poses = torch.linalg.inv(
304
+ get_moving_w2cs(
305
+ start_w2c,
306
+ look_at,
307
+ up_direction,
308
+ num_frames,
309
+ endpoint=True,
310
+ direction=direction,
311
+ )
312
+ ).numpy()
313
+ else:
314
+ poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy()
315
+ fov_rad_start = fov
316
+ if zoom_factor is None:
317
+ zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5
318
+ fov_rad_end = zoom_factor * fov
319
+ fovs = (
320
+ np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start)
321
+ + fov_rad_start
322
+ )
323
+ elif option in [
324
+ "move-forward",
325
+ "move-backward",
326
+ "move-up",
327
+ "move-down",
328
+ "move-left",
329
+ "move-right",
330
+ ]:
331
+ poses = torch.linalg.inv(
332
+ get_moving_w2cs(
333
+ start_w2c,
334
+ look_at,
335
+ up_direction,
336
+ num_frames,
337
+ endpoint=True,
338
+ direction=option.removeprefix("move-"),
339
+ )
340
+ ).numpy()
341
+ fovs = np.full((num_frames,), fov)
342
+ else:
343
+ raise ValueError(f"Unknown preset option {option}.")
344
+
345
+ return poses, fovs
346
+
347
+
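For illustration only: an 80-frame orbit from a hypothetical start camera two units away from the origin. The other preset options listed in the Literal above are used the same way.

import torch

from seva.geometry import get_preset_pose_fov

start_w2c = torch.eye(4)
start_w2c[2, 3] = 2.0                                     # camera 2 units away, looking at the origin

poses, fovs = get_preset_pose_fov(
    option="orbit",
    num_frames=80,
    start_w2c=start_w2c,
    look_at=torch.zeros(3),
    up_direction=torch.tensor([0.0, -1.0, 0.0]),
)
# poses: (80, 4, 4) camera-to-world matrices as numpy, fovs: (80,) FOVs in radians.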
348
+ def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
349
+ """Triangulate a set of rays to find a single lookat point.
350
+
351
+ Args:
352
+ origins (torch.Tensor): A (N, 3) array of ray origins.
353
+ viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
354
+
355
+ Returns:
356
+ torch.Tensor: A (3,) lookat point.
357
+ """
358
+
359
+ viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
360
+ eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
361
+ # Calculate projection matrix I - rr^T
362
+ I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
363
+ # Compute sum of projections
364
+ sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
365
+ # Solve for the intersection point using least squares
366
+ lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
367
+ # Check NaNs.
368
+ assert not torch.any(torch.isnan(lookat))
369
+ return lookat
370
+
371
+
372
+ def get_lookat_w2cs(
373
+ positions: torch.Tensor,
374
+ lookat: torch.Tensor,
375
+ up: torch.Tensor,
376
+ face_off: bool = False,
377
+ ):
378
+ """
379
+ Args:
380
+ positions: (N, 3) tensor of camera positions
381
+ lookat: (3,) tensor of lookat point
382
+ up: (3,) or (N, 3) tensor of up vector
383
+
384
+ Returns:
385
+ w2cs: (N, 4, 4) tensor of world-to-camera matrices
386
+ """
387
+ forward_vectors = F.normalize(lookat - positions, dim=-1)
388
+ if face_off:
389
+ forward_vectors = -forward_vectors
390
+ if up.dim() == 1:
391
+ up = up[None]
392
+ right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1)
393
+ down_vectors = F.normalize(
394
+ torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
395
+ )
396
+ Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
397
+ w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
398
+ return w2cs
399
+
400
+
401
+ def get_arc_horizontal_w2cs(
402
+ ref_w2c: torch.Tensor,
403
+ lookat: torch.Tensor,
404
+ up: torch.Tensor | None,
405
+ num_frames: int,
406
+ clockwise: bool = True,
407
+ face_off: bool = False,
408
+ endpoint: bool = False,
409
+ degree: float = 360.0,
410
+ ref_up_shift: float = 0.0,
411
+ ref_radius_scale: float = 1.0,
412
+ **_,
413
+ ) -> torch.Tensor:
414
+ ref_c2w = torch.linalg.inv(ref_w2c)
415
+ ref_position = ref_c2w[:3, 3]
416
+ if up is None:
417
+ up = -ref_c2w[:3, 1]
418
+ assert up is not None
419
+ ref_position += up * ref_up_shift
420
+ ref_position *= ref_radius_scale
421
+ thetas = (
422
+ torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
423
+ if endpoint
424
+ else torch.linspace(
425
+ 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
426
+ )[:-1]
427
+ )
428
+ if not clockwise:
429
+ thetas = -thetas
430
+ positions = (
431
+ torch.einsum(
432
+ "nij,j->ni",
433
+ roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
434
+ ref_position - lookat,
435
+ )
436
+ + lookat
437
+ )
438
+ return get_lookat_w2cs(positions, lookat, up, face_off=face_off)
439
+
440
+
441
+ def get_lemniscate_w2cs(
442
+ ref_w2c: torch.Tensor,
443
+ lookat: torch.Tensor,
444
+ up: torch.Tensor | None,
445
+ num_frames: int,
446
+ degree: float,
447
+ endpoint: bool = False,
448
+ **_,
449
+ ) -> torch.Tensor:
450
+ ref_c2w = torch.linalg.inv(ref_w2c)
451
+ a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
452
+ # Lemniscate curve in camera space. Starting at the origin.
453
+ thetas = (
454
+ torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device)
455
+ if endpoint
456
+ else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
457
+ ) + torch.pi / 2
458
+ positions = torch.stack(
459
+ [
460
+ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
461
+ a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
462
+ torch.zeros(num_frames, device=ref_w2c.device),
463
+ ],
464
+ dim=-1,
465
+ )
466
+ # Transform to world space.
467
+ positions = torch.einsum(
468
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
469
+ )
470
+ if up is None:
471
+ up = -ref_c2w[:3, 1]
472
+ assert up is not None
473
+ return get_lookat_w2cs(positions, lookat, up)
474
+
475
+
476
+ def get_moving_w2cs(
477
+ ref_w2c: torch.Tensor,
478
+ lookat: torch.Tensor,
479
+ up: torch.Tensor | None,
480
+ num_frames: int,
481
+ endpoint: bool = False,
482
+ direction: str = "forward",
483
+ tilt_xy: torch.Tensor = None,
484
+ ):
485
+ """
486
+ Args:
487
+ ref_w2c: (4, 4) tensor of the reference world-to-camera matrix
488
+ lookat: (3,) tensor of lookat point
489
+ up: (3,) tensor of up vector
490
+
491
+ Returns:
492
+ w2cs: (N, 4, 4) tensor of world-to-camera matrices
493
+ """
494
+ ref_c2w = torch.linalg.inv(ref_w2c)
495
+ ref_position = ref_c2w[:3, -1]
496
+ if up is None:
497
+ up = -ref_c2w[:3, 1]
498
+
499
+ direction_vectors = {
500
+ "forward": (lookat - ref_position).clone(),
501
+ "backward": -(lookat - ref_position).clone(),
502
+ "up": up.clone(),
503
+ "down": -up.clone(),
504
+ "right": torch.cross((lookat - ref_position), up, dim=0),
505
+ "left": -torch.cross((lookat - ref_position), up, dim=0),
506
+ }
507
+ if direction not in direction_vectors:
508
+ raise ValueError(
509
+ f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}"
510
+ )
511
+
512
+ positions = ref_position + (
513
+ F.normalize(direction_vectors[direction], dim=0)
514
+ * (
515
+ torch.linspace(0, 0.99, num_frames, device=ref_w2c.device)
516
+ if endpoint
517
+ else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1]
518
+ )[:, None]
519
+ )
520
+
521
+ if tilt_xy is not None:
522
+ positions[:, :2] += tilt_xy
523
+
524
+ return get_lookat_w2cs(positions, lookat, up)
525
+
526
+
527
+ def get_roll_w2cs(
528
+ ref_w2c: torch.Tensor,
529
+ lookat: torch.Tensor,
530
+ up: torch.Tensor | None,
531
+ num_frames: int,
532
+ endpoint: bool = False,
533
+ degree: float = 360.0,
534
+ **_,
535
+ ) -> torch.Tensor:
536
+ ref_c2w = torch.linalg.inv(ref_w2c)
537
+ ref_position = ref_c2w[:3, 3]
538
+ if up is None:
539
+ up = -ref_c2w[:3, 1] # Infer the up vector from the reference.
540
+
541
+ # Create vertical angles
542
+ thetas = (
543
+ torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
544
+ if endpoint
545
+ else torch.linspace(
546
+ 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
547
+ )[:-1]
548
+ )[:, None]
549
+
550
+ lookat_vector = F.normalize(lookat[None].float(), dim=-1)
551
+ up = up[None]
552
+ up = (
553
+ up * torch.cos(thetas)
554
+ + torch.cross(lookat_vector, up) * torch.sin(thetas)
555
+ + lookat_vector
556
+ * torch.einsum("ij,ij->i", lookat_vector, up)[:, None]
557
+ * (1 - torch.cos(thetas))
558
+ )
559
+
560
+ # Normalize the camera orientation
561
+ return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up)
562
+
563
+
564
+ def normalize(x):
565
+ """Normalization helper function."""
566
+ return x / np.linalg.norm(x)
567
+
568
+
569
+ def viewmatrix(lookdir, up, position, subtract_position=False):
570
+ """Construct lookat view matrix."""
571
+ vec2 = normalize((lookdir - position) if subtract_position else lookdir)
572
+ vec0 = normalize(np.cross(up, vec2))
573
+ vec1 = normalize(np.cross(vec2, vec0))
574
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
575
+ return m
576
+
577
+
578
+ def poses_avg(poses):
579
+ """New pose using average position, z-axis, and up vector of input poses."""
580
+ position = poses[:, :3, 3].mean(0)
581
+ z_axis = poses[:, :3, 2].mean(0)
582
+ up = poses[:, :3, 1].mean(0)
583
+ cam2world = viewmatrix(z_axis, up, position)
584
+ return cam2world
585
+
586
+
587
+ def generate_spiral_path(
588
+ poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None
589
+ ):
590
+ """Calculates a forward facing spiral path for rendering."""
591
+ # Find a reasonable 'focus depth' for this dataset as a weighted average
592
+ # of near and far bounds in disparity space.
593
+ close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0
594
+ dt = 0.75
595
+ focal = 1 / ((1 - dt) / close_depth + dt / inf_depth)
596
+
597
+ # Get radii for spiral path using 90th percentile of camera positions.
598
+ positions = poses[:, :3, 3]
599
+ if radii is None:
600
+ radii = np.percentile(np.abs(positions), 90, 0)
601
+ radii = np.concatenate([radii, [1.0]])
602
+
603
+ # Generate poses for spiral path.
604
+ render_poses = []
605
+ cam2world = poses_avg(poses)
606
+ up = poses[:, :3, 1].mean(0)
607
+ for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint):
608
+ t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
609
+ position = cam2world @ t
610
+ lookat = cam2world @ [0, 0, -focal, 1.0]
611
+ z_axis = position - lookat
612
+ render_poses.append(viewmatrix(z_axis, up, position))
613
+ render_poses = np.stack(render_poses, axis=0)
614
+ return render_poses
615
+
616
+
617
+ def generate_interpolated_path(
618
+ poses: np.ndarray,
619
+ n_interp: int,
620
+ spline_degree: int = 5,
621
+ smoothness: float = 0.03,
622
+ rot_weight: float = 0.1,
623
+ endpoint: bool = False,
624
+ ):
625
+ """Creates a smooth spline path between input keyframe camera poses.
626
+
627
+ Spline is calculated with poses in format (position, lookat-point, up-point).
628
+
629
+ Args:
630
+ poses: (n, 3, 4) array of input pose keyframes.
631
+ n_interp: returned path will have n_interp * (n - 1) total poses.
632
+ spline_degree: polynomial degree of B-spline.
633
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
634
+ rot_weight: relative weighting of rotation/translation in spline solve.
635
+
636
+ Returns:
637
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
638
+ """
639
+
640
+ def poses_to_points(poses, dist):
641
+ """Converts from pose matrices to (position, lookat, up) format."""
642
+ pos = poses[:, :3, -1]
643
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
644
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
645
+ return np.stack([pos, lookat, up], 1)
646
+
647
+ def points_to_poses(points):
648
+ """Converts from (position, lookat, up) format to pose matrices."""
649
+ return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
650
+
651
+ def interp(points, n, k, s):
652
+ """Runs multidimensional B-spline interpolation on the input points."""
653
+ sh = points.shape
654
+ pts = np.reshape(points, (sh[0], -1))
655
+ k = min(k, sh[0] - 1)
656
+ tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
657
+ u = np.linspace(0, 1, n, endpoint=endpoint)
658
+ new_points = np.array(scipy.interpolate.splev(u, tck))
659
+ new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
660
+ return new_points
661
+
662
+ points = poses_to_points(poses, dist=rot_weight)
663
+ new_points = interp(
664
+ points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
665
+ )
666
+ return points_to_poses(new_points)
667
+
668
+
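A small sketch: interpolating ten poses between each pair of four made-up keyframes (3x4 pose matrices) that drift along +x with a slight vertical wobble.

import numpy as np

from seva.geometry import generate_interpolated_path

keyframes = np.tile(np.eye(4)[:3], (4, 1, 1))
keyframes[:, 0, 3] = [0.0, 1.0, 2.0, 3.0]
keyframes[:, 1, 3] = [0.0, 0.5, 0.0, -0.5]

path = generate_interpolated_path(keyframes, n_interp=10, spline_degree=3)
assert path.shape == (30, 3, 4)                           # n_interp * (n - 1) poses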
669
+ def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
670
+ """
671
+ reference: nerf-factory
672
+ Get a similarity transform to normalize the dataset
673
+ from c2w (OpenCV convention) cameras.
674
+ :param c2w: (N, 4, 4)
675
+ :return T: (4, 4) similarity transform
676
+ """
677
+ t = c2w[:, :3, 3]
678
+ R = c2w[:, :3, :3]
679
+
680
+ # (1) Rotate the world so that z+ is the up axis
681
+ # we estimate the up axis by averaging the camera up axes
682
+ ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
683
+ world_up = np.mean(ups, axis=0)
684
+ world_up /= np.linalg.norm(world_up)
685
+
686
+ up_camspace = np.array([0.0, -1.0, 0.0])
687
+ c = (up_camspace * world_up).sum()
688
+ cross = np.cross(world_up, up_camspace)
689
+ skew = np.array(
690
+ [
691
+ [0.0, -cross[2], cross[1]],
692
+ [cross[2], 0.0, -cross[0]],
693
+ [-cross[1], cross[0], 0.0],
694
+ ]
695
+ )
696
+ if c > -1:
697
+ R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
698
+ else:
699
+ # In the unlikely case the original data has y+ up axis,
700
+ # rotate 180-deg about x axis
701
+ R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
702
+
703
+ # R_align = np.eye(3) # DEBUG
704
+ R = R_align @ R
705
+ fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
706
+ t = (R_align @ t[..., None])[..., 0]
707
+
708
+ # (2) Recenter the scene.
709
+ if center_method == "focus":
710
+ # find the closest point to the origin for each camera's center ray
711
+ nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
712
+ translate = -np.median(nearest, axis=0)
713
+ elif center_method == "poses":
714
+ # use center of the camera positions
715
+ translate = -np.median(t, axis=0)
716
+ else:
717
+ raise ValueError(f"Unknown center_method {center_method}")
718
+
719
+ transform = np.eye(4)
720
+ transform[:3, 3] = translate
721
+ transform[:3, :3] = R_align
722
+
723
+ # (3) Rescale the scene using camera distances
724
+ scale_fn = np.max if strict_scaling else np.median
725
+ inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1))
726
+ if inv_scale == 0:
727
+ inv_scale = 1.0
728
+ scale = 1.0 / inv_scale
729
+ transform[:3, :] *= scale
730
+
731
+ return transform
732
+
733
+
734
+ def align_principle_axes(point_cloud):
735
+ # Compute centroid
736
+ centroid = np.median(point_cloud, axis=0)
737
+
738
+ # Translate point cloud to centroid
739
+ translated_point_cloud = point_cloud - centroid
740
+
741
+ # Compute covariance matrix
742
+ covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
743
+
744
+ # Compute eigenvectors and eigenvalues
745
+ eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
746
+
747
+ # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
748
+ # is the principal axis with the smallest eigenvalue.
749
+ sort_indices = eigenvalues.argsort()[::-1]
750
+ eigenvectors = eigenvectors[:, sort_indices]
751
+
752
+ # Check orientation of eigenvectors. If the determinant of the eigenvectors is
753
+ # negative, then we need to flip the sign of one of the eigenvectors.
754
+ if np.linalg.det(eigenvectors) < 0:
755
+ eigenvectors[:, 0] *= -1
756
+
757
+ # Create rotation matrix
758
+ rotation_matrix = eigenvectors.T
759
+
760
+ # Create SE(3) matrix (4x4 transformation matrix)
761
+ transform = np.eye(4)
762
+ transform[:3, :3] = rotation_matrix
763
+ transform[:3, 3] = -rotation_matrix @ centroid
764
+
765
+ return transform
766
+
767
+
768
+ def transform_points(matrix, points):
769
+ """Transform points using a 4x4 SE(3) matrix.
770
+
771
+ Args:
772
+ matrix: 4x4 SE(3) matrix
773
+ points: Nx3 array of points
774
+
775
+ Returns:
776
+ Nx3 array of transformed points
777
+ """
778
+ assert matrix.shape == (4, 4)
779
+ assert len(points.shape) == 2 and points.shape[1] == 3
780
+ return points @ matrix[:3, :3].T + matrix[:3, 3]
781
+
782
+
783
+ def transform_cameras(matrix, camtoworlds):
784
+ """Transform cameras using a 4x4 SE(3) matrix.
785
+
786
+ Args:
787
+ matrix: 4x4 SE(3) matrix
788
+ camtoworlds: Nx4x4 array of camera-to-world matrices
789
+
790
+ Returns:
791
+ Nx4x4 array of transformed camera-to-world matrices
792
+ """
793
+ assert matrix.shape == (4, 4)
794
+ assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
795
+ camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
796
+ scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
797
+ camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
798
+ return camtoworlds
799
+
800
+
801
+ def normalize_scene(camtoworlds, points=None, camera_center_method="focus"):
802
+ T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method)
803
+ camtoworlds = transform_cameras(T1, camtoworlds)
804
+ if points is not None:
805
+ points = transform_points(T1, points)
806
+ T2 = align_principle_axes(points)
807
+ camtoworlds = transform_cameras(T2, camtoworlds)
808
+ points = transform_points(T2, points)
809
+ return camtoworlds, points, T2 @ T1
810
+ else:
811
+ return camtoworlds, T1
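
A minimal usage sketch for the normalization helpers above. The camera poses, point cloud, and the `camera_center_method` value are illustrative only, not taken from the repository's data pipeline:

import numpy as np
from seva.geometry import normalize_scene

# Eight camera-to-world poses on a unit circle with identity rotations (dummy data).
theta = np.linspace(0.0, 2.0 * np.pi, 8, endpoint=False)
camtoworlds = np.tile(np.eye(4), (8, 1, 1))
camtoworlds[:, 0, 3] = np.cos(theta)
camtoworlds[:, 1, 3] = np.sin(theta)
points = np.random.rand(1000, 3)  # dummy point cloud

# Recenter/rescale from the cameras (T1), then align the points' principal axes (T2).
camtoworlds_n, points_n, transform = normalize_scene(
    camtoworlds, points, camera_center_method="poses"
)
assert transform.shape == (4, 4)  # composed T2 @ T1
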
seva/gui.py ADDED
@@ -0,0 +1,975 @@
1
+ import colorsys
2
+ import dataclasses
3
+ import threading
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import scipy
9
+ import splines
10
+ import splines.quaternion
11
+ import torch
12
+ import viser
13
+ import viser.transforms as vt
14
+
15
+ from seva.geometry import get_preset_pose_fov
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Keyframe(object):
20
+ position: np.ndarray
21
+ wxyz: np.ndarray
22
+ override_fov_enabled: bool
23
+ override_fov_rad: float
24
+ aspect: float
25
+ override_transition_enabled: bool
26
+ override_transition_sec: float | None
27
+
28
+ @staticmethod
29
+ def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe":
30
+ return Keyframe(
31
+ camera.position,
32
+ camera.wxyz,
33
+ override_fov_enabled=False,
34
+ override_fov_rad=camera.fov,
35
+ aspect=aspect,
36
+ override_transition_enabled=False,
37
+ override_transition_sec=None,
38
+ )
39
+
40
+ @staticmethod
41
+ def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe":
42
+ return Keyframe(
43
+ se3.translation(),
44
+ se3.rotation().wxyz,
45
+ override_fov_enabled=False,
46
+ override_fov_rad=fov,
47
+ aspect=aspect,
48
+ override_transition_enabled=False,
49
+ override_transition_sec=None,
50
+ )
51
+
52
+
53
+ class CameraTrajectory(object):
54
+ def __init__(
55
+ self,
56
+ server: viser.ViserServer,
57
+ duration_element: viser.GuiInputHandle[float],
58
+ scene_scale: float,
59
+ scene_node_prefix: str = "/",
60
+ ):
61
+ self._server = server
62
+ self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {}
63
+ self._keyframe_counter: int = 0
64
+ self._spline_nodes: list[viser.SceneNodeHandle] = []
65
+ self._camera_edit_panel: viser.Gui3dContainerHandle | None = None
66
+
67
+ self._orientation_spline: splines.quaternion.KochanekBartels | None = None
68
+ self._position_spline: splines.KochanekBartels | None = None
69
+ self._fov_spline: splines.KochanekBartels | None = None
70
+
71
+ self._keyframes_visible: bool = True
72
+
73
+ self._duration_element = duration_element
74
+ self._scene_node_prefix = scene_node_prefix
75
+
76
+ self.scene_scale = scene_scale
77
+ # These parameters should be overridden externally.
78
+ self.loop: bool = False
79
+ self.framerate: float = 30.0
80
+ self.tension: float = 0.0 # Tension / alpha term.
81
+ self.default_fov: float = 0.0
82
+ self.default_transition_sec: float = 0.0
83
+ self.show_spline: bool = True
84
+
85
+ def set_keyframes_visible(self, visible: bool) -> None:
86
+ self._keyframes_visible = visible
87
+ for keyframe in self._keyframes.values():
88
+ keyframe[1].visible = visible
89
+
90
+ def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None:
91
+ """Add a new camera, or replace an old one if `keyframe_index` is passed in."""
92
+ server = self._server
93
+
94
+ # Add a keyframe if we aren't replacing an existing one.
95
+ if keyframe_index is None:
96
+ keyframe_index = self._keyframe_counter
97
+ self._keyframe_counter += 1
98
+
99
+ print(
100
+ f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}"
101
+ )
102
+ frustum_handle = server.scene.add_camera_frustum(
103
+ str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"),
104
+ fov=(
105
+ keyframe.override_fov_rad
106
+ if keyframe.override_fov_enabled
107
+ else self.default_fov
108
+ ),
109
+ aspect=keyframe.aspect,
110
+ scale=0.1 * self.scene_scale,
111
+ color=(200, 10, 30),
112
+ wxyz=keyframe.wxyz,
113
+ position=keyframe.position,
114
+ visible=self._keyframes_visible,
115
+ )
116
+ self._server.scene.add_icosphere(
117
+ str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"),
118
+ radius=0.03,
119
+ color=(200, 10, 30),
120
+ )
121
+
122
+ @frustum_handle.on_click
123
+ def _(_) -> None:
124
+ if self._camera_edit_panel is not None:
125
+ self._camera_edit_panel.remove()
126
+ self._camera_edit_panel = None
127
+
128
+ with server.scene.add_3d_gui_container(
129
+ "/camera_edit_panel",
130
+ position=keyframe.position,
131
+ ) as camera_edit_panel:
132
+ self._camera_edit_panel = camera_edit_panel
133
+ override_fov = server.gui.add_checkbox(
134
+ "Override FOV", initial_value=keyframe.override_fov_enabled
135
+ )
136
+ override_fov_degrees = server.gui.add_slider(
137
+ "Override FOV (degrees)",
138
+ 5.0,
139
+ 175.0,
140
+ step=0.1,
141
+ initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
142
+ disabled=not keyframe.override_fov_enabled,
143
+ )
144
+ delete_button = server.gui.add_button(
145
+ "Delete", color="red", icon=viser.Icon.TRASH
146
+ )
147
+ go_to_button = server.gui.add_button("Go to")
148
+ close_button = server.gui.add_button("Close")
149
+
150
+ @override_fov.on_update
151
+ def _(_) -> None:
152
+ keyframe.override_fov_enabled = override_fov.value
153
+ override_fov_degrees.disabled = not override_fov.value
154
+ self.add_camera(keyframe, keyframe_index)
155
+
156
+ @override_fov_degrees.on_update
157
+ def _(_) -> None:
158
+ keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi
159
+ self.add_camera(keyframe, keyframe_index)
160
+
161
+ @delete_button.on_click
162
+ def _(event: viser.GuiEvent) -> None:
163
+ assert event.client is not None
164
+ with event.client.gui.add_modal("Confirm") as modal:
165
+ event.client.gui.add_markdown("Delete keyframe?")
166
+ confirm_button = event.client.gui.add_button(
167
+ "Yes", color="red", icon=viser.Icon.TRASH
168
+ )
169
+ exit_button = event.client.gui.add_button("Cancel")
170
+
171
+ @confirm_button.on_click
172
+ def _(_) -> None:
173
+ assert camera_edit_panel is not None
174
+
175
+ keyframe_id = None
176
+ for i, keyframe_tuple in self._keyframes.items():
177
+ if keyframe_tuple[1] is frustum_handle:
178
+ keyframe_id = i
179
+ break
180
+ assert keyframe_id is not None
181
+
182
+ self._keyframes.pop(keyframe_id)
183
+ frustum_handle.remove()
184
+ camera_edit_panel.remove()
185
+ self._camera_edit_panel = None
186
+ modal.close()
187
+ self.update_spline()
188
+
189
+ @exit_button.on_click
190
+ def _(_) -> None:
191
+ modal.close()
192
+
193
+ @go_to_button.on_click
194
+ def _(event: viser.GuiEvent) -> None:
195
+ assert event.client is not None
196
+ client = event.client
197
+ T_world_current = vt.SE3.from_rotation_and_translation(
198
+ vt.SO3(client.camera.wxyz), client.camera.position
199
+ )
200
+ T_world_target = vt.SE3.from_rotation_and_translation(
201
+ vt.SO3(keyframe.wxyz), keyframe.position
202
+ ) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
203
+
204
+ T_current_target = T_world_current.inverse() @ T_world_target
205
+
206
+ for j in range(10):
207
+ T_world_set = T_world_current @ vt.SE3.exp(
208
+ T_current_target.log() * j / 9.0
209
+ )
210
+
211
+ # Important bit: we atomically set both the orientation and
212
+ # the position of the camera.
213
+ with client.atomic():
214
+ client.camera.wxyz = T_world_set.rotation().wxyz
215
+ client.camera.position = T_world_set.translation()
216
+ time.sleep(1.0 / 30.0)
217
+
218
+ @close_button.on_click
219
+ def _(_) -> None:
220
+ assert camera_edit_panel is not None
221
+ camera_edit_panel.remove()
222
+ self._camera_edit_panel = None
223
+
224
+ self._keyframes[keyframe_index] = (keyframe, frustum_handle)
225
+
226
+ def update_aspect(self, aspect: float) -> None:
227
+ for keyframe_index, frame in self._keyframes.items():
228
+ frame = dataclasses.replace(frame[0], aspect=aspect)
229
+ self.add_camera(frame, keyframe_index=keyframe_index)
230
+
231
+ def get_aspect(self) -> float:
232
+ """Get W/H aspect ratio, which is shared across all keyframes."""
233
+ assert len(self._keyframes) > 0
234
+ return next(iter(self._keyframes.values()))[0].aspect
235
+
236
+ def reset(self) -> None:
237
+ for frame in self._keyframes.values():
238
+ print(f"removing {frame[1]}")
239
+ frame[1].remove()
240
+ self._keyframes.clear()
241
+ self.update_spline()
242
+ print("camera traj reset")
243
+
244
+ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
245
+ """From a time value in seconds, compute a t value for our geometric
246
+ spline interpolation. An increment of 1 for the latter will move the
247
+ camera forward by one keyframe.
248
+
249
+ We use a PCHIP spline here to guarantee monotonicity.
250
+ """
251
+ transition_times_cumsum = self.compute_transition_times_cumsum()
252
+ spline_indices = np.arange(transition_times_cumsum.shape[0])
253
+
254
+ if self.loop:
255
+ # In the case of a loop, we pad the spline to match the start/end
256
+ # slopes.
257
+ interpolator = scipy.interpolate.PchipInterpolator(
258
+ x=np.concatenate(
259
+ [
260
+ [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
261
+ transition_times_cumsum,
262
+ transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
263
+ ],
264
+ axis=0,
265
+ ),
266
+ y=np.concatenate(
267
+ [[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore
268
+ axis=0,
269
+ ),
270
+ )
271
+ else:
272
+ interpolator = scipy.interpolate.PchipInterpolator(
273
+ x=transition_times_cumsum, y=spline_indices
274
+ )
275
+
276
+ # Clip to account for floating point error.
277
+ return np.clip(interpolator(time), 0, spline_indices[-1])
278
+
279
+ def interpolate_pose_and_fov_rad(
280
+ self, normalized_t: float
281
+ ) -> tuple[vt.SE3, float] | None:
282
+ if len(self._keyframes) < 2:
283
+ return None
284
+
285
+ self._fov_spline = splines.KochanekBartels(
286
+ [
287
+ (
288
+ keyframe[0].override_fov_rad
289
+ if keyframe[0].override_fov_enabled
290
+ else self.default_fov
291
+ )
292
+ for keyframe in self._keyframes.values()
293
+ ],
294
+ tcb=(self.tension, 0.0, 0.0),
295
+ endconditions="closed" if self.loop else "natural",
296
+ )
297
+
298
+ assert self._orientation_spline is not None
299
+ assert self._position_spline is not None
300
+ assert self._fov_spline is not None
301
+
302
+ max_t = self.compute_duration()
303
+ t = max_t * normalized_t
304
+ spline_t = float(self.spline_t_from_t_sec(np.array(t)))
305
+
306
+ quat = self._orientation_spline.evaluate(spline_t)
307
+ assert isinstance(quat, splines.quaternion.UnitQuaternion)
308
+ return (
309
+ vt.SE3.from_rotation_and_translation(
310
+ vt.SO3(np.array([quat.scalar, *quat.vector])),
311
+ self._position_spline.evaluate(spline_t),
312
+ ),
313
+ float(self._fov_spline.evaluate(spline_t)),
314
+ )
315
+
316
+ def update_spline(self) -> None:
317
+ num_frames = int(self.compute_duration() * self.framerate)
318
+ keyframes = list(self._keyframes.values())
319
+
320
+ if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
321
+ for node in self._spline_nodes:
322
+ node.remove()
323
+ self._spline_nodes.clear()
324
+ return
325
+
326
+ transition_times_cumsum = self.compute_transition_times_cumsum()
327
+
328
+ self._orientation_spline = splines.quaternion.KochanekBartels(
329
+ [
330
+ splines.quaternion.UnitQuaternion.from_unit_xyzw(
331
+ np.roll(keyframe[0].wxyz, shift=-1)
332
+ )
333
+ for keyframe in keyframes
334
+ ],
335
+ tcb=(self.tension, 0.0, 0.0),
336
+ endconditions="closed" if self.loop else "natural",
337
+ )
338
+ self._position_spline = splines.KochanekBartels(
339
+ [keyframe[0].position for keyframe in keyframes],
340
+ tcb=(self.tension, 0.0, 0.0),
341
+ endconditions="closed" if self.loop else "natural",
342
+ )
343
+
344
+ # Update visualized spline.
345
+ points_array = self._position_spline.evaluate(
346
+ self.spline_t_from_t_sec(
347
+ np.linspace(0, transition_times_cumsum[-1], num_frames)
348
+ )
349
+ )
350
+ colors_array = np.array(
351
+ [
352
+ colorsys.hls_to_rgb(h, 0.5, 1.0)
353
+ for h in np.linspace(0.0, 1.0, len(points_array))
354
+ ]
355
+ )
356
+
357
+ # Clear prior spline nodes.
358
+ for node in self._spline_nodes:
359
+ node.remove()
360
+ self._spline_nodes.clear()
361
+
362
+ self._spline_nodes.append(
363
+ self._server.scene.add_spline_catmull_rom(
364
+ str(Path(self._scene_node_prefix) / "camera_spline"),
365
+ positions=points_array,
366
+ color=(220, 220, 220),
367
+ closed=self.loop,
368
+ line_width=1.0,
369
+ segments=points_array.shape[0] + 1,
370
+ )
371
+ )
372
+ self._spline_nodes.append(
373
+ self._server.scene.add_point_cloud(
374
+ str(Path(self._scene_node_prefix) / "camera_spline/points"),
375
+ points=points_array,
376
+ colors=colors_array,
377
+ point_size=0.04,
378
+ )
379
+ )
380
+
381
+ def make_transition_handle(i: int) -> None:
382
+ assert self._position_spline is not None
383
+ transition_pos = self._position_spline.evaluate(
384
+ float(
385
+ self.spline_t_from_t_sec(
386
+ (transition_times_cumsum[i] + transition_times_cumsum[i + 1])
387
+ / 2.0,
388
+ )
389
+ )
390
+ )
391
+ transition_sphere = self._server.scene.add_icosphere(
392
+ str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"),
393
+ radius=0.04,
394
+ color=(255, 0, 0),
395
+ position=transition_pos,
396
+ )
397
+ self._spline_nodes.append(transition_sphere)
398
+
399
+ @transition_sphere.on_click
400
+ def _(_) -> None:
401
+ server = self._server
402
+
403
+ if self._camera_edit_panel is not None:
404
+ self._camera_edit_panel.remove()
405
+ self._camera_edit_panel = None
406
+
407
+ keyframe_index = (i + 1) % len(self._keyframes)
408
+ keyframe = keyframes[keyframe_index][0]
409
+
410
+ with server.scene.add_3d_gui_container(
411
+ "/camera_edit_panel",
412
+ position=transition_pos,
413
+ ) as camera_edit_panel:
414
+ self._camera_edit_panel = camera_edit_panel
415
+ override_transition_enabled = server.gui.add_checkbox(
416
+ "Override transition",
417
+ initial_value=keyframe.override_transition_enabled,
418
+ )
419
+ override_transition_sec = server.gui.add_number(
420
+ "Override transition (sec)",
421
+ initial_value=(
422
+ keyframe.override_transition_sec
423
+ if keyframe.override_transition_sec is not None
424
+ else self.default_transition_sec
425
+ ),
426
+ min=0.001,
427
+ max=30.0,
428
+ step=0.001,
429
+ disabled=not override_transition_enabled.value,
430
+ )
431
+ close_button = server.gui.add_button("Close")
432
+
433
+ @override_transition_enabled.on_update
434
+ def _(_) -> None:
435
+ keyframe.override_transition_enabled = (
436
+ override_transition_enabled.value
437
+ )
438
+ override_transition_sec.disabled = (
439
+ not override_transition_enabled.value
440
+ )
441
+ self._duration_element.value = self.compute_duration()
442
+
443
+ @override_transition_sec.on_update
444
+ def _(_) -> None:
445
+ keyframe.override_transition_sec = override_transition_sec.value
446
+ self._duration_element.value = self.compute_duration()
447
+
448
+ @close_button.on_click
449
+ def _(_) -> None:
450
+ assert camera_edit_panel is not None
451
+ camera_edit_panel.remove()
452
+ self._camera_edit_panel = None
453
+
454
+ (num_transitions_plus_1,) = transition_times_cumsum.shape
455
+ for i in range(num_transitions_plus_1 - 1):
456
+ make_transition_handle(i)
457
+
458
+ def compute_duration(self) -> float:
459
+ """Compute the total duration of the trajectory."""
460
+ total = 0.0
461
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
462
+ if i == 0 and not self.loop:
463
+ continue
464
+ del frustum
465
+ total += (
466
+ keyframe.override_transition_sec
467
+ if keyframe.override_transition_enabled
468
+ and keyframe.override_transition_sec is not None
469
+ else self.default_transition_sec
470
+ )
471
+ return total
472
+
473
+ def compute_transition_times_cumsum(self) -> np.ndarray:
474
+ """Compute the total duration of the trajectory."""
475
+ total = 0.0
476
+ out = [0.0]
477
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
478
+ if i == 0:
479
+ continue
480
+ del frustum
481
+ total += (
482
+ keyframe.override_transition_sec
483
+ if keyframe.override_transition_enabled
484
+ and keyframe.override_transition_sec is not None
485
+ else self.default_transition_sec
486
+ )
487
+ out.append(total)
488
+
489
+ if self.loop:
490
+ keyframe = next(iter(self._keyframes.values()))[0]
491
+ total += (
492
+ keyframe.override_transition_sec
493
+ if keyframe.override_transition_enabled
494
+ and keyframe.override_transition_sec is not None
495
+ else self.default_transition_sec
496
+ )
497
+ out.append(total)
498
+
499
+ return np.array(out)
500
+
501
+
502
+ @dataclasses.dataclass
503
+ class GuiState:
504
+ preview_render: bool
505
+ preview_fov: float
506
+ preview_aspect: float
507
+ camera_traj_list: list | None
508
+ active_input_index: int
509
+
510
+
511
+ def define_gui(
512
+ server: viser.ViserServer,
513
+ init_fov: float = 75.0,
514
+ img_wh: tuple[int, int] = (576, 576),
515
+ **kwargs,
516
+ ) -> GuiState:
517
+ gui_state = GuiState(
518
+ preview_render=False,
519
+ preview_fov=0.0,
520
+ preview_aspect=1.0,
521
+ camera_traj_list=None,
522
+ active_input_index=0,
523
+ )
524
+
525
+ with server.gui.add_folder(
526
+ "Preset camera trajectories", order=99, expand_by_default=False
527
+ ):
528
+ preset_traj_dropdown = server.gui.add_dropdown(
529
+ "Options",
530
+ [
531
+ "orbit",
532
+ "spiral",
533
+ "lemniscate",
534
+ "zoom-out",
535
+ "dolly zoom-out",
536
+ ],
537
+ initial_value="orbit",
538
+ hint="Select a preset camera trajectory.",
539
+ )
540
+ preset_duration_num = server.gui.add_number(
541
+ "Duration (sec)",
542
+ min=1.0,
543
+ max=60.0,
544
+ step=0.5,
545
+ initial_value=2.0,
546
+ )
547
+ preset_submit_button = server.gui.add_button(
548
+ "Submit",
549
+ icon=viser.Icon.PICK,
550
+ hint="Add a new keyframe at the current pose.",
551
+ )
552
+
553
+ @preset_submit_button.on_click
554
+ def _(event: viser.GuiEvent) -> None:
555
+ camera_traj.reset()
556
+ gui_state.camera_traj_list = None
557
+
558
+ duration = preset_duration_num.value
559
+ fps = framerate_number.value
560
+ num_frames = int(duration * fps)
561
+ transition_sec = duration / num_frames
562
+ transition_sec_number.value = transition_sec
563
+ assert event.client_id is not None
564
+ transition_sec_number.disabled = True
565
+ loop_checkbox.disabled = True
566
+ add_keyframe_button.disabled = True
567
+
568
+ camera = server.get_clients()[event.client_id].camera
569
+ start_w2c = torch.linalg.inv(
570
+ torch.as_tensor(
571
+ vt.SE3.from_rotation_and_translation(
572
+ vt.SO3(camera.wxyz), camera.position
573
+ ).as_matrix(),
574
+ dtype=torch.float32,
575
+ )
576
+ )
577
+ look_at = torch.as_tensor(camera.look_at, dtype=torch.float32)
578
+ up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32)
579
+ poses, fovs = get_preset_pose_fov(
580
+ option=preset_traj_dropdown.value, # type: ignore
581
+ num_frames=num_frames,
582
+ start_w2c=start_w2c,
583
+ look_at=look_at,
584
+ up_direction=up_direction,
585
+ fov=camera.fov,
586
+ )
587
+ assert poses is not None and fovs is not None
588
+ for pose, fov in zip(poses, fovs):
589
+ camera_traj.add_camera(
590
+ Keyframe.from_se3(
591
+ vt.SE3.from_matrix(pose),
592
+ fov=fov,
593
+ aspect=img_wh[0] / img_wh[1],
594
+ )
595
+ )
596
+
597
+ duration_number.value = camera_traj.compute_duration()
598
+ camera_traj.update_spline()
599
+
600
+ with server.gui.add_folder("Advanced", expand_by_default=False, order=100):
601
+ transition_sec_number = server.gui.add_number(
602
+ "Transition (sec)",
603
+ min=0.001,
604
+ max=30.0,
605
+ step=0.001,
606
+ initial_value=1.5,
607
+ hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
608
+ )
609
+ framerate_number = server.gui.add_number(
610
+ "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
611
+ )
612
+ framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
613
+ duration_number = server.gui.add_number(
614
+ "Duration (sec)",
615
+ min=0.0,
616
+ max=1e8,
617
+ step=0.001,
618
+ initial_value=0.0,
619
+ disabled=True,
620
+ )
621
+
622
+ @framerate_buttons.on_click
623
+ def _(_) -> None:
624
+ framerate_number.value = float(framerate_buttons.value)
625
+
626
+ fov_degree_slider = server.gui.add_slider(
627
+ "FOV",
628
+ initial_value=init_fov,
629
+ min=0.1,
630
+ max=175.0,
631
+ step=0.01,
632
+ hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
633
+ )
634
+
635
+ @fov_degree_slider.on_update
636
+ def _(_) -> None:
637
+ fov_radians = fov_degree_slider.value / 180.0 * np.pi
638
+ for client in server.get_clients().values():
639
+ client.camera.fov = fov_radians
640
+ camera_traj.default_fov = fov_radians
641
+
642
+ # Updating the aspect ratio will also re-render the camera frustums.
643
+ # Could rethink this.
644
+ camera_traj.update_aspect(img_wh[0] / img_wh[1])
645
+ compute_and_update_preview_camera_state()
646
+
647
+ scene_node_prefix = "/render_assets"
648
+ base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False)
649
+ add_keyframe_button = server.gui.add_button(
650
+ "Add keyframe",
651
+ icon=viser.Icon.PLUS,
652
+ hint="Add a new keyframe at the current pose.",
653
+ )
654
+
655
+ @add_keyframe_button.on_click
656
+ def _(event: viser.GuiEvent) -> None:
657
+ assert event.client_id is not None
658
+ camera = server.get_clients()[event.client_id].camera
659
+ pose = vt.SE3.from_rotation_and_translation(
660
+ vt.SO3(camera.wxyz), camera.position
661
+ )
662
+ print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
663
+ print(f"camera pose {pose.as_matrix()}")
664
+
665
+ # Add this camera to the trajectory.
666
+ camera_traj.add_camera(
667
+ Keyframe.from_camera(
668
+ camera,
669
+ aspect=img_wh[0] / img_wh[1],
670
+ ),
671
+ )
672
+ duration_number.value = camera_traj.compute_duration()
673
+ camera_traj.update_spline()
674
+
675
+ clear_keyframes_button = server.gui.add_button(
676
+ "Clear keyframes",
677
+ icon=viser.Icon.TRASH,
678
+ hint="Remove all keyframes from the render trajectory.",
679
+ )
680
+
681
+ @clear_keyframes_button.on_click
682
+ def _(event: viser.GuiEvent) -> None:
683
+ assert event.client_id is not None
684
+ client = server.get_clients()[event.client_id]
685
+ with client.atomic(), client.gui.add_modal("Confirm") as modal:
686
+ client.gui.add_markdown("Clear all keyframes?")
687
+ confirm_button = client.gui.add_button(
688
+ "Yes", color="red", icon=viser.Icon.TRASH
689
+ )
690
+ exit_button = client.gui.add_button("Cancel")
691
+
692
+ @confirm_button.on_click
693
+ def _(_) -> None:
694
+ camera_traj.reset()
695
+ modal.close()
696
+
697
+ duration_number.value = camera_traj.compute_duration()
698
+ add_keyframe_button.disabled = False
699
+ transition_sec_number.disabled = False
700
+ transition_sec_number.value = 1.5
701
+ loop_checkbox.disabled = False
702
+
703
+ nonlocal gui_state
704
+ gui_state.camera_traj_list = None
705
+
706
+ @exit_button.on_click
707
+ def _(_) -> None:
708
+ modal.close()
709
+
710
+ play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
711
+ pause_button = server.gui.add_button(
712
+ "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
713
+ )
714
+
715
+ # Poll the play button to see if we should be playing endlessly.
716
+ def play() -> None:
717
+ while True:
718
+ while not play_button.visible:
719
+ max_frame = int(framerate_number.value * duration_number.value)
720
+ if max_frame > 0:
721
+ assert preview_frame_slider is not None
722
+ preview_frame_slider.value = (
723
+ preview_frame_slider.value + 1
724
+ ) % max_frame
725
+ time.sleep(1.0 / framerate_number.value)
726
+ time.sleep(0.1)
727
+
728
+ threading.Thread(target=play).start()
729
+
730
+ # Play the camera trajectory when the play button is pressed.
731
+ @play_button.on_click
732
+ def _(_) -> None:
733
+ play_button.visible = False
734
+ pause_button.visible = True
735
+
736
+ # Pause the camera trajectory when the pause button is pressed.
737
+ @pause_button.on_click
738
+ def _(_) -> None:
739
+ play_button.visible = True
740
+ pause_button.visible = False
741
+
742
+ preview_render_button = server.gui.add_button(
743
+ "Preview render",
744
+ hint="Show a preview of the render in the viewport.",
745
+ icon=viser.Icon.CAMERA_CHECK,
746
+ )
747
+ preview_render_stop_button = server.gui.add_button(
748
+ "Exit render preview",
749
+ color="red",
750
+ icon=viser.Icon.CAMERA_CANCEL,
751
+ visible=False,
752
+ )
753
+
754
+ @preview_render_button.on_click
755
+ def _(_) -> None:
756
+ gui_state.preview_render = True
757
+ preview_render_button.visible = False
758
+ preview_render_stop_button.visible = True
759
+ play_button.visible = False
760
+ pause_button.visible = True
761
+ preset_submit_button.disabled = True
762
+
763
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
764
+ if maybe_pose_and_fov_rad is None:
765
+ remove_preview_camera()
766
+ return
767
+ pose, fov = maybe_pose_and_fov_rad
768
+ del fov
769
+
770
+ # Hide all render assets when we're previewing the render.
771
+ nonlocal base_scene_node
772
+ base_scene_node.visible = False
773
+
774
+ # Back up and then set camera poses.
775
+ for client in server.get_clients().values():
776
+ camera_pose_backup_from_id[client.client_id] = (
777
+ client.camera.position,
778
+ client.camera.look_at,
779
+ client.camera.up_direction,
780
+ )
781
+ with client.atomic():
782
+ client.camera.wxyz = pose.rotation().wxyz
783
+ client.camera.position = pose.translation()
784
+
785
+ def stop_preview_render() -> None:
786
+ gui_state.preview_render = False
787
+ preview_render_button.visible = True
788
+ preview_render_stop_button.visible = False
789
+ play_button.visible = True
790
+ pause_button.visible = False
791
+ preset_submit_button.disabled = False
792
+
793
+ # Revert camera poses.
794
+ for client in server.get_clients().values():
795
+ if client.client_id not in camera_pose_backup_from_id:
796
+ continue
797
+ cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
798
+ client.client_id
799
+ )
800
+ with client.atomic():
801
+ client.camera.position = cam_position
802
+ client.camera.look_at = cam_look_at
803
+ client.camera.up_direction = cam_up
804
+ client.flush()
805
+
806
+ # Un-hide render assets.
807
+ nonlocal base_scene_node
808
+ base_scene_node.visible = True
809
+ remove_preview_camera()
810
+
811
+ @preview_render_stop_button.on_click
812
+ def _(_) -> None:
813
+ stop_preview_render()
814
+
815
+ def get_max_frame_index() -> int:
816
+ return max(1, int(framerate_number.value * duration_number.value) - 1)
817
+
818
+ def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None:
819
+ """Helper for creating the current frame # slider. This is removed and
820
+ re-added anytime the `max` value changes."""
821
+
822
+ preview_frame_slider = server.gui.add_slider(
823
+ "Preview frame",
824
+ min=0,
825
+ max=get_max_frame_index(),
826
+ step=1,
827
+ initial_value=0,
828
+ order=set_traj_button.order + 0.01,
829
+ disabled=get_max_frame_index() == 1,
830
+ )
831
+ play_button.disabled = preview_frame_slider.disabled
832
+ preview_render_button.disabled = preview_frame_slider.disabled
833
+ set_traj_button.disabled = preview_frame_slider.disabled
834
+
835
+ @preview_frame_slider.on_update
836
+ def _(_) -> None:
837
+ nonlocal preview_camera_handle
838
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
839
+ if maybe_pose_and_fov_rad is None:
840
+ return
841
+ pose, fov_rad = maybe_pose_and_fov_rad
842
+
843
+ preview_camera_handle = server.scene.add_camera_frustum(
844
+ str(Path(scene_node_prefix) / "preview_camera"),
845
+ fov=fov_rad,
846
+ aspect=img_wh[0] / img_wh[1],
847
+ scale=0.35,
848
+ wxyz=pose.rotation().wxyz,
849
+ position=pose.translation(),
850
+ color=(10, 200, 30),
851
+ )
852
+ if gui_state.preview_render:
853
+ for client in server.get_clients().values():
854
+ with client.atomic():
855
+ client.camera.wxyz = pose.rotation().wxyz
856
+ client.camera.position = pose.translation()
857
+
858
+ return preview_frame_slider
859
+
860
+ set_traj_button = server.gui.add_button(
861
+ "Set camera trajectory",
862
+ color="green",
863
+ icon=viser.Icon.CHECK,
864
+ hint="Save the camera trajectory for rendering.",
865
+ )
866
+
867
+ @set_traj_button.on_click
868
+ def _(event: viser.GuiEvent) -> None:
869
+ assert event.client is not None
870
+ num_frames = int(framerate_number.value * duration_number.value)
871
+
872
+ def get_intrinsics(W, H, fov_rad):
873
+ focal = 0.5 * H / np.tan(0.5 * fov_rad)
874
+ return np.array(
875
+ [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
876
+ )
877
+
878
+ camera_traj_list = []
879
+ for i in range(num_frames):
880
+ maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
881
+ i / num_frames
882
+ )
883
+ if maybe_pose_and_fov_rad is None:
884
+ return
885
+ pose, fov_rad = maybe_pose_and_fov_rad
886
+ H = img_wh[1]
887
+ W = img_wh[0]
888
+ K = get_intrinsics(W, H, fov_rad)
889
+ w2c = pose.inverse().as_matrix()
890
+ camera_traj_list.append(
891
+ {
892
+ "w2c": w2c.flatten().tolist(),
893
+ "K": K.flatten().tolist(),
894
+ "img_wh": (W, H),
895
+ }
896
+ )
897
+ nonlocal gui_state
898
+ gui_state.camera_traj_list = camera_traj_list
899
+ print(f"Get camera_traj_list: {gui_state.camera_traj_list}")
900
+
901
+ stop_preview_render()
902
+
903
+ preview_frame_slider = add_preview_frame_slider()
904
+
905
+ loop_checkbox = server.gui.add_checkbox(
906
+ "Loop", False, hint="Add a segment between the first and last keyframes."
907
+ )
908
+
909
+ @loop_checkbox.on_update
910
+ def _(_) -> None:
911
+ camera_traj.loop = loop_checkbox.value
912
+ duration_number.value = camera_traj.compute_duration()
913
+
914
+ @transition_sec_number.on_update
915
+ def _(_) -> None:
916
+ camera_traj.default_transition_sec = transition_sec_number.value
917
+ duration_number.value = camera_traj.compute_duration()
918
+
919
+ preview_camera_handle: viser.SceneNodeHandle | None = None
920
+
921
+ def remove_preview_camera() -> None:
922
+ nonlocal preview_camera_handle
923
+ if preview_camera_handle is not None:
924
+ preview_camera_handle.remove()
925
+ preview_camera_handle = None
926
+
927
+ def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None:
928
+ """Update the render tab state with the current preview camera pose.
929
+ Returns current camera pose + FOV if available."""
930
+
931
+ if preview_frame_slider is None:
932
+ return None
933
+ maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
934
+ preview_frame_slider.value / get_max_frame_index()
935
+ )
936
+ if maybe_pose_and_fov_rad is None:
937
+ remove_preview_camera()
938
+ return None
939
+ pose, fov_rad = maybe_pose_and_fov_rad
940
+ gui_state.preview_fov = fov_rad
941
+ gui_state.preview_aspect = camera_traj.get_aspect()
942
+ return pose, fov_rad
943
+
944
+ # We back up the camera poses before and after we start previewing renders.
945
+ camera_pose_backup_from_id: dict[int, tuple] = {}
946
+
947
+ # Update the # of frames.
948
+ @duration_number.on_update
949
+ @framerate_number.on_update
950
+ def _(_) -> None:
951
+ remove_preview_camera() # Will be re-added when slider is updated.
952
+
953
+ nonlocal preview_frame_slider
954
+ old = preview_frame_slider
955
+ assert old is not None
956
+
957
+ preview_frame_slider = add_preview_frame_slider()
958
+ if preview_frame_slider is not None:
959
+ old.remove()
960
+ else:
961
+ preview_frame_slider = old
962
+
963
+ camera_traj.framerate = framerate_number.value
964
+ camera_traj.update_spline()
965
+
966
+ camera_traj = CameraTrajectory(
967
+ server,
968
+ duration_number,
969
+ scene_node_prefix=scene_node_prefix,
970
+ **kwargs,
971
+ )
972
+ camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi
973
+ camera_traj.default_transition_sec = transition_sec_number.value
974
+
975
+ return gui_state
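
A hedged sketch of wiring `define_gui` above into a viser server. The port, FOV, and `scene_scale` values are illustrative; `scene_scale` is forwarded to `CameraTrajectory` through `**kwargs`:

import time
import viser
from seva.gui import define_gui

server = viser.ViserServer(port=8080)
gui_state = define_gui(server, init_fov=60.0, img_wh=(576, 576), scene_scale=1.0)

# Block until the user clicks "Set camera trajectory" in the browser UI.
while gui_state.camera_traj_list is None:
    time.sleep(0.5)
print(f"Received {len(gui_state.camera_traj_list)} frames of w2c/K/img_wh dicts.")
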
seva/model.py ADDED
@@ -0,0 +1,234 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from seva.modules.layers import (
7
+ Downsample,
8
+ GroupNorm32,
9
+ ResBlock,
10
+ TimestepEmbedSequential,
11
+ Upsample,
12
+ timestep_embedding,
13
+ )
14
+ from seva.modules.transformer import MultiviewTransformer
15
+
16
+
17
+ @dataclass
18
+ class SevaParams(object):
19
+ in_channels: int = 11
20
+ model_channels: int = 320
21
+ out_channels: int = 4
22
+ num_frames: int = 21
23
+ num_res_blocks: int = 2
24
+ attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
25
+ channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
26
+ num_head_channels: int = 64
27
+ transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
28
+ context_dim: int = 1024
29
+ dense_in_channels: int = 6
30
+ dropout: float = 0.0
31
+ unflatten_names: list[str] = field(
32
+ default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
33
+ )
34
+
35
+ def __post_init__(self):
36
+ assert len(self.channel_mult) == len(self.transformer_depth)
37
+
38
+
39
+ class Seva(nn.Module):
40
+ def __init__(self, params: SevaParams) -> None:
41
+ super().__init__()
42
+ self.params = params
43
+ self.model_channels = params.model_channels
44
+ self.out_channels = params.out_channels
45
+ self.num_head_channels = params.num_head_channels
46
+
47
+ time_embed_dim = params.model_channels * 4
48
+ self.time_embed = nn.Sequential(
49
+ nn.Linear(params.model_channels, time_embed_dim),
50
+ nn.SiLU(),
51
+ nn.Linear(time_embed_dim, time_embed_dim),
52
+ )
53
+
54
+ self.input_blocks = nn.ModuleList(
55
+ [
56
+ TimestepEmbedSequential(
57
+ nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
58
+ )
59
+ ]
60
+ )
61
+ self._feature_size = params.model_channels
62
+ input_block_chans = [params.model_channels]
63
+ ch = params.model_channels
64
+ ds = 1
65
+ for level, mult in enumerate(params.channel_mult):
66
+ for _ in range(params.num_res_blocks):
67
+ input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
68
+ ResBlock(
69
+ channels=ch,
70
+ emb_channels=time_embed_dim,
71
+ out_channels=mult * params.model_channels,
72
+ dense_in_channels=params.dense_in_channels,
73
+ dropout=params.dropout,
74
+ )
75
+ ]
76
+ ch = mult * params.model_channels
77
+ if ds in params.attention_resolutions:
78
+ num_heads = ch // params.num_head_channels
79
+ dim_head = params.num_head_channels
80
+ input_layers.append(
81
+ MultiviewTransformer(
82
+ ch,
83
+ num_heads,
84
+ dim_head,
85
+ name=f"input_ds{ds}",
86
+ depth=params.transformer_depth[level],
87
+ context_dim=params.context_dim,
88
+ unflatten_names=params.unflatten_names,
89
+ )
90
+ )
91
+ self.input_blocks.append(TimestepEmbedSequential(*input_layers))
92
+ self._feature_size += ch
93
+ input_block_chans.append(ch)
94
+ if level != len(params.channel_mult) - 1:
95
+ ds *= 2
96
+ out_ch = ch
97
+ self.input_blocks.append(
98
+ TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
99
+ )
100
+ ch = out_ch
101
+ input_block_chans.append(ch)
102
+ self._feature_size += ch
103
+
104
+ num_heads = ch // params.num_head_channels
105
+ dim_head = params.num_head_channels
106
+
107
+ self.middle_block = TimestepEmbedSequential(
108
+ ResBlock(
109
+ channels=ch,
110
+ emb_channels=time_embed_dim,
111
+ out_channels=None,
112
+ dense_in_channels=params.dense_in_channels,
113
+ dropout=params.dropout,
114
+ ),
115
+ MultiviewTransformer(
116
+ ch,
117
+ num_heads,
118
+ dim_head,
119
+ name=f"middle_ds{ds}",
120
+ depth=params.transformer_depth[-1],
121
+ context_dim=params.context_dim,
122
+ unflatten_names=params.unflatten_names,
123
+ ),
124
+ ResBlock(
125
+ channels=ch,
126
+ emb_channels=time_embed_dim,
127
+ out_channels=None,
128
+ dense_in_channels=params.dense_in_channels,
129
+ dropout=params.dropout,
130
+ ),
131
+ )
132
+ self._feature_size += ch
133
+
134
+ self.output_blocks = nn.ModuleList([])
135
+ for level, mult in list(enumerate(params.channel_mult))[::-1]:
136
+ for i in range(params.num_res_blocks + 1):
137
+ ich = input_block_chans.pop()
138
+ output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
139
+ ResBlock(
140
+ channels=ch + ich,
141
+ emb_channels=time_embed_dim,
142
+ out_channels=params.model_channels * mult,
143
+ dense_in_channels=params.dense_in_channels,
144
+ dropout=params.dropout,
145
+ )
146
+ ]
147
+ ch = params.model_channels * mult
148
+ if ds in params.attention_resolutions:
149
+ num_heads = ch // params.num_head_channels
150
+ dim_head = params.num_head_channels
151
+
152
+ output_layers.append(
153
+ MultiviewTransformer(
154
+ ch,
155
+ num_heads,
156
+ dim_head,
157
+ name=f"output_ds{ds}",
158
+ depth=params.transformer_depth[level],
159
+ context_dim=params.context_dim,
160
+ unflatten_names=params.unflatten_names,
161
+ )
162
+ )
163
+ if level and i == params.num_res_blocks:
164
+ out_ch = ch
165
+ ds //= 2
166
+ output_layers.append(Upsample(ch, out_ch))
167
+ self.output_blocks.append(TimestepEmbedSequential(*output_layers))
168
+ self._feature_size += ch
169
+
170
+ self.out = nn.Sequential(
171
+ GroupNorm32(32, ch),
172
+ nn.SiLU(),
173
+ nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
174
+ )
175
+
176
+ def forward(
177
+ self,
178
+ x: torch.Tensor,
179
+ t: torch.Tensor,
180
+ y: torch.Tensor,
181
+ dense_y: torch.Tensor,
182
+ num_frames: int | None = None,
183
+ ) -> torch.Tensor:
184
+ num_frames = num_frames or self.params.num_frames
185
+ t_emb = timestep_embedding(t, self.model_channels)
186
+ t_emb = self.time_embed(t_emb)
187
+
188
+ hs = []
189
+ h = x
190
+ for module in self.input_blocks:
191
+ h = module(
192
+ h,
193
+ emb=t_emb,
194
+ context=y,
195
+ dense_emb=dense_y,
196
+ num_frames=num_frames,
197
+ )
198
+ hs.append(h)
199
+ h = self.middle_block(
200
+ h,
201
+ emb=t_emb,
202
+ context=y,
203
+ dense_emb=dense_y,
204
+ num_frames=num_frames,
205
+ )
206
+ for module in self.output_blocks:
207
+ h = torch.cat([h, hs.pop()], dim=1)
208
+ h = module(
209
+ h,
210
+ emb=t_emb,
211
+ context=y,
212
+ dense_emb=dense_y,
213
+ num_frames=num_frames,
214
+ )
215
+ h = h.type(x.dtype)
216
+ return self.out(h)
217
+
218
+
219
+ class SGMWrapper(nn.Module):
220
+ def __init__(self, module: Seva):
221
+ super().__init__()
222
+ self.module = module
223
+
224
+ def forward(
225
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
226
+ ) -> torch.Tensor:
227
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
228
+ return self.module(
229
+ x,
230
+ t=t,
231
+ y=c["crossattn"],
232
+ dense_y=c["dense_vector"],
233
+ **kwargs,
234
+ )
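
A shape-only sketch of a forward pass through the UNet above, using random tensors. Two frames and 32x32 latents keep the example small (the model defaults to 21 frames); what each tensor represents is inferred from the parameter names, and the attention blocks pin the flash SDP backend, so a GPU may be required depending on the PyTorch build:

import torch
from seva.model import Seva, SevaParams

params = SevaParams()
model = Seva(params).eval()  # builds the full network; memory-heavy

T = 2                                                        # frames (default is 21)
x = torch.randn(T, params.in_channels, 32, 32)               # noisy latents + concat conditioning
t = torch.ones(T)                                            # diffusion timesteps
y = torch.randn(T, 1, params.context_dim)                    # cross-attention context (e.g. CLIP embeddings)
dense_y = torch.randn(T, params.dense_in_channels, 32, 32)   # dense spatial conditioning (assumed camera/ray maps)

with torch.no_grad():
    out = model(x, t, y, dense_y, num_frames=T)
print(out.shape)  # torch.Size([2, 4, 32, 32])
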
seva/modules/__init__.py ADDED
File without changes
seva/modules/autoencoder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ from diffusers.models import AutoencoderKL # type: ignore
3
+ from torch import nn
4
+
5
+
6
+ class AutoEncoder(nn.Module):
7
+ scale_factor: float = 0.18215
8
+ downsample: int = 8
9
+
10
+ def __init__(self, chunk_size: int | None = None):
11
+ super().__init__()
12
+ self.module = AutoencoderKL.from_pretrained(
13
+ "stabilityai/stable-diffusion-2-1-base",
14
+ subfolder="vae",
15
+ force_download=False,
16
+ low_cpu_mem_usage=False,
17
+ )
18
+ self.module.eval().requires_grad_(False) # type: ignore
19
+ self.chunk_size = chunk_size
20
+
21
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
22
+ return (
23
+ self.module.encode(x).latent_dist.mean # type: ignore
24
+ * self.scale_factor
25
+ )
26
+
27
+ def encode(self, x: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
28
+ chunk_size = chunk_size or self.chunk_size
29
+ if chunk_size is not None:
30
+ return torch.cat(
31
+ [self._encode(x_chunk) for x_chunk in x.split(chunk_size)],
32
+ dim=0,
33
+ )
34
+ else:
35
+ return self._encode(x)
36
+
37
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
38
+ return self.module.decode(z / self.scale_factor).sample # type: ignore
39
+
40
+ def decode(self, z: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
41
+ chunk_size = chunk_size or self.chunk_size
42
+ if chunk_size is not None:
43
+ return torch.cat(
44
+ [self._decode(z_chunk) for z_chunk in z.split(chunk_size)],
45
+ dim=0,
46
+ )
47
+ else:
48
+ return self._decode(z)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return self.decode(self.encode(x))
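
A small round-trip sketch for the wrapper above. It downloads the SD 2.1 VAE weights from the Hugging Face Hub on first use; the random image tensor only illustrates the expected [-1, 1] range and the 8x spatial downsampling:

import torch
from seva.modules.autoencoder import AutoEncoder

ae = AutoEncoder(chunk_size=2)              # encode/decode at most 2 images at a time

imgs = torch.rand(4, 3, 256, 256) * 2 - 1   # 4 RGB images in [-1, 1]
with torch.no_grad():
    z = ae.encode(imgs)                      # (4, 4, 32, 32): latents, scaled by 0.18215
    recon = ae.decode(z)                     # (4, 3, 256, 256)
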
seva/modules/conditioner.py ADDED
@@ -0,0 +1,39 @@
1
+ import kornia
2
+ import open_clip
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class CLIPConditioner(nn.Module):
8
+ mean: torch.Tensor
9
+ std: torch.Tensor
10
+
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.module = open_clip.create_model_and_transforms(
14
+ "ViT-H-14", pretrained="laion2b_s32b_b79k"
15
+ )[0]
16
+ self.module.eval().requires_grad_(False) # type: ignore
17
+ self.register_buffer(
18
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
19
+ )
20
+ self.register_buffer(
21
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
22
+ )
23
+
24
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
25
+ x = kornia.geometry.resize(
26
+ x,
27
+ (224, 224),
28
+ interpolation="bicubic",
29
+ align_corners=True,
30
+ antialias=True,
31
+ )
32
+ x = (x + 1.0) / 2.0
33
+ x = kornia.enhance.normalize(x, self.mean, self.std)
34
+ return x
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ x = self.preprocess(x)
38
+ x = self.module.encode_image(x)
39
+ return x
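
A usage sketch for the conditioner above. It downloads the ViT-H-14 laion2b weights on first use; inputs are assumed to be in [-1, 1], matching the `(x + 1.0) / 2.0` rescaling in `preprocess`:

import torch
from seva.modules.conditioner import CLIPConditioner

clip = CLIPConditioner()
imgs = torch.rand(2, 3, 576, 576) * 2 - 1   # 2 RGB images in [-1, 1]
with torch.no_grad():
    emb = clip(imgs)
print(emb.shape)  # (2, 1024) image embeddings for ViT-H-14
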
seva/modules/layers.py ADDED
@@ -0,0 +1,140 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import repeat
6
+ from torch import nn
7
+
8
+ from .transformer import MultiviewTransformer
9
+
10
+
11
+ def timestep_embedding(
12
+ timesteps: torch.Tensor,
13
+ dim: int,
14
+ max_period: int = 10000,
15
+ repeat_only: bool = False,
16
+ ) -> torch.Tensor:
17
+ if not repeat_only:
18
+ half = dim // 2
19
+ freqs = torch.exp(
20
+ -math.log(max_period)
21
+ * torch.arange(start=0, end=half, dtype=torch.float32)
22
+ / half
23
+ ).to(device=timesteps.device)
24
+ args = timesteps[:, None].float() * freqs[None]
25
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
26
+ if dim % 2:
27
+ embedding = torch.cat(
28
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
29
+ )
30
+ else:
31
+ embedding = repeat(timesteps, "b -> b d", d=dim)
32
+ return embedding
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, channels: int, out_channels: int | None = None):
37
+ super().__init__()
38
+ self.channels = channels
39
+ self.out_channels = out_channels or channels
40
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ assert x.shape[1] == self.channels
44
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
45
+ x = self.conv(x)
46
+ return x
47
+
48
+
49
+ class Downsample(nn.Module):
50
+ def __init__(self, channels: int, out_channels: int | None = None):
51
+ super().__init__()
52
+ self.channels = channels
53
+ self.out_channels = out_channels or channels
54
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ assert x.shape[1] == self.channels
58
+ return self.op(x)
59
+
60
+
61
+ class GroupNorm32(nn.GroupNorm):
62
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
63
+ return super().forward(input.float()).type(input.dtype)
64
+
65
+
66
+ class TimestepEmbedSequential(nn.Sequential):
67
+ def forward( # type: ignore[override]
68
+ self,
69
+ x: torch.Tensor,
70
+ emb: torch.Tensor,
71
+ context: torch.Tensor,
72
+ dense_emb: torch.Tensor,
73
+ num_frames: int,
74
+ ) -> torch.Tensor:
75
+ for layer in self:
76
+ if isinstance(layer, MultiviewTransformer):
77
+ assert num_frames is not None
78
+ x = layer(x, context, num_frames)
79
+ elif isinstance(layer, ResBlock):
80
+ x = layer(x, emb, dense_emb)
81
+ else:
82
+ x = layer(x)
83
+ return x
84
+
85
+
86
+ class ResBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ channels: int,
90
+ emb_channels: int,
91
+ out_channels: int | None,
92
+ dense_in_channels: int,
93
+ dropout: float,
94
+ ):
95
+ super().__init__()
96
+ out_channels = out_channels or channels
97
+
98
+ self.in_layers = nn.Sequential(
99
+ GroupNorm32(32, channels),
100
+ nn.SiLU(),
101
+ nn.Conv2d(channels, out_channels, 3, 1, 1),
102
+ )
103
+ self.emb_layers = nn.Sequential(
104
+ nn.SiLU(), nn.Linear(emb_channels, out_channels)
105
+ )
106
+ self.dense_emb_layers = nn.Sequential(
107
+ nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
108
+ )
109
+ self.out_layers = nn.Sequential(
110
+ GroupNorm32(32, out_channels),
111
+ nn.SiLU(),
112
+ nn.Dropout(dropout),
113
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
114
+ )
115
+ if out_channels == channels:
116
+ self.skip_connection = nn.Identity()
117
+ else:
118
+ self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)
119
+
120
+ def forward(
121
+ self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
122
+ ) -> torch.Tensor:
123
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
124
+ h = in_rest(x)
125
+ dense = self.dense_emb_layers(
126
+ F.interpolate(
127
+ dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
128
+ )
129
+ ).type(h.dtype)
130
+ dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
131
+ h = h * (1 + dense_scale) + dense_shift
132
+ h = in_conv(h)
133
+ emb_out = self.emb_layers(emb).type(h.dtype)
134
+ # TODO(hangg): Optimize this?
135
+ while len(emb_out.shape) < len(h.shape):
136
+ emb_out = emb_out[..., None]
137
+ h = h + emb_out
138
+ h = self.out_layers(h)
139
+ h = self.skip_connection(x) + h
140
+ return h
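
A shape sketch for two of the building blocks above, with arbitrary values; `emb_channels=1280` mirrors `4 * model_channels` from `SevaParams`:

import torch
from seva.modules.layers import ResBlock, timestep_embedding

emb = timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=320)
print(emb.shape)  # (3, 320): concatenated cos/sin sinusoidal features

block = ResBlock(
    channels=320, emb_channels=1280, out_channels=None,
    dense_in_channels=6, dropout=0.0,
)
x = torch.randn(3, 320, 16, 16)
t_emb = torch.randn(3, 1280)        # output of the model's time_embed MLP
dense = torch.randn(3, 6, 16, 16)   # dense conditioning, applied as a scale/shift modulation
out = block(x, t_emb, dense)
print(out.shape)  # (3, 320, 16, 16)
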
seva/modules/preprocessor.py ADDED
@@ -0,0 +1,116 @@
1
+ import contextlib
2
+ import os
3
+ import os.path as osp
4
+ import sys
5
+ from typing import cast
6
+
7
+ import imageio.v3 as iio
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class Dust3rPipeline(object):
13
+ def __init__(self, device: str | torch.device = "cuda"):
14
+ submodule_path = osp.realpath(
15
+ osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
16
+ )
17
+ if submodule_path not in sys.path:
18
+ sys.path.insert(0, submodule_path)
19
+ try:
20
+ with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
21
+ from dust3r.cloud_opt import ( # type: ignore[import]
22
+ GlobalAlignerMode,
23
+ global_aligner,
24
+ )
25
+ from dust3r.image_pairs import make_pairs # type: ignore[import]
26
+ from dust3r.inference import inference # type: ignore[import]
27
+ from dust3r.model import AsymmetricCroCo3DStereo # type: ignore[import]
28
+ from dust3r.utils.image import load_images # type: ignore[import]
29
+ except ImportError:
30
+ raise ImportError(
31
+ "Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
32
+ "To initialize them, run the following command in the project root:\n"
33
+ " git submodule update --init --recursive"
34
+ )
35
+
36
+ self.device = torch.device(device)
37
+ self.model = AsymmetricCroCo3DStereo.from_pretrained(
38
+ "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
39
+ ).to(self.device)
40
+
41
+ self._GlobalAlignerMode = GlobalAlignerMode
42
+ self._global_aligner = global_aligner
43
+ self._make_pairs = make_pairs
44
+ self._inference = inference
45
+ self._load_images = load_images
46
+
47
+ def infer_cameras_and_points(
48
+ self,
49
+ img_paths: list[str],
50
+ Ks: list[list] = None,
51
+ c2ws: list[list] = None,
52
+ batch_size: int = 16,
53
+ schedule: str = "cosine",
54
+ lr: float = 0.01,
55
+ niter: int = 500,
56
+ min_conf_thr: int = 3,
57
+ ) -> tuple[
58
+ list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
59
+ ]:
60
+ num_img = len(img_paths)
61
+ if num_img == 1:
62
+ print("Only one image found, duplicating it to create a stereo pair.")
63
+ img_paths = img_paths * 2
64
+
65
+ images = self._load_images(img_paths, size=512)
66
+ pairs = self._make_pairs(
67
+ images,
68
+ scene_graph="complete",
69
+ prefilter=None,
70
+ symmetrize=True,
71
+ )
72
+ output = self._inference(pairs, self.model, self.device, batch_size=batch_size)
73
+
74
+ ori_imgs = [iio.imread(p) for p in img_paths]
75
+ ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
76
+ img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)
77
+
78
+ scene = self._global_aligner(
79
+ output,
80
+ device=self.device,
81
+ mode=self._GlobalAlignerMode.PointCloudOptimizer,
82
+ same_focals=True,
83
+ optimize_pp=False, # True,
84
+ min_conf_thr=min_conf_thr,
85
+ )
86
+
87
+ # if Ks is not None:
88
+ # scene.preset_focal(
89
+ # torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
90
+ # )
91
+
92
+ if c2ws is not None:
93
+ scene.preset_pose(c2ws)
94
+
95
+ _ = scene.compute_global_alignment(
96
+ init="msp", niter=niter, schedule=schedule, lr=lr
97
+ )
98
+
99
+ imgs = cast(list, scene.imgs)
100
+ Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
101
+ c2ws = scene.get_im_poses().detach().cpu().numpy() # type: ignore
102
+ pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()] # type: ignore
103
+ if num_img > 1:
104
+ masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
105
+ points = [p[m] for p, m in zip(pts3d, masks)]
106
+ point_colors = [img[m] for img, m in zip(imgs, masks)]
107
+ else:
108
+ points = [p.reshape(-1, 3) for p in pts3d]
109
+ point_colors = [img.reshape(-1, 3) for img in imgs]
110
+
111
+ # Convert back to the original image size.
112
+ imgs = ori_imgs
113
+ Ks[:, :2, -1] *= ori_img_whs / img_whs
114
+ Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]
115
+
116
+ return imgs, Ks, c2ws, points, point_colors
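
A hedged usage sketch for the pipeline above. It needs a CUDA device, the `third_party/dust3r` submodule checked out, and the pretrained DUSt3R checkpoint; the image paths below are hypothetical:

from seva.modules.preprocessor import Dust3rPipeline

pipeline = Dust3rPipeline(device="cuda")
imgs, Ks, c2ws, points, point_colors = pipeline.infer_cameras_and_points(
    ["examples/view_0.png", "examples/view_1.png"],  # hypothetical input views
)
print(Ks.shape, c2ws.shape)   # (N, 3, 3) intrinsics, (N, 4, 4) camera-to-world poses
print(len(points))            # per-view confident 3D points, with matching point_colors
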
seva/modules/transformer.py ADDED
@@ -0,0 +1,247 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int | None = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out or dim
+        self.net = nn.Sequential(
+            GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        context_dim: int | None = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.heads = heads
+        self.dim_head = dim_head
+        inner_dim = dim_head * heads
+        context_dim = context_dim or query_dim
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+        )
+
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        q = self.to_q(x)
+        context = context if context is not None else x
+        k = self.to_k(context)
+        v = self.to_v(context)
+        q, k, v = map(
+            lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
+            (q, k, v),
+        )
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+            out = F.scaled_dot_product_attention(q, k, v)
+        out = rearrange(out, "b h l d -> b l (h d)")
+        out = self.to_out(out)
+        return out
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        d_head: int,
+        context_dim: int,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.attn1 = Attention(
+            query_dim=dim,
+            context_dim=None,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )
+        self.ff = FeedForward(dim, dropout=dropout)
+        self.attn2 = Attention(
+            query_dim=dim,
+            context_dim=context_dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class TransformerBlockTimeMix(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        d_head: int,
+        context_dim: int,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        inner_dim = n_heads * d_head
+        self.norm_in = nn.LayerNorm(dim)
+        self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout)
+        self.attn1 = Attention(
+            query_dim=inner_dim,
+            context_dim=None,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )
+        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout)
+        self.attn2 = Attention(
+            query_dim=inner_dim,
+            context_dim=context_dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )
+        self.norm1 = nn.LayerNorm(inner_dim)
+        self.norm2 = nn.LayerNorm(inner_dim)
+        self.norm3 = nn.LayerNorm(inner_dim)
+
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
+    ) -> torch.Tensor:
+        _, s, _ = x.shape
+        x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)
+        x = self.ff_in(self.norm_in(x)) + x
+        x = self.attn1(self.norm1(x), context=None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x))
+        x = rearrange(x, "(b s) t c -> (b t) s c", s=s)
+        return x
+
+
+class SkipConnect(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
+    ) -> torch.Tensor:
+        return x_spatial + x_temporal
+
+
+class MultiviewTransformer(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        n_heads: int,
+        d_head: int,
+        name: str,
+        unflatten_names: list[str] = [],
+        depth: int = 1,
+        context_dim: int = 1024,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.name = name
+        self.unflatten_names = unflatten_names
+
+        inner_dim = n_heads * d_head
+        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+        self.transformer_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    inner_dim,
+                    n_heads,
+                    d_head,
+                    context_dim=context_dim,
+                    dropout=dropout,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+        self.time_mixer = SkipConnect()
+        self.time_mix_blocks = nn.ModuleList(
+            [
+                TransformerBlockTimeMix(
+                    inner_dim,
+                    n_heads,
+                    d_head,
+                    context_dim=context_dim,
+                    dropout=dropout,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
+    ) -> torch.Tensor:
+        assert context.ndim == 3
+        _, _, h, w = x.shape
+        x_in = x
+
+        time_context = context
+        time_context_first_timestep = time_context[::num_frames]
+        time_context = repeat(
+            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
+        )
+
+        if self.name in self.unflatten_names:
+            context = context[::num_frames]
+
+        x = self.norm(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        x = self.proj_in(x)
+
+        for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
+            if self.name in self.unflatten_names:
+                x = rearrange(x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w)
+            x = block(x, context=context)
+            if self.name in self.unflatten_names:
+                x = rearrange(x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w)
+            x_mix = mix_block(x, context=time_context, num_frames=num_frames)
+            x = self.time_mixer(x_spatial=x, x_temporal=x_mix)
+
+        x = self.proj_out(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        out = x + x_in
+        return out
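A note on usage of the module above: `MultiviewTransformer` expects frames to be folded into the batch axis, i.e. `x` of shape `(b * num_frames, c, h, w)` and a token-level `context` of shape `(b * num_frames, tokens, context_dim)`, and it returns a tensor of the same shape as `x`. The sketch below is only a shape check with hypothetical dimensions (320 channels, 5 heads of 64, 8×8 latents, 77 context tokens) and assumes the classes above are in scope; since `Attention.forward` pins the scaled-dot-product kernel to `SDPBackend.FLASH_ATTENTION`, it generally needs a recent PyTorch (the `torch.nn.attention` module), a CUDA device, and fp16/bf16 tensors.

```python
import torch

# Hypothetical hyperparameters, chosen only so the shapes line up
# (in_channels must be divisible by 32 for the GroupNorm).
device = torch.device("cuda")
model = MultiviewTransformer(
    in_channels=320, n_heads=5, d_head=64, name="mid", context_dim=1024
).to(device, torch.bfloat16)

num_frames = 4
# Frames are flattened into the batch dimension: (b * t, c, h, w) and (b * t, tokens, context_dim).
x = torch.randn(num_frames, 320, 8, 8, device=device, dtype=torch.bfloat16)
context = torch.randn(num_frames, 77, 1024, device=device, dtype=torch.bfloat16)

with torch.no_grad():
    out = model(x, context, num_frames=num_frames)
print(out.shape)  # torch.Size([4, 320, 8, 8]) -- residual output, same shape as x
```

The spatial blocks attend within each frame while the time-mix blocks rearrange `(b t) s c -> (b s) t c` to attend across frames at each spatial location, and `SkipConnect` simply sums the two branches.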
seva/sampling.py ADDED
@@ -0,0 +1,405 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from tqdm import tqdm
+
+from seva.geometry import get_camera_dist
+
+
+def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(
+            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+        )
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def append_zero(x: torch.Tensor) -> torch.Tensor:
+    return torch.cat([x, x.new_zeros([1])])
+
+
+def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
+    return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def make_betas(
+    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
+) -> np.ndarray:
+    betas = (
+        torch.linspace(
+            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
+        )
+        ** 2
+    )
+    return betas.numpy()
+
+
+def generate_roughly_equally_spaced_steps(
+    num_substeps: int, max_step: int
+) -> np.ndarray:
+    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+class EpsScaling(object):
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        c_skip = torch.ones_like(sigma, device=sigma.device)
+        c_out = -sigma
+        c_in = 1 / (sigma**2 + 1.0) ** 0.5
+        c_noise = sigma.clone()
+        return c_skip, c_out, c_in, c_noise
+
+
+class DDPMDiscretization(object):
+    def __init__(
+        self,
+        linear_start: float = 5e-06,
+        linear_end: float = 0.012,
+        num_timesteps: int = 1000,
+        log_snr_shift: float | None = 2.4,
+    ):
+        self.num_timesteps = num_timesteps
+
+        betas = make_betas(
+            num_timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+        )
+        self.log_snr_shift = log_snr_shift
+
+        alphas = 1.0 - betas  # first alpha here is on data side
+        self.alphas_cumprod = np.cumprod(alphas, axis=0)
+
+    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
+        if n < self.num_timesteps:
+            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+            alphas_cumprod = self.alphas_cumprod[timesteps]
+        elif n == self.num_timesteps:
+            alphas_cumprod = self.alphas_cumprod
+        else:
+            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")
+
+        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        if self.log_snr_shift is not None:
+            sigmas = sigmas * np.exp(self.log_snr_shift)
+        return torch.flip(
+            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
+        )
+
+    def __call__(
+        self,
+        n: int,
+        do_append_zero: bool = True,
+        flip: bool = False,
+        device: str | torch.device = "cpu",
+    ) -> torch.Tensor:
+        sigmas = self.get_sigmas(n, device=device)
+        sigmas = append_zero(sigmas) if do_append_zero else sigmas
+        return sigmas if not flip else torch.flip(sigmas, (0,))
+
+
+class DiscreteDenoiser(object):
+    sigmas: torch.Tensor
+
+    def __init__(
+        self,
+        discretization: DDPMDiscretization,
+        num_idx: int = 1000,
+        device: str | torch.device = "cpu",
+    ):
+        self.scaling = EpsScaling()
+        self.discretization = discretization
+        self.num_idx = num_idx
+        self.device = device
+
+        self.register_sigmas()
+
+    def register_sigmas(self):
+        self.sigmas = self.discretization(
+            self.num_idx, do_append_zero=False, flip=True, device=self.device
+        )
+
+    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
+        dists = sigma - self.sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
+        return self.sigmas[idx]
+
+    def __call__(
+        self,
+        network: nn.Module,
+        input: torch.Tensor,
+        sigma: torch.Tensor,
+        cond: dict,
+        **additional_model_inputs,
+    ) -> torch.Tensor:
+        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
+        sigma_shape = sigma.shape
+        sigma = append_dims(sigma, input.ndim)
+        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
+        if "replace" in cond:
+            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
+            input = input * (1 - mask) + x * mask
+        return (
+            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+            + input * c_skip
+        )
+
+
+class ConstantScaleRule(object):
+    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
+        return scale
+
+
+class MultiviewScaleRule(object):
+    def __init__(self, min_scale: float = 1.0):
+        self.min_scale = min_scale
+
+    def __call__(
+        self,
+        scale: float | torch.Tensor,
+        c2w: torch.Tensor,
+        K: torch.Tensor,
+        input_frame_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        c2w_input = c2w[input_frame_mask]
+        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
+        translation_diff = (
+            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
+        )
+        K_diff = (
+            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
+        )
+        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
+        if isinstance(scale, torch.Tensor):
+            scale = scale.clone()
+            scale[close_frame] = self.min_scale
+        elif isinstance(scale, float):
+            scale = torch.where(close_frame, self.min_scale, scale)
+        else:
+            raise ValueError(f"Invalid scale type {type(scale)}.")
+        return scale
+
+
+class ConstantScaleSchedule(object):
+    def __call__(
+        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
+    ) -> float | torch.Tensor:
+        if isinstance(sigma, float):
+            return scale
+        elif isinstance(sigma, torch.Tensor):
+            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
+                sigma = append_dims(sigma, scale.ndim)
+            return scale * torch.ones_like(sigma)
+        else:
+            raise ValueError(f"Invalid sigma type {type(sigma)}.")
+
+
+class ConstantGuidance(object):
+    def __call__(
+        self,
+        uncond: torch.Tensor,
+        cond: torch.Tensor,
+        scale: float | torch.Tensor,
+    ) -> torch.Tensor:
+        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
+            scale = append_dims(scale, cond.ndim)
+        return uncond + scale * (cond - uncond)
+
+
+class VanillaCFG(object):
+    def __init__(self):
+        self.scale_rule = ConstantScaleRule()
+        self.scale_schedule = ConstantScaleSchedule()
+        self.guidance = ConstantGuidance()
+
+    def __call__(
+        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
+    ) -> torch.Tensor:
+        x_u, x_c = x.chunk(2)
+        scale = self.scale_rule(scale)
+        scale_value = self.scale_schedule(sigma, scale)
+        x_pred = self.guidance(x_u, x_c, scale_value)
+        return x_pred
+
+    def prepare_inputs(
+        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
+    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
+        c_out = dict()
+
+        for k in c:
+            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
+                c_out[k] = torch.cat((uc[k], c[k]), 0)
+            else:
+                assert c[k] == uc[k]
+                c_out[k] = c[k]
+        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class MultiviewCFG(VanillaCFG):
+    def __init__(self, cfg_min: float = 1.0):
+        self.scale_min = cfg_min
+        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
+        self.scale_schedule = ConstantScaleSchedule()
+        self.guidance = ConstantGuidance()
+
+    def __call__(  # type: ignore
+        self,
+        x: torch.Tensor,
+        sigma: float | torch.Tensor,
+        scale: float | torch.Tensor,
+        c2w: torch.Tensor,
+        K: torch.Tensor,
+        input_frame_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        x_u, x_c = x.chunk(2)
+        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
+        scale_value = self.scale_schedule(sigma, scale)
+        x_pred = self.guidance(x_u, x_c, scale_value)
+        return x_pred
+
+
+class MultiviewTemporalCFG(MultiviewCFG):
+    def __init__(self, num_frames: int, cfg_min: float = 1.0):
+        super().__init__(cfg_min=cfg_min)
+
+        self.num_frames = num_frames
+        distance_matrix = (
+            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
+        ).abs()
+        self.distance_matrix = distance_matrix
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+        sigma: float | torch.Tensor,
+        scale: float | torch.Tensor,
+        c2w: torch.Tensor,
+        K: torch.Tensor,
+        input_frame_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        input_frame_mask = rearrange(
+            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
+        )
+        min_distance = (
+            self.distance_matrix[None].to(x.device)
+            + (~input_frame_mask[:, None]) * self.num_frames
+        ).min(-1)[0]
+        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
+        scale = min_distance * (scale - self.scale_min) + self.scale_min
+        scale = rearrange(scale, "b t ... -> (b t) ...")
+        scale = append_dims(scale, x.ndim)
+        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
+
+
+class EulerEDMSampler(object):
+    def __init__(
+        self,
+        discretization: DDPMDiscretization,
+        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
+        num_steps: int | None = None,
+        verbose: bool = False,
+        device: str | torch.device = "cuda",
+        s_churn=0.0,
+        s_tmin=0.0,
+        s_tmax=float("inf"),
+        s_noise=1.0,
+    ):
+        self.num_steps = num_steps
+        self.discretization = discretization
+        self.guider = guider
+        self.verbose = verbose
+        self.device = device
+
+        self.s_churn = s_churn
+        self.s_tmin = s_tmin
+        self.s_tmax = s_tmax
+        self.s_noise = s_noise
+
+    def prepare_sampling_loop(
+        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
+        num_steps = num_steps or self.num_steps
+        assert num_steps is not None, "num_steps must be specified"
+        sigmas = self.discretization(num_steps, device=self.device)
+        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+        num_sigmas = len(sigmas)
+        s_in = x.new_ones([x.shape[0]])
+        return x, s_in, sigmas, num_sigmas, cond, uc
+
+    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
+        sigma_generator = range(num_sigmas - 1)
+        if self.verbose and verbose:
+            sigma_generator = tqdm(
+                sigma_generator,
+                total=num_sigmas - 1,
+                desc="Sampling",
+                leave=False,
+            )
+        return sigma_generator
+
+    def sampler_step(
+        self,
+        sigma: torch.Tensor,
+        next_sigma: torch.Tensor,
+        denoiser,
+        x: torch.Tensor,
+        scale: float | torch.Tensor,
+        cond: dict,
+        uc: dict,
+        gamma: float = 0.0,
+        **guider_kwargs,
+    ) -> torch.Tensor:
+        sigma_hat = sigma * (gamma + 1.0) + 1e-6
+
+        eps = torch.randn_like(x) * self.s_noise
+        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
+        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
+        d = to_d(x, sigma_hat, denoised)
+        dt = append_dims(next_sigma - sigma_hat, x.ndim)
+        return x + dt * d
+
+    def __call__(
+        self,
+        denoiser,
+        x: torch.Tensor,
+        scale: float | torch.Tensor,
+        cond: dict,
+        uc: dict | None = None,
+        num_steps: int | None = None,
+        verbose: bool = True,
+        **guider_kwargs,
+    ) -> torch.Tensor:
+        uc = cond if uc is None else uc
+        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+            x,
+            cond,
+            uc,
+            num_steps,
+        )
+        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
+            gamma = (
+                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+                if self.s_tmin <= sigmas[i] <= self.s_tmax
+                else 0.0
+            )
+            x = self.sampler_step(
+                s_in * sigmas[i],
+                s_in * sigmas[i + 1],
+                denoiser,
+                x,
+                scale,
+                cond,
+                uc,
+                gamma,
+                **guider_kwargs,
+            )
+        return x
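To show how these pieces fit together: the sampler does not call `DiscreteDenoiser` directly; inside `sampler_step` it calls `denoiser(input, sigma, cond)`, so the denoiser and the actual network have to be wrapped in a small closure. The sketch below is a minimal, self-contained illustration of that wiring under stated assumptions: `DummyNet` stands in for the real Seva network, the conditioning dicts contain only a single hypothetical `"crossattn"` entry, and all shapes and the scale value are illustrative.

```python
import torch


class DummyNet(torch.nn.Module):
    # Stand-in for the real denoising network: same call convention
    # (scaled input, discrete noise index, conditioning dict), trivial output.
    def forward(self, x_t, t_idx, cond, **kwargs):
        return torch.zeros_like(x_t)


model = DummyNet()
discretization = DDPMDiscretization()
denoiser = DiscreteDenoiser(discretization, device="cpu")
guider = VanillaCFG()
sampler = EulerEDMSampler(discretization, guider, num_steps=50, verbose=True, device="cpu")

# Illustrative latent and conditioning shapes; the uncond branch is zeroed for CFG.
x = torch.randn(1, 4, 72, 72)
crossattn_emb = torch.randn(1, 77, 1024)
cond = {"crossattn": crossattn_emb}
uc = {"crossattn": torch.zeros_like(crossattn_emb)}


def denoise_fn(x_t, sigma, c):
    # Adapts DiscreteDenoiser's (network, input, sigma, cond) signature to the
    # (input, sigma, cond) call issued inside EulerEDMSampler.sampler_step.
    return denoiser(model, x_t, sigma, c)


samples = sampler(denoise_fn, x, scale=2.0, cond=cond, uc=uc)
```

Note that `prepare_sampling_loop` scales the input by `sqrt(1 + sigma_max**2)`, so `x` should be unit-variance noise, and `prepare_inputs` doubles the batch so that the unconditional and conditional branches run in a single forward pass before `VanillaCFG` recombines them.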
seva/utils.py ADDED
@@ -0,0 +1,56 @@
+import os
+
+import safetensors.torch
+import torch
+from huggingface_hub import hf_hub_download
+
+from seva.model import Seva, SevaParams
+
+
+def seed_everything(seed: int = 0):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+    if len(missing) > 0 and len(unexpected) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+        print("\n" + "-" * 79 + "\n")
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+    elif len(missing) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+    elif len(unexpected) > 0:
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+def load_model(
+    pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera",
+    weight_name: str = "model.safetensors",
+    device: str | torch.device = "cuda",
+    verbose: bool = False,
+) -> Seva:
+    if os.path.isdir(pretrained_model_name_or_path):
+        weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
+    else:
+        weight_path = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=weight_name
+        )
+        _ = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename="config.yaml"
+        )
+
+    state_dict = safetensors.torch.load_file(
+        weight_path,
+        device=str(device),
+    )
+
+    with torch.device("meta"):
+        model = Seva(SevaParams()).to(torch.bfloat16)
+
+    missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
+    if verbose:
+        print_load_warning(missing, unexpected)
+    return model
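For reference, `load_model` resolves `model.safetensors` (and `config.yaml`) from the `stabilityai/stable-virtual-camera` repo by default, instantiates the network on the meta device in bfloat16, and materializes the weights from the checkpoint via `assign=True`. A minimal usage sketch, assuming the package is importable as `seva` and a CUDA device is available (the seed value is arbitrary):

```python
from seva.utils import load_model, seed_everything

seed_everything(23)  # arbitrary seed, shown only for reproducibility
model = load_model(device="cuda", verbose=True)  # downloads weights on first use
model = model.eval()
```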
third_party/dust3r ADDED
@@ -0,0 +1 @@
+Subproject commit 44b87f5a466ec32435036e40125d0b87a5746c20