diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d22cef17abd65c6b9edcc26dfb84d6ae5fe6c6ac 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 2875742637ffac46795e86bb22831c0fa692df28..c6a6fbbe6d144a2b7d177c2e60404e0e9e85f259 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ sdk_version: 5.23.1
app_file: app.py
pinned: false
license: mit
-short_description: https://arxiv.org/abs/2501.08295
+short_description: "LayerAnimate: Layer-level Control for Animation"
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/__assets__/demos/demo_1/first_frame.jpg b/__assets__/demos/demo_1/first_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f60be07dba8315293b9e2509e4ca8d4d6a9dfe2d
Binary files /dev/null and b/__assets__/demos/demo_1/first_frame.jpg differ
diff --git a/__assets__/demos/demo_1/layer_0.jpg b/__assets__/demos/demo_1/layer_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..21f5d8a1548f7a20adc91fa864cf2fc9a5a26286
Binary files /dev/null and b/__assets__/demos/demo_1/layer_0.jpg differ
diff --git a/__assets__/demos/demo_1/layer_1.jpg b/__assets__/demos/demo_1/layer_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0823270c27edfa02461a1c3da5a0d0e2c6f35ca5
Binary files /dev/null and b/__assets__/demos/demo_1/layer_1.jpg differ
diff --git a/__assets__/demos/demo_1/layer_2.jpg b/__assets__/demos/demo_1/layer_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ffacabc31c6006179d9a007adee229885cda47fd
Binary files /dev/null and b/__assets__/demos/demo_1/layer_2.jpg differ
diff --git a/__assets__/demos/demo_1/sketch.mp4 b/__assets__/demos/demo_1/sketch.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9e7bba620e6b14bb4fb97d31f0388874ed81e959
Binary files /dev/null and b/__assets__/demos/demo_1/sketch.mp4 differ
diff --git a/__assets__/demos/demo_1/trajectory.json b/__assets__/demos/demo_1/trajectory.json
new file mode 100644
index 0000000000000000000000000000000000000000..4697b21da16ddeff8d242750cdf5b95830916512
--- /dev/null
+++ b/__assets__/demos/demo_1/trajectory.json
@@ -0,0 +1,200 @@
+[
+ [
+ [
+ 111.87965393066406,
+ 204.28741455078125
+ ],
+ [
+ 83.42483520507812,
+ 204.21835327148438
+ ],
+ [
+ 52.417137145996094,
+ 205.34869384765625
+ ],
+ [
+ -10.01504135131836,
+ 205.83694458007812
+ ],
+ [
+ -33.109561920166016,
+ 206.53018188476562
+ ],
+ [
+ -86.02885437011719,
+ 205.10772705078125
+ ],
+ [
+ -119.59435272216797,
+ 204.4576873779297
+ ],
+ [
+ -168.70248413085938,
+ 210.6188201904297
+ ],
+ [
+ -185.9542999267578,
+ 211.16294860839844
+ ],
+ [
+ -206.82852172851562,
+ 207.50912475585938
+ ],
+ [
+ -232.2637939453125,
+ 208.35643005371094
+ ],
+ [
+ -177.6964111328125,
+ 205.50949096679688
+ ],
+ [
+ -231.19761657714844,
+ 203.8624267578125
+ ],
+ [
+ -276.06622314453125,
+ 208.6024169921875
+ ],
+ [
+ -285.68218994140625,
+ 210.30313110351562
+ ],
+ [
+ -235.0211639404297,
+ 207.910400390625
+ ]
+ ],
+ [
+ [
+ 130.59063720703125,
+ 131.48106384277344
+ ],
+ [
+ 101.31892395019531,
+ 131.62567138671875
+ ],
+ [
+ 69.3387451171875,
+ 132.40696716308594
+ ],
+ [
+ 6.821704864501953,
+ 133.10546875
+ ],
+ [
+ -21.6120548248291,
+ 132.92977905273438
+ ],
+ [
+ -83.36480712890625,
+ 132.2947998046875
+ ],
+ [
+ -111.29481506347656,
+ 131.91827392578125
+ ],
+ [
+ -168.74850463867188,
+ 138.11587524414062
+ ],
+ [
+ -198.75299072265625,
+ 139.32774353027344
+ ],
+ [
+ -253.08055114746094,
+ 136.65480041503906
+ ],
+ [
+ -278.3507080078125,
+ 136.42958068847656
+ ],
+ [
+ -312.9150390625,
+ 134.22898864746094
+ ],
+ [
+ -332.20989990234375,
+ 133.93161010742188
+ ],
+ [
+ -357.1211853027344,
+ 139.33224487304688
+ ],
+ [
+ -361.4031677246094,
+ 139.66172790527344
+ ],
+ [
+ -338.45501708984375,
+ 141.38809204101562
+ ]
+ ],
+ [
+ [
+ 308.344970703125,
+ 6.6701483726501465
+ ],
+ [
+ 278.66864013671875,
+ 7.116205215454102
+ ],
+ [
+ 247.65390014648438,
+ 7.756659507751465
+ ],
+ [
+ 184.76953125,
+ 8.749884605407715
+ ],
+ [
+ 154.9658203125,
+ 8.66163444519043
+ ],
+ [
+ 92.775146484375,
+ 7.572597503662109
+ ],
+ [
+ 63.20433044433594,
+ 7.524573802947998
+ ],
+ [
+ 1.4797935485839844,
+ 13.07353401184082
+ ],
+ [
+ -26.288057327270508,
+ 13.74260139465332
+ ],
+ [
+ -83.00379943847656,
+ 11.522849082946777
+ ],
+ [
+ -109.52509307861328,
+ 10.739717483520508
+ ],
+ [
+ -140.5462646484375,
+ 8.596296310424805
+ ],
+ [
+ -155.35394287109375,
+ 8.009984970092773
+ ],
+ [
+ -180.55775451660156,
+ 13.584362030029297
+ ],
+ [
+ -185.0371856689453,
+ 14.09956169128418
+ ],
+ [
+ -203.57778930664062,
+ 18.082473754882812
+ ]
+ ]
+]
\ No newline at end of file
diff --git a/__assets__/demos/demo_1/trajectory.npz b/__assets__/demos/demo_1/trajectory.npz
new file mode 100644
index 0000000000000000000000000000000000000000..c8bcac4a0e831ebc2e9cbf1c89946bfbd11b57a5
--- /dev/null
+++ b/__assets__/demos/demo_1/trajectory.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:232a68740a9d2828277e786d760cb2d7436f4617ae1d64d31a61888be0c65ea1
+size 994
diff --git a/__assets__/demos/demo_2/first_frame.jpg b/__assets__/demos/demo_2/first_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5857abee5ff9bc0e87cee79aeb100f294421ff1f
Binary files /dev/null and b/__assets__/demos/demo_2/first_frame.jpg differ
diff --git a/__assets__/demos/demo_2/layer_0.jpg b/__assets__/demos/demo_2/layer_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f45a856cf35a5be895cbcde5e8eeb70513ab038
Binary files /dev/null and b/__assets__/demos/demo_2/layer_0.jpg differ
diff --git a/__assets__/demos/demo_2/layer_1.jpg b/__assets__/demos/demo_2/layer_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3b661b6bc4e7225f5efc5fdc780cd64edd7ab488
Binary files /dev/null and b/__assets__/demos/demo_2/layer_1.jpg differ
diff --git a/__assets__/demos/demo_2/layer_2.jpg b/__assets__/demos/demo_2/layer_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2412784cdf7fb67eb92cef80edb39bd3a4c5be68
Binary files /dev/null and b/__assets__/demos/demo_2/layer_2.jpg differ
diff --git a/__assets__/demos/demo_2/sketch.mp4 b/__assets__/demos/demo_2/sketch.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4f1846212648bfde1bf520c33fe87459bd7ad5ad
Binary files /dev/null and b/__assets__/demos/demo_2/sketch.mp4 differ
diff --git a/__assets__/demos/demo_2/trajectory.json b/__assets__/demos/demo_2/trajectory.json
new file mode 100644
index 0000000000000000000000000000000000000000..4e2f69bc9803beacdde5d63bbf64d4900993d6bb
--- /dev/null
+++ b/__assets__/demos/demo_2/trajectory.json
@@ -0,0 +1,200 @@
+[
+ [
+ [
+ 158.21946716308594,
+ 245.89105224609375
+ ],
+ [
+ 148.94857788085938,
+ 246.4789276123047
+ ],
+ [
+ 137.88522338867188,
+ 247.1299285888672
+ ],
+ [
+ 128.4403839111328,
+ 247.8033905029297
+ ],
+ [
+ 127.84039306640625,
+ 246.24864196777344
+ ],
+ [
+ 127.06155395507812,
+ 244.60606384277344
+ ],
+ [
+ 126.77435302734375,
+ 243.17208862304688
+ ],
+ [
+ 126.42509460449219,
+ 243.04747009277344
+ ],
+ [
+ 125.61285400390625,
+ 242.14913940429688
+ ],
+ [
+ 125.40904235839844,
+ 242.65948486328125
+ ],
+ [
+ 125.03759765625,
+ 242.90908813476562
+ ],
+ [
+ 124.67877197265625,
+ 242.95994567871094
+ ],
+ [
+ 125.00759887695312,
+ 242.61265563964844
+ ],
+ [
+ 125.37916564941406,
+ 242.13555908203125
+ ],
+ [
+ 125.7420654296875,
+ 242.410888671875
+ ],
+ [
+ 125.54336547851562,
+ 242.98825073242188
+ ]
+ ],
+ [
+ [
+ 223.55435180664062,
+ 204.28741455078125
+ ],
+ [
+ 207.83377075195312,
+ 202.7445068359375
+ ],
+ [
+ 193.4696044921875,
+ 200.418701171875
+ ],
+ [
+ 178.7669677734375,
+ 199.83621215820312
+ ],
+ [
+ 178.14218139648438,
+ 200.34848022460938
+ ],
+ [
+ 176.58251953125,
+ 200.19627380371094
+ ],
+ [
+ 175.0523681640625,
+ 200.24407958984375
+ ],
+ [
+ 174.57379150390625,
+ 199.90940856933594
+ ],
+ [
+ 173.37542724609375,
+ 200.4640350341797
+ ],
+ [
+ 173.5262451171875,
+ 200.5198974609375
+ ],
+ [
+ 173.60935974121094,
+ 200.36471557617188
+ ],
+ [
+ 173.8643035888672,
+ 200.39389038085938
+ ],
+ [
+ 173.903076171875,
+ 200.2958984375
+ ],
+ [
+ 173.96859741210938,
+ 200.00491333007812
+ ],
+ [
+ 174.22422790527344,
+ 200.09921264648438
+ ],
+ [
+ 174.16683959960938,
+ 200.00193786621094
+ ]
+ ],
+ [
+ [
+ 232.88790893554688,
+ 261.492431640625
+ ],
+ [
+ 224.37376403808594,
+ 258.9049072265625
+ ],
+ [
+ 214.7504119873047,
+ 255.82171630859375
+ ],
+ [
+ 205.59695434570312,
+ 252.74368286132812
+ ],
+ [
+ 203.56024169921875,
+ 254.83567810058594
+ ],
+ [
+ 200.3128662109375,
+ 256.933349609375
+ ],
+ [
+ 197.56045532226562,
+ 258.17236328125
+ ],
+ [
+ 196.72007751464844,
+ 258.3282470703125
+ ],
+ [
+ 194.2041473388672,
+ 259.42486572265625
+ ],
+ [
+ 194.23858642578125,
+ 259.9649353027344
+ ],
+ [
+ 194.01547241210938,
+ 260.14569091796875
+ ],
+ [
+ 193.87156677246094,
+ 259.9699401855469
+ ],
+ [
+ 193.9617919921875,
+ 259.7339172363281
+ ],
+ [
+ 193.89659118652344,
+ 259.5014343261719
+ ],
+ [
+ 193.8680419921875,
+ 259.7557373046875
+ ],
+ [
+ 193.91842651367188,
+ 260.28717041015625
+ ]
+ ]
+]
\ No newline at end of file
diff --git a/__assets__/demos/demo_2/trajectory.npz b/__assets__/demos/demo_2/trajectory.npz
new file mode 100644
index 0000000000000000000000000000000000000000..29549b7078bebcb3a89535ee9ea87cfcba35f4a8
--- /dev/null
+++ b/__assets__/demos/demo_2/trajectory.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba8194e3bd1376e10cb6c708d59603c406269b95bb1e266b20c7cfa66e248875
+size 972
diff --git a/__assets__/demos/demo_3/first_frame.jpg b/__assets__/demos/demo_3/first_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de487915b2b1aa3a7f996b729b8656a21ba089ae
Binary files /dev/null and b/__assets__/demos/demo_3/first_frame.jpg differ
diff --git a/__assets__/demos/demo_3/last_frame.jpg b/__assets__/demos/demo_3/last_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..32b453a70e67b1ecdd1b2348240451911294efa9
Binary files /dev/null and b/__assets__/demos/demo_3/last_frame.jpg differ
diff --git a/__assets__/demos/demo_3/layer_0.jpg b/__assets__/demos/demo_3/layer_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1c8e02cf61ab095e8f14a28cae14af52957e5768
Binary files /dev/null and b/__assets__/demos/demo_3/layer_0.jpg differ
diff --git a/__assets__/demos/demo_3/layer_0_last.jpg b/__assets__/demos/demo_3/layer_0_last.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da2bd681bcd3e6f6474056d8913d9b9d56d20694
Binary files /dev/null and b/__assets__/demos/demo_3/layer_0_last.jpg differ
diff --git a/__assets__/demos/demo_3/layer_1.jpg b/__assets__/demos/demo_3/layer_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d4fc6a6301ad78768ff49af2e02b5fd7ee3dad46
Binary files /dev/null and b/__assets__/demos/demo_3/layer_1.jpg differ
diff --git a/__assets__/demos/demo_3/layer_1_last.jpg b/__assets__/demos/demo_3/layer_1_last.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..354d9897e8c5f2a5430655b65be181ad1ddccc7d
Binary files /dev/null and b/__assets__/demos/demo_3/layer_1_last.jpg differ
diff --git a/__assets__/demos/demo_3/layer_2.jpg b/__assets__/demos/demo_3/layer_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b19cd4af7438b4176f55969a2b5aa7b2be5471b3
Binary files /dev/null and b/__assets__/demos/demo_3/layer_2.jpg differ
diff --git a/__assets__/demos/demo_3/layer_2_last.jpg b/__assets__/demos/demo_3/layer_2_last.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9fdf0d9adb42c0f4ee382d90af69db18db6b3bff
Binary files /dev/null and b/__assets__/demos/demo_3/layer_2_last.jpg differ
diff --git a/__assets__/demos/demo_3/layer_3.jpg b/__assets__/demos/demo_3/layer_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ccc74e9066ba66c1780b17c86b24bc67dac8804c
Binary files /dev/null and b/__assets__/demos/demo_3/layer_3.jpg differ
diff --git a/__assets__/demos/demo_3/layer_3_last.jpg b/__assets__/demos/demo_3/layer_3_last.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..490f072ddd159a88f4201bff394deec3a5474046
Binary files /dev/null and b/__assets__/demos/demo_3/layer_3_last.jpg differ
diff --git a/__assets__/demos/demo_3/sketch.mp4 b/__assets__/demos/demo_3/sketch.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..abe9e92d9b50435c25f2a18683e3d762f3985531
Binary files /dev/null and b/__assets__/demos/demo_3/sketch.mp4 differ
diff --git a/__assets__/demos/demo_3/trajectory.json b/__assets__/demos/demo_3/trajectory.json
new file mode 100644
index 0000000000000000000000000000000000000000..13c5156e5626f3e51a20ecc8e6652eb95f5d14e6
--- /dev/null
+++ b/__assets__/demos/demo_3/trajectory.json
@@ -0,0 +1,134 @@
+[
+ [
+ [
+ 49.66927719116211,
+ 126.28060150146484
+ ],
+ [
+ 53.070796966552734,
+ 140.00479125976562
+ ],
+ [
+ 58.86982345581055,
+ 157.8321533203125
+ ],
+ [
+ 69.01676177978516,
+ 175.84800720214844
+ ],
+ [
+ 76.01651000976562,
+ 197.62847900390625
+ ],
+ [
+ 93.34223937988281,
+ 232.17538452148438
+ ],
+ [
+ 96.88280487060547,
+ 246.68162536621094
+ ],
+ [
+ 105.09373474121094,
+ 265.91741943359375
+ ],
+ [
+ 122.41947174072266,
+ 300.46429443359375
+ ],
+ [
+ 139.74520874023438,
+ 335.0111999511719
+ ],
+ [
+ 157.07093811035156,
+ 369.55810546875
+ ],
+ [
+ 174.39666748046875,
+ 404.10498046875
+ ],
+ [
+ 191.722412109375,
+ 438.65185546875
+ ],
+ [
+ 209.0481414794922,
+ 473.19873046875
+ ],
+ [
+ 226.37387084960938,
+ 507.74560546875
+ ],
+ [
+ 243.6995849609375,
+ 542.29248046875
+ ]
+ ],
+ [
+ [
+ 56.677669525146484,
+ 69.07560729980469
+ ],
+ [
+ 66.92218780517578,
+ 90.37911224365234
+ ],
+ [
+ 79.62323760986328,
+ 116.14250183105469
+ ],
+ [
+ 91.2628173828125,
+ 141.8087921142578
+ ],
+ [
+ 103.7956771850586,
+ 167.58724975585938
+ ],
+ [
+ 117.59683227539062,
+ 195.22598266601562
+ ],
+ [
+ 127.79037475585938,
+ 221.12567138671875
+ ],
+ [
+ 140.4638671875,
+ 248.97164916992188
+ ],
+ [
+ 138.9651641845703,
+ 256.9488830566406
+ ],
+ [
+ 165.24566650390625,
+ 296.32525634765625
+ ],
+ [
+ 191.52615356445312,
+ 335.70166015625
+ ],
+ [
+ 217.806640625,
+ 375.07806396484375
+ ],
+ [
+ 244.08714294433594,
+ 414.4544372558594
+ ],
+ [
+ 270.3676452636719,
+ 453.830810546875
+ ],
+ [
+ 296.64813232421875,
+ 493.20721435546875
+ ],
+ [
+ 322.92864990234375,
+ 532.5836181640625
+ ]
+ ]
+]
\ No newline at end of file
diff --git a/__assets__/demos/demo_3/trajectory.npz b/__assets__/demos/demo_3/trajectory.npz
new file mode 100644
index 0000000000000000000000000000000000000000..8bc355659b58c39586a685f0aba3357db5414f13
--- /dev/null
+++ b/__assets__/demos/demo_3/trajectory.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1080b8523b361f2e4fb3f5591c88f50e44d176a404e5f62b04cfc2bfe8c2f5d
+size 857
diff --git a/__assets__/demos/demo_4/first_frame.jpg b/__assets__/demos/demo_4/first_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df907a7d9d867eadcf038fa8df0ff8c4f2a93d38
Binary files /dev/null and b/__assets__/demos/demo_4/first_frame.jpg differ
diff --git a/__assets__/demos/demo_4/layer_0.jpg b/__assets__/demos/demo_4/layer_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..72760831c505bff25b56e60145869420539b382e
Binary files /dev/null and b/__assets__/demos/demo_4/layer_0.jpg differ
diff --git a/__assets__/demos/demo_4/layer_1.jpg b/__assets__/demos/demo_4/layer_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4eb96548a0f0e746dd7cd6b76ca77e736372b570
Binary files /dev/null and b/__assets__/demos/demo_4/layer_1.jpg differ
diff --git a/__assets__/demos/demo_4/layer_2.jpg b/__assets__/demos/demo_4/layer_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c435db9a6164cf342b1d06887ff2772504d59e55
Binary files /dev/null and b/__assets__/demos/demo_4/layer_2.jpg differ
diff --git a/__assets__/demos/demo_4/sketch.mp4 b/__assets__/demos/demo_4/sketch.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..84506e5ffc349f63ec1241dc15d6f6e9035e0675
Binary files /dev/null and b/__assets__/demos/demo_4/sketch.mp4 differ
diff --git a/__assets__/demos/demo_4/trajectory.json b/__assets__/demos/demo_4/trajectory.json
new file mode 100644
index 0000000000000000000000000000000000000000..79124288bf3f869fce0978bddcb7252e3297a060
--- /dev/null
+++ b/__assets__/demos/demo_4/trajectory.json
@@ -0,0 +1,200 @@
+[
+ [
+ [
+ 186.72357177734375,
+ 225.0892333984375
+ ],
+ [
+ 186.59104919433594,
+ 220.61599731445312
+ ],
+ [
+ 190.39842224121094,
+ 216.0291748046875
+ ],
+ [
+ 199.52769470214844,
+ 213.26031494140625
+ ],
+ [
+ 204.145263671875,
+ 214.56866455078125
+ ],
+ [
+ 209.41751098632812,
+ 214.23330688476562
+ ],
+ [
+ 211.30255126953125,
+ 216.12774658203125
+ ],
+ [
+ 215.53131103515625,
+ 215.55880737304688
+ ],
+ [
+ 211.28453063964844,
+ 215.3497314453125
+ ],
+ [
+ 205.66819763183594,
+ 210.34344482421875
+ ],
+ [
+ 208.09231567382812,
+ 197.720458984375
+ ],
+ [
+ 201.51205444335938,
+ 215.72598266601562
+ ],
+ [
+ 191.19480895996094,
+ 223.12850952148438
+ ],
+ [
+ 194.90512084960938,
+ 222.38108825683594
+ ],
+ [
+ 200.74607849121094,
+ 217.3187713623047
+ ],
+ [
+ 207.563720703125,
+ 235.63250732421875
+ ]
+ ],
+ [
+ [
+ 289.63397216796875,
+ 230.28970336914062
+ ],
+ [
+ 289.8543701171875,
+ 227.20205688476562
+ ],
+ [
+ 292.2384033203125,
+ 223.03854370117188
+ ],
+ [
+ 301.47711181640625,
+ 219.50289916992188
+ ],
+ [
+ 308.8260803222656,
+ 220.3004608154297
+ ],
+ [
+ 315.6751403808594,
+ 219.62095642089844
+ ],
+ [
+ 317.8089599609375,
+ 221.09295654296875
+ ],
+ [
+ 320.73956298828125,
+ 221.21011352539062
+ ],
+ [
+ 317.1898193359375,
+ 221.21250915527344
+ ],
+ [
+ 319.5433349609375,
+ 217.74606323242188
+ ],
+ [
+ 317.6147155761719,
+ 207.62603759765625
+ ],
+ [
+ 308.29156494140625,
+ 224.09878540039062
+ ],
+ [
+ 294.7052917480469,
+ 230.4814910888672
+ ],
+ [
+ 298.7985534667969,
+ 230.0016326904297
+ ],
+ [
+ 304.0728454589844,
+ 226.04998779296875
+ ],
+ [
+ 314.6731872558594,
+ 242.630126953125
+ ]
+ ],
+ [
+ [
+ 214.7900390625,
+ 230.28970336914062
+ ],
+ [
+ 214.2034912109375,
+ 226.12539672851562
+ ],
+ [
+ 216.921630859375,
+ 221.91062927246094
+ ],
+ [
+ 226.7117156982422,
+ 219.55148315429688
+ ],
+ [
+ 232.1102294921875,
+ 220.2542724609375
+ ],
+ [
+ 237.49270629882812,
+ 219.5577850341797
+ ],
+ [
+ 240.1033935546875,
+ 220.77169799804688
+ ],
+ [
+ 243.27154541015625,
+ 220.56069946289062
+ ],
+ [
+ 240.3792724609375,
+ 221.12344360351562
+ ],
+ [
+ 235.10897827148438,
+ 216.4136962890625
+ ],
+ [
+ 234.0819091796875,
+ 202.91900634765625
+ ],
+ [
+ 224.08642578125,
+ 220.4688720703125
+ ],
+ [
+ 212.40911865234375,
+ 227.7927703857422
+ ],
+ [
+ 218.22300720214844,
+ 226.47549438476562
+ ],
+ [
+ 225.32315063476562,
+ 221.8306884765625
+ ],
+ [
+ 234.59808349609375,
+ 239.94235229492188
+ ]
+ ]
+]
\ No newline at end of file
diff --git a/__assets__/demos/demo_4/trajectory.npz b/__assets__/demos/demo_4/trajectory.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fcf141bee93a53a0ce17816a88c7772020042988
--- /dev/null
+++ b/__assets__/demos/demo_4/trajectory.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2904e38cbc8820daaa5f88085bbfc33aa3cd8b9be7d9588e02d6cadcccf2fa
+size 973
diff --git a/__assets__/demos/demo_5/first_frame.jpg b/__assets__/demos/demo_5/first_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..abda485d2674818e31b06c8d69a37527dda0f0ce
Binary files /dev/null and b/__assets__/demos/demo_5/first_frame.jpg differ
diff --git a/__assets__/demos/demo_5/layer_0.jpg b/__assets__/demos/demo_5/layer_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95c55ae1c1fff7041cfeae5e572ef9efb54bae1e
Binary files /dev/null and b/__assets__/demos/demo_5/layer_0.jpg differ
diff --git a/__assets__/demos/demo_5/layer_1.jpg b/__assets__/demos/demo_5/layer_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1449431b4ac96c71ba94c6289f31391400ae5ab0
Binary files /dev/null and b/__assets__/demos/demo_5/layer_1.jpg differ
diff --git a/__assets__/demos/demo_5/sketch.mp4 b/__assets__/demos/demo_5/sketch.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..673f9c61ba5ecf5a8ca44f3bd38424ecb6a3795a
Binary files /dev/null and b/__assets__/demos/demo_5/sketch.mp4 differ
diff --git a/__assets__/demos/demo_5/trajectory.json b/__assets__/demos/demo_5/trajectory.json
new file mode 100644
index 0000000000000000000000000000000000000000..760eba5119d5bdf20fcc54a36d27ace3d152cfca
--- /dev/null
+++ b/__assets__/demos/demo_5/trajectory.json
@@ -0,0 +1,332 @@
+[
+ [
+ [
+ 494.2274169921875,
+ 22.271512985229492
+ ],
+ [
+ 499.44189453125,
+ 21.746015548706055
+ ],
+ [
+ 504.0919189453125,
+ 21.225364685058594
+ ],
+ [
+ 514.5880737304688,
+ 20.82619285583496
+ ],
+ [
+ 520.4939575195312,
+ 20.672199249267578
+ ],
+ [
+ 526.637451171875,
+ 20.305557250976562
+ ],
+ [
+ 534.9617919921875,
+ 20.358591079711914
+ ],
+ [
+ 539.2017211914062,
+ 20.12591552734375
+ ],
+ [
+ 543.9376220703125,
+ 20.107173919677734
+ ],
+ [
+ 549.5306396484375,
+ 19.739456176757812
+ ],
+ [
+ 553.4171142578125,
+ 20.842308044433594
+ ],
+ [
+ 554.49462890625,
+ 20.15322494506836
+ ],
+ [
+ 559.0555419921875,
+ 21.292396545410156
+ ],
+ [
+ 558.5130004882812,
+ 21.357444763183594
+ ],
+ [
+ 561.72607421875,
+ 20.114139556884766
+ ],
+ [
+ 560.4268798828125,
+ 21.73964500427246
+ ]
+ ],
+ [
+ [
+ 494.2274169921875,
+ 48.27378463745117
+ ],
+ [
+ 494.85711669921875,
+ 48.05669403076172
+ ],
+ [
+ 494.21563720703125,
+ 48.0822868347168
+ ],
+ [
+ 492.88446044921875,
+ 48.20854187011719
+ ],
+ [
+ 491.5914306640625,
+ 48.36796569824219
+ ],
+ [
+ 490.6370849609375,
+ 48.649070739746094
+ ],
+ [
+ 488.6202392578125,
+ 48.874202728271484
+ ],
+ [
+ 487.603271484375,
+ 49.16374969482422
+ ],
+ [
+ 486.469970703125,
+ 49.414939880371094
+ ],
+ [
+ 484.92120361328125,
+ 49.98759460449219
+ ],
+ [
+ 483.7000427246094,
+ 50.26809310913086
+ ],
+ [
+ 482.22125244140625,
+ 50.42219161987305
+ ],
+ [
+ 480.54931640625,
+ 50.766448974609375
+ ],
+ [
+ 479.24481201171875,
+ 51.03229522705078
+ ],
+ [
+ 478.1097106933594,
+ 51.489837646484375
+ ],
+ [
+ 476.470947265625,
+ 52.048194885253906
+ ]
+ ],
+ [
+ [
+ 64.8839111328125,
+ 287.4947204589844
+ ],
+ [
+ 81.71736145019531,
+ 288.09869384765625
+ ],
+ [
+ 100.02552795410156,
+ 288.89111328125
+ ],
+ [
+ 128.72686767578125,
+ 289.8943176269531
+ ],
+ [
+ 149.62322998046875,
+ 290.7263488769531
+ ],
+ [
+ 170.50192260742188,
+ 291.29925537109375
+ ],
+ [
+ 203.6192626953125,
+ 292.2691345214844
+ ],
+ [
+ 227.08547973632812,
+ 292.68035888671875
+ ],
+ [
+ 250.68621826171875,
+ 293.3591613769531
+ ],
+ [
+ 286.62176513671875,
+ 294.1515197753906
+ ],
+ [
+ 311.21240234375,
+ 294.3829650878906
+ ],
+ [
+ 335.68389892578125,
+ 294.7114562988281
+ ],
+ [
+ 373.18115234375,
+ 295.2404479980469
+ ],
+ [
+ 397.2961120605469,
+ 295.111572265625
+ ],
+ [
+ 422.346923828125,
+ 295.5068054199219
+ ],
+ [
+ 457.2431335449219,
+ 295.49383544921875
+ ]
+ ],
+ [
+ [
+ 64.8839111328125,
+ 235.4901580810547
+ ],
+ [
+ 61.33024597167969,
+ 235.5504150390625
+ ],
+ [
+ 57.36271667480469,
+ 235.6099090576172
+ ],
+ [
+ 50.592864990234375,
+ 235.9037322998047
+ ],
+ [
+ 46.184783935546875,
+ 235.94981384277344
+ ],
+ [
+ 42.2303466796875,
+ 235.8488006591797
+ ],
+ [
+ 35.333221435546875,
+ 235.73272705078125
+ ],
+ [
+ 29.864356994628906,
+ 236.13253784179688
+ ],
+ [
+ 24.596290588378906,
+ 236.366943359375
+ ],
+ [
+ 17.585124969482422,
+ 236.61953735351562
+ ],
+ [
+ 12.934989929199219,
+ 236.7737274169922
+ ],
+ [
+ 8.478790283203125,
+ 236.75421142578125
+ ],
+ [
+ 2.206012725830078,
+ 236.9993896484375
+ ],
+ [
+ -2.862123489379883,
+ 237.2617645263672
+ ],
+ [
+ -7.3507843017578125,
+ 237.2784423828125
+ ],
+ [
+ -12.782325744628906,
+ 237.2703094482422
+ ]
+ ],
+ [
+ [
+ 92.88457489013672,
+ 225.0892333984375
+ ],
+ [
+ 88.737548828125,
+ 225.09442138671875
+ ],
+ [
+ 84.08223724365234,
+ 225.36553955078125
+ ],
+ [
+ 76.90846252441406,
+ 225.7208251953125
+ ],
+ [
+ 72.26066589355469,
+ 225.9451141357422
+ ],
+ [
+ 67.7042465209961,
+ 226.13169860839844
+ ],
+ [
+ 60.917144775390625,
+ 226.32199096679688
+ ],
+ [
+ 55.98236083984375,
+ 226.5792236328125
+ ],
+ [
+ 51.30162811279297,
+ 226.9581298828125
+ ],
+ [
+ 44.654823303222656,
+ 227.06956481933594
+ ],
+ [
+ 40.06951904296875,
+ 227.15420532226562
+ ],
+ [
+ 35.59206771850586,
+ 227.13719177246094
+ ],
+ [
+ 29.056011199951172,
+ 227.17002868652344
+ ],
+ [
+ 24.805736541748047,
+ 227.24826049804688
+ ],
+ [
+ 20.537612915039062,
+ 227.34564208984375
+ ],
+ [
+ 14.309333801269531,
+ 227.30154418945312
+ ]
+ ]
+]
\ No newline at end of file
diff --git a/__assets__/demos/demo_5/trajectory.npz b/__assets__/demos/demo_5/trajectory.npz
new file mode 100644
index 0000000000000000000000000000000000000000..821d2643d46b2bf9b5b98c53a66d8fecd1fc32d6
--- /dev/null
+++ b/__assets__/demos/demo_5/trajectory.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e9da4a1142e8210f0486ff1682fe7853e8714ecf813bfef4b9019efbc102f61
+size 1222
diff --git a/__assets__/figs/demos.gif b/__assets__/figs/demos.gif
new file mode 100644
index 0000000000000000000000000000000000000000..6a8bee693edf1c3fb75f64be6b28daa8c84cdc1d
--- /dev/null
+++ b/__assets__/figs/demos.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fec782faeaf8433550a05a782216b449e84bb3e1c1db03cbcd2fbb25f5a0bc1
+size 10389569
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..06580f3e216d0a7622092291c62284618a60398a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,651 @@
+import argparse
+import datetime
+import os
+import json
+
+import torch
+import torchvision.transforms as transforms
+from torchvision.transforms import functional as F
+
+import spaces
+from huggingface_hub import snapshot_download
+import gradio as gr
+
+from diffusers import DDIMScheduler
+
+from lvdm.models.unet import UNetModel
+from lvdm.models.autoencoder import AutoencoderKL, AutoencoderKL_Dualref
+from lvdm.models.condition import FrozenOpenCLIPEmbedder, FrozenOpenCLIPImageEmbedderV2, Resampler
+from lvdm.models.layer_controlnet import LayerControlNet
+from lvdm.pipelines.pipeline_animation import AnimationPipeline
+from lvdm.utils import generate_gaussian_heatmap, save_videos_grid, save_videos_with_traj
+
+from einops import rearrange
+import cv2
+import decord
+from PIL import Image
+import numpy as np
+from scipy.interpolate import PchipInterpolator
+
+SAVE_DIR = "outputs"
+LENGTH = 16
+WIDTH = 512
+HEIGHT = 320
+LAYER_CAPACITY = 4
+DEVICE = "cuda"
+
+os.makedirs("checkpoints", exist_ok=True)
+
+snapshot_download(
+ "Yuppie1204/LayerAnimate-Mix",
+ local_dir="checkpoints/LayerAnimate-Mix",
+)
+
+class LayerAnimate:
+
+ @spaces.GPU
+ def __init__(self):
+ self.savedir = SAVE_DIR
+ os.makedirs(self.savedir, exist_ok=True)
+
+ self.weight_dtype = torch.bfloat16
+ self.device = DEVICE
+ self.text_encoder = FrozenOpenCLIPEmbedder().eval()
+ self.image_encoder = FrozenOpenCLIPImageEmbedderV2().eval()
+
+ self.W = WIDTH
+ self.H = HEIGHT
+ self.L = LENGTH
+ self.layer_capacity = LAYER_CAPACITY
+
+ self.transforms = transforms.Compose([
+ transforms.Resize(min(self.H, self.W)),
+ transforms.CenterCrop((self.H, self.W)),
+ ])
+ self.pipeline = None
+ self.generator = None
+ # sample_grid is used to generate fixed trajectories to freeze static layers
+ self.sample_grid = np.meshgrid(np.linspace(0, self.W - 1, 10, dtype=int), np.linspace(0, self.H - 1, 10, dtype=int))
+ self.sample_grid = np.stack(self.sample_grid, axis=-1).reshape(-1, 1, 2)
+ self.sample_grid = np.repeat(self.sample_grid, self.L, axis=1) # [N, F, 2]
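+        # Grid points are stored as (x, y); in run(), the layer mask is indexed as
+        # [y, x] to pick the grid points that fall inside a static layer.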
+
+ @spaces.GPU
+ def set_seed(self, seed):
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ self.generator = torch.Generator(self.device).manual_seed(seed)
+
+ @spaces.GPU
+ def set_model(self, pretrained_model_path):
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ image_projector = Resampler.from_pretrained(pretrained_model_path, subfolder="image_projector").eval()
+ vae, vae_dualref = None, None
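+        # The checkpoint name encodes the variant: "I2V"/"Mix" checkpoints use the
+        # standard VAE, while "Interp"/"Mix" checkpoints also need the dual-reference VAE.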
+        if "I2V" in pretrained_model_path or "Mix" in pretrained_model_path:
+            vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").eval()
+        if "Interp" in pretrained_model_path or "Mix" in pretrained_model_path:
+            vae_dualref = AutoencoderKL_Dualref.from_pretrained(pretrained_model_path, subfolder="vae_dualref").eval()
+ unet = UNetModel.from_pretrained(pretrained_model_path, subfolder="unet").eval()
+ layer_controlnet = LayerControlNet.from_pretrained(pretrained_model_path, subfolder="layer_controlnet").eval()
+
+ self.pipeline = AnimationPipeline(
+ vae=vae, vae_dualref=vae_dualref, text_encoder=self.text_encoder, image_encoder=self.image_encoder, image_projector=image_projector,
+ unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
+ ).to(device=self.device, dtype=self.weight_dtype)
+        if "Interp" in pretrained_model_path or "Mix" in pretrained_model_path:
+ self.pipeline.vae_dualref.decoder.to(dtype=torch.float32)
+ return pretrained_model_path
+
+ def upload_image(self, image):
+ image = self.transforms(image)
+ return image
+
+ def run(self, input_image, input_image_end, pretrained_model_path, seed,
+ prompt, n_prompt, num_inference_steps, guidance_scale,
+ *layer_args):
+ self.set_seed(seed)
+ global layer_tracking_points
+ args_layer_tracking_points = [layer_tracking_points[i].value for i in range(self.layer_capacity)]
+
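+        # layer_args packs 7 groups of layer_capacity entries each, in this order:
+        # first-frame masks, last-frame masks, control types, motion scores,
+        # sketch videos, validity flags, and static flags.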
+ args_layer_masks = layer_args[:self.layer_capacity]
+ args_layer_masks_end = layer_args[self.layer_capacity : 2 * self.layer_capacity]
+ args_layer_controls = layer_args[2 * self.layer_capacity : 3 * self.layer_capacity]
+ args_layer_scores = list(layer_args[3 * self.layer_capacity : 4 * self.layer_capacity])
+ args_layer_sketches = layer_args[4 * self.layer_capacity : 5 * self.layer_capacity]
+ args_layer_valids = layer_args[5 * self.layer_capacity : 6 * self.layer_capacity]
+ args_layer_statics = layer_args[6 * self.layer_capacity : 7 * self.layer_capacity]
+ for layer_idx in range(self.layer_capacity):
+ if args_layer_controls[layer_idx] != "score":
+ args_layer_scores[layer_idx] = -1
+ if args_layer_statics[layer_idx]:
+ args_layer_scores[layer_idx] = 0
+
+ mode = "i2v"
+ image1 = F.to_tensor(input_image) * 2 - 1
+ frame_tensor = image1[None].to(self.device) # [F, C, H, W]
+ if input_image_end is not None:
+ mode = "interpolate"
+ image2 = F.to_tensor(input_image_end) * 2 - 1
+ frame_tensor2 = image2[None].to(self.device)
+ frame_tensor = torch.cat([frame_tensor, frame_tensor2], dim=0)
+ frame_tensor = frame_tensor[None]
+
+ if mode == "interpolate":
+ layer_masks = torch.zeros((1, self.layer_capacity, 2, 1, self.H, self.W), dtype=torch.bool)
+ else:
+ layer_masks = torch.zeros((1, self.layer_capacity, 1, 1, self.H, self.W), dtype=torch.bool)
+ for layer_idx in range(self.layer_capacity):
+ if args_layer_masks[layer_idx] is not None:
+ mask = F.to_tensor(args_layer_masks[layer_idx]) > 0.5
+ layer_masks[0, layer_idx, 0] = mask
+ if args_layer_masks_end[layer_idx] is not None and mode == "interpolate":
+ mask = F.to_tensor(args_layer_masks_end[layer_idx]) > 0.5
+ layer_masks[0, layer_idx, 1] = mask
+ layer_masks = layer_masks.to(self.device)
+ layer_regions = layer_masks * frame_tensor[:, None]
+ layer_validity = torch.tensor([args_layer_valids], dtype=torch.bool, device=self.device)
+ motion_scores = torch.tensor([args_layer_scores], dtype=self.weight_dtype, device=self.device)
+ layer_static = torch.tensor([args_layer_statics], dtype=torch.bool, device=self.device)
+
+ sketch = torch.ones((1, self.layer_capacity, self.L, 3, self.H, self.W), dtype=self.weight_dtype)
+ for layer_idx in range(self.layer_capacity):
+ sketch_path = args_layer_sketches[layer_idx]
+ if sketch_path is not None:
+ video_reader = decord.VideoReader(sketch_path)
+                assert len(video_reader) == self.L, f"The sketch video must contain {self.L} frames to match the output video length."
+ video_frames = video_reader.get_batch(range(self.L)).asnumpy()
+ sketch_values = [F.to_tensor(self.transforms(Image.fromarray(frame))) for frame in video_frames]
+ sketch_values = torch.stack(sketch_values) * 2 - 1
+ sketch[0, layer_idx] = sketch_values
+ sketch = sketch.to(self.device)
+
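+        # Trajectory conditioning: tracks are rendered into per-layer, per-frame
+        # 3-channel maps, where channel 0 holds a Gaussian gray map of the track points
+        # and channels 1-2 hold x/y offsets, normalized below by the maximum offset radius.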
+ heatmap = torch.zeros((1, self.layer_capacity, self.L, 3, self.H, self.W), dtype=self.weight_dtype)
+ heatmap[:, :, :, 0] -= 1
+ trajectory = []
+ traj_layer_index = []
+ for layer_idx in range(self.layer_capacity):
+ tracking_points = args_layer_tracking_points[layer_idx]
+ if args_layer_statics[layer_idx]:
+ # generate pseudo trajectory for static layers
+ temp_layer_mask = layer_masks[0, layer_idx, 0, 0].cpu().numpy()
+ valid_flag = temp_layer_mask[self.sample_grid[:, 0, 1], self.sample_grid[:, 0, 0]]
+ valid_grid = self.sample_grid[valid_flag] # [F, N, 2]
+ trajectory.extend(list(valid_grid))
+ traj_layer_index.extend([layer_idx] * valid_grid.shape[0])
+ else:
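+                # Resample the clicked control points to L frames with monotonic
+                # (PCHIP) interpolation so trajectories stay smooth without overshoot.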
+ for temp_track in tracking_points:
+ if len(temp_track) > 1:
+ x = [point[0] for point in temp_track]
+ y = [point[1] for point in temp_track]
+ t = np.linspace(0, 1, len(temp_track))
+ fx = PchipInterpolator(t, x)
+ fy = PchipInterpolator(t, y)
+ t_new = np.linspace(0, 1, self.L)
+ x_new = fx(t_new)
+ y_new = fy(t_new)
+ temp_traj = np.stack([x_new, y_new], axis=-1).astype(np.float32)
+ trajectory.append(temp_traj)
+ traj_layer_index.append(layer_idx)
+ elif len(temp_track) == 1:
+ trajectory.append(np.array(temp_track * self.L))
+ traj_layer_index.append(layer_idx)
+
+ trajectory = np.stack(trajectory)
+ trajectory = np.transpose(trajectory, (1, 0, 2))
+ traj_layer_index = np.array(traj_layer_index)
+ heatmap = generate_gaussian_heatmap(trajectory, self.W, self.H, traj_layer_index, self.layer_capacity, offset=True)
+ heatmap = rearrange(heatmap, "f n c h w -> (f n) c h w")
+ graymap, offset = heatmap[:, :1], heatmap[:, 1:]
+ graymap = graymap / 255.
+ rad = torch.sqrt(offset[:, 0:1]**2 + offset[:, 1:2]**2)
+ rad_max = torch.max(rad)
+ epsilon = 1e-5
+ offset = offset / (rad_max + epsilon)
+ graymap = graymap * 2 - 1
+ heatmap = torch.cat([graymap, offset], dim=1)
+ heatmap = rearrange(heatmap, '(f n) c h w -> n f c h w', n=self.layer_capacity)
+ heatmap = heatmap[None]
+ heatmap = heatmap.to(self.device)
+
+ sample = self.pipeline(
+ prompt,
+ self.L,
+ self.H,
+ self.W,
+ frame_tensor,
+ layer_masks = layer_masks,
+ layer_regions = layer_regions,
+ layer_static = layer_static,
+ motion_scores = motion_scores,
+ sketch = sketch,
+ trajectory = heatmap,
+ layer_validity = layer_validity,
+ num_inference_steps = num_inference_steps,
+ guidance_scale = guidance_scale,
+ guidance_rescale = 0.7,
+ negative_prompt = n_prompt,
+ num_videos_per_prompt = 1,
+ eta = 1.0,
+ generator = self.generator,
+ fps = 24,
+ mode = mode,
+ weight_dtype = self.weight_dtype,
+ output_type = "tensor",
+ ).videos
+ output_video_path = os.path.join(self.savedir, "video.mp4")
+ save_videos_grid(sample, output_video_path, fps=8)
+ output_video_traj_path = os.path.join(self.savedir, "video_with_traj.mp4")
+ vis_traj_flag = np.zeros(trajectory.shape[1], dtype=bool)
+ for traj_idx in range(trajectory.shape[1]):
+ if not args_layer_statics[traj_layer_index[traj_idx]]:
+ vis_traj_flag[traj_idx] = True
+ vis_traj = torch.from_numpy(trajectory[:, vis_traj_flag])
+        save_videos_with_traj(sample[0], vis_traj, output_video_traj_path, fps=8, line_width=7, circle_radius=10)
+ return output_video_path, output_video_traj_path
+
+
+def update_layer_region(image, layer_mask):
+ if image is None or layer_mask is None:
+ return None, False
+ layer_mask_tensor = (F.to_tensor(layer_mask) > 0.5).float()
+ image = F.to_tensor(image)
+ layer_region = image * layer_mask_tensor
+ layer_region = F.to_pil_image(layer_region)
+ layer_region.putalpha(layer_mask)
+ return layer_region, True
+
+def control_layers(control_type):
+ if control_type == "score":
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+ elif control_type == "trajectory":
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+def visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask):
+ first_mask_tensor = (F.to_tensor(first_mask) > 0.5).float()
+ first_frame = F.to_tensor(first_frame)
+ first_region = first_frame * first_mask_tensor
+ first_region = F.to_pil_image(first_region)
+ first_region.putalpha(first_mask)
+ transparent_background = first_region.convert('RGBA')
+
+ if last_frame is not None and last_mask is not None:
+ last_mask_tensor = (F.to_tensor(last_mask) > 0.5).float()
+ last_frame = F.to_tensor(last_frame)
+ last_region = last_frame * last_mask_tensor
+ last_region = F.to_pil_image(last_region)
+ last_region.putalpha(last_mask)
+ transparent_background_end = last_region.convert('RGBA')
+
+ width, height = transparent_background.size
+ transparent_layer = np.zeros((height, width, 4))
+ for track in tracking_points:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = np.array(track[i], dtype=np.int32)
+ end_point = np.array(track[i+1], dtype=np.int32)
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = max(np.sqrt(vx**2 + vy**2), 1)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ elif len(track) == 1:
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+ if last_frame is not None and last_mask is not None:
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
+ else:
+ trajectory_map_end = None
+ return trajectory_map, trajectory_map_end
+
+def add_drag(layer_idx):
+ global layer_tracking_points
+ tracking_points = layer_tracking_points[layer_idx].value
+ tracking_points.append([])
+ return
+
+def delete_last_drag(layer_idx, first_frame, first_mask, last_frame, last_mask):
+ global layer_tracking_points
+ tracking_points = layer_tracking_points[layer_idx].value
+ tracking_points.pop()
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
+ return trajectory_map, trajectory_map_end
+
+def delete_last_step(layer_idx, first_frame, first_mask, last_frame, last_mask):
+ global layer_tracking_points
+ tracking_points = layer_tracking_points[layer_idx].value
+ tracking_points[-1].pop()
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
+ return trajectory_map, trajectory_map_end
+
+def add_tracking_points(layer_idx, first_frame, first_mask, last_frame, last_mask, evt: gr.SelectData): # SelectData is a subclass of EventData
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+ global layer_tracking_points
+ tracking_points = layer_tracking_points[layer_idx].value
+ tracking_points[-1].append(evt.index)
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
+ return trajectory_map, trajectory_map_end
+
+def reset_states(layer_idx, first_frame, first_mask, last_frame, last_mask):
+ global layer_tracking_points
+ layer_tracking_points[layer_idx].value = [[]]
+ tracking_points = layer_tracking_points[layer_idx].value
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
+ return trajectory_map, trajectory_map_end
+
+def upload_tracking_points(tracking_path, layer_idx, first_frame, first_mask, last_frame, last_mask):
+ if tracking_path is None:
+ layer_region, _ = update_layer_region(first_frame, first_mask)
+ layer_region_end, _ = update_layer_region(last_frame, last_mask)
+ return layer_region, layer_region_end
+
+ global layer_tracking_points
+ with open(tracking_path, "r") as f:
+ tracking_points = json.load(f)
+ layer_tracking_points[layer_idx].value = tracking_points
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
+ return trajectory_map, trajectory_map_end
+
+def reset_all_controls():
+ global layer_tracking_points
+ outputs = []
+ # Reset tracking points states
+ for layer_idx in range(LAYER_CAPACITY):
+ layer_tracking_points[layer_idx].value = [[]]
+
+ # Reset global components
+ outputs.extend([
+ "an anime scene.", # text prompt
+ "", # negative text prompt
+ 50, # inference steps
+ 7.5, # guidance scale
+ 42, # seed
+ None, # input image
+ None, # input image end
+ None, # output video
+ None, # output video with trajectory
+ ])
+ # Reset layer controls visibility
+ outputs.extend([None] * LAYER_CAPACITY) # layer masks
+ outputs.extend([None] * LAYER_CAPACITY) # layer masks end
+ outputs.extend([None] * LAYER_CAPACITY) # layer regions
+ outputs.extend([None] * LAYER_CAPACITY) # layer regions end
+ outputs.extend(["sketch"] * LAYER_CAPACITY) # layer controls
+ outputs.extend([gr.update(visible=False, value=-1) for _ in range(LAYER_CAPACITY)]) # layer score controls
+ outputs.extend([gr.update(visible=False) for _ in range(4 * LAYER_CAPACITY)]) # layer trajectory control 4 buttons
+ outputs.extend([gr.update(visible=False, value=None) for _ in range(LAYER_CAPACITY)]) # layer trajectory file
+ outputs.extend([None] * LAYER_CAPACITY) # layer sketch controls
+ outputs.extend([False] * LAYER_CAPACITY) # layer validity
+ outputs.extend([False] * LAYER_CAPACITY) # layer statics
+ return outputs
+
+if __name__ == "__main__":
+ with gr.Blocks() as demo:
+        gr.Markdown("""LayerAnimate: Layer-level Control for Animation""")
+
+ gr.Markdown("""Gradio Demo for LayerAnimate: Layer-level Control for Animation.
+ Github Repo can be found at https://github.com/IamCreateAI/LayerAnimate
+ The template is inspired by Framer.""")
+
+ gr.Image(label="LayerAnimate: Layer-level Control for Animation", value="__assets__/figs/demos.gif", height=540, width=960)
+
+    gr.Markdown("""## Usage:
+    1. Select a pretrained model via the "Pretrained Model" dropdown in the right column.
+    2. Upload frames in the right column.
+        2.1. Upload the first frame.
+        2.2. (Optional) Upload the last frame.
+    3. Input layer-level controls in the left column.
+        3.1. Upload layer mask images for each layer, which can be obtained from tools such as https://huggingface.co/spaces/yumyum2081/SAM2-Image-Predictor.
+        3.2. Choose a control type from "motion score", "trajectory" and "sketch".
+        3.3. For trajectory control, you can draw trajectories on layer regions.
+            3.3.1. Click "Add New Trajectory" to add a new trajectory.
+            3.3.2. Click "Reset" to reset all trajectories.
+            3.3.3. Click "Delete Last Step" to delete the latest clicked control point.
+            3.3.4. Click "Delete Last Trajectory" to delete the whole latest path.
+            3.3.5. Or upload a trajectory file in JSON format, as illustrated by the examples below.
+        3.4. For sketch control, you can upload a sketch video.
+    4. We provide four layers for you to control, and it is not necessary to use all of them.
+    5. Click the "Run" button to generate videos.
+    6. **Note: Remember to click the "Clear" button to clear all controls before switching to another example.**
+    """)
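+    # Note on trajectory files: the bundled demo assets (__assets__/demos/*/trajectory.json)
+    # store a JSON list with one entry per drawn path, each entry being a list of [x, y]
+    # pixel coordinates (the demos use 16 points per path, matching the 16 output frames):
+    # [
+    #   [[111.9, 204.3], [83.4, 204.2], ...],
+    #   [[130.6, 131.5], [101.3, 131.6], ...]
+    # ]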
+
+ layeranimate = LayerAnimate()
+ layer_indices = [gr.Number(value=i, visible=False) for i in range(LAYER_CAPACITY)]
+ layer_tracking_points = [gr.State([[]]) for _ in range(LAYER_CAPACITY)]
+ layer_masks = []
+ layer_masks_end = []
+ layer_regions = []
+ layer_regions_end = []
+ layer_controls = []
+ layer_score_controls = []
+ layer_traj_controls = []
+ layer_traj_files = []
+ layer_sketch_controls = []
+ layer_statics = []
+ layer_valids = []
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ for layer_idx in range(LAYER_CAPACITY):
+ with gr.Accordion(label=f"Layer {layer_idx+1}", open=True if layer_idx == 0 else False):
+                    gr.Markdown("""Layer Masks""")
+ gr.Markdown("**Note**: Layer mask for the last frame is not required in I2V mode.")
+ with gr.Row():
+ with gr.Column():
+ layer_masks.append(gr.Image(
+ label="Layer Mask for First Frame",
+ height=320,
+ width=512,
+ image_mode="L",
+ type="pil",
+ ))
+
+ with gr.Column():
+ layer_masks_end.append(gr.Image(
+ label="Layer Mask for Last Frame",
+ height=320,
+ width=512,
+ image_mode="L",
+ type="pil",
+ ))
+                    gr.Markdown("""Layer Regions""")
+ with gr.Row():
+ with gr.Column():
+ layer_regions.append(gr.Image(
+ label="Layer Region for First Frame",
+ height=320,
+ width=512,
+ image_mode="RGBA",
+ type="pil",
+ # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
+ ))
+
+ with gr.Column():
+ layer_regions_end.append(gr.Image(
+ label="Layer Region for Last Frame",
+ height=320,
+ width=512,
+ image_mode="RGBA",
+ type="pil",
+ # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
+ ))
+ layer_controls.append(
+ gr.Radio(["score", "trajectory", "sketch"], label="Choose A Control Type", value="sketch")
+ )
+ layer_score_controls.append(
+ gr.Number(label="Motion Score", value=-1, visible=False)
+ )
+ layer_traj_controls.append(
+ [
+ gr.Button(value="Add New Trajectory", visible=False),
+ gr.Button(value="Reset", visible=False),
+ gr.Button(value="Delete Last Step", visible=False),
+ gr.Button(value="Delete Last Trajectory", visible=False),
+ ]
+ )
+ layer_traj_files.append(
+ gr.File(label="Trajectory File", visible=False)
+ )
+ layer_sketch_controls.append(
+ gr.Video(label="Sketch", height=320, width=512, visible=True)
+ )
+ layer_controls[layer_idx].change(
+ fn=control_layers,
+ inputs=layer_controls[layer_idx],
+ outputs=[layer_score_controls[layer_idx], *layer_traj_controls[layer_idx], layer_traj_files[layer_idx], layer_sketch_controls[layer_idx]]
+ )
+ with gr.Row():
+ layer_valids.append(gr.Checkbox(label="Valid", info="Is the layer valid?"))
+ layer_statics.append(gr.Checkbox(label="Static", info="Is the layer static?"))
+
+ with gr.Column(scale=1):
+ pretrained_model_path = gr.Dropdown(
+ label="Pretrained Model",
+ choices=[
+ "None",
+ "checkpoints/LayerAnimate-Mix",
+ ],
+ value="None",
+ )
+ text_prompt = gr.Textbox(label="Text Prompt", value="an anime scene.")
+ text_n_prompt = gr.Textbox(label="Negative Text Prompt", value="")
+ with gr.Row():
+ num_inference_steps = gr.Number(label="Inference Steps", value=50, minimum=1, maximum=1000)
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
+ seed = gr.Number(label="Seed", value=42)
+ with gr.Row():
+ input_image = gr.Image(
+ label="First Frame",
+ height=320,
+ width=512,
+ type="pil",
+ )
+ input_image_end = gr.Image(
+ label="Last Frame",
+ height=320,
+ width=512,
+ type="pil",
+ )
+ run_button = gr.Button(value="Run")
+ with gr.Row():
+ output_video = gr.Video(
+ label="Output Video",
+ height=320,
+ width=512,
+ )
+ output_video_traj = gr.Video(
+ label="Output Video with Trajectory",
+ height=320,
+ width=512,
+ )
+ clear_button = gr.Button(value="Clear")
+
+ with gr.Row():
+ gr.Markdown("""
+ ## Citation
+ ```bibtex
+ @article{yang2025layeranimate,
+ author = {Yang, Yuxue and Fan, Lue and Lin, Zuzeng and Wang, Feng and Zhang, Zhaoxiang},
+ title = {LayerAnimate: Layer-level Control for Animation},
+ journal = {arXiv preprint arXiv:2501.08295},
+ year = {2025},
+ }
+ ```
+ """)
+
+ pretrained_model_path.input(layeranimate.set_model, pretrained_model_path, pretrained_model_path)
+ input_image.upload(layeranimate.upload_image, input_image, input_image)
+ input_image_end.upload(layeranimate.upload_image, input_image_end, input_image_end)
+ for i in range(LAYER_CAPACITY):
+ layer_masks[i].upload(layeranimate.upload_image, layer_masks[i], layer_masks[i])
+ layer_masks[i].change(update_layer_region, [input_image, layer_masks[i]], [layer_regions[i], layer_valids[i]])
+ layer_masks_end[i].upload(layeranimate.upload_image, layer_masks_end[i], layer_masks_end[i])
+ layer_masks_end[i].change(update_layer_region, [input_image_end, layer_masks_end[i]], [layer_regions_end[i], layer_valids[i]])
+ layer_traj_controls[i][0].click(add_drag, layer_indices[i], None)
+ layer_traj_controls[i][1].click(
+ reset_states,
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
+ layer_traj_controls[i][2].click(
+ delete_last_step,
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
+ layer_traj_controls[i][3].click(
+ delete_last_drag,
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
+ layer_traj_files[i].change(
+ upload_tracking_points,
+ [layer_traj_files[i], layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
+ layer_regions[i].select(
+ add_tracking_points,
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
+ layer_regions_end[i].select(
+ add_tracking_points,
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
+ [layer_regions[i], layer_regions_end[i]]
+ )
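+    # The order of the unpacked layer components below must match the slicing in
+    # LayerAnimate.run: masks, end masks, control types, motion scores, sketches,
+    # validity flags, static flags.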
+ run_button.click(
+ layeranimate.run,
+ [input_image, input_image_end, pretrained_model_path, seed, text_prompt, text_n_prompt, num_inference_steps, guidance_scale,
+ *layer_masks, *layer_masks_end, *layer_controls, *layer_score_controls, *layer_sketch_controls, *layer_valids, *layer_statics],
+ [output_video, output_video_traj]
+ )
+ clear_button.click(
+ reset_all_controls,
+ [],
+ [
+ text_prompt, text_n_prompt, num_inference_steps, guidance_scale, seed,
+ input_image, input_image_end, output_video, output_video_traj,
+ *layer_masks, *layer_masks_end, *layer_regions, *layer_regions_end,
+ *layer_controls, *layer_score_controls, *[button for temp_layer_controls in layer_traj_controls for button in temp_layer_controls], *layer_traj_files,
+ *layer_sketch_controls, *layer_valids, *layer_statics
+ ]
+ )
+ examples = gr.Examples(
+ examples=[
+ [
+ "__assets__/demos/demo_3/first_frame.jpg", "__assets__/demos/demo_3/last_frame.jpg",
+ "score", "__assets__/demos/demo_3/layer_0.jpg", "__assets__/demos/demo_3/layer_0_last.jpg", 0.4, None, None, True, False,
+ "score", "__assets__/demos/demo_3/layer_1.jpg", "__assets__/demos/demo_3/layer_1_last.jpg", 0.2, None, None, True, False,
+ "trajectory", "__assets__/demos/demo_3/layer_2.jpg", "__assets__/demos/demo_3/layer_2_last.jpg", -1, "__assets__/demos/demo_3/trajectory.json", None, True, False,
+ "sketch", "__assets__/demos/demo_3/layer_3.jpg", "__assets__/demos/demo_3/layer_3_last.jpg", -1, None, "__assets__/demos/demo_3/sketch.mp4", True, False,
+ 52
+ ],
+ [
+ "__assets__/demos/demo_4/first_frame.jpg", None,
+ "score", "__assets__/demos/demo_4/layer_0.jpg", None, 0.0, None, None, True, True,
+ "trajectory", "__assets__/demos/demo_4/layer_1.jpg", None, -1, "__assets__/demos/demo_4/trajectory.json", None, True, False,
+ "sketch", "__assets__/demos/demo_4/layer_2.jpg", None, -1, None, "__assets__/demos/demo_4/sketch.mp4", True, False,
+ "score", None, None, -1, None, None, False, False,
+ 42
+ ],
+ [
+ "__assets__/demos/demo_5/first_frame.jpg", None,
+ "sketch", "__assets__/demos/demo_5/layer_0.jpg", None, -1, None, "__assets__/demos/demo_5/sketch.mp4", True, False,
+ "trajectory", "__assets__/demos/demo_5/layer_1.jpg", None, -1, "__assets__/demos/demo_5/trajectory.json", None, True, False,
+ "score", None, None, -1, None, None, False, False,
+ "score", None, None, -1, None, None, False, False,
+ 47
+ ],
+ ],
+ inputs=[
+ input_image, input_image_end,
+ layer_controls[0], layer_masks[0], layer_masks_end[0], layer_score_controls[0], layer_traj_files[0], layer_sketch_controls[0], layer_valids[0], layer_statics[0],
+ layer_controls[1], layer_masks[1], layer_masks_end[1], layer_score_controls[1], layer_traj_files[1], layer_sketch_controls[1], layer_valids[1], layer_statics[1],
+ layer_controls[2], layer_masks[2], layer_masks_end[2], layer_score_controls[2], layer_traj_files[2], layer_sketch_controls[2], layer_valids[2], layer_statics[2],
+ layer_controls[3], layer_masks[3], layer_masks_end[3], layer_score_controls[3], layer_traj_files[3], layer_sketch_controls[3], layer_valids[3], layer_statics[3],
+ seed
+ ],
+ )
+ demo.launch()
\ No newline at end of file
diff --git a/lvdm/basics.py b/lvdm/basics.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc35f8f848da706dbad8f043a475d761e8df289
--- /dev/null
+++ b/lvdm/basics.py
@@ -0,0 +1,100 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+import torch.nn as nn
+from .utils import instantiate_from_config
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
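+    # Zero-initialization is commonly used for the final layers of new conditioning
+    # branches (e.g. ControlNet-style modules) so they start as a no-op on the
+    # pretrained backbone.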
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def nonlinearity(type='silu'):
+ if type == 'silu':
+ return nn.SiLU()
+ elif type == 'leaky_relu':
+ return nn.LeakyReLU()
+
+
+class GroupNormSpecific(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels, num_groups=32):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNormSpecific(num_groups, channels)
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
\ No newline at end of file
diff --git a/lvdm/common.py b/lvdm/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a150b618e275f01d3a59ad9c7579176c4ea1b8
--- /dev/null
+++ b/lvdm/common.py
@@ -0,0 +1,94 @@
+import math
+from inspect import isfunction
+import torch
+from torch import nn
+import torch.distributed as dist
+import torch.utils.checkpoint
+
+
+def gather_data(data, return_np=True):
+ ''' gather data from multiple processes to one list '''
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
+ dist.all_gather(data_list, data) # gather not supported with NCCL
+ if return_np:
+ data_list = [data.cpu().numpy() for data in data_list]
+ return data_list
+
+def autocast(f):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(enabled=True,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled()):
+ return f(*args, **kwargs)
+ return do_autocast
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
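+
+
+# Illustrative sketch: extract_into_tensor picks one scalar per batch element
+# from a 1-D schedule (e.g. alphas_cumprod of length 1000) and reshapes it so
+# it broadcasts against an image-shaped tensor. The schedule and shapes below
+# are assumptions made up for this example.
+def _extract_into_tensor_example():
+    schedule = torch.linspace(1.0, 0.0, steps=1000)   # stand-in for alphas_cumprod
+    t = torch.tensor([10, 500])                        # one timestep per sample
+    x = torch.randn(2, 3, 64, 64)
+    coeff = extract_into_tensor(schedule, t, x.shape)  # shape (2, 1, 1, 1)
+    return coeff * x                                   # broadcasts over c, h, w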
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def exists(val):
+ return val is not None
+
+def identity(*args, **kwargs):
+ return nn.Identity()
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+def shape_to_str(x):
+    shape_str = "x".join([str(dim) for dim in x.shape])
+ return shape_str
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+ckpt = torch.utils.checkpoint.checkpoint
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ return ckpt(func, *inputs, use_reentrant=False)
+ else:
+ return func(*inputs)
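+
+
+# Illustrative sketch of the flag semantics above: the same call site can turn
+# activation checkpointing on or off without touching the wrapped function.
+# The module, shapes and parameter list are assumptions made up for this
+# example; params is accepted but unused, matching how callers pass it.
+def _checkpoint_example(use_checkpoint=True):
+    layer = nn.Linear(16, 16)
+    x = torch.randn(4, 16, requires_grad=True)
+    # flag=True recomputes activations in the backward pass; flag=False is
+    # simply layer(x).
+    return checkpoint(layer, (x,), layer.parameters(), use_checkpoint)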
\ No newline at end of file
diff --git a/lvdm/models/autoencoder.py b/lvdm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f3ff37a5d32450cc995ba97638b60358b3f2ce9
--- /dev/null
+++ b/lvdm/models/autoencoder.py
@@ -0,0 +1,143 @@
+import os
+from functools import partial
+from dataclasses import dataclass
+
+import torch
+import numpy as np
+from einops import rearrange
+import torch.nn.functional as F
+from ..common import checkpoint  # custom checkpoint(func, inputs, params, flag) wrapper
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from diffusers.utils import BaseOutput
+
+from ..modules.ae_modules import Encoder, Decoder
+from ..modules.ae_dualref_modules import VideoDecoder
+from ..utils import instantiate_from_config
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self,
+ ddconfig,
+ embed_dim,
+ image_key="image",
+ input_dim=4,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ self.input_dim = input_dim
+ self.use_checkpoint = use_checkpoint
+
+ def encode(self, x, return_hidden_states=False, **kwargs):
+ if return_hidden_states:
+ h, hidden = self.encoder(x, return_hidden_states)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return AutoencoderKLOutput(latent_dist=posterior), hidden
+ else:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, z, **kwargs):
+ if len(kwargs) == 0: ## use the original decoder in AutoencoderKL
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z, **kwargs) ##change for SVD decoder by adding **kwargs
+ return dec
+
+ def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
+ input_tuple = (input, )
+ forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
+ return checkpoint(forward_temp, input_tuple, self.parameters(), self.use_checkpoint)
+
+
+ def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
+ posterior = self.encode(input)[0]
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, **additional_decode_kwargs)
+ ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if x.dim() == 5 and self.input_dim == 4:
+ b,c,t,h,w = x.shape
+ self.b = b
+ self.t = t
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+
+ return x
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
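+
+
+# Illustrative round-trip sketch for the class above. `ddconfig` is a made-up
+# minimal configuration (f=8, 4 latent channels), not the configuration shipped
+# with the released checkpoints; it only serves to show the expected shapes.
+def _autoencoder_roundtrip_example():
+    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
+                    out_ch=3, ch=64, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
+                    attn_resolutions=(), dropout=0.0)
+    vae = AutoencoderKL(ddconfig, embed_dim=4)
+    x = torch.randn(1, 3, 256, 256)
+    posterior = vae.encode(x).latent_dist
+    z = posterior.sample()        # (1, 4, 32, 32)
+    return vae.decode(z)          # (1, 3, 256, 256)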
+
+
+class AutoencoderKL_Dualref(AutoencoderKL):
+ @register_to_config
+ def __init__(self,
+ ddconfig,
+ embed_dim,
+ image_key="image",
+ input_dim=4,
+ use_checkpoint=False,
+ ):
+ super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
+ self.decoder = VideoDecoder(**ddconfig)
+
+ def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
+ posterior, hidden_states = self.encode(input, return_hidden_states=True)
+
+ hidden_states_first_last = []
+ ### use only the first and last hidden states
+ for hid in hidden_states:
+ hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
+ hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
+ hidden_states_first_last.append(hid_new)
+
+ if sample_posterior:
+ z = posterior[0].sample()
+ else:
+ z = posterior[0].mode()
+ dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
+ ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
+ return dec, posterior
\ No newline at end of file
diff --git a/lvdm/models/condition.py b/lvdm/models/condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..79dfd51b82368af9488d3b9cacb7b86f66ddfd77
--- /dev/null
+++ b/lvdm/models/condition.py
@@ -0,0 +1,477 @@
+import math
+import torch
+import torch.nn as nn
+from torchvision.transforms import functional as F
+import open_clip
+from torch.utils.checkpoint import checkpoint
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from ..common import autocast
+from ..utils import count_params
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+class IdentityEncoder(AbstractEncoder):
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(self, version="google/t5-v1_1-large", max_length=77,
+ freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+
+ def __init__(self, version="openai/clip-vit-large-patch14", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate"
+ ]
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
+ freeze=True, layer="penultimate"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text) ## all clip models use 77 as context length
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
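+
+
+# Illustrative usage sketch for the text encoder above. With the default
+# ViT-H-14 / laion2b_s32b_b79k weights (downloaded on first use) the output is
+# one 1024-dim feature per token, i.e. (batch, 77, 1024). The prompts are
+# arbitrary examples.
+def _frozen_openclip_text_example():
+    encoder = FrozenOpenCLIPEmbedder(layer="penultimate")
+    with torch.no_grad():
+        z = encoder.encode(["an anime character running", ""])
+    return z  # (2, 77, 1024)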
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
+ freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+ super().__init__()
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+ pretrained=version, )
+ del model.transformer
+ self.model = model
+ self.preprocess_val = preprocess_val
+ # self.mapper = torch.nn.Linear(1280, 1024)
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = F.resize(x, (224, 224), interpolation=F.InterpolationMode.BICUBIC, antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = F.normalize(x, mean=self.mean, std=self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ if self.ucg_rate > 0. and not no_dropout:
+ z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ img = self.preprocess(img)
+ x = self.model.visual(img)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k",
+ freeze=True, layer="pooled", antialias=True):
+ super().__init__()
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+ pretrained=version, )
+ del model.transformer
+ self.model = model
+ self.preprocess_val = preprocess_val
+
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = F.resize(x, (224, 224), interpolation=F.InterpolationMode.BICUBIC, antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = F.normalize(x, mean=self.mean, std=self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def forward(self, image, no_dropout=False):
+ ## image: b c h w
+ z = self.encode_with_vision_transformer(image)
+ return z
+
+ def encode_with_vision_transformer(self, x):
+ x = self.preprocess(x)
+
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+ if self.model.visual.input_patchnorm:
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+ x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
+ x = self.model.visual.patchnorm_pre_ln(x)
+ x = self.model.visual.conv1(x)
+ else:
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat(
+ [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.model.visual.patch_dropout(x)
+ x = self.model.visual.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.model.visual.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ return x
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+    # (bs, n_heads, length, dim_per_head), made contiguous via reshape
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ video_length=None, # using frame-wise version or not
+ ):
+ super().__init__()
+ ## queries for a single frame / image
+ self.num_queries = num_queries
+ self.video_length = video_length
+
+ ## queries for each frame
+ if video_length is not None:
+ num_queries = num_queries * video_length
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_in = nn.Linear(embedding_dim, dim)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents) # B L C or B (T L) C
+
+ return latents
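+
+
+# Illustrative sketch of how the Resampler is used in this codebase: per-frame
+# CLIP image tokens (ViT-H-14 yields 257 tokens of width 1280 per frame) are
+# compressed to a fixed number of query tokens per frame. Batch size, frame
+# count and layer sizes below are assumptions made up for this example.
+def _resampler_example(video_length=16):
+    resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
+                          num_queries=16, embedding_dim=1280, output_dim=1024,
+                          video_length=video_length)
+    img_tokens = torch.randn(2, video_length * 257, 1280)
+    return resampler(img_tokens)  # (2, video_length * 16, 1024)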
\ No newline at end of file
diff --git a/lvdm/models/controlnet.py b/lvdm/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0361c85db2c2f61240f58c6f6b95e489f565ee1
--- /dev/null
+++ b/lvdm/models/controlnet.py
@@ -0,0 +1,500 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from einops import rearrange, repeat
+import numpy as np
+from functools import partial
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .unet import TimestepEmbedSequential, ResBlock, Downsample, Upsample, TemporalConvBlock
+from ..basics import zero_module, conv_nd
+from ..modules.attention import SpatialTransformer, TemporalTransformer
+from ..common import checkpoint
+
+from diffusers import __version__
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.utils import (
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ logging,
+ _get_model_file,
+ _add_variant
+)
+from omegaconf import ListConfig, DictConfig, OmegaConf
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ResBlock_v2(nn.Module):
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ dims=2,
+ use_checkpoint=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ use_temporal_conv=False,
+ tempspatial_aware=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_temporal_conv = use_temporal_conv
+
+ self.in_layers = nn.Sequential(
+ nn.GroupNorm(32, channels),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, channels, self.out_channels, 3, padding=1)),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ if self.use_temporal_conv:
+ self.temopral_conv = TemporalConvBlock(
+ self.out_channels,
+ self.out_channels,
+ dropout=0.1,
+ spatial_aware=tempspatial_aware
+ )
+
+ def forward(self, x, batch_size=None):
+ """
+        Apply the block to a Tensor. Unlike ResBlock, this variant does not use a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ input_tuple = (x, )
+ if batch_size:
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
+
+ def _forward(self, x, batch_size=None):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ h = self.skip_connection(x) + h
+
+ if self.use_temporal_conv and batch_size:
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
+ h = self.temopral_conv(h)
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
+ return h
+
+
+class TrajectoryEncoder(nn.Module):
+ def __init__(self, cin, time_embed_dim, channels=[320, 640, 1280, 1280], nums_rb=3,
+ dropout=0.0, use_checkpoint=False, tempspatial_aware=False, temporal_conv=False):
+ super(TrajectoryEncoder, self).__init__()
+ # self.unshuffle = nn.PixelUnshuffle(8)
+ self.channels = channels
+ self.nums_rb = nums_rb
+ self.body = []
+ # self.conv_out = []
+ for i in range(len(channels)):
+ for j in range(nums_rb):
+ if (i != 0) and (j == 0):
+ self.body.append(
+ ResBlock_v2(channels[i - 1], time_embed_dim, dropout,
+ out_channels=channels[i], dims=2, use_checkpoint=use_checkpoint,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ down=True
+ )
+ )
+ else:
+ self.body.append(
+ ResBlock_v2(channels[i], time_embed_dim, dropout,
+ out_channels=channels[i], dims=2, use_checkpoint=use_checkpoint,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ down=False
+ )
+ )
+ self.body.append(
+ ResBlock_v2(channels[-1], time_embed_dim, dropout,
+ out_channels=channels[-1], dims=2, use_checkpoint=use_checkpoint,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv,
+ down=True
+ )
+ )
+ self.body = nn.ModuleList(self.body)
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
+ self.conv_out = zero_module(conv_nd(2, channels[-1], channels[-1], 3, 1, 1))
+
+ def forward(self, x, batch_size=None):
+ # unshuffle
+ # x = self.unshuffle(x)
+ # extract features
+ # features = []
+ x = self.conv_in(x)
+ for i in range(len(self.channels)):
+ for j in range(self.nums_rb):
+ idx = i * self.nums_rb + j
+ x = self.body[idx](x, batch_size)
+ x = self.body[-1](x, batch_size)
+ out = self.conv_out(x)
+ return out
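+
+
+# Illustrative shape sketch for the encoder above (all values assumed): with
+# the default channels [320, 640, 1280, 1280] there is one stride-2 block per
+# channel transition plus a final stride-2 block, so a (b*t, cin, H, W) input
+# becomes (b*t, 1280, H/16, W/16).
+def _trajectory_encoder_example():
+    enc = TrajectoryEncoder(cin=8, time_embed_dim=1280)
+    x = torch.randn(2 * 16, 8, 128, 128)   # (batch * frames, cin, H, W)
+    return enc(x, batch_size=2)            # (32, 1280, 8, 8)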
+
+
+class ControlNet(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0.0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ context_dim=None,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ transformer_depth=1,
+ use_linear=False,
+ use_checkpoint=False,
+ temporal_conv=False,
+ tempspatial_aware=False,
+ temporal_attention=True,
+ use_relative_position=True,
+ use_causal_attention=False,
+ temporal_length=None,
+ addition_attention=False,
+ temporal_selfatt_only=True,
+ image_cross_attention=False,
+ image_cross_attention_scale_learnable=False,
+ default_fps=4,
+ fps_condition=False,
+ ignore_noisy_latents=True,
+ conditioning_channels=4,
+ ):
+ super().__init__()
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.temporal_attention = temporal_attention
+ time_embed_dim = model_channels * 4
+ self.use_checkpoint = use_checkpoint
+ temporal_self_att_only = True
+ self.addition_attention = addition_attention
+ self.temporal_length = temporal_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
+ self.default_fps = default_fps
+ self.fps_condition = fps_condition
+ self.ignore_noisy_latents = ignore_noisy_latents
+
+ ## Time embedding blocks
+ self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
+
+ if fps_condition:
+ self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
+ nn.init.zeros_(self.fps_embedding.linear_2.weight)
+ nn.init.zeros_(self.fps_embedding.linear_2.bias)
+
+ # self.cond_embedding = TrajectoryEncoder(
+ # cin=conditioning_channels, time_embed_dim=time_embed_dim, channels=trajectory_channels, nums_rb=3,
+ # dropout=dropout, use_checkpoint=use_checkpoint, tempspatial_aware=tempspatial_aware, temporal_conv=False
+ # )
+ self.cond_embedding = zero_module(conv_nd(dims, conditioning_channels, model_channels, 3, padding=1))
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
+ ]
+ )
+
+ ## Output Block
+ self.downsample_output = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.GroupNorm(32, model_channels),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, model_channels, 3, padding=1))
+ )
+ ]
+ )
+
+ if self.addition_attention:
+ self.init_attn = TimestepEmbedSequential(
+ TemporalTransformer(
+ model_channels,
+ n_heads=8,
+ d_head=num_head_channels,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
+ causal_attention=False, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
+ video_length=temporal_length, image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.downsample_output.append(
+ nn.Sequential(
+ nn.GroupNorm(32, ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, ch, ch, 3, padding=1))
+ )
+ )
+ if level < len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ self.downsample_output.append(
+ nn.Sequential(
+ nn.GroupNorm(32, out_ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, out_ch, out_ch, 3, padding=1))
+ )
+ )
+ ch = out_ch
+ ds *= 2
+
+ def forward(
+ self,
+ noisy_latents,
+ timesteps,
+ context_text,
+ context_img=None,
+ fps=None,
+ condition=None, # [b, t, c, h, w]
+ ):
+ if self.ignore_noisy_latents:
+ noisy_latents = torch.zeros_like(noisy_latents)
+
+ b, _, t, height, width = noisy_latents.shape
+ t_emb = self.time_proj(timesteps).type(noisy_latents.dtype)
+ emb = self.time_embed(t_emb)
+
+ ## repeat t times for context [(b t) 77 768] & time embedding
+ ## check if we use per-frame image conditioning
+ if context_img is not None: ## decompose context into text and image
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
+ context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
+ context = torch.cat([context_text, context_img], dim=1)
+ else:
+ context = context_text.repeat_interleave(repeats=t, dim=0)
+ emb = emb.repeat_interleave(repeats=t, dim=0)
+
+ ## always in shape (b n t) c h w, except for temporal layer
+ noisy_latents = rearrange(noisy_latents, 'b c t h w -> (b t) c h w')
+ condition = rearrange(condition, 'b t c h w -> (b t) c h w')
+
+ ## combine emb
+ if self.fps_condition:
+ if fps is None:
+ fps = torch.tensor(
+                [self.default_fps] * b, dtype=torch.long, device=noisy_latents.device)
+ fps_emb = self.time_proj(fps).type(noisy_latents.dtype)
+
+ fps_embed = self.fps_embedding(fps_emb)
+ fps_embed = fps_embed.repeat_interleave(repeats=t, dim=0)
+ emb = emb + fps_embed
+
+ h = noisy_latents.type(self.dtype)
+ hs = []
+ for id, module in enumerate(self.input_blocks):
+ h = module(h, emb, context=context, batch_size=b)
+ if id == 0:
+ h = h + self.cond_embedding(condition)
+ if self.addition_attention:
+ h = self.init_attn(h, emb, context=context, batch_size=b)
+ hs.append(h)
+
+ guidance_feature_list = []
+ for hidden, module in zip(hs, self.downsample_output):
+ h = module(hidden)
+ guidance_feature_list.append(h)
+
+ return guidance_feature_list
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, layer_encoder_additional_kwargs={}, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ for key, value in layer_encoder_additional_kwargs.items():
+ if isinstance(value, (ListConfig, DictConfig)):
+ config[key] = OmegaConf.to_container(value, resolve=True)
+ else:
+ config[key] = value
+
+ # load model
+ model_file = None
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ except IOError as e:
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+ if not allow_pickle:
+ raise
+ logger.warning(
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+ )
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ model = cls.from_config(config, **unused_kwargs)
+ state_dict = load_state_dict(model_file, variant)
+
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ print(f"Controlnet loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
+ return model
\ No newline at end of file
diff --git a/lvdm/models/layer_controlnet.py b/lvdm/models/layer_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb920f72da55418c929999188c096b91f6102f52
--- /dev/null
+++ b/lvdm/models/layer_controlnet.py
@@ -0,0 +1,444 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from einops import rearrange, repeat
+import numpy as np
+from functools import partial
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .unet import TimestepEmbedSequential, ResBlock, Downsample, Upsample, TemporalConvBlock
+from ..basics import zero_module, conv_nd
+from ..modules.attention import SpatialTransformer, TemporalTransformer
+from ..common import checkpoint
+
+from diffusers import __version__
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.utils import (
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ logging,
+ _get_model_file,
+ _add_variant
+)
+from omegaconf import ListConfig, DictConfig, OmegaConf
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
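+
+
+# Illustrative sketch of the paragraph quoted in the docstring above: the three
+# stride-2 convolutions in the default block_out_channels=(16, 32, 96, 256)
+# bring an image-space condition down by a factor of 8 per side, matching the
+# latent resolution. The channel counts and input size are assumptions made up
+# for this example.
+def _conditioning_embedding_example():
+    embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320,
+                                            conditioning_channels=3)
+    cond = torch.randn(1, 3, 256, 256)   # e.g. one rasterized sketch frame
+    return embed(cond)                    # (1, 320, 32, 32)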
+
+
+class LayerControlNet(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0.0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ context_dim=None,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ transformer_depth=1,
+ use_linear=False,
+ use_checkpoint=False,
+ temporal_conv=False,
+ tempspatial_aware=False,
+ temporal_attention=True,
+ use_relative_position=True,
+ use_causal_attention=False,
+ temporal_length=None,
+ addition_attention=False,
+ temporal_selfatt_only=True,
+ image_cross_attention=False,
+ image_cross_attention_scale_learnable=False,
+ default_fps=4,
+ fps_condition=False,
+ ignore_noisy_latents=True,
+ condition_channels={},
+ control_injection_mode='add',
+ use_vae_for_trajectory=False,
+ ):
+ super().__init__()
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.temporal_attention = temporal_attention
+ time_embed_dim = model_channels * 4
+ self.use_checkpoint = use_checkpoint
+ temporal_self_att_only = True
+ self.addition_attention = addition_attention
+ self.temporal_length = temporal_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
+ self.default_fps = default_fps
+ self.fps_condition = fps_condition
+ self.ignore_noisy_latents = ignore_noisy_latents
+ assert len(condition_channels) > 0, 'Condition types must be specified'
+ self.condition_channels = condition_channels
+ self.control_injection_mode = control_injection_mode
+ self.use_vae_for_trajectory = use_vae_for_trajectory
+
+ ## Time embedding blocks
+ self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
+
+ if fps_condition:
+ self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
+ nn.init.zeros_(self.fps_embedding.linear_2.weight)
+ nn.init.zeros_(self.fps_embedding.linear_2.bias)
+
+ if "motion_score" in condition_channels:
+ if control_injection_mode == 'add':
+ self.motion_embedding = zero_module(conv_nd(dims, condition_channels["motion_score"], model_channels, 3, padding=1))
+ elif control_injection_mode == 'concat':
+ self.motion_embedding = zero_module(conv_nd(dims, condition_channels["motion_score"], condition_channels["motion_score"], 3, padding=1))
+ else:
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
+ if "sketch" in condition_channels:
+ if control_injection_mode == 'add':
+ self.sketch_embedding = zero_module(conv_nd(dims, condition_channels["sketch"], model_channels, 3, padding=1))
+ elif control_injection_mode == 'concat':
+ self.sketch_embedding = zero_module(conv_nd(dims, condition_channels["sketch"], condition_channels["sketch"], 3, padding=1))
+ else:
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
+ if "trajectory" in condition_channels:
+ if control_injection_mode == 'add':
+ if use_vae_for_trajectory:
+ self.trajectory_embedding = zero_module(conv_nd(dims, condition_channels["trajectory"], model_channels, 3, padding=1))
+ else:
+ self.trajectory_embedding = ControlNetConditioningEmbedding(model_channels, condition_channels["trajectory"])
+ elif control_injection_mode == 'concat':
+ if use_vae_for_trajectory:
+ self.trajectory_embedding = zero_module(conv_nd(dims, condition_channels["trajectory"], condition_channels["trajectory"], 3, padding=1))
+ else:
+ self.trajectory_embedding = ControlNetConditioningEmbedding(condition_channels["trajectory"], condition_channels["trajectory"])
+ else:
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
+ ]
+ )
+
+ if self.addition_attention:
+ self.init_attn = TimestepEmbedSequential(
+ TemporalTransformer(
+ model_channels,
+ n_heads=8,
+ d_head=num_head_channels,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
+ causal_attention=False, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
+ video_length=temporal_length, image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+
+ if level < len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ ds *= 2
+
+ def forward(
+ self,
+ noisy_latents,
+ timesteps,
+ context_text,
+ context_img=None,
+ fps=None,
+ layer_latents=None, # [b, n_layer, t, c, h, w]
+ layer_latent_mask=None, # [b, n_layer, t, 1, h, w]
+ motion_scores=None, # [b, n_layer]
+ sketch=None, # [b, n_layer, t, c, h, w]
+ trajectory=None, # [b, n_layer, t, c, h, w]
+ ):
+ if self.ignore_noisy_latents:
+ noisy_latents_shape = list(noisy_latents.shape)
+ noisy_latents_shape[1] = 0
+ noisy_latents = torch.zeros(noisy_latents_shape, device=noisy_latents.device, dtype=noisy_latents.dtype)
+
+ b, _, t, height, width = noisy_latents.shape
+ n_layer = layer_latents.shape[1]
+ t_emb = self.time_proj(timesteps).type(noisy_latents.dtype)
+ emb = self.time_embed(t_emb)
+
+ ## repeat t times for context [(b t) 77 768] & time embedding
+ ## check if we use per-frame image conditioning
+ if context_img is not None: ## decompose context into text and image
+ context_text = repeat(context_text, 'b l c -> (b n t) l c', n=n_layer, t=t)
+ context_img = repeat(context_img, 'b tl c -> b n tl c', n=n_layer)
+ context_img = rearrange(context_img, 'b n (t l) c -> (b n t) l c', t=t)
+ context = torch.cat([context_text, context_img], dim=1)
+ else:
+ context = repeat(context_text, 'b l c -> (b n t) l c', n=n_layer, t=t)
+ emb = repeat(emb, 'b c -> (b n t) c', n=n_layer, t=t)
+
+ ## always in shape (b n t) c h w, except for temporal layer
+ noisy_latents = repeat(noisy_latents, 'b c t h w -> (b n t) c h w', n=n_layer)
+
+ ## combine emb
+ if self.fps_condition:
+ if fps is None:
+ fps = torch.tensor(
+                [self.default_fps] * b, dtype=torch.long, device=noisy_latents.device)
+ fps_emb = self.time_proj(fps).type(noisy_latents.dtype)
+
+ fps_embed = self.fps_embedding(fps_emb)
+ fps_embed = repeat(fps_embed, 'b c -> (b n t) c', n=n_layer, t=t)
+ emb = emb + fps_embed
+
+ ## process conditions
+ layer_condition = torch.cat([layer_latents, layer_latent_mask], dim=3)
+ layer_condition = rearrange(layer_condition, 'b n t c h w -> (b n t) c h w')
+ h = torch.cat([noisy_latents, layer_condition], dim=1)
+
+ if "motion_score" in self.condition_channels:
+ motion_condition = repeat(motion_scores, 'b n -> b n t 1 h w', t=t, h=height, w=width)
+ motion_condition = torch.cat([motion_condition, layer_latent_mask], dim=3)
+ motion_condition = rearrange(motion_condition, 'b n t c h w -> (b n t) c h w')
+ motion_condition = self.motion_embedding(motion_condition)
+ if self.control_injection_mode == 'concat':
+ h = torch.cat([h, motion_condition], dim=1)
+
+ if "sketch" in self.condition_channels:
+ sketch_condition = rearrange(sketch, 'b n t c h w -> (b n t) c h w')
+ sketch_condition = self.sketch_embedding(sketch_condition)
+ if self.control_injection_mode == 'concat':
+ h = torch.cat([h, sketch_condition], dim=1)
+
+ if "trajectory" in self.condition_channels:
+ traj_condition = rearrange(trajectory, 'b n t c h w -> (b n t) c h w')
+ traj_condition = self.trajectory_embedding(traj_condition)
+ if self.control_injection_mode == 'concat':
+ h = torch.cat([h, traj_condition], dim=1)
+
+ layer_features = []
+ for id, module in enumerate(self.input_blocks):
+ h = module(h, emb, context=context, batch_size=b*n_layer)
+ if id == 0:
+ if self.control_injection_mode == 'add':
+ if "motion_score" in self.condition_channels:
+ h = h + motion_condition
+ if "sketch" in self.condition_channels:
+ h = h + sketch_condition
+ if "trajectory" in self.condition_channels:
+ h = h + traj_condition
+ if self.addition_attention:
+ h = self.init_attn(h, emb, context=context, batch_size=b*n_layer)
+            if any(isinstance(m, SpatialTransformer) for m in module):
+ layer_features.append(rearrange(h, '(b n t) c h w -> b n t c h w', b=b, n=n_layer))
+
+ return layer_features
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, layer_controlnet_additional_kwargs={}, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ for key, value in layer_controlnet_additional_kwargs.items():
+ if isinstance(value, (ListConfig, DictConfig)):
+ config[key] = OmegaConf.to_container(value, resolve=True)
+ else:
+ config[key] = value
+
+ # load model
+ model_file = None
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ except IOError as e:
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+ if not allow_pickle:
+ raise
+ logger.warning(
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+ )
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ model = cls.from_config(config, **unused_kwargs)
+ state_dict = load_state_dict(model_file, variant)
+
+ if state_dict['input_blocks.0.0.weight'].shape[1] != model.input_blocks[0][0].weight.shape[1]:
+ state_dict.pop('input_blocks.0.0.weight')
+
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ print(f"LayerControlNet loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
+ return model
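+
+
+# Illustrative loading sketch. The checkpoint path, subfolder name and override
+# key below are hypothetical placeholders, not paths shipped with this code;
+# they only show how layer_controlnet_additional_kwargs feeds into the config.
+def _layer_controlnet_loading_example():
+    model = LayerControlNet.from_pretrained(
+        "path/to/LayerAnimate-checkpoint",                       # hypothetical
+        subfolder="layer_controlnet",                            # hypothetical
+        layer_controlnet_additional_kwargs={"control_injection_mode": "add"},
+    )
+    return model.eval()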
\ No newline at end of file
diff --git a/lvdm/models/unet.py b/lvdm/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..669fc4a288d601d4fb1c1e59e899cce02f703a92
--- /dev/null
+++ b/lvdm/models/unet.py
@@ -0,0 +1,731 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+import os
+from os import PathLike
+from typing import List, Mapping, Optional, Tuple, Union
+from functools import partial
+from abc import abstractmethod
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from diffusers import __version__
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from diffusers.utils import (
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ BaseOutput,
+ logging,
+ _get_model_file,
+ _add_variant
+)
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.model_loading_utils import load_state_dict
+
+from ..common import checkpoint
+from ..basics import avg_pool_nd, conv_nd, zero_module
+from ..modules.attention import SpatialTransformer, TemporalTransformer, CrossAttention
+from omegaconf import ListConfig, DictConfig, OmegaConf
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None, batch_size=None, **kwargs):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb, batch_size=batch_size)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context, **kwargs)
+ elif isinstance(layer, TemporalTransformer):
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
+ x = layer(x, context)
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
+ else:
+ x = layer(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
+ else:
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ :param use_temporal_conv: if True, use the temporal convolution.
+ :param use_image_dataset: if True, the temporal parameters will not be optimized.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ use_temporal_conv=False,
+ tempspatial_aware=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.use_temporal_conv = use_temporal_conv
+
+ self.in_layers = nn.Sequential(
+ nn.GroupNorm(32, channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ nn.GroupNorm(32, self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
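+        # note: the 'temopral_conv' attribute name is left as-is; renaming it would change the state_dict keys expected by pretrained weights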
+ if self.use_temporal_conv:
+ self.temopral_conv = TemporalConvBlock(
+ self.out_channels,
+ self.out_channels,
+ dropout=0.1,
+ spatial_aware=tempspatial_aware
+ )
+
+ def forward(self, x, emb, batch_size=None):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ input_tuple = (x, emb)
+ if batch_size:
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
+
+ def _forward(self, x, emb, batch_size=None):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ h = self.skip_connection(x) + h
+
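+        # when the per-frame batch size is known, restore the time axis so the temporal conv can mix information across frames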
+ if self.use_temporal_conv and batch_size:
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
+ h = self.temopral_conv(h)
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
+ return h
+
+
+class TemporalConvBlock(nn.Module):
+ """
+ Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
+ """
+ def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
+ super(TemporalConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
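+        # pseudo-3D temporal kernels: with spatial_aware=True, 'th' also spans height and 'tw' also spans width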
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
+
+ # conv layers
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(32, in_channels), nn.SiLU(),
+ nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
+
+        # zero out the last layer params, so the conv block starts as an identity
+ nn.init.zeros_(self.conv4[-1].weight)
+ nn.init.zeros_(self.conv4[-1].bias)
+
+ def forward(self, x):
+ identity = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ return identity + x
+
+
+@dataclass
+class UNetModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNetModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0.0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ context_dim=None,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ transformer_depth=1,
+ use_linear=False,
+ use_checkpoint=False,
+ temporal_conv=False,
+ tempspatial_aware=False,
+ temporal_attention=True,
+ use_relative_position=True,
+ use_causal_attention=False,
+ temporal_length=None,
+ addition_attention=False,
+ temporal_selfatt_only=True,
+ image_cross_attention=False,
+ image_cross_attention_scale_learnable=False,
+ masked_layer_fusion=False,
+ default_fps=4,
+ fps_condition=False,
+ ):
+ super().__init__()
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.temporal_attention = temporal_attention
+ time_embed_dim = model_channels * 4
+ self.use_checkpoint = use_checkpoint
+ temporal_self_att_only = True
+ self.addition_attention = addition_attention
+ self.temporal_length = temporal_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
+ self.default_fps = default_fps
+ self.fps_condition = fps_condition
+
+ ## Time embedding blocks
+ self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
+
+ if fps_condition:
+ self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
+ nn.init.zeros_(self.fps_embedding.linear_2.weight)
+ nn.init.zeros_(self.fps_embedding.linear_2.bias)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
+ ]
+ )
+
+ if self.addition_attention:
+ self.init_attn = TimestepEmbedSequential(
+ TemporalTransformer(
+ model_channels,
+ n_heads=8,
+ d_head=num_head_channels,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
+ causal_attention=False, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
+ video_length=temporal_length, image_cross_attention=self.image_cross_attention,
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers = [
+ ResBlock(ch, time_embed_dim, dropout,
+ dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ ),
+ SpatialTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
+ image_cross_attention=self.image_cross_attention,image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
+ )
+ ]
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+ layers.append(
+ ResBlock(ch, time_embed_dim, dropout,
+ dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ )
+ )
+
+ ## Middle Block
+ self.middle_block = TimestepEmbedSequential(*layers)
+
+ ## Output Block
+ self.output_blocks = nn.ModuleList([])
+
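+        # masked layer fusion: zero-initialized cross-attention adapters that inject layer-level control features into selected decoder skip connections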
+ self.masked_layer_fusion = masked_layer_fusion
+ if self.masked_layer_fusion:
+ self.masked_layer_fusion_norm_list = nn.ModuleList([])
+ self.masked_layer_fusion_attn_list = nn.ModuleList([])
+ self.masked_layer_fusion_out_list = nn.ModuleList([])
+ self.layer_feature_block_indices = []
+
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ input_channel = ch
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(ch + ich, time_embed_dim, dropout,
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if self.masked_layer_fusion and i < num_res_blocks:
+ self.masked_layer_fusion_norm_list.append(nn.LayerNorm(input_channel))
+ self.masked_layer_fusion_attn_list.append(
+ CrossAttention(
+ query_dim=input_channel,
+ context_dim=ch,
+ dim_head=dim_head,
+ heads=num_heads,
+ use_xformers=False,
+ )
+ )
+ self.masked_layer_fusion_out_list.append(
+ zero_module(conv_nd(dims, input_channel, ch, 3, padding=1))
+ )
+ self.layer_feature_block_indices.append(len(self.output_blocks))
+ layers.append(
+ SpatialTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
+ image_cross_attention=self.image_cross_attention, image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
+ )
+ )
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch, num_heads, dim_head,
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
+ temporal_length=temporal_length
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(ch, time_embed_dim, dropout,
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+ self.out = nn.Sequential(
+ nn.GroupNorm(32, ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ def forward(self, x, timesteps, context_text, context_img=None, controls=None, layer_validity=None, fps=None, **kwargs):
+ b, _, t, _, _ = x.shape
+ t_emb = self.time_proj(timesteps).type(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ ## repeat t times for context [(b t) 77 768] & time embedding
+ ## check if we use per-frame image conditioning
+ if context_img is not None: ## decompose context into text and image
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
+ context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
+ context = torch.cat([context_text, context_img], dim=1)
+ else:
+ context = context_text.repeat_interleave(repeats=t, dim=0)
+ emb = emb.repeat_interleave(repeats=t, dim=0)
+
+ ## always in shape (b t) c h w, except for temporal layer
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+
+ ## combine emb
+        if self.fps_condition:
+            if fps is None:
+                fps = torch.tensor(
+                    [self.default_fps] * b, dtype=torch.long, device=x.device)
+ fps_emb = self.time_proj(fps).type(x.dtype)
+
+ fps_embed = self.fps_embedding(fps_emb)
+ fps_embed = fps_embed.repeat_interleave(repeats=t, dim=0)
+ emb = emb + fps_embed
+
+ h = x.type(self.dtype)
+ hs = []
+ for id, module in enumerate(self.input_blocks):
+ h = module(h, emb, context=context, batch_size=b)
+ if id == 0 and self.addition_attention:
+ h = self.init_attn(h, emb, context=context, batch_size=b)
+ hs.append(h)
+
+ h = self.middle_block(h, emb, context=context, batch_size=b)
+
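+        # decoder pass: fuse layer control features into the matching skip connections via masked cross-attention before concatenation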
+ layer_fusion_idx = 0
+ for id, module in enumerate(self.output_blocks):
+ skip = hs.pop()
+ if controls is not None and len(controls) > 0 and id in self.layer_feature_block_indices:
+ layer_features = controls.pop()
+ feature_h, feature_w = layer_features.shape[-2:]
+ layer_features = rearrange(layer_features, 'b n t c h w -> (b t h w) n c')
+ frame_features = rearrange(h, '(b t) c h w -> (b t) (h w) c', b=b)
+ frame_features = self.masked_layer_fusion_norm_list[layer_fusion_idx](frame_features)
+ frame_features = rearrange(frame_features, '(b t) (h w) c -> (b t h w) 1 c', b=b, t=t, h=feature_h, w=feature_w)
+ fused_features = self.masked_layer_fusion_attn_list[layer_fusion_idx](
+ frame_features,
+ layer_features,
+ mask=repeat(layer_validity, "b n -> (b t h w) 1 n", t=t, h=feature_h, w=feature_w)
+ )
+ fused_features = rearrange(fused_features, '(b t h w) 1 c -> (b t) c h w', b=b, t=t, h=feature_h, w=feature_w)
+ fused_features = self.masked_layer_fusion_out_list[layer_fusion_idx](fused_features)
+ skip += fused_features
+ layer_fusion_idx += 1
+ h = torch.cat([h, skip], dim=1)
+ h = module(h, emb, context=context, batch_size=b)
+ h = h.type(x.dtype)
+ y = self.out(h)
+
+ # reshape back to (b c t h w)
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
+ return UNetModelOutput(sample=y)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ for key, value in unet_additional_kwargs.items():
+ if isinstance(value, (ListConfig, DictConfig)):
+ config[key] = OmegaConf.to_container(value, resolve=True)
+ else:
+ config[key] = value
+
+ # load model
+ model_file = None
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ except IOError as e:
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+ if not allow_pickle:
+ raise
+ logger.warning(
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+ )
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ model = cls.from_config(config, **unused_kwargs)
+ state_dict = load_state_dict(model_file, variant)
+
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ print(f"UNetModel loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
+ return model
\ No newline at end of file
diff --git a/lvdm/modules/ae_dualref_modules.py b/lvdm/modules/ae_dualref_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..3794142c2ad53beca17a26ed83975262bcea04a7
--- /dev/null
+++ b/lvdm/modules/ae_dualref_modules.py
@@ -0,0 +1,1179 @@
+#### https://github.com/Stability-AI/generative-models
+from einops import rearrange, repeat
+import logging
+import math
+from typing import Any, Callable, Optional, Iterable, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from packaging import version
+logpy = logging.getLogger(__name__)
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warning("no module 'xformers'. Processing without...")
+
+from .attention_svd import LinearAttention, MemoryEfficientCrossAttention
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
+ )
+        # compute attention; the scale defaults to dim ** -0.5
+        h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None, **unused_kwargs):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, "b c h w -> b (h w) c")
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+        # add the residual in the original (b, c, h, w) layout
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ "memory-efficient-cross-attn-fusion",
+ ], f"attn_type {attn_type} unknown"
+ if (
+ version.parse(torch.__version__) < version.parse("2.0.0")
+ and attn_type != "none"
+ ):
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ # attn_type = "vanilla-xformers"
+ logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ logpy.info(
+ f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
+ )
+ return MemoryEfficientAttnBlock(in_channels)
+ elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "memory-efficient-cross-attn-fusion":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapperFusion(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAttention):
+    # expected shapes: x is (bt, c, h, w), e.g. torch.Size([8, 128, 256, 256]); context is (b, c, 2, h, w), e.g. torch.Size([1, 128, 2, 256, 256])
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0, **kwargs):
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout, **kwargs)
+ self.norm = Normalize(query_dim)
+ nn.init.zeros_(self.to_out[0].weight)
+ nn.init.zeros_(self.to_out[0].bias)
+
+ def forward(self, x, context=None, mask=None):
+ if self.training:
+ return checkpoint(self._forward, x, context, mask, use_reentrant=False)
+ else:
+ return self._forward(x, context, mask)
+
+ def _forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ ):
+ bt, c, h, w = x.shape
+ h_ = self.norm(x)
+ h_ = rearrange(h_, "b c h w -> b (h w) c")
+ q = self.to_q(h_)
+
+
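+        # context is expected to share x's channel count and spatial size, so c, h and w are reused below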
+ b, c, l, h, w = context.shape
+ context = rearrange(context, "b c l h w -> (b l) (h w) c")
+ k = self.to_k(context)
+ v = self.to_v(context)
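+        # give every frame access to both reference frames: repeat ref 0 and ref 1 along the frame axis, then concatenate their tokens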
+ k = rearrange(k, "(b l) d c -> b l d c", l=l)
+ k = torch.cat([k[:, [0] * (bt//b)], k[:, [1]*(bt//b)]], dim=2)
+ k = rearrange(k, "b l d c -> (b l) d c")
+
+ v = rearrange(v, "(b l) d c -> b l d c", l=l)
+ v = torch.cat([v[:, [0] * (bt//b)], v[:, [1]*(bt//b)]], dim=2)
+ v = rearrange(v, "b l d c -> (b l) d c")
+
+
+ b, _, _ = q.shape ##actually bt
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
+ # NOTE: workaround for
+ # https://github.com/facebookresearch/xformers/issues/845
+ max_bs = 32768
+ N = q.shape[0]
+ n_batches = math.ceil(N / max_bs)
+ out = list()
+ for i_batch in range(n_batches):
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
+ out.append(
+ xformers.ops.memory_efficient_attention(
+ q[batch],
+ k[batch],
+ v[batch],
+ attn_bias=None,
+ op=self.attention_op,
+ )
+ )
+ out = torch.cat(out, 0)
+ else:
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ out = self.to_out(out)
+ out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
+ return x + out
+
+class Combiner(nn.Module):
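+    """Zero-initialized 1x1 conv that adds projected start/end reference-frame features to the first and last frames."""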
+ def __init__(self, ch) -> None:
+ super().__init__()
+ self.conv = nn.Conv2d(ch,ch,1,padding=0)
+
+ nn.init.zeros_(self.conv.weight)
+ nn.init.zeros_(self.conv.bias)
+
+ def forward(self, x, context):
+ if self.training:
+ return checkpoint(self._forward, x, context, use_reentrant=False)
+ else:
+ return self._forward(x, context)
+
+ def _forward(self, x, context):
+ ## x: b c h w, context: b c 2 h w
+ b, c, l, h, w = context.shape
+ bt, c, h, w = x.shape
+ context = rearrange(context, "b c l h w -> (b l) c h w")
+ context = self.conv(context)
+ context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=bt//b)
+ x[:,:,0] = x[:,:,0] + context[:,:,0]
+ x[:,:,-1] = x[:,:,-1] + context[:,:,1]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla-xformers",
+ attn_level=[2,3],
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+ self.attn_level = attn_level
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ logpy.info(
+ "Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)
+ )
+ )
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ self.attn_refinement = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ if i_level in self.attn_level:
+ self.attn_refinement.insert(0, make_attn_cls(block_in, attn_type='memory-efficient-cross-attn-fusion', attn_kwargs={}))
+ else:
+ self.attn_refinement.insert(0, Combiner(block_in))
+ # end
+ self.norm_out = Normalize(block_in)
+ self.attn_refinement.append(Combiner(block_in))
+ self.conv_out = make_conv_cls(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, ref_context=None, **kwargs):
+ ## ref_context: b c 2 h w, 2 means starting and ending frame
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if ref_context:
+ h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ if ref_context:
+ # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
+ h = self.attn_refinement[-1](x=h, context=ref_context[-1])
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return DecoderOutput(sample=h)
+
+#####
+
+
+from abc import abstractmethod
+from diffusers.models.embeddings import Timesteps
+
+from torch.utils.checkpoint import checkpoint
+from ..basics import (
+ zero_module,
+ conv_nd,
+ linear,
+ normalization,
+)
+from ..models.unet import Upsample, Downsample
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor, emb: torch.Tensor):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ emb_channels: int,
+ dropout: float,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ use_scale_shift_norm: bool = False,
+ dims: int = 2,
+ use_checkpoint: bool = False,
+ up: bool = False,
+ down: bool = False,
+ kernel_size: int = 3,
+ exchange_temb_dims: bool = False,
+ skip_t_emb: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = (
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ )
+ if self.skip_t_emb:
+ # print(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, kernel_size, padding=padding
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.use_checkpoint:
+ return checkpoint(self._forward, x, emb, use_reentrant=False)
+ else:
+ return self._forward(x, emb)
+
+ def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = torch.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+#####
+
+#####
+from lvdm.modules.attention_svd import *
+class VideoTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention,
+ "softmax-xformers": MemoryEfficientCrossAttention,
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ timesteps=None,
+ ff_in=False,
+ inner_dim=None,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ switch_temporal_ca_to_sa=False,
+ ):
+ super().__init__()
+
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+
+ self.ff_in = ff_in or inner_dim is not None
+ if inner_dim is None:
+ inner_dim = dim
+
+ assert int(n_heads * d_head) == inner_dim
+
+ self.is_res = inner_dim == dim
+
+ if self.ff_in:
+ self.norm_in = nn.LayerNorm(dim)
+ self.ff_in = FeedForward(
+ dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
+ )
+
+ self.timesteps = timesteps
+ self.disable_self_attn = disable_self_attn
+ if self.disable_self_attn:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ context_dim=context_dim,
+ dropout=dropout,
+ ) # is a cross-attention
+ else:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
+
+ if disable_temporal_crossattention:
+ if switch_temporal_ca_to_sa:
+ raise ValueError
+ else:
+ self.attn2 = None
+ else:
+ self.norm2 = nn.LayerNorm(inner_dim)
+ if switch_temporal_ca_to_sa:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ else:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+
+ self.norm1 = nn.LayerNorm(inner_dim)
+ self.norm3 = nn.LayerNorm(inner_dim)
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
+
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"====>{self.__class__.__name__} is using checkpointing")
+ else:
+ print(f"====>{self.__class__.__name__} is NOT using checkpointing")
+
+ def forward(
+ self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
+ ) -> torch.Tensor:
+ if self.checkpoint:
+ return checkpoint(self._forward, x, context, timesteps, use_reentrant=False)
+ else:
+ return self._forward(x, context, timesteps=timesteps)
+
+ def _forward(self, x, context=None, timesteps=None):
+ assert self.timesteps or timesteps
+ assert not (self.timesteps and timesteps) or self.timesteps == timesteps
+ timesteps = self.timesteps or timesteps
+ B, S, C = x.shape
+ x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
+
+ if self.ff_in:
+ x_skip = x
+ x = self.ff_in(self.norm_in(x))
+ if self.is_res:
+ x += x_skip
+
+ if self.disable_self_attn:
+ x = self.attn1(self.norm1(x), context=context) + x
+ else:
+ x = self.attn1(self.norm1(x)) + x
+
+ if self.attn2 is not None:
+ if self.switch_temporal_ca_to_sa:
+ x = self.attn2(self.norm2(x)) + x
+ else:
+ x = self.attn2(self.norm2(x), context=context) + x
+ x_skip = x
+ x = self.ff(self.norm3(x))
+ if self.is_res:
+ x += x_skip
+
+ x = rearrange(
+ x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
+ )
+ return x
+
+ def get_last_layer(self):
+ return self.ff.net[-1].weight
+
+#####
+
+#####
+import functools
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+######
+
+class VideoResBlock(ResnetBlock):
+ def __init__(
+ self,
+ out_channels,
+ *args,
+ dropout=0.0,
+ video_kernel_size=3,
+ alpha=0.0,
+ merge_strategy="learned",
+ **kwargs,
+ ):
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
+ if video_kernel_size is None:
+ video_kernel_size = [3, 1, 1]
+ self.time_stack = ResBlock(
+ channels=out_channels,
+ emb_channels=0,
+ dropout=dropout,
+ dims=3,
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=True,
+ skip_t_emb=True,
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, bs):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError()
+
+ def forward(self, x, temb, skip_video=False, timesteps=None):
+ if timesteps is None:
+ timesteps = self.timesteps
+
+ b, c, h, w = x.shape
+
+ x = super().forward(x, temb)
+
+ if not skip_video:
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = self.time_stack(x, temb)
+
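+            # blend the temporally mixed features with the spatial-only output using a fixed or learned mixing factor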
+ alpha = self.get_alpha(bs=b // timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix
+
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class AE3DConv(torch.nn.Conv2d):
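+    """Standard 2D convolution followed by a temporal 3D convolution that mixes information across frames."""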
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ if isinstance(video_kernel_size, Iterable):
+ padding = [int(k // 2) for k in video_kernel_size]
+ else:
+ padding = int(video_kernel_size // 2)
+
+ self.time_mix_conv = torch.nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=video_kernel_size,
+ padding=padding,
+ )
+
+ def forward(self, input, timesteps, skip_video=False):
+ x = super().forward(input)
+ if skip_video:
+ return x
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+ x = self.time_mix_conv(x)
+ return rearrange(x, "b c t h w -> (b t) c h w")
+
+
+class VideoBlock(AttnBlock):
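+    """Spatial self-attention block that additionally runs a temporal transformer and blends the two paths with a fixed or learned alpha."""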
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=True,
+ ff_in=True,
+ attn_mode="softmax",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.timestep_embedding = Timesteps(self.in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_video=False):
+ if skip_video:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = self.timestep_embedding(num_frames)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
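+    """xformers-based counterpart of VideoBlock: memory-efficient spatial attention blended with a temporal transformer."""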
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=True,
+ ff_in=True,
+ attn_mode="softmax-xformers",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.timestep_embedding = Timesteps(self.in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_time_block=False):
+ if skip_time_block:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = self.timestep_embedding(num_frames)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+def make_time_attn(
+ in_channels,
+ attn_type="vanilla",
+ attn_kwargs=None,
+ alpha: float = 0,
+ merge_strategy: str = "learned",
+):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ ], f"attn_type {attn_type} not supported for spatio-temporal attention"
+ print(
+ f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
+ )
+ if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
+ print(
+ f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_type = "vanilla"
+
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return partialclass(
+ VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
+ )
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return partialclass(
+ MemoryEfficientVideoBlock,
+ in_channels,
+ alpha=alpha,
+ merge_strategy=merge_strategy,
+ )
+ else:
+        raise NotImplementedError()
+
+
+class Conv2DWrapper(torch.nn.Conv2d):
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
+ return super().forward(input)
+
+
+class VideoDecoder(Decoder):
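+    """Decoder whose attention, conv, and resblock factories are swapped for temporal variants according to time_mode."""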
+ available_time_modes = ["all", "conv-only", "attn-only"]
+
+ def __init__(
+ self,
+ *args,
+ video_kernel_size: Union[int, list] = [3,1,1],
+ alpha: float = 0.0,
+ merge_strategy: str = "learned",
+ time_mode: str = "conv-only",
+ **kwargs,
+ ):
+ self.video_kernel_size = video_kernel_size
+ self.alpha = alpha
+ self.merge_strategy = merge_strategy
+ self.time_mode = time_mode
+ assert (
+ self.time_mode in self.available_time_modes
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
+ super().__init__(*args, **kwargs)
+
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
+ if self.time_mode == "attn-only":
+ raise NotImplementedError("TODO")
+ else:
+ return (
+ self.conv_out.time_mix_conv.weight
+ if not skip_time_mix
+ else self.conv_out.weight
+ )
+
+ def _make_attn(self) -> Callable:
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
+ return partialclass(
+ make_time_attn,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_attn()
+
+ def _make_conv(self) -> Callable:
+ if self.time_mode != "attn-only":
+ return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
+ else:
+ return Conv2DWrapper
+
+ def _make_resblock(self) -> Callable:
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
+ return partialclass(
+ VideoResBlock,
+ video_kernel_size=self.video_kernel_size,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_resblock()
\ No newline at end of file
diff --git a/lvdm/modules/ae_modules.py b/lvdm/modules/ae_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d5969cdd9ffe676f201e868c66270bab40b6c7
--- /dev/null
+++ b/lvdm/modules/ae_modules.py
@@ -0,0 +1,794 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+
+import torch
+import numpy as np
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils import instantiate_from_config
+from .attention import LinearAttention
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w) # bcl
+ q = q.permute(0,2,1) # bcl -> blc l=hw
+ k = k.reshape(b,c,h*w) # bcl
+
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ #print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
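+# Worked example (illustrative): with timesteps = torch.tensor([0, 10]) and embedding_dim = 8,
+# get_timestep_embedding returns a (2, 8) tensor whose first four columns hold sin terms and
+# last four hold cos terms at geometrically spaced frequencies.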
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, return_hidden_states=False):
+ # timestep embedding
+ temb = None
+
+ # print(f'encoder-input={x.shape}')
+ # downsampling
+ hs = [self.conv_in(x)]
+
+ ## if hidden states are requested (for later use in the decoder), collect them in a list
+ if return_hidden_states:
+ hidden_states = []
+ # print(f'encoder-conv in feat={hs[0].shape}')
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ # print(f'encoder-down feat={h.shape}')
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if return_hidden_states:
+ hidden_states.append(h)
+ if i_level != self.num_resolutions-1:
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
+ if return_hidden_states:
+ hidden_states.append(hs[0])
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ # print(f'encoder-mid1 feat={h.shape}')
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'encoder-mid2 feat={h.shape}')
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'end feat={h.shape}')
+ if return_hidden_states:
+ return h, hidden_states
+ else:
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("AE working on z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # print(f'decoder-input={z.shape}')
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+ # print(f'decoder-conv in feat={h.shape}')
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'decoder-mid feat={h.shape}')
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ # print(f'decoder-up feat={h.shape}')
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # print(f'decoder-upsample feat={h.shape}')
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'decoder-conv_out feat={h.shape}')
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return DecoderOutput(sample=h)
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[i_level](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
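Between the two file diffs, a quick sanity-check sketch for the autoencoder modules added above. The hyperparameters are example values only (the real ones come from LayerAnimate's config files), and the snippet assumes the `lvdm` package from this diff is importable.

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from lvdm.modules.ae_modules import Encoder, Decoder

# Example values, not the shipped config: 3 resolutions -> 4x spatial downsampling.
cfg = dict(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
           attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)

encoder = Encoder(double_z=True, **cfg)
decoder = Decoder(**cfg)

x = torch.randn(1, 3, 256, 256)                  # one RGB frame
moments = encoder(x)                             # (1, 2*z_channels, 64, 64): mean and logvar
posterior = DiagonalGaussianDistribution(moments)
z = posterior.mode()                             # deterministic latent, (1, 4, 64, 64)
recon = decoder(z).sample                        # Decoder wraps its output in DecoderOutput
print(moments.shape, z.shape, recon.shape)
```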
diff --git a/lvdm/modules/attention.py b/lvdm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..76ae58a01dbc926c956a27fe95a93679b9381590
--- /dev/null
+++ b/lvdm/modules/attention.py
@@ -0,0 +1,567 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from functools import partial
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+from ..common import (
+ checkpoint,
+ exists,
+ default,
+)
+from ..basics import zero_module
+
+
+class RelativePosition(nn.Module):
+ """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
+
+ def __init__(self, num_units, max_relative_position):
+ super().__init__()
+ self.num_units = num_units
+ self.max_relative_position = max_relative_position
+ self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
+ nn.init.xavier_uniform_(self.embeddings_table)
+
+ def forward(self, length_q, length_k):
+ device = self.embeddings_table.device
+ range_vec_q = torch.arange(length_q, device=device)
+ range_vec_k = torch.arange(length_k, device=device)
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
+ final_mat = distance_mat_clipped + self.max_relative_position
+ final_mat = final_mat.long()
+ embeddings = self.embeddings_table[final_mat]
+ return embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., use_xformers=True,
+ relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77,
+ layer_cross_attention=False, layer_cross_attention_scale=1.0, layer_cross_attention_scale_learnable=False):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.dim_head = dim_head
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ self.relative_position = relative_position
+ if self.relative_position:
+ assert(temporal_length is not None)
+ self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
+ self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
+ else:
+ ## the xformers-efficient forward is only used for spatial attention, NOT for temporal attention
+ if XFORMERS_IS_AVAILBLE and temporal_length is None and use_xformers:
+ self.forward = self.efficient_forward
+
+ self.video_length = video_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale = image_cross_attention_scale
+ self.text_context_len = text_context_len
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
+ self.layer_cross_attention = layer_cross_attention
+ self.layer_cross_attention_scale = layer_cross_attention_scale
+ self.layer_cross_attention_scale_learnable = layer_cross_attention_scale_learnable
+ if self.image_cross_attention:
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ if image_cross_attention_scale_learnable:
+ self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) )
+ if self.layer_cross_attention:
+ self.to_k_layer = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_v_layer = nn.Linear(query_dim, inner_dim, bias=False)
+ if layer_cross_attention_scale_learnable:
+ self.register_parameter('layer_alpha', nn.Parameter(torch.tensor(0.)) )
+
+
+ def forward(self, x, context=None, mask=None, layer_feature=None, layer_mask=None, layer_score=None):
+ spatial_self_attn = (context is None)
+ k_ip, v_ip, out_ip = None, None, None
+
+ h = self.heads
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:]
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ else:
+ if not spatial_self_attn:
+ context = context[:,:self.text_context_len,:]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if self.layer_cross_attention:
+ k_layer = self.to_k_layer(layer_feature)
+ v_layer = self.to_v_layer(layer_feature)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+ if self.relative_position:
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
+ k2 = self.relative_position_k(len_q, len_k)
+ sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check
+ sim += sim2
+ del k
+
+ if exists(mask):
+ ## only a causal attention mask is supported here
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b i j -> (b h) i j', h=h)
+ sim.masked_fill_(~(mask>0.5), max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
+ if self.relative_position:
+ v2 = self.relative_position_v(len_q, len_v)
+ out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
+ out += out2
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+
+
+ ## for image cross-attention
+ if k_ip is not None:
+ k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip))
+ sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale
+ del k_ip
+ sim_ip = sim_ip.softmax(dim=-1)
+ out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
+ out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
+
+ if self.layer_cross_attention:
+ k_layer, v_layer = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_layer, v_layer))
+ sim_layer = torch.einsum('b i d, b j d -> b i j', q, k_layer) * self.scale
+ del k_layer
+ if exists(layer_mask):
+ max_neg_value = -torch.finfo(sim_layer.dtype).max
+ f = x.shape[0] // layer_mask.shape[0] # bf / b = f
+ i = x.shape[1] # hw
+ j_times = layer_feature.shape[1] // layer_mask.shape[1] # nhw / n = hw
+ assert i == j_times
+ layer_mask = repeat(layer_mask, 'b n -> (b f) i (n j)', f=f, i=i, j=j_times)
+ layer_mask = repeat(layer_mask, 'b i j -> (b h) i j', h=h)
+ sim_layer.masked_fill_(~layer_mask, max_neg_value)
+ sim_layer = sim_layer.softmax(dim=-1)
+
+ if exists(layer_score):
+ f = x.shape[0] // layer_score.shape[0] # bf / b = f
+ i = x.shape[1] # hw
+ j_times = layer_feature.shape[1] // layer_score.shape[1] # nhw / n = hw
+ assert i == j_times
+ weight = repeat(layer_score, 'b n -> (b f) i (n j)', f=f, i=i, j=j_times)
+ weight = repeat(weight, 'b i j -> (b h) i j', h=h)
+ sim_layer = sim_layer * weight
+ sim_layer = sim_layer / sim_layer.sum(dim=-1, keepdim=True)
+ out_layer = torch.einsum('b i j, b j d -> b i d', sim_layer, v_layer)
+ out_layer = rearrange(out_layer, '(b h) n d -> b n (h d)', h=h)
+
+ if out_ip is not None:
+ if self.image_cross_attention_scale_learnable:
+ out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1)
+ else:
+ out = out + self.image_cross_attention_scale * out_ip
+
+ if self.layer_cross_attention:
+ if self.layer_cross_attention_scale_learnable:
+ out = out + self.layer_cross_attention_scale * out_layer * (torch.tanh(self.layer_alpha)+1)
+ else:
+ out = out + self.layer_cross_attention_scale * out_layer
+ return self.to_out(out)
+
+ def efficient_forward(self, x, context=None, mask=None, layer_feature=None, layer_mask=None, layer_score=None):
+ assert layer_feature is None, "layer cross-attention is not supported in efficient_forward"
+ spatial_self_attn = (context is None)
+ k_ip, v_ip, out_ip = None, None, None
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:]
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ else:
+ if not spatial_self_attn:
+ context = context[:,:self.text_context_len,:]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
+
+ ## for image cross-attention
+ if k_ip is not None:
+ k_ip, v_ip = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (k_ip, v_ip),
+ )
+ out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None)
+ out_ip = (
+ out_ip.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if out_ip is not None:
+ if self.image_cross_attention_scale_learnable:
+ out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1)
+ else:
+ out = out + self.image_cross_attention_scale * out_ip
+
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, use_xformers=True,
+ disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77,
+ layer_cross_attention=False, layer_cross_attention_scale=1.0, layer_cross_attention_scale_learnable=False):
+ super().__init__()
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None, use_xformers=use_xformers,
+ layer_cross_attention=layer_cross_attention, layer_cross_attention_scale=layer_cross_attention_scale, layer_cross_attention_scale_learnable=layer_cross_attention_scale_learnable)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length,
+ context_dim=context_dim, use_xformers=use_xformers, text_context_len=text_context_len,
+ image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
+ layer_cross_attention=layer_cross_attention, layer_cross_attention_scale=layer_cross_attention_scale, layer_cross_attention_scale_learnable=layer_cross_attention_scale_learnable)
+ self.image_cross_attention = image_cross_attention
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+
+ def forward(self, x, context=None, mask=None, **kwargs):
+ ## implementation trick: checkpointing does not support non-tensor (e.g. None or scalar) arguments,
+ ## so only tensor arguments are packed into the tuple handed to checkpoint()
+ input_tuple = (x,) ## must be (x,), not (x): a bare (x) is just x, and *input_tuple would then unpack the tensor itself
+ if context is not None:
+ input_tuple = (x, context)
+ if mask is not None:
+ kwargs['mask'] = mask
+ return checkpoint(partial(self._forward, **kwargs), input_tuple, self.parameters(), self.checkpoint)
+
+
+ def _forward(self, x, context=None, mask=None, **kwargs):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask, **kwargs) + x
+ x = self.attn2(self.norm2(x), context=context, mask=mask, **kwargs) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data in spatial axis.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
+ use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None,
+ image_cross_attention=False, image_cross_attention_scale_learnable=False, use_xformers=True,
+ layer_cross_attention=False, layer_cross_attention_scale_learnable=False):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ attention_cls = None
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint,
+ attention_cls=attention_cls,
+ video_length=video_length,
+ image_cross_attention=image_cross_attention,
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
+ use_xformers=use_xformers,
+ layer_cross_attention=layer_cross_attention,
+ layer_cross_attention_scale_learnable=layer_cross_attention_scale_learnable
+ ) for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+
+ def forward(self, x, context=None, **kwargs):
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context, **kwargs)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class TemporalTransformer(nn.Module):
+ """
+ Transformer block for image-like data in temporal axis.
+ First, reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
+ use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1,
+ relative_position=False, temporal_length=None):
+ super().__init__()
+ self.only_self_att = only_self_att
+ self.relative_position = relative_position
+ self.causal_attention = causal_attention
+ self.causal_block_size = causal_block_size
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ if not use_linear:
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ if relative_position:
+ assert(temporal_length is not None)
+ attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
+ else:
+ attention_cls = partial(CrossAttention, temporal_length=temporal_length)
+ if self.causal_attention:
+ assert(temporal_length is not None)
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
+
+ if self.only_self_att:
+ context_dim = None
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attention_cls=attention_cls,
+ checkpoint=use_checkpoint) for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ b, c, t, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+
+ temp_mask = None
+ if self.causal_attention:
+ # slice the causal mask to the current temporal length
+ temp_mask = self.mask[:,:t,:t].to(x.device)
+
+ if temp_mask is not None:
+ mask = temp_mask.to(x.device)
+ mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
+ else:
+ mask = None
+
+ if self.only_self_att:
+ ## note: if no context is given, cross-attention defaults to self-attention
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, mask=mask)
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+ else:
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+ context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
+ for i, block in enumerate(self.transformer_blocks):
+ # process each batch element separately (some attention backends cannot handle a dimension larger than 65,535)
+ for j in range(b):
+ context_j = repeat(
+ context[j],
+ 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
+ ## note: the causal mask is not applied in the cross-attention case
+ x[j] = block(x[j], context=context_j)
+
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
+ x = self.proj_out(x)
+ x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
+
+ return x + x_in
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
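Similarly, a hedged smoke test for the `CrossAttention` block defined above; `use_xformers=False` keeps the plain einsum path so it runs even without xformers installed. All shapes and widths below are made up for illustration.

```python
import torch
from lvdm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, heads=8, dim_head=40, use_xformers=False)

tokens = torch.randn(2, 64, 320)             # (batch, tokens, channels), e.g. a flattened 8x8 feature map
print(attn(tokens).shape)                    # no context -> self-attention, output (2, 64, 320)

text_ctx = torch.randn(2, 77, 320)           # a 77-token text context of matching width
print(attn(tokens, context=text_ctx).shape)  # cross-attention, output (2, 64, 320)
```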
diff --git a/lvdm/modules/attention_svd.py b/lvdm/modules/attention_svd.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ceb3c978025c9bb9a640d63558a20a4989d377
--- /dev/null
+++ b/lvdm/modules/attention_svd.py
@@ -0,0 +1,759 @@
+import logging
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+logpy = logging.getLogger(__name__)
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ logpy.warn(
+ f"No SDP backend available, likely because you are running in pytorch "
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+ f"You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warn("no module 'xformers'. Processing without...")
+
+# from .diffusionmodules.util import mixed_checkpoint as checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SelfAttention(nn.Module):
+ ATTENTION_MODES = ("xformers", "torch", "math")
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ attn_mode: str = "xformers",
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ assert attn_mode in self.ATTENTION_MODES
+ self.attn_mode = attn_mode
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if self.attn_mode == "torch":
+ qkv = rearrange(
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
+ ).float()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+ elif self.attn_mode == "xformers":
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
+ elif self.attn_mode == "math":
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+ raise NotImplementedError(f"Unknown attention mode: {self.attn_mode}")
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ ## old
+ """
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ """
+ ## new
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ out = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask
+ ) # scale is dim_head ** -0.5 per default
+
+ del q, k, v
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
+ ):
+ super().__init__()
+ logpy.debug(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
+ f"context_dim is {context_dim} and using {heads} heads with a "
+ f"dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
+ # NOTE: workaround for
+ # https://github.com/facebookresearch/xformers/issues/845
+ max_bs = 32768
+ N = q.shape[0]
+ n_batches = math.ceil(N / max_bs)
+ out = list()
+ for i_batch in range(n_batches):
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
+ out.append(
+ xformers.ops.memory_efficient_attention(
+ q[batch],
+ k[batch],
+ v[batch],
+ attn_bias=None,
+ op=self.attention_op,
+ )
+ )
+ out = torch.cat(out, 0)
+ else:
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ logpy.warn(
+ f"Attention mode '{attn_mode}' is not available. Falling "
+ f"back to native attention. This is not a problem in "
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
+ f"version {torch.__version__}."
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ logpy.warn(
+ "We do not support vanilla attention anymore, as it is too "
+ "expensive. Sorry."
+ )
+ if not XFORMERS_IS_AVAILABLE:
+ assert (
+ False
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ logpy.info("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update(
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+ )
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ if self.checkpoint:
+ # inputs = {"x": x, "context": context}
+ return checkpoint(self._forward, x, context)
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ else:
+ return self._forward(**kwargs)
+
+ def _forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+ if not self.disable_self_attn
+ else 0,
+ )
+ + x
+ )
+ x = (
+ self.attn2(
+ self.norm2(x), context=context, additional_tokens=additional_tokens
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ return x
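For orientation, the block above is the usual pre-norm residual layout: self-attention, cross-attention over `context`, then a feed-forward, each applied as `x = f(norm(x)) + x`. The sketch below mirrors that structure with PyTorch's built-in `nn.MultiheadAttention` purely as an illustration; the real block relies on the `CrossAttention`/`MemoryEfficientCrossAttention` classes defined in this file.

```python
import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    # Illustrative only: same residual structure as BasicTransformerBlock,
    # but swapping in torch.nn.MultiheadAttention for the attention layers.
    def __init__(self, dim, n_heads, context_dim):
        super().__init__()
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, n_heads, kdim=context_dim,
                                                vdim=context_dim, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, context):
        h = self.norm1(x)
        x = self.self_attn(h, h, h, need_weights=False)[0] + x      # self-attention + residual
        h = self.norm2(x)
        x = self.cross_attn(h, context, context, need_weights=False)[0] + x  # cross-attention + residual
        return self.ff(self.norm3(x)) + x                            # feed-forward + residual
```

With, say, dim=320, n_heads=8, context_dim=1024, an input of shape [2, 64, 320] and a context of shape [2, 77, 1024] come back out as [2, 64, 320], matching the shape-preserving behavior of the block above.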
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ # inputs = {"x": x, "context": context}
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ return checkpoint(self._forward, x, context)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding) and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape back to an image.
+ NEW: use_linear for more efficiency instead of the 1x1 convs.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ logpy.debug(
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
+ f"{in_channels} channels and {n_heads} heads."
+ )
+
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ logpy.warn(
+ f"{self.__class__.__name__}: Found context dims "
+ f"{context_dim} of depth {len(context_dim)}, which does not "
+ f"match the specified 'depth' of {depth}. Setting context_dim "
+ f"to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
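The forward pass above hinges on the einops round-trip between the image layout and the token layout the transformer blocks expect. A quick self-contained check of that reshape, with assumed shapes:

```python
import torch
from einops import rearrange

x = torch.randn(2, 320, 32, 32)                       # b c h w feature map
tokens = rearrange(x, "b c h w -> b (h w) c")         # b (h*w) c, what the blocks consume
restored = rearrange(tokens, "b (h w) c -> b c h w", h=32, w=32)
assert torch.equal(x, restored)                       # lossless round-trip
```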
+
+
+class SimpleTransformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ context_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ BasicTransformerBlock(
+ dim,
+ heads,
+ dim_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attn_mode="softmax-xformers",
+ checkpoint=checkpoint,
+ )
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(x, context)
+ return x
\ No newline at end of file
diff --git a/lvdm/pipelines/pipeline_animation.py b/lvdm/pipelines/pipeline_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc6479d153a4e28b9eb2f0e7670aacaca9ca9886
--- /dev/null
+++ b/lvdm/pipelines/pipeline_animation.py
@@ -0,0 +1,550 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn.functional as F
+
+
+from diffusers.utils import is_accelerate_available
+
+from ..models.unet import UNetModel
+from ..models.autoencoder import AutoencoderKL, AutoencoderKL_Dualref
+from ..models.condition import FrozenOpenCLIPEmbedder, FrozenOpenCLIPImageEmbedderV2, Resampler
+from ..models.layer_controlnet import LayerControlNet
+
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import BaseOutput, logging
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+
+from einops import rearrange
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class AnimationPipelineOutput(BaseOutput):
+ videos: Union[List[Image.Image], np.ndarray]
+
+
+class AnimationPipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "image_encoder->unet->vae"
+ _callback_tensor_inputs = ["latents"]
+ def __init__(
+ self,
+ vae,
+ vae_dualref,
+ text_encoder,
+ image_encoder,
+ image_projector,
+ unet: UNetModel,
+ layer_controlnet: LayerControlNet,
+ scheduler: DDIMScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ vae_dualref=vae_dualref,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_projector=image_projector,
+ unet=unet,
+ layer_controlnet=layer_controlnet,
+ scheduler=scheduler,
+ )
+ if vae is not None:
+ self.vae_scale_factor = 2 ** (len(self.vae.config.ddconfig["ch_mult"]) - 1)
+ else:
+ self.vae_scale_factor = 2 ** (len(self.vae_dualref.config.ddconfig["ch_mult"]) - 1)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.layer_controlnet, self.text_encoder, self.vae, self.vae_dualref]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ def _execution_device(self):
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_embeddings = self.text_encoder(prompt)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_embeddings = self.text_encoder(uncond_tokens)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ batch_size = image.shape[0]
+
+ image_embeddings = self.image_encoder(image)
+ image_embeddings = self.image_projector(image_embeddings)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_embeddings = self.image_encoder(torch.zeros_like(image))
+ uncond_embeddings = self.image_projector(uncond_embeddings)
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and image embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_controls(
+ self,
+ layer_masks,
+ layer_regions,
+ layer_validity,
+ motion_scores,
+ layer_static,
+ trajectories,
+ sketches,
+ video_length,
+ mode,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance
+ ):
+ vae = self.vae if self.vae is not None else self.vae_dualref
+
+ batch_size, n_layers = layer_masks.shape[:2]
+ # Frame decomposition
+ layer_regions = rearrange(layer_regions, "b n f c h w -> (b n f) c h w")
+ keyframe_layer_latents = vae.encode(layer_regions)[0].sample() * 0.18215
+ keyframe_layer_latents = rearrange(keyframe_layer_latents, "(b n f) c h w -> b n f c h w", b=batch_size, n=n_layers)
+ layer_latents_shape = list(keyframe_layer_latents.shape)
+ layer_latents_shape[2] = video_length
+ layer_latents = torch.zeros(layer_latents_shape, device=device, dtype=keyframe_layer_latents.dtype)
+ resized_layer_masks = rearrange(layer_masks, "b n f c h w -> (b n f) c h w")
+ resized_layer_masks = F.interpolate(resized_layer_masks.float(), size=layer_latents.shape[-2:], mode="bilinear")
+ resized_layer_masks = rearrange(resized_layer_masks, "(b n f) c h w -> b n f c h w", b=batch_size, n=n_layers).to(dtype=layer_latents.dtype)
+ layer_latent_mask_shape = list(resized_layer_masks.shape)
+ layer_latent_mask_shape[2] = video_length
+ layer_latent_mask = torch.zeros(layer_latent_mask_shape, device=device, dtype=resized_layer_masks.dtype)
+
+ for batch_idx in range(batch_size):
+ if mode != "interpolate":
+ layer_latents[batch_idx, :, 0] = keyframe_layer_latents[batch_idx, :, 0]
+ layer_latent_mask[batch_idx, :, 0] = resized_layer_masks[batch_idx, :, 0]
+ if layer_static[batch_idx].any():
+ static_indices = torch.nonzero(layer_static[batch_idx]).squeeze(1)
+ layer_latents[batch_idx, static_indices, :] = keyframe_layer_latents[batch_idx, static_indices, 0:1].repeat(1, video_length, 1, 1, 1)
+ layer_latent_mask[batch_idx, static_indices, :] = resized_layer_masks[batch_idx, static_indices, 0:1].repeat(1, video_length, 1, 1, 1)
+ else:
+ layer_latents[batch_idx, :, 0] = keyframe_layer_latents[batch_idx, :, 0]
+ layer_latents[batch_idx, :, -1] = keyframe_layer_latents[batch_idx, :, -1]
+ layer_latent_mask[batch_idx, :, 0] = resized_layer_masks[batch_idx, :, 0]
+ layer_latent_mask[batch_idx, :, -1] = resized_layer_masks[batch_idx, :, -1]
+ if layer_static[batch_idx].any():
+ static_indices = torch.nonzero(layer_static[batch_idx]).squeeze(1)
+ layer_latents[batch_idx, static_indices, :video_length//2] = keyframe_layer_latents[batch_idx, static_indices, 0:1].repeat(1, video_length//2, 1, 1, 1)
+ layer_latents[batch_idx, static_indices, video_length//2:] = keyframe_layer_latents[batch_idx, static_indices, -1:].repeat(1, video_length//2, 1, 1, 1)
+ layer_latent_mask[batch_idx, static_indices, :video_length//2] = resized_layer_masks[batch_idx, static_indices, 0:1].repeat(1, video_length//2, 1, 1, 1)
+ layer_latent_mask[batch_idx, static_indices, video_length//2:] = resized_layer_masks[batch_idx, static_indices, -1:].repeat(1, video_length//2, 1, 1, 1)
+ layer_latents = torch.repeat_interleave(layer_latents, num_videos_per_prompt, dim=0)
+ layer_latent_mask = torch.repeat_interleave(layer_latent_mask, num_videos_per_prompt, dim=0)
+ layer_validity = torch.repeat_interleave(layer_validity, num_videos_per_prompt, dim=0)
+
+ sketches = rearrange(sketches, 'b n f c h w -> (b n f) c h w')
+ layer_sketch_latents = vae.encode(sketches)[0].sample() * 0.18215
+ layer_sketch_latents = rearrange(layer_sketch_latents, '(b n f) c h w -> b n f c h w', b=batch_size, n=n_layers)
+ layer_sketch_latents = torch.repeat_interleave(layer_sketch_latents, num_videos_per_prompt, dim=0)
+
+ trajectories = torch.repeat_interleave(trajectories, num_videos_per_prompt, dim=0)
+
+ motion_scores = torch.repeat_interleave(motion_scores, num_videos_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ layer_latents = torch.cat([layer_latents, layer_latents], dim=0)
+ layer_latent_mask = torch.cat([layer_latent_mask, layer_latent_mask], dim=0)
+ motion_scores = torch.cat([motion_scores, motion_scores], dim=0)
+ layer_sketch_latents = torch.cat([layer_sketch_latents, layer_sketch_latents], dim=0)
+ trajectories = torch.cat([trajectories, trajectories], dim=0)
+ layer_validity = torch.cat([layer_validity, layer_validity], dim=0)
+ return dict(
+ layer_latents=layer_latents,
+ layer_latent_mask=layer_latent_mask,
+ motion_scores=motion_scores,
+ sketch=layer_sketch_latents,
+ trajectory=trajectories,
+ layer_validity=layer_validity,
+ )
+
+ def get_latent_z_with_hidden_states(self, videos):
+ b, f, c, h, w = videos.shape
+ x = rearrange(videos, 'b f c h w -> (b f) c h w')
+ encoder_posterior, hidden_states = self.vae_dualref.encode(x, return_hidden_states=True)
+ hidden_states_first_last = []
+ ### use only the first and last hidden states
+ for hid in hidden_states:
+ hid = rearrange(hid, '(b f) c h w -> b c f h w', f=f)
+ hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
+ hidden_states_first_last.append(hid_new.float())
+
+ z = encoder_posterior[0].sample() * 0.18215
+ z = rearrange(z, '(b f) c h w -> b c f h w', b=b, f=f).detach()
+ return z, hidden_states_first_last
+
+ def get_latent_z(self, videos):
+ b, f, c, h, w = videos.shape
+ x = rearrange(videos, 'b f c h w -> (b f) c h w')
+ z = self.vae.encode(x)[0].sample() * 0.18215
+ z = rearrange(z, '(b f) c h w -> b c f h w', b=b, f=f).detach()
+ return z
+
+ def decode_latents(self, latents):
+ batch_size = latents.shape[0]
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for batch_idx in range(batch_size):
+ video.append(self.vae.decode(latents[batch_idx * video_length:(batch_idx + 1) * video_length]).sample)
+ video = torch.cat(video, dim=0)
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.cpu().float().numpy()
+ return video
+
+ def decode_latents_with_hidden_states(self, latents, hidden_states):
+ batch_size = latents.shape[0]
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for batch_idx in range(batch_size):
+ video.append(self.vae_dualref.decode(latents[batch_idx * video_length:(batch_idx + 1) * video_length].float(), ref_context=hidden_states, timesteps=video_length).sample)
+ video = torch.cat(video, dim=0)
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.cpu().float().numpy()
+ return video
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
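`prepare_extra_step_kwargs` probes the scheduler's `step` signature instead of assuming every scheduler accepts `eta` or `generator`. The same probing pattern in isolation, using diffusers' `DDIMScheduler` (the scheduler type hinted in the constructor) as an example:

```python
import inspect
from diffusers.schedulers import DDIMScheduler

scheduler = DDIMScheduler()
step_params = set(inspect.signature(scheduler.step).parameters.keys())

extra_step_kwargs = {}
if "eta" in step_params:          # DDIM-specific knob; other schedulers may not take it
    extra_step_kwargs["eta"] = 0.0
if "generator" in step_params:    # only pass a generator when the scheduler accepts one
    extra_step_kwargs["generator"] = None
```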
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ rand_device = device
+
+ if isinstance(generator, list):
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+ for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ video_length: int,
+ height: int,
+ width: int,
+ frame_tensor: torch.FloatTensor,
+ layer_masks: torch.FloatTensor, # [b, n_layers, 1 (2), c, h, w]
+ layer_regions: torch.FloatTensor, # [b, n_layers, 1 (2), c, h, w]
+ layer_static: torch.Tensor, # [b, n_layers]
+ motion_scores: torch.Tensor, # [b, n_layers]
+ sketch: torch.FloatTensor, # [b, n_layers, f, c, h, w]
+ trajectory: torch.FloatTensor, # [b, n_layers, f, c, h, w]
+ layer_validity: torch.Tensor, # [b, n_layers]
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ guidance_rescale: float = 0.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+
+ fps: Optional[int] = 24,
+ mode: str = "interpolate",
+ weight_dtype: torch.dtype = torch.float32,
+
+ **kwargs,
+ ):
+ # Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # Define call parameters
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ batch_size = len(frame_tensor)
+ if isinstance(prompt, list):
+ batch_size = len(prompt)
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ fps = torch.tensor([fps] * batch_size * num_videos_per_prompt, device=device, dtype=weight_dtype)
+ frame_tensor = frame_tensor.to(dtype=weight_dtype)
+ layer_regions = layer_regions.to(dtype=weight_dtype)
+ motion_scores = motion_scores.to(dtype=weight_dtype)
+ sketch = sketch.to(dtype=weight_dtype)
+ trajectory = trajectory.to(dtype=weight_dtype)
+
+ # Encode layer-level controls
+ encoded_layer_controls = self._encode_controls(
+ layer_masks,
+ layer_regions,
+ layer_validity,
+ motion_scores,
+ layer_static,
+ trajectory,
+ sketch,
+ video_length,
+ mode,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance
+ )
+ layer_validity = encoded_layer_controls.pop("layer_validity")
+
+ # Encode input prompt
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
+ if negative_prompt is not None:
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ cond_frame = frame_tensor[:, 0] # [b, f, c, h, w] -> [b, c, h, w]
+ image_embeddings = self._encode_image(
+ cond_frame, device, num_videos_per_prompt, do_classifier_free_guidance
+ )
+
+ if mode == "interpolate":
+ z, hidden_states = self.get_latent_z_with_hidden_states(frame_tensor)
+ else:
+ z = self.get_latent_z(frame_tensor)
+ z = z.to(dtype=weight_dtype)
+ if mode != "interpolate":
+ img_cat_cond = z[:, :, :1]
+ img_cat_cond = img_cat_cond.repeat(1, 1, video_length, 1, 1)
+ else:
+ img_cat_cond = torch.zeros_like(z[:, :, :1].repeat(1, 1, video_length, 1, 1))
+ img_cat_cond[:, :, 0] = z[:, :, 0]
+ img_cat_cond[:, :, -1] = z[:, :, -1]
+ img_cat_cond = torch.repeat_interleave(img_cat_cond, num_videos_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ img_cat_cond = torch.cat([img_cat_cond, img_cat_cond], dim=0)
+ fps = torch.cat([fps, fps], dim=0)
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # Prepare latent variables
+ num_channels_latents = self.unet.out_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ video_length,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ )
+
+ # Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ noise_with_img = torch.cat([latent_model_input, img_cat_cond], dim=1)
+
+ if do_classifier_free_guidance:
+ ts = torch.full((batch_size * num_videos_per_prompt * 2,), t, device=device, dtype=torch.long)
+ else:
+ ts = torch.full((batch_size * num_videos_per_prompt,), t, device=device, dtype=torch.long)
+
+ layer_features = self.layer_controlnet(
+ noise_with_img, ts,
+ context_text=text_embeddings,
+ context_img=image_embeddings,
+ fps=fps,
+ **encoded_layer_controls
+ )
+ noise_pred = self.unet(
+ noise_with_img, ts,
+ context_text=text_embeddings,
+ context_img=image_embeddings,
+ fps=fps,
+ controls=layer_features,
+ layer_validity=layer_validity,
+ ).sample.to(dtype=weight_dtype)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # Post-processing
+ if mode == "interpolate":
+ video = self.decode_latents_with_hidden_states(latents, hidden_states)
+ else:
+ video = self.decode_latents(latents)
+
+ # Convert to tensor
+ if output_type == "tensor":
+ video = torch.from_numpy(video)
+
+ if not return_dict:
+ return video
+
+ return AnimationPipelineOutput(videos=video)
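As a shape reference for the `__call__` signature above, the sketch below assembles dummy control tensors matching the annotated layouts; the concrete sizes (2 layers, 16 frames, 320×512 resolution, two keyframes for "interpolate" mode) are assumptions for illustration only. Running it end to end additionally requires the pretrained components passed to `AnimationPipeline.__init__`.

```python
import torch

b, n_layers, f, h, w = 1, 2, 16, 320, 512   # assumed sizes
dummy_inputs = dict(
    prompt="an example prompt",
    video_length=f, height=h, width=w,
    frame_tensor=torch.randn(b, 2, 3, h, w),           # first and last frame for "interpolate" mode
    layer_masks=torch.zeros(b, n_layers, 2, 1, h, w),  # [b, n_layers, keyframes, c, h, w]
    layer_regions=torch.randn(b, n_layers, 2, 3, h, w),
    layer_static=torch.zeros(b, n_layers, dtype=torch.bool),
    motion_scores=torch.rand(b, n_layers),
    sketch=torch.randn(b, n_layers, f, 3, h, w),
    trajectory=torch.zeros(b, n_layers, f, 3, h, w),
    layer_validity=torch.ones(b, n_layers, dtype=torch.bool),
    mode="interpolate",
)
# video = pipeline(**dummy_inputs).videos   # pipeline: an instantiated AnimationPipeline
```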
diff --git a/lvdm/utils.py b/lvdm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb9cce23e63dedf112324e15d8baa7cba37da7ee
--- /dev/null
+++ b/lvdm/utils.py
@@ -0,0 +1,394 @@
+import importlib
+import numpy as np
+import cv2
+import torch
+import torch.distributed as dist
+import os
+from einops import rearrange
+import imageio
+import torchvision
+from PIL import Image
+import io
+from matplotlib import pyplot as plt
+
+
+RY = 15
+YG = 6
+GC = 4
+CB = 11
+BM = 13
+MR = 6
+
+COLORWHEEL = torch.zeros((RY + YG + GC + CB + BM + MR, 3))
+col = 0
+
+# RY
+COLORWHEEL[0:RY, 0] = 255
+COLORWHEEL[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
+col = col + RY
+# YG
+COLORWHEEL[col:col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
+COLORWHEEL[col:col + YG, 1] = 255
+col = col + YG
+# GC
+COLORWHEEL[col:col + GC, 1] = 255
+COLORWHEEL[col:col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
+col = col + GC
+# CB
+COLORWHEEL[col:col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
+COLORWHEEL[col:col + CB, 2] = 255
+col = col + CB
+# BM
+COLORWHEEL[col:col + BM, 2] = 255
+COLORWHEEL[col:col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
+col = col + BM
+# MR
+COLORWHEEL[col:col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
+COLORWHEEL[col:col + MR, 0] = 255
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def check_istarget(name, para_list):
+ """
+ name: full name of a source parameter
+ para_list: list of partial target parameter names to match against
+ """
+ istarget = False
+ for para in para_list:
+ if para in name:
+ return True
+ return istarget
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def load_npz_from_dir(data_dir):
+ data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def load_npz_from_paths(data_paths):
+ data = [np.load(data_path)['arr_0'] for data_path in data_paths]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
+ h, w = image.shape[:2]
+ if resize_short_edge is not None:
+ k = resize_short_edge / min(h, w)
+ else:
+ k = max_resolution / (h * w)
+ k = k**0.5
+ h = int(np.round(h * k / 64)) * 64
+ w = int(np.round(w * k / 64)) * 64
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+ return image
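`resize_numpy_image` rescales so the area roughly matches `max_resolution` (or the short edge matches `resize_short_edge`) and snaps both sides to multiples of 64. A quick check under assumed input dimensions:

```python
import numpy as np
from lvdm.utils import resize_numpy_image

img = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)
out = resize_numpy_image(img, max_resolution=512 * 512)
# Both sides come out as multiples of 64, e.g. (384, 704, 3) for this input.
assert out.shape[0] % 64 == 0 and out.shape[1] % 64 == 0
```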
+
+
+def setup_dist(args):
+ if dist.is_initialized():
+ return
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(
+ 'nccl',
+ init_method='env://'
+ )
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, outputs, fps=fps)
+
+def save_images_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ os.makedirs(path, exist_ok=True)
+ for time_idx, x in enumerate(videos):
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ image = Image.fromarray(x)
+ image.save(os.path.join(path, f"{time_idx:04d}.png"))
+
+def save_image_with_mask(image: torch.Tensor, masks: torch.Tensor, path: str, rescale=False, alpha=0.6):
+ # image: [C, H, W], mask: [N, H, W]
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ image = rearrange(image, "c h w -> h w c")
+ if rescale:
+ image = (image + 1.0) / 2.0 # -1,1 -> 0,1
+ image = (image * 255).numpy().astype(np.uint8)
+ final_image = Image.fromarray(image).convert("RGBA")
+ cmap = plt.get_cmap("tab20c")
+ masks = masks.cpu().numpy().astype(np.float32)
+ for i, img in enumerate(masks):
+ mask_color = np.array([*cmap(i * 4 + 2)[:3], alpha])
+ mask = img[:,:,None] * mask_color[None,None,:] * 255
+ mask = mask.astype(np.uint8)
+ mask = Image.fromarray(mask).convert("RGBA")
+ final_image = Image.alpha_composite(final_image, mask)
+ final_image.save(path)
+
+def save_videos_with_heatmap(videos: torch.Tensor, trajectory: torch.Tensor, path: str, n_rows=6, fps=8):
+ # use Image RGBA and alpha_composite to combine video and trajectory
+ # use imageio to save video
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ trajectory = rearrange(trajectory, "b c t h w -> t b c h w")
+ outputs = []
+ for x, y in zip(videos, trajectory):
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ x = (x * 255).numpy().astype(np.uint8)
+ y = torchvision.utils.make_grid(y, nrow=n_rows)
+ y = y.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ y = torch.cat([y, torch.mean(y, dim=-1, keepdim=True)], dim=-1)
+ y = (y * 255).numpy().astype(np.uint8)
+ x = Image.fromarray(x).convert("RGBA")
+ y = Image.fromarray(y)
+ x = Image.alpha_composite(x, y)
+ outputs.append(x)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, outputs, fps=fps)
+
+def save_videos_with_traj(videos: torch.Tensor, trajectory: torch.Tensor, path: str, rescale=False, fps=8, line_width=3, circle_radius=5):
+ # videos: [C, F, H, W]
+ # trajectory: [F, N, 2]
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ videos = rearrange(videos, "c f h w -> f h w c")
+ if rescale:
+ videos = (videos + 1) / 2
+ videos = (videos * 255).numpy().astype(np.uint8)
+ outputs = []
+ for frame_idx, img in enumerate(videos):
+ # img: [H, W, C], traj: [N, 2]
+ # draw the trajectory history with cv2.line
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ for traj_idx in range(trajectory.shape[1]):
+ for history_idx in range(frame_idx):
+ cv2.line(img, tuple(trajectory[history_idx, traj_idx].int().tolist()), tuple(trajectory[history_idx+1, traj_idx].int().tolist()), (0, 0, 255), line_width)
+ cv2.circle(img, tuple(trajectory[frame_idx, traj_idx].int().tolist()), circle_radius, (100, 230, 160), -1)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ outputs.append(img)
+ imageio.mimsave(path, outputs, fps=fps)
+
+def save_layer_prompts_video(videos, layer_masks, motion_scores, flow_maps, path, alpha=0.6, fps=8, flow_step=10, flow_scale=1.0):
+ # videos: [F, C, H, W]
+ # layer_masks: [N, F, H, W]
+ # motion_scores: [N, ]
+ # flow_maps: [F, 2, H, W]
+ frame_length = videos.shape[0]
+ h, w = videos.shape[-2:]
+ n_keyframes = layer_masks.shape[1]
+ if n_keyframes == 1:
+ keyframe_indices = [0]
+ elif n_keyframes == 2:
+ keyframe_indices = [0, frame_length - 1]
+ else:
+ keyframe_indices = list(range(n_keyframes))
+ videos = rearrange(videos, "t c h w -> t h w c")
+ videos = ((videos + 1) / 2 * 255).clamp(0, 255).numpy().astype(np.uint8)
+ layer_masks = layer_masks.numpy()
+ flow_maps = flow_maps.float().numpy()
+ frame_list = []
+ cmap = plt.get_cmap("tab10")
+ for frame_idx in range(frame_length):
+ output_frame = Image.new("RGBA", (w * 2, h * 2))
+ frame = Image.fromarray(videos[frame_idx]).convert("RGBA")
+ frame_mask = None
+ output_frame.paste(frame, (0, 0))
+ for layer_idx, layer_mask in enumerate(layer_masks):
+ if frame_idx in keyframe_indices:
+ layer_color = (np.array([*cmap(layer_idx)[:3], alpha]) * 255).astype(np.uint8)
+ if frame_idx == frame_length - 1:
+ mask_with_color = Image.fromarray(layer_mask[-1, :, :, np.newaxis] * layer_color[np.newaxis, np.newaxis, :])
+ else:
+ mask_with_color = Image.fromarray(layer_mask[frame_idx, :, :, np.newaxis] * layer_color[np.newaxis, np.newaxis, :])
+ else:
+ mask_with_color = Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8))
+ frame = Image.alpha_composite(frame, mask_with_color)
+ frame_mask = Image.alpha_composite(frame_mask, mask_with_color) if frame_mask is not None else mask_with_color
+ output_frame.paste(frame, (w, 0))
+ output_frame.paste(frame_mask, (0, h))
+ flow_x = flow_maps[frame_idx, 0] * flow_scale
+ flow_y = flow_maps[frame_idx, 1] * flow_scale
+ x, y = np.arange(0, w, step=flow_step), np.arange(0, h, step=flow_step)
+ X, Y = np.meshgrid(x, y)
+ U, V = flow_x[::flow_step, ::flow_step], flow_y[::flow_step, ::flow_step]
+ plt.figure()
+ plt.gca().set_facecolor('white')
+ plt.quiver(X, Y, U, V, color='black', angles='xy', scale_units='xy', scale=1)
+ plt.xlim(0, w)
+ plt.ylim(h, 0)
+ plt.gca().set_xticks([])
+ plt.gca().set_yticks([])
+ buf = io.BytesIO()
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+ buf.seek(0)
+ flow = Image.open(buf).convert("RGBA")
+ output_frame.paste(flow, (w, h))
+ plt.close()
+ frame_list.append(output_frame)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, frame_list, fps=fps)
+
+def flow_uv_to_colors(u, v, rad, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ Based on the C++ source code of Daniel Scharstein and the
+ Matlab source code of Deqing Sun.
+
+ Args:
+ u (torch.tensor): Input horizontal flow of shape [N,H,W]
+ v (torch.tensor): Input vertical flow of shape [N,H,W]
+ rad (torch.tensor): Flow magnitude of shape [N,H,W], used to scale saturation
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ torch.tensor: Flow visualization image of shape [N,3,H,W]
+ """
+ flow_image = torch.zeros((u.shape[0], 3, u.shape[1], u.shape[2]), dtype=torch.uint8, device=u.device)
+ colorwheel = COLORWHEEL.to(u.device)
+ ncols = colorwheel.shape[0]
+ a = torch.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1)
+ k0 = torch.floor(fk).int()
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, ch_idx, :, :] = torch.floor(255 * col)
+ return flow_image
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Adapted from Tora: https://github.com/alibaba/Tora/blob/14db1b0a074284a6c265564eef07f5320911dc00/sat/utils/flow_utils.py#L120
+ Expects a batch of two-channel flow fields of shape [N,2,H,W].
+
+ Args:
+ flow_uv (torch.Tensor): Flow UV image of shape [N,2,H,W]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ torch.Tensor: Flow visualization image of shape [N,3,H,W]
+ """
+ if clip_flow is not None:
+ flow_uv = torch.clamp(flow_uv, 0, clip_flow)
+ u = flow_uv[:, 0]
+ v = flow_uv[:, 1]
+ rad = torch.sqrt(u**2 + v**2)
+ rad_max = torch.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ flow_image = flow_uv_to_colors(u, v, rad, convert_to_bgr)
+ return flow_image
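`flow_to_image` normalizes the flow by its maximum magnitude and maps direction and magnitude onto the precomputed color wheel. A small usage sketch with a random flow batch:

```python
import torch
from lvdm.utils import flow_to_image

flow = torch.randn(4, 2, 64, 64)     # [N, 2, H, W] horizontal/vertical flow
rgb = flow_to_image(flow)            # [N, 3, H, W] uint8 color-wheel visualization
assert rgb.shape == (4, 3, 64, 64) and rgb.dtype == torch.uint8
```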
+
+def generate_gaussian_template(imgSize=200):
+ """ Adapted from DragAnything: https://github.com/showlab/DragAnything/blob/79355363218a7eb9b3437a31b8604b6d436d9337/dataset/dataset.py#L110"""
+ circle_img = np.zeros((imgSize, imgSize), np.float32)
+ circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
+
+ isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
+
+ # Gaussian map
+ for i in range(imgSize):
+ for j in range(imgSize):
+ isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
+ -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
+
+ isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
+
+ # isotropicGrayscaleImage = cv2.resize(isotropicGrayscaleImage, (40, 40))
+ return isotropicGrayscaleImage
+
+def generate_gaussian_heatmap(tracks, width, height, layer_index, layer_capacity, side=20, offset=True):
+ heatmap_template = generate_gaussian_template()
+ num_frames, num_points = tracks.shape[:2]
+ if isinstance(tracks, torch.Tensor):
+ tracks = tracks.cpu().numpy()
+ if offset:
+ offset_kernel = cv2.resize(heatmap_template / 255, (2 * side + 1, 2 * side + 1))
+ offset_kernel /= np.sum(offset_kernel)
+ offset_kernel /= offset_kernel[side, side]
+ heatmaps = []
+ for frame_idx in range(num_frames):
+ if offset:
+ layer_imgs = np.zeros((layer_capacity, height, width, 3), dtype=np.float32)
+ else:
+ layer_imgs = np.zeros((layer_capacity, height, width, 1), dtype=np.float32)
+ layer_heatmaps = []
+ for point_idx in range(num_points):
+ x, y = tracks[frame_idx, point_idx]
+ layer_id = layer_index[point_idx]
+ if x < 0 or y < 0 or x >= width or y >= height:
+ continue
+ x1 = int(max(x - side, 0))
+ x2 = int(min(x + side, width - 1))
+ y1 = int(max(y - side, 0))
+ y2 = int(min(y + side, height - 1))
+ if (x2 - x1) < 1 or (y2 - y1) < 1:
+ continue
+ temp_map = cv2.resize(heatmap_template, (x2-x1, y2-y1))
+ layer_imgs[layer_id, y1:y2,x1:x2, 0] = np.maximum(layer_imgs[layer_id, y1:y2,x1:x2, 0], temp_map)
+ if offset:
+ if frame_idx < num_frames - 1:
+ next_x, next_y = tracks[frame_idx + 1, point_idx]
+ else:
+ next_x, next_y = x, y
+ layer_imgs[layer_id, int(y), int(x), 1] = next_x - x
+ layer_imgs[layer_id, int(y), int(x), 2] = next_y - y
+ for img in layer_imgs:
+ if offset:
+ img[:, :, 1:] = cv2.filter2D(img[:, :, 1:], -1, offset_kernel)
+ else:
+ img = cv2.cvtColor(img[:, :, 0].astype(np.uint8), cv2.COLOR_GRAY2RGB)
+ layer_heatmaps.append(img)
+ heatmaps.append(np.stack(layer_heatmaps, axis=0))
+ heatmaps = np.stack(heatmaps, axis=0)
+ return torch.from_numpy(heatmaps).permute(0, 1, 4, 2, 3).contiguous().float() # [F, N_layer, C, H, W]
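`generate_gaussian_heatmap` rasterizes point tracks into per-layer Gaussian blobs, plus x/y offset channels when `offset=True`. A usage sketch with a single assumed track across two frames:

```python
import torch
from lvdm.utils import generate_gaussian_heatmap

tracks = torch.tensor([[[100.0, 120.0]],      # frame 0, one point (x, y)
                       [[110.0, 125.0]]])     # frame 1, same point moved
heatmaps = generate_gaussian_heatmap(tracks, width=512, height=320,
                                     layer_index=[0], layer_capacity=4)
print(heatmaps.shape)   # torch.Size([2, 4, 3, 320, 512]) -> [F, N_layer, C, H, W]
```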
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd0d213724be00a3aaf736470a23437b13a7d257
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+accelerate==1.5.2
+decord==0.6.0
+diffusers==0.30.2
+einops==0.8.1
+gradio==5.23.1
+huggingface_hub==0.29.3
+imageio==2.27.0
+imageio-ffmpeg==0.6.0
+matplotlib==3.10.1
+numpy==2.2.4
+omegaconf==2.3.0
+open_clip_torch==2.22.0
+opencv_python==4.11.0.86
+packaging==24.2
+Pillow==11.1.0
+scipy==1.15.2
+spaces==0.34.0
+torch==2.4.0
+torchvision==0.19.0
+transformers==4.46.2
+xformers==0.0.27.post2