Merge pull request #10 from LightricksResearch/pre-commit
- .github/workflows/pylint.yml +27 -0
- .gitignore +1 -1
- .pre-commit-config.yaml +16 -0
- scripts/to_safetensors.py +64 -30
- setup.py +11 -7
- xora/__init__.py +0 -1
- xora/examples/image_to_video.py +94 -44
- xora/examples/text_to_video.py +29 -13
- xora/models/autoencoders/causal_conv3d.py +9 -3
- xora/models/autoencoders/causal_video_autoencoder.py +133 -33
- xora/models/autoencoders/conv_nd_factory.py +6 -2
- xora/models/autoencoders/dual_conv3d.py +36 -6
- xora/models/autoencoders/vae.py +74 -24
- xora/models/autoencoders/vae_encode.py +62 -17
- xora/models/autoencoders/video_autoencoder.py +170 -46
- xora/models/transformers/attention.py +174 -53
- xora/models/transformers/embeddings.py +6 -2
- xora/models/transformers/symmetric_patchifier.py +19 -4
- xora/models/transformers/transformer3d.py +86 -23
- xora/pipelines/pipeline_video_pixart_alpha.py +205 -63
- xora/schedulers/rf.py +43 -13
- xora/utils/conditioning_method.py +2 -1
- xora/utils/torch_utils.py +5 -1
.github/workflows/pylint.yml
ADDED
@@ -0,0 +1,27 @@
+name: Ruff
+
+on: [push]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - name: Checkout repository and submodules
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install ruff==0.2.2 black==24.2.0
+      - name: Analyzing the code with ruff
+        run: |
+          ruff $(git ls-files '*.py')
+      - name: Verify that no Black changes are required
+        run: |
+          black --check $(git ls-files '*.py')
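Reviewer note: the two CI checks above can be reproduced locally before pushing. A minimal sketch, assuming ruff 0.2.2 and black 24.2.0 are installed in the active environment; the helper below is illustrative only and is not part of this PR:

# hypothetical local mirror of the Ruff/Black workflow above
import subprocess

def tracked_python_files():
    # same file selection as the workflow: git ls-files '*.py'
    result = subprocess.run(
        ["git", "ls-files", "*.py"], capture_output=True, text=True, check=True
    )
    return result.stdout.split()

files = tracked_python_files()
subprocess.run(["ruff", *files], check=True)              # lint, as in the "Analyzing the code with ruff" step
subprocess.run(["black", "--check", *files], check=True)  # fail if reformatting would be needed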
.gitignore
CHANGED
@@ -159,4 +159,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-
+.idea/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,16 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.2.2
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [--fix] # Automatically fix issues if possible.
+        types: [python] # Ensure it only runs on .py files.
+
+  - repo: https://github.com/psf/black
+    rev: 24.2.0 # Specify the version of Black you want
+    hooks:
+      - id: black
+        name: Black code formatter
+        language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
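To take effect these hooks still need to be registered once per clone; this is the standard pre-commit workflow rather than anything added by this PR. A minimal sketch, assuming the pre-commit package is installed:

# hypothetical one-time setup helper; equivalent to running the two commands by hand
import subprocess

subprocess.run(["pre-commit", "install"], check=True)             # register the git hook
subprocess.run(["pre-commit", "run", "--all-files"], check=True)  # run ruff and black over the whole tree once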
scripts/to_safetensors.py
CHANGED
@@ -1,6 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import Dict
 import safetensors.torch
 import torch
 import json
@@ -8,12 +8,14 @@ import shutil


 def load_text_encoder(index_path: Path) -> Dict:
+    with open(index_path, "r") as f:
         index: Dict = json.load(f)

     loaded_tensors = {}
     for part_file in set(index.get("weight_map", {}).values()):
+        tensors = safetensors.torch.load_file(
+            index_path.parent / part_file, device="cpu"
+        )
         for tensor_name in tensors:
             loaded_tensors[tensor_name] = tensors[tensor_name]

@@ -30,23 +32,30 @@ def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
     state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
     stats_path = vae_path / "per_channel_statistics.json"
     if stats_path.exists():
+        with open(stats_path, "r") as f:
             data = json.load(f)
         transposed_data = list(zip(*data["data"]))
         data_dict = {
+            f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(
+                vals
+            )
             for col, vals in zip(data["columns"], transposed_data)
         }
     else:
         data_dict = {}

+    result = {
+        ("vae." if add_prefix else "") + key: value for key, value in state_dict.items()
+    }
     result.update(data_dict)
     return result


 def convert_encoder(encoder: Dict) -> Dict:
+    return {
+        "text_encoders.t5xxl.transformer." + key: value
+        for key, value in encoder.items()
+    }


 def save_config(config_src: str, config_dst: str):
@@ -60,50 +69,75 @@ def load_vae_config(vae_path: Path) -> str:
     return str(config_path)


+def main(
+    unet_path: str,
+    vae_path: str,
+    out_path: str,
+    mode: str,
+    unet_config_path: str = None,
+    scheduler_config_path: str = None,
+) -> None:
+    unet = convert_unet(
+        torch.load(unet_path, weights_only=True), add_prefix=(mode == "single")
+    )

     # Load VAE from directory and config
+    vae = convert_vae(Path(vae_path), add_prefix=(mode == "single"))
     vae_config_path = load_vae_config(Path(vae_path))

+    if mode == "single":
         result = {**unet, **vae}
         safetensors.torch.save_file(result, out_path)
+    elif mode == "separate":
         # Create directories for unet, vae, and scheduler
+        unet_dir = Path(out_path) / "unet"
+        vae_dir = Path(out_path) / "vae"
+        scheduler_dir = Path(out_path) / "scheduler"

         unet_dir.mkdir(parents=True, exist_ok=True)
         vae_dir.mkdir(parents=True, exist_ok=True)
         scheduler_dir.mkdir(parents=True, exist_ok=True)

         # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
+        safetensors.torch.save_file(
+            unet, unet_dir / "diffusion_pytorch_model.safetensors"
+        )
+        safetensors.torch.save_file(
+            vae, vae_dir / "diffusion_pytorch_model.safetensors"
+        )

         # Save config files for unet, vae, and scheduler
         if unet_config_path:
+            save_config(unet_config_path, unet_dir / "config.json")
         if vae_config_path:
+            save_config(vae_config_path, vae_dir / "config.json")
         if scheduler_config_path:
+            save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")


+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
+    parser.add_argument("--vae_path", "-v", type=str, default="vae/")
+    parser.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
+    parser.add_argument(
+        "--mode",
+        "-m",
+        type=str,
+        choices=["single", "separate"],
+        default="single",
+        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
+    )
+    parser.add_argument(
+        "--unet_config_path",
+        type=str,
+        help="Path to the UNet config file (for separate mode)",
+    )
+    parser.add_argument(
+        "--scheduler_config_path",
+        type=str,
+        help="Path to the Scheduler config file (for separate mode)",
+    )

     args = parser.parse_args()
     main(**args.__dict__)
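Because main() now takes explicit keyword arguments, the converter can also be driven from Python without the CLI. A minimal sketch; the paths are placeholders, and scripts.to_safetensors is assumed to be importable (repo root on sys.path):

from scripts.to_safetensors import main  # assumes the repo root is on sys.path

# "separate" mode writes unet/, vae/ and scheduler/ subdirectories, each holding
# diffusion_pytorch_model.safetensors plus its config, which is the layout the
# example scripts in this PR load from.
main(
    unet_path="unet/ema-002.pt",                              # default from the argparse definition
    vae_path="vae/",
    out_path="converted_ckpt/",                               # placeholder output directory
    mode="separate",
    unet_config_path="unet/config.json",                      # placeholder config locations
    scheduler_config_path="scheduler/scheduler_config.json",
)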
setup.py
CHANGED
@@ -1,7 +1,9 @@
 from setuptools import setup, find_packages
+
+
 def parse_requirements(filename):
     """Load requirements from a pip requirements file."""
+    with open(filename, "r") as file:
         return file.read().splitlines()


@@ -13,11 +15,13 @@ setup(
     author_email="[email protected]",  # Your email
     url="https://github.com/LightricksResearch/xora-core",  # URL for the project (GitHub, etc.)
     packages=find_packages(),  # Automatically find all packages inside `xora`
+    install_requires=parse_requirements(
+        "requirements.txt"
+    ),  # Install dependencies from requirements.txt
     classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
     ],
+    python_requires=">=3.10",  # Specify Python version compatibility
+)
xora/__init__.py
CHANGED
@@ -1 +0,0 @@
-from .pipelines import *
xora/examples/image_to_video.py
CHANGED
@@ -1,4 +1,3 @@
-import time
 import torch
 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
 from xora.models.transformers.transformer3d import Transformer3DModel
@@ -15,19 +14,20 @@ import os
 import numpy as np
 import cv2
 from PIL import Image
-from tqdm import tqdm
 import random

+
 def load_vae(vae_dir):
     vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
     vae_config_path = vae_dir / "config.json"
+    with open(vae_config_path, "r") as f:
         vae_config = json.load(f)
     vae = CausalVideoAutoencoder.from_config(vae_config)
     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
     vae.load_state_dict(vae_state_dict)
     return vae.cuda().to(torch.bfloat16)

+
 def load_unet(unet_dir):
     unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
     unet_config_path = unet_dir / "config.json"
@@ -37,11 +37,13 @@ def load_unet(unet_dir):
     transformer.load_state_dict(unet_state_dict, strict=True)
     return transformer.cuda()

+
 def load_scheduler(scheduler_dir):
     scheduler_config_path = scheduler_dir / "scheduler_config.json"
     scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
     return RectifiedFlowScheduler.from_config(scheduler_config)

+
 def center_crop_and_resize(frame, target_height, target_width):
     h, w, _ = frame.shape
     aspect_ratio_target = target_width / target_height
@@ -49,14 +51,15 @@ def center_crop_and_resize(frame, target_height, target_width):
     if aspect_ratio_frame > aspect_ratio_target:
         new_width = int(h * aspect_ratio_target)
         x_start = (w - new_width) // 2
+        frame_cropped = frame[:, x_start : x_start + new_width]
     else:
         new_height = int(w / aspect_ratio_target)
         y_start = (h - new_height) // 2
+        frame_cropped = frame[y_start : y_start + new_height, :]
     frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
     return frame_resized

+
 def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
     cap = cv2.VideoCapture(video_path)
     frames = []
@@ -72,6 +75,7 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
     video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
     return video_tensor

+
 def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
     image = Image.open(image_path).convert("RGB")
     image_np = np.array(image)
@@ -81,51 +85,90 @@ def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
     return frame_tensor.unsqueeze(0).unsqueeze(2)

+
 def main():
+    parser = argparse.ArgumentParser(
+        description="Load models from separate directories and run the pipeline."
+    )

     # Directories
+    parser.add_argument(
+        "--ckpt_dir",
+        type=str,
+        required=True,
+        help="Path to the directory containing unet, vae, and scheduler subdirectories",
+    )
+    parser.add_argument(
+        "--video_path", type=str, help="Path to the input video file (first frame used)"
+    )
+    parser.add_argument("--image_path", type=str, help="Path to the input image file")
+    parser.add_argument("--seed", type=int, default="171198")

     # Pipeline parameters
+    parser.add_argument(
+        "--num_inference_steps", type=int, default=40, help="Number of inference steps"
+    )
+    parser.add_argument(
+        "--num_images_per_prompt",
+        type=int,
+        default=1,
+        help="Number of images per prompt",
+    )
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=3,
+        help="Guidance scale for the pipeline",
+    )
+    parser.add_argument(
+        "--height", type=int, default=512, help="Height of the output video frames"
+    )
+    parser.add_argument(
+        "--width", type=int, default=768, help="Width of the output video frames"
+    )
+    parser.add_argument(
+        "--num_frames",
+        type=int,
+        default=121,
+        help="Number of frames to generate in the output video",
+    )
+    parser.add_argument(
+        "--frame_rate", type=int, default=25, help="Frame rate for the output video"
+    )

     # Prompts
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
+        help="Text prompt to guide generation",
+    )
+    parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default="worst quality, inconsistent motion, blurry, jittery, distorted",
+        help="Negative prompt for undesired features",
+    )

     args = parser.parse_args()

     # Paths for the separate mode directories
     ckpt_dir = Path(args.ckpt_dir)
+    unet_dir = ckpt_dir / "unet"
+    vae_dir = ckpt_dir / "vae"
+    scheduler_dir = ckpt_dir / "scheduler"

     # Load models
     vae = load_vae(vae_dir)
     unet = load_unet(unet_dir)
     scheduler = load_scheduler(scheduler_dir)
     patchifier = SymmetricPatchifier(patch_size=1)
+    text_encoder = T5EncoderModel.from_pretrained(
+        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
+    ).to("cuda")
+    tokenizer = T5Tokenizer.from_pretrained(
+        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
+    )

     # Use submodels for the pipeline
     submodel_dict = {
@@ -141,22 +184,25 @@ def main():

     # Load media (video or image)
     if args.video_path:
+        media_items = load_video_to_tensor_with_resize(
+            args.video_path, args.height, args.width
+        ).unsqueeze(0)
     elif args.image_path:
+        media_items = load_image_to_tensor_with_resize(
+            args.image_path, args.height, args.width
+        )
     else:
         raise ValueError("Either --video_path or --image_path must be provided.")

     # Prepare input for the pipeline
     sample = {
         "prompt": args.prompt,
+        "prompt_attention_mask": None,
+        "negative_prompt": args.negative_prompt,
+        "negative_prompt_attention_mask": None,
+        "media_items": media_items,
     }

-    start_time = time.time()
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -177,16 +223,18 @@ def main():
         **sample,
         is_video=True,
         vae_per_channel_normalize=True,
+        conditioning_method=ConditioningMethod.FIRST_FRAME,
     ).images
+
     # Save output video
+    def get_unique_filename(base, ext, dir=".", index_range=1000):
         for i in range(index_range):
             filename = os.path.join(dir, f"{base}_{i}{ext}")
             if not os.path.exists(filename):
                 return filename
+        raise FileExistsError(
+            f"Could not find a unique filename after {index_range} attempts."
+        )

     for i in range(images.shape[0]):
         video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
@@ -195,7 +243,9 @@ def main():
         height, width = video_np.shape[1:3]
         output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")

+        out = cv2.VideoWriter(
+            output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
+        )

         for frame in video_np[..., ::-1]:
             out.write(frame)
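One behavioral detail of the example worth calling out: center_crop_and_resize() crops to the target aspect ratio before resizing, so inputs are never stretched. A standalone check of that logic; the aspect_ratio_frame = w / h line sits between the hunks above and is assumed here:

import numpy as np
import cv2

def center_crop_and_resize(frame, target_height, target_width):
    # same crop-then-resize logic as in the example script
    h, w, _ = frame.shape
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = w / h  # assumed; this line is outside the shown hunks
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(h * aspect_ratio_target)
        x_start = (w - new_width) // 2
        frame_cropped = frame[:, x_start : x_start + new_width]
    else:
        new_height = int(w / aspect_ratio_target)
        y_start = (h - new_height) // 2
        frame_cropped = frame[y_start : y_start + new_height, :]
    return cv2.resize(frame_cropped, (target_width, target_height))

frame = np.zeros((720, 1280, 3), dtype=np.uint8)      # 16:9 test frame
print(center_crop_and_resize(frame, 512, 768).shape)  # (512, 768, 3)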
xora/examples/text_to_video.py
CHANGED
@@ -10,16 +10,18 @@ import safetensors.torch
 import json
 import argparse

+
 def load_vae(vae_dir):
     vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
     vae_config_path = vae_dir / "config.json"
+    with open(vae_config_path, "r") as f:
         vae_config = json.load(f)
     vae = CausalVideoAutoencoder.from_config(vae_config)
     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
     vae.load_state_dict(vae_state_dict)
     return vae.cuda().to(torch.bfloat16)

+
 def load_unet(unet_dir):
     unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
     unet_config_path = unet_dir / "config.json"
@@ -29,22 +31,31 @@ def load_unet(unet_dir):
     transformer.load_state_dict(unet_state_dict, strict=True)
     return transformer.cuda()

+
 def load_scheduler(scheduler_dir):
     scheduler_config_path = scheduler_dir / "scheduler_config.json"
     scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
     return RectifiedFlowScheduler.from_config(scheduler_config)

+
 def main():
     # Parse command line arguments
+    parser = argparse.ArgumentParser(
+        description="Load models from separate directories"
+    )
+    parser.add_argument(
+        "--separate_dir",
+        type=str,
+        required=True,
+        help="Path to the directory containing unet, vae, and scheduler subdirectories",
+    )
     args = parser.parse_args()

     # Paths for the separate mode directories
     separate_dir = Path(args.separate_dir)
+    unet_dir = separate_dir / "unet"
+    vae_dir = separate_dir / "vae"
+    scheduler_dir = separate_dir / "scheduler"

     # Load models
     vae = load_vae(vae_dir)
@@ -54,8 +65,12 @@ def main():
     # Patchifier (remains the same)
     patchifier = SymmetricPatchifier(patch_size=1)

+    text_encoder = T5EncoderModel.from_pretrained(
+        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
+    ).to("cuda")
+    tokenizer = T5Tokenizer.from_pretrained(
+        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
+    )

     # Use submodels for the pipeline
     submodel_dict = {
@@ -79,14 +94,14 @@ def main():
     frame_rate = 25
     sample = {
         "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
+        "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
+        "prompt_attention_mask": None,  # Adjust attention masks as needed
+        "negative_prompt": "Ugly deformed",
+        "negative_prompt_attention_mask": None,
     }

     # Generate images (video frames)
+    _ = pipeline(
         num_inference_steps=num_inference_steps,
         num_images_per_prompt=num_images_per_prompt,
         guidance_scale=guidance_scale,
@@ -104,5 +119,6 @@ def main():

     print("Generated images (video frames).")

+
 if __name__ == "__main__":
     main()
xora/models/autoencoders/causal_conv3d.py
CHANGED
@@ -40,11 +40,17 @@ class CausalConv3d(nn.Module):

     def forward(self, x, causal: bool = True):
         if causal:
+            first_frame_pad = x[:, :, :1, :, :].repeat(
+                (1, 1, self.time_kernel_size - 1, 1, 1)
+            )
             x = torch.concatenate((first_frame_pad, x), dim=2)
         else:
+            first_frame_pad = x[:, :, :1, :, :].repeat(
+                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
+            )
+            last_frame_pad = x[:, :, -1:, :, :].repeat(
+                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
+            )
             x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
         x = self.conv(x)
         return x
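The change to forward() is purely formatting, but the padding scheme is worth spelling out: in causal mode the input is padded along the time axis by repeating the first frame time_kernel_size - 1 times, so no output frame ever looks at future frames; the non-causal branch splits the padding symmetrically between the first and last frames. A standalone sketch of the causal case, independent of the CausalConv3d class and with an arbitrary kernel size:

import torch

def causal_time_pad(x: torch.Tensor, time_kernel_size: int) -> torch.Tensor:
    # x has shape (B, C, T, H, W); replicate the first frame so a temporal
    # convolution with this kernel size only looks backwards in time.
    first_frame_pad = x[:, :, :1, :, :].repeat(1, 1, time_kernel_size - 1, 1, 1)
    return torch.cat((first_frame_pad, x), dim=2)

x = torch.randn(1, 4, 8, 16, 16)
padded = causal_time_pad(x, time_kernel_size=3)
print(padded.shape)  # torch.Size([1, 4, 10, 16, 16])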
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -16,9 +16,15 @@ from xora.models.autoencoders.vae import AutoencoderKLWrapper

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+
 class CausalVideoAutoencoder(AutoencoderKLWrapper):
     @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *args,
+        **kwargs,
+    ):
         config_local_path = pretrained_model_name_or_path / "config.json"
         config = cls.load_config(config_local_path, **kwargs)
         video_vae = cls.from_config(config)
@@ -28,29 +34,41 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
         ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
         video_vae.load_state_dict(ckpt_state_dict)

+        statistics_local_path = (
+            pretrained_model_name_or_path / "per_channel_statistics.json"
+        )
         if statistics_local_path.exists():
             with open(statistics_local_path, "r") as file:
                 data = json.load(file)
             transposed_data = list(zip(*data["data"]))
+            data_dict = {
+                col: torch.tensor(vals)
+                for col, vals in zip(data["columns"], transposed_data)
+            }
             video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
             video_vae.register_buffer(
+                "mean_of_means",
+                data_dict.get(
+                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
+                ),
             )

         return video_vae

     @staticmethod
     def from_config(config):
+        assert (
+            config["_class_name"] == "CausalVideoAutoencoder"
+        ), "config must have _class_name=CausalVideoAutoencoder"
         if isinstance(config["dims"], list):
             config["dims"] = tuple(config["dims"])

         assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

         double_z = config.get("double_z", True)
+        latent_log_var = config.get(
+            "latent_log_var", "per_channel" if double_z else "none"
+        )
         use_quant_conv = config.get("use_quant_conv", True)

         if use_quant_conv and latent_log_var == "uniform":
@@ -91,7 +109,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
             _class_name="CausalVideoAutoencoder",
             dims=self.dims,
             in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
+            out_channels=self.decoder.conv_out.out_channels
+            // self.decoder.patch_size**2,
             latent_channels=self.decoder.conv_in.in_channels,
             blocks=self.encoder.blocks_desc,
             scaling_factor=1.0,
@@ -112,13 +131,26 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
     @property
     def spatial_downscale_factor(self):
         return (
+            2
+            ** len(
+                [
+                    block
+                    for block in self.encoder.blocks_desc
+                    if block[0] in ["compress_space", "compress_all"]
+                ]
+            )
             * self.encoder.patch_size
         )

     @property
     def temporal_downscale_factor(self):
+        return 2 ** len(
+            [
+                block
+                for block in self.encoder.blocks_desc
+                if block[0] in ["compress_time", "compress_all"]
+            ]
+        )

     def to_json_string(self) -> str:
         import json
@@ -146,7 +178,9 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
                 key = key.replace(k, v)

             if "norm" in key and key not in model_keys:
+                logger.info(
+                    f"Removing key {key} from state_dict as it is not present in the model"
+                )
                 continue

             converted_state_dict[key] = value
@@ -293,7 +327,9 @@ class Encoder(nn.Module):

         # out
         if norm_layer == "group_norm":
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
+            )
         elif norm_layer == "pixel_norm":
             self.conv_norm_out = PixelNorm()
         elif norm_layer == "layer_norm":
@@ -308,7 +344,9 @@ class Encoder(nn.Module):
             conv_out_channels += 1
         elif latent_log_var != "none":
             raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
+        self.conv_out = make_conv_nd(
+            dims, output_channel, conv_out_channels, 3, padding=1, causal=True
+        )

         self.gradient_checkpointing = False

@@ -337,11 +375,15 @@ class Encoder(nn.Module):

             if num_dims == 4:
                 # For shape (B, C, H, W)
+                repeated_last_channel = last_channel.repeat(
+                    1, sample.shape[1] - 2, 1, 1
+                )
                 sample = torch.cat([sample, repeated_last_channel], dim=1)
             elif num_dims == 5:
                 # For shape (B, C, F, H, W)
+                repeated_last_channel = last_channel.repeat(
+                    1, sample.shape[1] - 2, 1, 1, 1
+                )
                 sample = torch.cat([sample, repeated_last_channel], dim=1)
             else:
                 raise ValueError(f"Invalid input shape: {sample.shape}")
@@ -430,25 +472,35 @@ class Decoder(nn.Module):
                     norm_layer=norm_layer,
                 )
             elif block_name == "compress_time":
+                block = DepthToSpaceUpsample(
+                    dims=dims, in_channels=input_channel, stride=(2, 1, 1)
+                )
             elif block_name == "compress_space":
+                block = DepthToSpaceUpsample(
+                    dims=dims, in_channels=input_channel, stride=(1, 2, 2)
+                )
             elif block_name == "compress_all":
+                block = DepthToSpaceUpsample(
+                    dims=dims, in_channels=input_channel, stride=(2, 2, 2)
+                )
             else:
                 raise ValueError(f"unknown layer: {block_name}")

             self.up_blocks.append(block)

         if norm_layer == "group_norm":
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
+            )
         elif norm_layer == "pixel_norm":
             self.conv_norm_out = PixelNorm()
         elif norm_layer == "layer_norm":
             self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

         self.conv_act = nn.SiLU()
+        self.conv_out = make_conv_nd(
+            dims, output_channel, out_channels, 3, padding=1, causal=True
+        )

         self.gradient_checkpointing = False

@@ -509,7 +561,9 @@ class UNetMidBlock3D(nn.Module):
         norm_layer: str = "group_norm",
     ):
         super().__init__()
+        resnet_groups = (
+            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+        )

         self.res_blocks = nn.ModuleList(
             [
@@ -526,7 +580,9 @@ class UNetMidBlock3D(nn.Module):
             ]
         )

+    def forward(
+        self, hidden_states: torch.FloatTensor, causal: bool = True
+    ) -> torch.FloatTensor:
         for resnet in self.res_blocks:
             hidden_states = resnet(hidden_states, causal=causal)

@@ -604,7 +660,9 @@ class ResnetBlock3D(nn.Module):
         self.use_conv_shortcut = conv_shortcut

         if norm_layer == "group_norm":
+            self.norm1 = nn.GroupNorm(
+                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
+            )
         elif norm_layer == "pixel_norm":
             self.norm1 = PixelNorm()
         elif norm_layer == "layer_norm":
@@ -612,10 +670,20 @@ class ResnetBlock3D(nn.Module):

         self.non_linearity = nn.SiLU()

+        self.conv1 = make_conv_nd(
+            dims,
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            causal=True,
+        )

         if norm_layer == "group_norm":
+            self.norm2 = nn.GroupNorm(
+                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
+            )
         elif norm_layer == "pixel_norm":
             self.norm2 = PixelNorm()
         elif norm_layer == "layer_norm":
@@ -623,16 +691,28 @@ class ResnetBlock3D(nn.Module):

         self.dropout = torch.nn.Dropout(dropout)

+        self.conv2 = make_conv_nd(
+            dims,
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            causal=True,
+        )

         self.conv_shortcut = (
+            make_linear_nd(
+                dims=dims, in_channels=in_channels, out_channels=out_channels
+            )
             if in_channels != out_channels
             else nn.Identity()
         )

         self.norm3 = (
+            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
+            if in_channels != out_channels
+            else nn.Identity()
         )

     def forward(
@@ -669,9 +749,17 @@ def patchify(x, patch_size_hw, patch_size_t=1):
     if patch_size_hw == 1 and patch_size_t == 1:
         return x
     if x.dim() == 4:
+        x = rearrange(
+            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
+        )
     elif x.dim() == 5:
+        x = rearrange(
+            x,
+            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
+            p=patch_size_t,
+            q=patch_size_hw,
+            r=patch_size_hw,
+        )
     else:
         raise ValueError(f"Invalid input shape: {x.shape}")

@@ -683,9 +771,17 @@ def unpatchify(x, patch_size_hw, patch_size_t=1):
         return x

     if x.dim() == 4:
+        x = rearrange(
+            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
+        )
     elif x.dim() == 5:
+        x = rearrange(
+            x,
+            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
+            p=patch_size_t,
+            q=patch_size_hw,
+            r=patch_size_hw,
+        )

     return x

@@ -755,14 +851,18 @@ def demo_video_autoencoder_forward_backward():
     print(f"input shape={input_videos.shape}")
     print(f"latent shape={latent.shape}")

+    reconstructed_videos = video_autoencoder.decode(
+        latent, target_shape=input_videos.shape
+    ).sample

     print(f"reconstructed shape={reconstructed_videos.shape}")

     # Validate that single image gets treated the same way as first frame
     input_image = input_videos[:, :, :1, :, :]
     image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
+    reconstructed_image = video_autoencoder.decode(
+        image_latent, target_shape=image_latent.shape
+    ).sample

     first_frame_latent = latent[:, :, :1, :, :]
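The patchify/unpatchify pair reformatted above is a space-to-depth trick: spatial (and optionally temporal) patches are folded into the channel dimension and unfolded back. A quick round-trip check of the einops patterns used in the diff, with arbitrary sizes:

import torch
from einops import rearrange

x = torch.randn(2, 3, 4, 8, 8)  # (b, c, f, h, w)
patched = rearrange(
    x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=2, q=2, r=2
)
print(patched.shape)  # torch.Size([2, 24, 2, 4, 4])
restored = rearrange(
    patched, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=2, q=2, r=2
)
print(torch.equal(x, restored))  # True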
xora/models/autoencoders/conv_nd_factory.py
CHANGED
@@ -71,8 +71,12 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
+        return torch.nn.Conv2d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
+        )
     elif dims == 3 or dims == (2, 1):
+        return torch.nn.Conv3d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
+        )
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
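A kernel_size=1 convolution, as returned by make_linear_nd, is just a per-position linear layer over the channel dimension, which is why the function is named the way it is. A small sanity check of that equivalence for the dims == 2 branch (freshly created layers, arbitrary sizes):

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=1, bias=True)
linear = nn.Linear(8, 4)
linear.weight.data = conv.weight.data.view(4, 8)  # (4, 8, 1, 1) -> (4, 8)
linear.bias.data = conv.bias.data

x = torch.randn(2, 8, 5, 5)
out_conv = conv(x)
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True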
xora/models/autoencoders/dual_conv3d.py
CHANGED
|
@@ -27,7 +27,9 @@ class DualConv3d(nn.Module):
|
|
| 27 |
if isinstance(kernel_size, int):
|
| 28 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 29 |
if kernel_size == (1, 1, 1):
|
| 30 |
-
raise ValueError(
|
|
|
|
|
|
|
| 31 |
if isinstance(stride, int):
|
| 32 |
stride = (stride, stride, stride)
|
| 33 |
if isinstance(padding, int):
|
|
@@ -40,11 +42,19 @@ class DualConv3d(nn.Module):
|
|
| 40 |
self.bias = bias
|
| 41 |
|
| 42 |
# Define the size of the channels after the first convolution
|
| 43 |
-
intermediate_channels =
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# Define parameters for the first convolution
|
| 46 |
self.weight1 = nn.Parameter(
|
| 47 |
-
torch.Tensor(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
)
|
| 49 |
self.stride1 = (1, stride[1], stride[2])
|
| 50 |
self.padding1 = (0, padding[1], padding[2])
|
|
@@ -55,7 +65,11 @@ class DualConv3d(nn.Module):
|
|
| 55 |
self.register_parameter("bias1", None)
|
| 56 |
|
| 57 |
# Define parameters for the second convolution
|
| 58 |
-
self.weight2 = nn.Parameter(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
self.stride2 = (stride[0], 1, 1)
|
| 60 |
self.padding2 = (padding[0], 0, 0)
|
| 61 |
self.dilation2 = (dilation[0], 1, 1)
|
|
@@ -86,13 +100,29 @@ class DualConv3d(nn.Module):
|
|
| 86 |
|
| 87 |
def forward_with_3d(self, x, skip_time_conv):
|
| 88 |
# First convolution
|
| 89 |
-
x = F.conv3d(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
if skip_time_conv:
|
| 92 |
return x
|
| 93 |
|
| 94 |
# Second convolution
|
| 95 |
-
x = F.conv3d(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
return x
|
| 98 |
|
|
|
|
| 27 |
if isinstance(kernel_size, int):
|
| 28 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 29 |
if kernel_size == (1, 1, 1):
|
| 30 |
+
raise ValueError(
|
| 31 |
+
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
| 32 |
+
)
|
| 33 |
if isinstance(stride, int):
|
| 34 |
stride = (stride, stride, stride)
|
| 35 |
if isinstance(padding, int):
|
|
|
|
| 42 |
self.bias = bias
|
| 43 |
|
| 44 |
# Define the size of the channels after the first convolution
|
| 45 |
+
intermediate_channels = (
|
| 46 |
+
out_channels if in_channels < out_channels else in_channels
|
| 47 |
+
)
|
| 48 |
|
| 49 |
# Define parameters for the first convolution
|
| 50 |
self.weight1 = nn.Parameter(
|
| 51 |
+
torch.Tensor(
|
| 52 |
+
intermediate_channels,
|
| 53 |
+
in_channels // groups,
|
| 54 |
+
1,
|
| 55 |
+
kernel_size[1],
|
| 56 |
+
kernel_size[2],
|
| 57 |
+
)
|
| 58 |
)
|
| 59 |
self.stride1 = (1, stride[1], stride[2])
|
| 60 |
self.padding1 = (0, padding[1], padding[2])
|
|
|
|
| 65 |
self.register_parameter("bias1", None)
|
| 66 |
|
| 67 |
# Define parameters for the second convolution
|
| 68 |
+
self.weight2 = nn.Parameter(
|
| 69 |
+
torch.Tensor(
|
| 70 |
+
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
| 71 |
+
)
|
| 72 |
+
)
|
| 73 |
self.stride2 = (stride[0], 1, 1)
|
| 74 |
self.padding2 = (padding[0], 0, 0)
|
| 75 |
self.dilation2 = (dilation[0], 1, 1)
|
|
|
|
| 100 |
|
| 101 |
def forward_with_3d(self, x, skip_time_conv):
|
| 102 |
# First convolution
|
| 103 |
+
x = F.conv3d(
|
| 104 |
+
x,
|
| 105 |
+
self.weight1,
|
| 106 |
+
self.bias1,
|
| 107 |
+
self.stride1,
|
| 108 |
+
self.padding1,
|
| 109 |
+
self.dilation1,
|
| 110 |
+
self.groups,
|
| 111 |
+
)
|
| 112 |
|
| 113 |
if skip_time_conv:
|
| 114 |
return x
|
| 115 |
|
| 116 |
# Second convolution
|
| 117 |
+
x = F.conv3d(
|
| 118 |
+
x,
|
| 119 |
+
self.weight2,
|
| 120 |
+
self.bias2,
|
| 121 |
+
self.stride2,
|
| 122 |
+
self.padding2,
|
| 123 |
+
self.dilation2,
|
| 124 |
+
self.groups,
|
| 125 |
+
)
|
| 126 |
|
| 127 |
return x
|
| 128 |
|
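A hedged illustration of the (2+1)D factorization that DualConv3d.forward_with_3d applies above: a spatial 1xKxK convolution followed by a temporal Kx1x1 convolution, both through F.conv3d. The weight shapes follow weight1 and weight2 in the diff; the channel sizes below are invented for the example.

import torch
import torch.nn.functional as F

B, C_in, C_mid, C_out, T, H, W, K = 1, 8, 16, 16, 5, 32, 32, 3
x = torch.randn(B, C_in, T, H, W)
w_spatial = torch.randn(C_mid, C_in, 1, K, K)    # like self.weight1: no mixing across time
w_temporal = torch.randn(C_out, C_mid, K, 1, 1)  # like self.weight2: no mixing across space

y = F.conv3d(x, w_spatial, stride=(1, 1, 1), padding=(0, 1, 1))   # spatial pass
y = F.conv3d(y, w_temporal, stride=(1, 1, 1), padding=(1, 0, 0))  # temporal pass
print(y.shape)  # torch.Size([1, 16, 5, 32, 32])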
xora/models/autoencoders/vae.py
CHANGED
|
@@ -4,7 +4,10 @@ import torch
|
|
| 4 |
import math
|
| 5 |
import torch.nn as nn
|
| 6 |
from diffusers import ConfigMixin, ModelMixin
|
| 7 |
-
from diffusers.models.autoencoders.vae import
|
|
|
|
|
|
|
|
|
|
| 8 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 9 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd
|
| 10 |
|
|
@@ -43,8 +46,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 43 |
quant_dims = 2 if dims == 2 else 3
|
| 44 |
self.decoder = decoder
|
| 45 |
if use_quant_conv:
|
| 46 |
-
self.quant_conv = make_conv_nd(
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
else:
|
| 49 |
self.quant_conv = nn.Identity()
|
| 50 |
self.post_quant_conv = nn.Identity()
|
|
@@ -104,7 +111,13 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 104 |
for i in range(0, x.shape[3], overlap_size):
|
| 105 |
row = []
|
| 106 |
for j in range(0, x.shape[4], overlap_size):
|
| 107 |
-
tile = x[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
tile = self.encoder(tile)
|
| 109 |
tile = self.quant_conv(tile)
|
| 110 |
row.append(tile)
|
|
@@ -125,42 +138,58 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 125 |
moments = torch.cat(result_rows, dim=3)
|
| 126 |
return moments
|
| 127 |
|
| 128 |
-
def blend_z(
|
|
|
|
|
|
|
| 129 |
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
| 130 |
for z in range(blend_extent):
|
| 131 |
-
b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
|
| 132 |
-
z / blend_extent
|
| 133 |
-
)
|
| 134 |
return b
|
| 135 |
|
| 136 |
-
def blend_v(
|
|
|
|
|
|
|
| 137 |
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
| 138 |
for y in range(blend_extent):
|
| 139 |
-
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
| 140 |
-
y / blend_extent
|
| 141 |
-
)
|
| 142 |
return b
|
| 143 |
|
| 144 |
-
def blend_h(
|
|
|
|
|
|
|
| 145 |
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
| 146 |
for x in range(blend_extent):
|
| 147 |
-
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
| 148 |
-
x / blend_extent
|
| 149 |
-
)
|
| 150 |
return b
|
| 151 |
|
| 152 |
def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
|
| 153 |
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
| 154 |
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
| 155 |
row_limit = self.tile_sample_min_size - blend_extent
|
| 156 |
-
tile_target_shape = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
# Split z into overlapping 64x64 tiles and decode them separately.
|
| 158 |
# The tiles have an overlap to avoid seams between tiles.
|
| 159 |
rows = []
|
| 160 |
for i in range(0, z.shape[3], overlap_size):
|
| 161 |
row = []
|
| 162 |
for j in range(0, z.shape[4], overlap_size):
|
| 163 |
-
tile = z[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
tile = self.post_quant_conv(tile)
|
| 165 |
decoded = self.decoder(tile, target_shape=tile_target_shape)
|
| 166 |
row.append(decoded)
|
|
@@ -181,20 +210,34 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 181 |
dec = torch.cat(result_rows, dim=3)
|
| 182 |
return dec
|
| 183 |
|
| 184 |
-
def encode(
|
|
|
|
|
|
|
| 185 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 186 |
num_splits = z.shape[2] // self.z_sample_size
|
| 187 |
sizes = [self.z_sample_size] * num_splits
|
| 188 |
-
sizes =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
tiles = z.split(sizes, dim=2)
|
| 190 |
moments_tiles = [
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
for z_tile in tiles
|
| 193 |
]
|
| 194 |
moments = torch.cat(moments_tiles, dim=2)
|
| 195 |
|
| 196 |
else:
|
| 197 |
-
moments =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
posterior = DiagonalGaussianDistribution(moments)
|
| 200 |
if not return_dict:
|
|
@@ -207,7 +250,9 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 207 |
moments = self.quant_conv(h)
|
| 208 |
return moments
|
| 209 |
|
| 210 |
-
def _decode(
|
|
|
|
|
|
|
| 211 |
z = self.post_quant_conv(z)
|
| 212 |
dec = self.decoder(z, target_shape=target_shape)
|
| 213 |
return dec
|
|
@@ -219,7 +264,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 219 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 220 |
reduction_factor = int(
|
| 221 |
self.encoder.patch_size_t
|
| 222 |
-
* 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
split_size = self.z_sample_size // reduction_factor
|
| 225 |
num_splits = z.shape[2] // split_size
|
|
|
|
| 4 |
import math
|
| 5 |
import torch.nn as nn
|
| 6 |
from diffusers import ConfigMixin, ModelMixin
|
| 7 |
+
from diffusers.models.autoencoders.vae import (
|
| 8 |
+
DecoderOutput,
|
| 9 |
+
DiagonalGaussianDistribution,
|
| 10 |
+
)
|
| 11 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 12 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd
|
| 13 |
|
|
|
|
| 46 |
quant_dims = 2 if dims == 2 else 3
|
| 47 |
self.decoder = decoder
|
| 48 |
if use_quant_conv:
|
| 49 |
+
self.quant_conv = make_conv_nd(
|
| 50 |
+
quant_dims, 2 * latent_channels, 2 * latent_channels, 1
|
| 51 |
+
)
|
| 52 |
+
self.post_quant_conv = make_conv_nd(
|
| 53 |
+
quant_dims, latent_channels, latent_channels, 1
|
| 54 |
+
)
|
| 55 |
else:
|
| 56 |
self.quant_conv = nn.Identity()
|
| 57 |
self.post_quant_conv = nn.Identity()
|
|
|
|
| 111 |
for i in range(0, x.shape[3], overlap_size):
|
| 112 |
row = []
|
| 113 |
for j in range(0, x.shape[4], overlap_size):
|
| 114 |
+
tile = x[
|
| 115 |
+
:,
|
| 116 |
+
:,
|
| 117 |
+
:,
|
| 118 |
+
i : i + self.tile_sample_min_size,
|
| 119 |
+
j : j + self.tile_sample_min_size,
|
| 120 |
+
]
|
| 121 |
tile = self.encoder(tile)
|
| 122 |
tile = self.quant_conv(tile)
|
| 123 |
row.append(tile)
|
|
|
|
| 138 |
moments = torch.cat(result_rows, dim=3)
|
| 139 |
return moments
|
| 140 |
|
| 141 |
+
def blend_z(
|
| 142 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 143 |
+
) -> torch.Tensor:
|
| 144 |
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
| 145 |
for z in range(blend_extent):
|
| 146 |
+
b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
|
| 147 |
+
1 - z / blend_extent
|
| 148 |
+
) + b[:, :, z, :, :] * (z / blend_extent)
|
| 149 |
return b
|
| 150 |
|
| 151 |
+
def blend_v(
|
| 152 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 153 |
+
) -> torch.Tensor:
|
| 154 |
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
| 155 |
for y in range(blend_extent):
|
| 156 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
| 157 |
+
1 - y / blend_extent
|
| 158 |
+
) + b[:, :, :, y, :] * (y / blend_extent)
|
| 159 |
return b
|
| 160 |
|
| 161 |
+
def blend_h(
|
| 162 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 163 |
+
) -> torch.Tensor:
|
| 164 |
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
| 165 |
for x in range(blend_extent):
|
| 166 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
| 167 |
+
1 - x / blend_extent
|
| 168 |
+
) + b[:, :, :, :, x] * (x / blend_extent)
|
| 169 |
return b
|
| 170 |
|
| 171 |
def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
|
| 172 |
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
| 173 |
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
| 174 |
row_limit = self.tile_sample_min_size - blend_extent
|
| 175 |
+
tile_target_shape = (
|
| 176 |
+
*target_shape[:3],
|
| 177 |
+
self.tile_sample_min_size,
|
| 178 |
+
self.tile_sample_min_size,
|
| 179 |
+
)
|
| 180 |
# Split z into overlapping 64x64 tiles and decode them separately.
|
| 181 |
# The tiles have an overlap to avoid seams between tiles.
|
| 182 |
rows = []
|
| 183 |
for i in range(0, z.shape[3], overlap_size):
|
| 184 |
row = []
|
| 185 |
for j in range(0, z.shape[4], overlap_size):
|
| 186 |
+
tile = z[
|
| 187 |
+
:,
|
| 188 |
+
:,
|
| 189 |
+
:,
|
| 190 |
+
i : i + self.tile_latent_min_size,
|
| 191 |
+
j : j + self.tile_latent_min_size,
|
| 192 |
+
]
|
| 193 |
tile = self.post_quant_conv(tile)
|
| 194 |
decoded = self.decoder(tile, target_shape=tile_target_shape)
|
| 195 |
row.append(decoded)
|
|
|
|
| 210 |
dec = torch.cat(result_rows, dim=3)
|
| 211 |
return dec
|
| 212 |
|
| 213 |
+
def encode(
|
| 214 |
+
self, z: torch.FloatTensor, return_dict: bool = True
|
| 215 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 216 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 217 |
num_splits = z.shape[2] // self.z_sample_size
|
| 218 |
sizes = [self.z_sample_size] * num_splits
|
| 219 |
+
sizes = (
|
| 220 |
+
sizes + [z.shape[2] - sum(sizes)]
|
| 221 |
+
if z.shape[2] - sum(sizes) > 0
|
| 222 |
+
else sizes
|
| 223 |
+
)
|
| 224 |
tiles = z.split(sizes, dim=2)
|
| 225 |
moments_tiles = [
|
| 226 |
+
(
|
| 227 |
+
self._hw_tiled_encode(z_tile, return_dict)
|
| 228 |
+
if self.use_hw_tiling
|
| 229 |
+
else self._encode(z_tile)
|
| 230 |
+
)
|
| 231 |
for z_tile in tiles
|
| 232 |
]
|
| 233 |
moments = torch.cat(moments_tiles, dim=2)
|
| 234 |
|
| 235 |
else:
|
| 236 |
+
moments = (
|
| 237 |
+
self._hw_tiled_encode(z, return_dict)
|
| 238 |
+
if self.use_hw_tiling
|
| 239 |
+
else self._encode(z)
|
| 240 |
+
)
|
| 241 |
|
| 242 |
posterior = DiagonalGaussianDistribution(moments)
|
| 243 |
if not return_dict:
|
|
|
|
| 250 |
moments = self.quant_conv(h)
|
| 251 |
return moments
|
| 252 |
|
| 253 |
+
def _decode(
|
| 254 |
+
self, z: torch.FloatTensor, target_shape=None
|
| 255 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 256 |
z = self.post_quant_conv(z)
|
| 257 |
dec = self.decoder(z, target_shape=target_shape)
|
| 258 |
return dec
|
|
|
|
| 264 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 265 |
reduction_factor = int(
|
| 266 |
self.encoder.patch_size_t
|
| 267 |
+
* 2
|
| 268 |
+
** (
|
| 269 |
+
len(self.encoder.down_blocks)
|
| 270 |
+
- 1
|
| 271 |
+
- math.sqrt(self.encoder.patch_size)
|
| 272 |
+
)
|
| 273 |
)
|
| 274 |
split_size = self.z_sample_size // reduction_factor
|
| 275 |
num_splits = z.shape[2] // split_size
|
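A minimal 1-D sketch of the linear crossfade that blend_z, blend_v and blend_h implement above: over the overlap region, tile a fades out while tile b fades in, using the same weighting formula as the diff. The tensor sizes here are illustrative only.

import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # a's trailing samples fade out while b's leading samples fade in
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b

a = torch.ones(8)   # e.g. right edge of the previous tile
b = torch.zeros(8)  # left edge of the next tile
print(blend_1d(a, b, 4))  # first four entries fade 1.00 -> 0.25, the rest stay 0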
xora/models/autoencoders/vae_encode.py
CHANGED
|
@@ -6,12 +6,19 @@ from torch import Tensor
|
|
| 6 |
|
| 7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
|
|
|
| 9 |
try:
|
| 10 |
import torch_xla.core.xla_model as xm
|
| 11 |
-
except:
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
-
def vae_encode(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
Encodes media items (images or videos) into latent representations using a specified VAE model.
|
| 17 |
The function supports processing batches of images or video frames and can handle the processing
|
|
@@ -48,11 +55,15 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
| 48 |
if channels != 3:
|
| 49 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
| 50 |
|
| 51 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
|
|
| 52 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
| 53 |
if split_size > 1:
|
| 54 |
if len(media_items) % split_size != 0:
|
| 55 |
-
raise ValueError(
|
|
|
|
|
|
|
| 56 |
encode_bs = len(media_items) // split_size
|
| 57 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
| 58 |
latents = []
|
|
@@ -67,22 +78,32 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
| 67 |
latents = vae.encode(media_items).latent_dist.sample()
|
| 68 |
|
| 69 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
| 70 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
|
|
| 71 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
| 72 |
return latents
|
| 73 |
|
| 74 |
|
| 75 |
def vae_decode(
|
| 76 |
-
latents: Tensor,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
) -> Tensor:
|
| 78 |
is_video_shaped = latents.dim() == 5
|
| 79 |
batch_size = latents.shape[0]
|
| 80 |
|
| 81 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
|
|
| 82 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
| 83 |
if split_size > 1:
|
| 84 |
if len(latents) % split_size != 0:
|
| 85 |
-
raise ValueError(
|
|
|
|
|
|
|
| 86 |
encode_bs = len(latents) // split_size
|
| 87 |
image_batch = [
|
| 88 |
_run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
|
|
@@ -92,12 +113,16 @@ def vae_decode(
|
|
| 92 |
else:
|
| 93 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
| 94 |
|
| 95 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
|
|
| 96 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
| 97 |
return images
|
| 98 |
|
| 99 |
|
| 100 |
-
def _run_decoder(
|
|
|
|
|
|
|
| 101 |
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
| 102 |
*_, fl, hl, wl = latents.shape
|
| 103 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
|
@@ -105,7 +130,13 @@ def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_ch
|
|
| 105 |
image = vae.decode(
|
| 106 |
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
| 107 |
return_dict=False,
|
| 108 |
-
target_shape=(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
)[0]
|
| 110 |
else:
|
| 111 |
image = vae.decode(
|
|
@@ -120,14 +151,26 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
|
|
| 120 |
spatial = vae.spatial_downscale_factor
|
| 121 |
temporal = vae.temporal_downscale_factor
|
| 122 |
else:
|
| 123 |
-
down_blocks = len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
spatial = vae.config.patch_size * 2**down_blocks
|
| 125 |
-
temporal =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
return (temporal, spatial, spatial)
|
| 128 |
|
| 129 |
|
| 130 |
-
def normalize_latents(
|
|
|
|
|
|
|
| 131 |
return (
|
| 132 |
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
|
| 133 |
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
|
@@ -136,10 +179,12 @@ def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_norma
|
|
| 136 |
)
|
| 137 |
|
| 138 |
|
| 139 |
-
def un_normalize_latents(
|
|
|
|
|
|
|
| 140 |
return (
|
| 141 |
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 142 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 143 |
if vae_per_channel_normalize
|
| 144 |
else latents / vae.config.scaling_factor
|
| 145 |
-
)
|
|
|
|
| 6 |
|
| 7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
| 9 |
+
|
| 10 |
try:
|
| 11 |
import torch_xla.core.xla_model as xm
|
| 12 |
+
except ImportError:
|
| 13 |
+
xm = None
|
| 14 |
+
|
| 15 |
|
| 16 |
+
def vae_encode(
|
| 17 |
+
media_items: Tensor,
|
| 18 |
+
vae: AutoencoderKL,
|
| 19 |
+
split_size: int = 1,
|
| 20 |
+
vae_per_channel_normalize=False,
|
| 21 |
+
) -> Tensor:
|
| 22 |
"""
|
| 23 |
Encodes media items (images or videos) into latent representations using a specified VAE model.
|
| 24 |
The function supports processing batches of images or video frames and can handle the processing
|
|
|
|
| 55 |
if channels != 3:
|
| 56 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
| 57 |
|
| 58 |
+
if is_video_shaped and not isinstance(
|
| 59 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 60 |
+
):
|
| 61 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
| 62 |
if split_size > 1:
|
| 63 |
if len(media_items) % split_size != 0:
|
| 64 |
+
raise ValueError(
|
| 65 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
| 66 |
+
)
|
| 67 |
encode_bs = len(media_items) // split_size
|
| 68 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
| 69 |
latents = []
|
|
|
|
| 78 |
latents = vae.encode(media_items).latent_dist.sample()
|
| 79 |
|
| 80 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
| 81 |
+
if is_video_shaped and not isinstance(
|
| 82 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 83 |
+
):
|
| 84 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
| 85 |
return latents
|
| 86 |
|
| 87 |
|
| 88 |
def vae_decode(
|
| 89 |
+
latents: Tensor,
|
| 90 |
+
vae: AutoencoderKL,
|
| 91 |
+
is_video: bool = True,
|
| 92 |
+
split_size: int = 1,
|
| 93 |
+
vae_per_channel_normalize=False,
|
| 94 |
) -> Tensor:
|
| 95 |
is_video_shaped = latents.dim() == 5
|
| 96 |
batch_size = latents.shape[0]
|
| 97 |
|
| 98 |
+
if is_video_shaped and not isinstance(
|
| 99 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 100 |
+
):
|
| 101 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
| 102 |
if split_size > 1:
|
| 103 |
if len(latents) % split_size != 0:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
| 106 |
+
)
|
| 107 |
encode_bs = len(latents) // split_size
|
| 108 |
image_batch = [
|
| 109 |
_run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
|
|
|
|
| 113 |
else:
|
| 114 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
| 115 |
|
| 116 |
+
if is_video_shaped and not isinstance(
|
| 117 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 118 |
+
):
|
| 119 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
| 120 |
return images
|
| 121 |
|
| 122 |
|
| 123 |
+
def _run_decoder(
|
| 124 |
+
latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False
|
| 125 |
+
) -> Tensor:
|
| 126 |
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
| 127 |
*_, fl, hl, wl = latents.shape
|
| 128 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
|
|
|
| 130 |
image = vae.decode(
|
| 131 |
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
| 132 |
return_dict=False,
|
| 133 |
+
target_shape=(
|
| 134 |
+
1,
|
| 135 |
+
3,
|
| 136 |
+
fl * temporal_scale if is_video else 1,
|
| 137 |
+
hl * spatial_scale,
|
| 138 |
+
wl * spatial_scale,
|
| 139 |
+
),
|
| 140 |
)[0]
|
| 141 |
else:
|
| 142 |
image = vae.decode(
|
|
|
|
| 151 |
spatial = vae.spatial_downscale_factor
|
| 152 |
temporal = vae.temporal_downscale_factor
|
| 153 |
else:
|
| 154 |
+
down_blocks = len(
|
| 155 |
+
[
|
| 156 |
+
block
|
| 157 |
+
for block in vae.encoder.down_blocks
|
| 158 |
+
if isinstance(block.downsample, Downsample3D)
|
| 159 |
+
]
|
| 160 |
+
)
|
| 161 |
spatial = vae.config.patch_size * 2**down_blocks
|
| 162 |
+
temporal = (
|
| 163 |
+
vae.config.patch_size_t * 2**down_blocks
|
| 164 |
+
if isinstance(vae, VideoAutoencoder)
|
| 165 |
+
else 1
|
| 166 |
+
)
|
| 167 |
|
| 168 |
return (temporal, spatial, spatial)
|
| 169 |
|
| 170 |
|
| 171 |
+
def normalize_latents(
|
| 172 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
| 173 |
+
) -> Tensor:
|
| 174 |
return (
|
| 175 |
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
|
| 176 |
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
|
|
|
| 179 |
)
|
| 180 |
|
| 181 |
|
| 182 |
+
def un_normalize_latents(
|
| 183 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
| 184 |
+
) -> Tensor:
|
| 185 |
return (
|
| 186 |
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 187 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 188 |
if vae_per_channel_normalize
|
| 189 |
else latents / vae.config.scaling_factor
|
| 190 |
+
)
|
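A hedged round-trip sketch of the per-channel branch of normalize_latents / un_normalize_latents above: subtract mean_of_means, divide by std_of_means, then invert. The formula is taken from the diff; the SimpleNamespace stand-in for a VAE carrying those buffers is an assumption for illustration only.

import torch
from types import SimpleNamespace

fake_vae = SimpleNamespace(
    mean_of_means=torch.tensor([0.1, -0.2, 0.0, 0.3]),
    std_of_means=torch.tensor([1.5, 0.8, 1.0, 2.0]),
)

latents = torch.randn(1, 4, 2, 8, 8)  # (B, C, F, H, W)
mean = fake_vae.mean_of_means.view(1, -1, 1, 1, 1)
std = fake_vae.std_of_means.view(1, -1, 1, 1, 1)

normalized = (latents - mean) / std  # what normalize_latents does when vae_per_channel_normalize=True
restored = normalized * std + mean   # what un_normalize_latents undoes
print(torch.allclose(restored, latents, atol=1e-6))  # True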
xora/models/autoencoders/video_autoencoder.py
CHANGED
|
@@ -21,7 +21,12 @@ logger = logging.get_logger(__name__)
|
|
| 21 |
|
| 22 |
class VideoAutoencoder(AutoencoderKLWrapper):
|
| 23 |
@classmethod
|
| 24 |
-
def from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
| 26 |
config = cls.load_config(config_local_path, **kwargs)
|
| 27 |
video_vae = cls.from_config(config)
|
|
@@ -31,29 +36,41 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
| 31 |
ckpt_state_dict = torch.load(model_local_path)
|
| 32 |
video_vae.load_state_dict(ckpt_state_dict)
|
| 33 |
|
| 34 |
-
statistics_local_path =
|
|
|
|
|
|
|
| 35 |
if statistics_local_path.exists():
|
| 36 |
with open(statistics_local_path, "r") as file:
|
| 37 |
data = json.load(file)
|
| 38 |
transposed_data = list(zip(*data["data"]))
|
| 39 |
-
data_dict = {
|
|
|
|
|
|
|
|
|
|
| 40 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 41 |
video_vae.register_buffer(
|
| 42 |
-
"mean_of_means",
|
|
|
|
|
|
|
|
|
|
| 43 |
)
|
| 44 |
|
| 45 |
return video_vae
|
| 46 |
|
| 47 |
@staticmethod
|
| 48 |
def from_config(config):
|
| 49 |
-
assert
|
|
|
|
|
|
|
| 50 |
if isinstance(config["dims"], list):
|
| 51 |
config["dims"] = tuple(config["dims"])
|
| 52 |
|
| 53 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
| 54 |
|
| 55 |
double_z = config.get("double_z", True)
|
| 56 |
-
latent_log_var = config.get(
|
|
|
|
|
|
|
| 57 |
use_quant_conv = config.get("use_quant_conv", True)
|
| 58 |
|
| 59 |
if use_quant_conv and latent_log_var == "uniform":
|
|
@@ -96,8 +113,10 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
| 96 |
return SimpleNamespace(
|
| 97 |
_class_name="VideoAutoencoder",
|
| 98 |
dims=self.dims,
|
| 99 |
-
in_channels=self.encoder.conv_in.in_channels
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
latent_channels=self.decoder.conv_in.in_channels,
|
| 102 |
block_out_channels=[
|
| 103 |
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
|
@@ -143,7 +162,9 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
| 143 |
key = key.replace(k, v)
|
| 144 |
|
| 145 |
if "norm" in key and key not in model_keys:
|
| 146 |
-
logger.info(
|
|
|
|
|
|
|
| 147 |
continue
|
| 148 |
|
| 149 |
converted_state_dict[key] = value
|
|
@@ -253,7 +274,11 @@ class Encoder(nn.Module):
|
|
| 253 |
|
| 254 |
# out
|
| 255 |
if norm_layer == "group_norm":
|
| 256 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
elif norm_layer == "pixel_norm":
|
| 258 |
self.conv_norm_out = PixelNorm()
|
| 259 |
self.conv_act = nn.SiLU()
|
|
@@ -265,14 +290,23 @@ class Encoder(nn.Module):
|
|
| 265 |
conv_out_channels += 1
|
| 266 |
elif latent_log_var != "none":
|
| 267 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 268 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
|
|
| 269 |
|
| 270 |
self.gradient_checkpointing = False
|
| 271 |
|
| 272 |
@property
|
| 273 |
def downscale_factor(self):
|
| 274 |
return (
|
| 275 |
-
2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
* self.patch_size
|
| 277 |
)
|
| 278 |
|
|
@@ -299,7 +333,9 @@ class Encoder(nn.Module):
|
|
| 299 |
)
|
| 300 |
|
| 301 |
for down_block in self.down_blocks:
|
| 302 |
-
sample = checkpoint_fn(down_block)(
|
|
|
|
|
|
|
| 303 |
|
| 304 |
sample = checkpoint_fn(self.mid_block)(sample)
|
| 305 |
|
|
@@ -314,11 +350,15 @@ class Encoder(nn.Module):
|
|
| 314 |
|
| 315 |
if num_dims == 4:
|
| 316 |
# For shape (B, C, H, W)
|
| 317 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
|
|
| 318 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 319 |
elif num_dims == 5:
|
| 320 |
# For shape (B, C, F, H, W)
|
| 321 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
|
|
| 322 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 323 |
else:
|
| 324 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
|
@@ -405,7 +445,8 @@ class Decoder(nn.Module):
|
|
| 405 |
num_layers=self.layers_per_block + 1,
|
| 406 |
in_channels=prev_output_channel,
|
| 407 |
out_channels=output_channel,
|
| 408 |
-
add_upsample=not is_final_block
|
|
|
|
| 409 |
resnet_eps=1e-6,
|
| 410 |
resnet_groups=norm_num_groups,
|
| 411 |
norm_layer=norm_layer,
|
|
@@ -413,12 +454,16 @@ class Decoder(nn.Module):
|
|
| 413 |
self.up_blocks.append(up_block)
|
| 414 |
|
| 415 |
if norm_layer == "group_norm":
|
| 416 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
|
|
| 417 |
elif norm_layer == "pixel_norm":
|
| 418 |
self.conv_norm_out = PixelNorm()
|
| 419 |
|
| 420 |
self.conv_act = nn.SiLU()
|
| 421 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
|
|
| 422 |
|
| 423 |
self.gradient_checkpointing = False
|
| 424 |
|
|
@@ -494,15 +539,24 @@ class DownEncoderBlock3D(nn.Module):
|
|
| 494 |
self.res_blocks = nn.ModuleList(res_blocks)
|
| 495 |
|
| 496 |
if add_downsample:
|
| 497 |
-
self.downsample = Downsample3D(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
else:
|
| 499 |
self.downsample = Identity()
|
| 500 |
|
| 501 |
-
def forward(
|
|
|
|
|
|
|
| 502 |
for resnet in self.res_blocks:
|
| 503 |
hidden_states = resnet(hidden_states)
|
| 504 |
|
| 505 |
-
hidden_states = self.downsample(
|
|
|
|
|
|
|
| 506 |
|
| 507 |
return hidden_states
|
| 508 |
|
|
@@ -536,7 +590,9 @@ class UNetMidBlock3D(nn.Module):
|
|
| 536 |
norm_layer: str = "group_norm",
|
| 537 |
):
|
| 538 |
super().__init__()
|
| 539 |
-
resnet_groups =
|
|
|
|
|
|
|
| 540 |
|
| 541 |
self.res_blocks = nn.ModuleList(
|
| 542 |
[
|
|
@@ -595,13 +651,17 @@ class UpDecoderBlock3D(nn.Module):
|
|
| 595 |
self.res_blocks = nn.ModuleList(res_blocks)
|
| 596 |
|
| 597 |
if add_upsample:
|
| 598 |
-
self.upsample = Upsample3D(
|
|
|
|
|
|
|
| 599 |
else:
|
| 600 |
self.upsample = Identity()
|
| 601 |
|
| 602 |
self.resolution_idx = resolution_idx
|
| 603 |
|
| 604 |
-
def forward(
|
|
|
|
|
|
|
| 605 |
for resnet in self.res_blocks:
|
| 606 |
hidden_states = resnet(hidden_states)
|
| 607 |
|
|
@@ -641,25 +701,35 @@ class ResnetBlock3D(nn.Module):
|
|
| 641 |
self.use_conv_shortcut = conv_shortcut
|
| 642 |
|
| 643 |
if norm_layer == "group_norm":
|
| 644 |
-
self.norm1 = torch.nn.GroupNorm(
|
|
|
|
|
|
|
| 645 |
elif norm_layer == "pixel_norm":
|
| 646 |
self.norm1 = PixelNorm()
|
| 647 |
|
| 648 |
self.non_linearity = nn.SiLU()
|
| 649 |
|
| 650 |
-
self.conv1 = make_conv_nd(
|
|
|
|
|
|
|
| 651 |
|
| 652 |
if norm_layer == "group_norm":
|
| 653 |
-
self.norm2 = torch.nn.GroupNorm(
|
|
|
|
|
|
|
| 654 |
elif norm_layer == "pixel_norm":
|
| 655 |
self.norm2 = PixelNorm()
|
| 656 |
|
| 657 |
self.dropout = torch.nn.Dropout(dropout)
|
| 658 |
|
| 659 |
-
self.conv2 = make_conv_nd(
|
|
|
|
|
|
|
| 660 |
|
| 661 |
self.conv_shortcut = (
|
| 662 |
-
make_linear_nd(
|
|
|
|
|
|
|
| 663 |
if in_channels != out_channels
|
| 664 |
else nn.Identity()
|
| 665 |
)
|
|
@@ -692,7 +762,14 @@ class ResnetBlock3D(nn.Module):
|
|
| 692 |
|
| 693 |
|
| 694 |
class Downsample3D(nn.Module):
|
| 695 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
super().__init__()
|
| 697 |
stride: int = 2
|
| 698 |
self.padding = padding
|
|
@@ -735,18 +812,24 @@ class Upsample3D(nn.Module):
|
|
| 735 |
self.dims = dims
|
| 736 |
self.channels = channels
|
| 737 |
self.out_channels = out_channels or channels
|
| 738 |
-
self.conv = make_conv_nd(
|
|
|
|
|
|
|
| 739 |
|
| 740 |
def forward(self, x, upsample_in_time):
|
| 741 |
if self.dims == 2:
|
| 742 |
-
x = functional.interpolate(
|
|
|
|
|
|
|
| 743 |
else:
|
| 744 |
time_scale_factor = 2 if upsample_in_time else 1
|
| 745 |
# print("before:", x.shape)
|
| 746 |
b, c, d, h, w = x.shape
|
| 747 |
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 748 |
# height and width interpolate
|
| 749 |
-
x = functional.interpolate(
|
|
|
|
|
|
|
| 750 |
_, _, h, w = x.shape
|
| 751 |
|
| 752 |
if not upsample_in_time and self.dims == (2, 1):
|
|
@@ -760,7 +843,9 @@ class Upsample3D(nn.Module):
|
|
| 760 |
new_d = x.shape[-1] * time_scale_factor
|
| 761 |
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
| 762 |
# (b h w) c 1 new_d
|
| 763 |
-
x = rearrange(
|
|
|
|
|
|
|
| 764 |
# b c d h w
|
| 765 |
|
| 766 |
# x = functional.interpolate(
|
|
@@ -775,13 +860,25 @@ def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
|
| 775 |
if patch_size_hw == 1 and patch_size_t == 1:
|
| 776 |
return x
|
| 777 |
if x.dim() == 4:
|
| 778 |
-
x = rearrange(
|
|
|
|
|
|
|
| 779 |
elif x.dim() == 5:
|
| 780 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
else:
|
| 782 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 783 |
|
| 784 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
| 786 |
padding_zeros = torch.zeros(
|
| 787 |
x.shape[0],
|
|
@@ -801,14 +898,26 @@ def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
|
| 801 |
if patch_size_hw == 1 and patch_size_t == 1:
|
| 802 |
return x
|
| 803 |
|
| 804 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
| 806 |
x = x[:, :channels_to_keep, :, :, :]
|
| 807 |
|
| 808 |
if x.dim() == 4:
|
| 809 |
-
x = rearrange(
|
|
|
|
|
|
|
| 810 |
elif x.dim() == 5:
|
| 811 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
|
| 813 |
return x
|
| 814 |
|
|
@@ -818,11 +927,19 @@ def create_video_autoencoder_config(
|
|
| 818 |
):
|
| 819 |
config = {
|
| 820 |
"_class_name": "VideoAutoencoder",
|
| 821 |
-
"dims": (
|
|
|
|
|
|
|
|
|
|
| 822 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 823 |
"out_channels": 3, # Number of output color channels
|
| 824 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 825 |
-
"block_out_channels": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
"patch_size": 1,
|
| 827 |
}
|
| 828 |
|
|
@@ -834,11 +951,15 @@ def create_video_autoencoder_pathify4x4x4_config(
|
|
| 834 |
):
|
| 835 |
config = {
|
| 836 |
"_class_name": "VideoAutoencoder",
|
| 837 |
-
"dims": (
|
|
|
|
|
|
|
|
|
|
| 838 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 839 |
"out_channels": 3, # Number of output color channels
|
| 840 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 841 |
-
"block_out_channels": [512]
|
|
|
|
| 842 |
"patch_size": 4,
|
| 843 |
"latent_log_var": "uniform",
|
| 844 |
}
|
|
@@ -855,7 +976,8 @@ def create_video_autoencoder_pathify4x4_config(
|
|
| 855 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 856 |
"out_channels": 3, # Number of output color channels
|
| 857 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 858 |
-
"block_out_channels": [512]
|
|
|
|
| 859 |
"patch_size": 4,
|
| 860 |
"norm_layer": "pixel_norm",
|
| 861 |
}
|
|
@@ -894,7 +1016,9 @@ def demo_video_autoencoder_forward_backward():
|
|
| 894 |
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
| 895 |
print(f"input shape={input_videos.shape}")
|
| 896 |
print(f"latent shape={latent.shape}")
|
| 897 |
-
reconstructed_videos = video_autoencoder.decode(
|
|
|
|
|
|
|
| 898 |
|
| 899 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
| 900 |
|
|
|
|
| 21 |
|
| 22 |
class VideoAutoencoder(AutoencoderKLWrapper):
|
| 23 |
@classmethod
|
| 24 |
+
def from_pretrained(
|
| 25 |
+
cls,
|
| 26 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 27 |
+
*args,
|
| 28 |
+
**kwargs,
|
| 29 |
+
):
|
| 30 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
| 31 |
config = cls.load_config(config_local_path, **kwargs)
|
| 32 |
video_vae = cls.from_config(config)
|
|
|
|
| 36 |
ckpt_state_dict = torch.load(model_local_path)
|
| 37 |
video_vae.load_state_dict(ckpt_state_dict)
|
| 38 |
|
| 39 |
+
statistics_local_path = (
|
| 40 |
+
pretrained_model_name_or_path / "per_channel_statistics.json"
|
| 41 |
+
)
|
| 42 |
if statistics_local_path.exists():
|
| 43 |
with open(statistics_local_path, "r") as file:
|
| 44 |
data = json.load(file)
|
| 45 |
transposed_data = list(zip(*data["data"]))
|
| 46 |
+
data_dict = {
|
| 47 |
+
col: torch.tensor(vals)
|
| 48 |
+
for col, vals in zip(data["columns"], transposed_data)
|
| 49 |
+
}
|
| 50 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 51 |
video_vae.register_buffer(
|
| 52 |
+
"mean_of_means",
|
| 53 |
+
data_dict.get(
|
| 54 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
| 55 |
+
),
|
| 56 |
)
|
| 57 |
|
| 58 |
return video_vae
|
| 59 |
|
| 60 |
@staticmethod
|
| 61 |
def from_config(config):
|
| 62 |
+
assert (
|
| 63 |
+
config["_class_name"] == "VideoAutoencoder"
|
| 64 |
+
), "config must have _class_name=VideoAutoencoder"
|
| 65 |
if isinstance(config["dims"], list):
|
| 66 |
config["dims"] = tuple(config["dims"])
|
| 67 |
|
| 68 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
| 69 |
|
| 70 |
double_z = config.get("double_z", True)
|
| 71 |
+
latent_log_var = config.get(
|
| 72 |
+
"latent_log_var", "per_channel" if double_z else "none"
|
| 73 |
+
)
|
| 74 |
use_quant_conv = config.get("use_quant_conv", True)
|
| 75 |
|
| 76 |
if use_quant_conv and latent_log_var == "uniform":
|
|
|
|
| 113 |
return SimpleNamespace(
|
| 114 |
_class_name="VideoAutoencoder",
|
| 115 |
dims=self.dims,
|
| 116 |
+
in_channels=self.encoder.conv_in.in_channels
|
| 117 |
+
// (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
| 118 |
+
out_channels=self.decoder.conv_out.out_channels
|
| 119 |
+
// (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
| 120 |
latent_channels=self.decoder.conv_in.in_channels,
|
| 121 |
block_out_channels=[
|
| 122 |
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
|
|
|
| 162 |
key = key.replace(k, v)
|
| 163 |
|
| 164 |
if "norm" in key and key not in model_keys:
|
| 165 |
+
logger.info(
|
| 166 |
+
f"Removing key {key} from state_dict as it is not present in the model"
|
| 167 |
+
)
|
| 168 |
continue
|
| 169 |
|
| 170 |
converted_state_dict[key] = value
|
|
|
|
| 274 |
|
| 275 |
# out
|
| 276 |
if norm_layer == "group_norm":
|
| 277 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 278 |
+
num_channels=block_out_channels[-1],
|
| 279 |
+
num_groups=norm_num_groups,
|
| 280 |
+
eps=1e-6,
|
| 281 |
+
)
|
| 282 |
elif norm_layer == "pixel_norm":
|
| 283 |
self.conv_norm_out = PixelNorm()
|
| 284 |
self.conv_act = nn.SiLU()
|
|
|
|
| 290 |
conv_out_channels += 1
|
| 291 |
elif latent_log_var != "none":
|
| 292 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 293 |
+
self.conv_out = make_conv_nd(
|
| 294 |
+
dims, block_out_channels[-1], conv_out_channels, 3, padding=1
|
| 295 |
+
)
|
| 296 |
|
| 297 |
self.gradient_checkpointing = False
|
| 298 |
|
| 299 |
@property
|
| 300 |
def downscale_factor(self):
|
| 301 |
return (
|
| 302 |
+
2
|
| 303 |
+
** len(
|
| 304 |
+
[
|
| 305 |
+
block
|
| 306 |
+
for block in self.down_blocks
|
| 307 |
+
if isinstance(block.downsample, Downsample3D)
|
| 308 |
+
]
|
| 309 |
+
)
|
| 310 |
* self.patch_size
|
| 311 |
)
|
| 312 |
|
|
|
|
| 333 |
)
|
| 334 |
|
| 335 |
for down_block in self.down_blocks:
|
| 336 |
+
sample = checkpoint_fn(down_block)(
|
| 337 |
+
sample, downsample_in_time=downsample_in_time
|
| 338 |
+
)
|
| 339 |
|
| 340 |
sample = checkpoint_fn(self.mid_block)(sample)
|
| 341 |
|
|
|
|
| 350 |
|
| 351 |
if num_dims == 4:
|
| 352 |
# For shape (B, C, H, W)
|
| 353 |
+
repeated_last_channel = last_channel.repeat(
|
| 354 |
+
1, sample.shape[1] - 2, 1, 1
|
| 355 |
+
)
|
| 356 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 357 |
elif num_dims == 5:
|
| 358 |
# For shape (B, C, F, H, W)
|
| 359 |
+
repeated_last_channel = last_channel.repeat(
|
| 360 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
| 361 |
+
)
|
| 362 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 363 |
else:
|
| 364 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
|
|
|
| 445 |
num_layers=self.layers_per_block + 1,
|
| 446 |
in_channels=prev_output_channel,
|
| 447 |
out_channels=output_channel,
|
| 448 |
+
add_upsample=not is_final_block
|
| 449 |
+
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
| 450 |
resnet_eps=1e-6,
|
| 451 |
resnet_groups=norm_num_groups,
|
| 452 |
norm_layer=norm_layer,
|
|
|
|
| 454 |
self.up_blocks.append(up_block)
|
| 455 |
|
| 456 |
if norm_layer == "group_norm":
|
| 457 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 458 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
| 459 |
+
)
|
| 460 |
elif norm_layer == "pixel_norm":
|
| 461 |
self.conv_norm_out = PixelNorm()
|
| 462 |
|
| 463 |
self.conv_act = nn.SiLU()
|
| 464 |
+
self.conv_out = make_conv_nd(
|
| 465 |
+
dims, block_out_channels[0], out_channels, 3, padding=1
|
| 466 |
+
)
|
| 467 |
|
| 468 |
self.gradient_checkpointing = False
|
| 469 |
|
|
|
|
| 539 |
self.res_blocks = nn.ModuleList(res_blocks)
|
| 540 |
|
| 541 |
if add_downsample:
|
| 542 |
+
self.downsample = Downsample3D(
|
| 543 |
+
dims,
|
| 544 |
+
out_channels,
|
| 545 |
+
out_channels=out_channels,
|
| 546 |
+
padding=downsample_padding,
|
| 547 |
+
)
|
| 548 |
else:
|
| 549 |
self.downsample = Identity()
|
| 550 |
|
| 551 |
+
def forward(
|
| 552 |
+
self, hidden_states: torch.FloatTensor, downsample_in_time
|
| 553 |
+
) -> torch.FloatTensor:
|
| 554 |
for resnet in self.res_blocks:
|
| 555 |
hidden_states = resnet(hidden_states)
|
| 556 |
|
| 557 |
+
hidden_states = self.downsample(
|
| 558 |
+
hidden_states, downsample_in_time=downsample_in_time
|
| 559 |
+
)
|
| 560 |
|
| 561 |
return hidden_states
|
| 562 |
|
|
|
|
| 590 |
norm_layer: str = "group_norm",
|
| 591 |
):
|
| 592 |
super().__init__()
|
| 593 |
+
resnet_groups = (
|
| 594 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 595 |
+
)
|
| 596 |
|
| 597 |
self.res_blocks = nn.ModuleList(
|
| 598 |
[
|
|
|
|
| 651 |
self.res_blocks = nn.ModuleList(res_blocks)
|
| 652 |
|
| 653 |
if add_upsample:
|
| 654 |
+
self.upsample = Upsample3D(
|
| 655 |
+
dims=dims, channels=out_channels, out_channels=out_channels
|
| 656 |
+
)
|
| 657 |
else:
|
| 658 |
self.upsample = Identity()
|
| 659 |
|
| 660 |
self.resolution_idx = resolution_idx
|
| 661 |
|
| 662 |
+
def forward(
|
| 663 |
+
self, hidden_states: torch.FloatTensor, upsample_in_time=True
|
| 664 |
+
) -> torch.FloatTensor:
|
| 665 |
for resnet in self.res_blocks:
|
| 666 |
hidden_states = resnet(hidden_states)
|
| 667 |
|
|
|
|
| 701 |
self.use_conv_shortcut = conv_shortcut
|
| 702 |
|
| 703 |
if norm_layer == "group_norm":
|
| 704 |
+
self.norm1 = torch.nn.GroupNorm(
|
| 705 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
| 706 |
+
)
|
| 707 |
elif norm_layer == "pixel_norm":
|
| 708 |
self.norm1 = PixelNorm()
|
| 709 |
|
| 710 |
self.non_linearity = nn.SiLU()
|
| 711 |
|
| 712 |
+
self.conv1 = make_conv_nd(
|
| 713 |
+
dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 714 |
+
)
|
| 715 |
|
| 716 |
if norm_layer == "group_norm":
|
| 717 |
+
self.norm2 = torch.nn.GroupNorm(
|
| 718 |
+
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
| 719 |
+
)
|
| 720 |
elif norm_layer == "pixel_norm":
|
| 721 |
self.norm2 = PixelNorm()
|
| 722 |
|
| 723 |
self.dropout = torch.nn.Dropout(dropout)
|
| 724 |
|
| 725 |
+
self.conv2 = make_conv_nd(
|
| 726 |
+
dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 727 |
+
)
|
| 728 |
|
| 729 |
self.conv_shortcut = (
|
| 730 |
+
make_linear_nd(
|
| 731 |
+
dims=dims, in_channels=in_channels, out_channels=out_channels
|
| 732 |
+
)
|
| 733 |
if in_channels != out_channels
|
| 734 |
else nn.Identity()
|
| 735 |
)
|
|
|
|
| 762 |
|
| 763 |
|
| 764 |
class Downsample3D(nn.Module):
|
| 765 |
+
def __init__(
|
| 766 |
+
self,
|
| 767 |
+
dims,
|
| 768 |
+
in_channels: int,
|
| 769 |
+
out_channels: int,
|
| 770 |
+
kernel_size: int = 3,
|
| 771 |
+
padding: int = 1,
|
| 772 |
+
):
|
| 773 |
super().__init__()
|
| 774 |
stride: int = 2
|
| 775 |
self.padding = padding
|
|
|
|
| 812 |
self.dims = dims
|
| 813 |
self.channels = channels
|
| 814 |
self.out_channels = out_channels or channels
|
| 815 |
+
self.conv = make_conv_nd(
|
| 816 |
+
dims, channels, out_channels, kernel_size=3, padding=1, bias=True
|
| 817 |
+
)
|
| 818 |
|
| 819 |
def forward(self, x, upsample_in_time):
|
| 820 |
if self.dims == 2:
|
| 821 |
+
x = functional.interpolate(
|
| 822 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
| 823 |
+
)
|
| 824 |
else:
|
| 825 |
time_scale_factor = 2 if upsample_in_time else 1
|
| 826 |
# print("before:", x.shape)
|
| 827 |
b, c, d, h, w = x.shape
|
| 828 |
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 829 |
# height and width interpolate
|
| 830 |
+
x = functional.interpolate(
|
| 831 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
| 832 |
+
)
|
| 833 |
_, _, h, w = x.shape
|
| 834 |
|
| 835 |
if not upsample_in_time and self.dims == (2, 1):
|
|
|
|
| 843 |
new_d = x.shape[-1] * time_scale_factor
|
| 844 |
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
| 845 |
# (b h w) c 1 new_d
|
| 846 |
+
x = rearrange(
|
| 847 |
+
x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
|
| 848 |
+
)
|
| 849 |
# b c d h w
|
| 850 |
|
| 851 |
# x = functional.interpolate(
|
|
|
|
| 860 |
if patch_size_hw == 1 and patch_size_t == 1:
|
| 861 |
return x
|
| 862 |
if x.dim() == 4:
|
| 863 |
+
x = rearrange(
|
| 864 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
| 865 |
+
)
|
| 866 |
elif x.dim() == 5:
|
| 867 |
+
x = rearrange(
|
| 868 |
+
x,
|
| 869 |
+
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
| 870 |
+
p=patch_size_t,
|
| 871 |
+
q=patch_size_hw,
|
| 872 |
+
r=patch_size_hw,
|
| 873 |
+
)
|
| 874 |
else:
|
| 875 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 876 |
|
| 877 |
+
if (
|
| 878 |
+
(x.dim() == 5)
|
| 879 |
+
and (patch_size_hw > patch_size_t)
|
| 880 |
+
and (patch_size_t > 1 or add_channel_padding)
|
| 881 |
+
):
|
| 882 |
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
| 883 |
padding_zeros = torch.zeros(
|
| 884 |
x.shape[0],
|
|
|
|
| 898 |
if patch_size_hw == 1 and patch_size_t == 1:
|
| 899 |
return x
|
| 900 |
|
| 901 |
+
if (
|
| 902 |
+
(x.dim() == 5)
|
| 903 |
+
and (patch_size_hw > patch_size_t)
|
| 904 |
+
and (patch_size_t > 1 or add_channel_padding)
|
| 905 |
+
):
|
| 906 |
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
| 907 |
x = x[:, :channels_to_keep, :, :, :]
|
| 908 |
|
| 909 |
if x.dim() == 4:
|
| 910 |
+
x = rearrange(
|
| 911 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
| 912 |
+
)
|
| 913 |
elif x.dim() == 5:
|
| 914 |
+
x = rearrange(
|
| 915 |
+
x,
|
| 916 |
+
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
| 917 |
+
p=patch_size_t,
|
| 918 |
+
q=patch_size_hw,
|
| 919 |
+
r=patch_size_hw,
|
| 920 |
+
)
|
| 921 |
|
| 922 |
return x
|
| 923 |
|
|
|
|
| 927 |
):
|
| 928 |
config = {
|
| 929 |
"_class_name": "VideoAutoencoder",
|
| 930 |
+
"dims": (
|
| 931 |
+
2,
|
| 932 |
+
1,
|
| 933 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
| 934 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 935 |
"out_channels": 3, # Number of output color channels
|
| 936 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 937 |
+
"block_out_channels": [
|
| 938 |
+
128,
|
| 939 |
+
256,
|
| 940 |
+
512,
|
| 941 |
+
512,
|
| 942 |
+
], # Number of output channels of each encoder / decoder inner block
|
| 943 |
"patch_size": 1,
|
| 944 |
}
|
| 945 |
|
|
|
|
| 951 |
):
|
| 952 |
config = {
|
| 953 |
"_class_name": "VideoAutoencoder",
|
| 954 |
+
"dims": (
|
| 955 |
+
2,
|
| 956 |
+
1,
|
| 957 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
| 958 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 959 |
"out_channels": 3, # Number of output color channels
|
| 960 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 961 |
+
"block_out_channels": [512]
|
| 962 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
| 963 |
"patch_size": 4,
|
| 964 |
"latent_log_var": "uniform",
|
| 965 |
}
|
|
|
|
| 976 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 977 |
"out_channels": 3, # Number of output color channels
|
| 978 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 979 |
+
"block_out_channels": [512]
|
| 980 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
| 981 |
"patch_size": 4,
|
| 982 |
"norm_layer": "pixel_norm",
|
| 983 |
}
|
|
|
|
| 1016 |
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
| 1017 |
print(f"input shape={input_videos.shape}")
|
| 1018 |
print(f"latent shape={latent.shape}")
|
| 1019 |
+
reconstructed_videos = video_autoencoder.decode(
|
| 1020 |
+
latent, target_shape=input_videos.shape
|
| 1021 |
+
).sample
|
| 1022 |
|
| 1023 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
| 1024 |
|
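A minimal sketch of the patchify round trip used above: spatial and temporal patches are folded into the channel dimension with einops.rearrange and unfolded back. The rearrange patterns are the ones in the diff; the tensor sizes are examples.

import torch
from einops import rearrange

p, q = 2, 4                       # patch_size_t, patch_size_hw
x = torch.randn(1, 3, 8, 32, 32)  # (B, C, F, H, W)

packed = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=p, q=q, r=q)
print(packed.shape)  # torch.Size([1, 96, 4, 8, 8])

unpacked = rearrange(packed, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=p, q=q, r=q)
print(torch.equal(unpacked, x))  # True: patchify/unpatchify is a lossless reshape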
xora/models/transformers/attention.py
CHANGED
|
@@ -106,11 +106,15 @@ class BasicTransformerBlock(nn.Module):
|
|
| 106 |
assert standardization_norm in ["layer_norm", "rms_norm"]
|
| 107 |
assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
|
| 108 |
|
| 109 |
-
make_norm_layer =
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# Define 3 blocks. Each block has its own normalization layer.
|
| 112 |
# 1. Self-Attn
|
| 113 |
-
self.norm1 = make_norm_layer(
|
|
|
|
|
|
|
| 114 |
|
| 115 |
self.attn1 = Attention(
|
| 116 |
query_dim=dim,
|
|
@@ -130,7 +134,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 130 |
if cross_attention_dim is not None or double_self_attention:
|
| 131 |
self.attn2 = Attention(
|
| 132 |
query_dim=dim,
|
| 133 |
-
cross_attention_dim=
|
|
|
|
|
|
|
| 134 |
heads=num_attention_heads,
|
| 135 |
dim_head=attention_head_dim,
|
| 136 |
dropout=dropout,
|
|
@@ -143,7 +149,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 143 |
) # is self-attn if encoder_hidden_states is none
|
| 144 |
|
| 145 |
if adaptive_norm == "none":
|
| 146 |
-
self.attn2_norm = make_norm_layer(
|
|
|
|
|
|
|
| 147 |
else:
|
| 148 |
self.attn2 = None
|
| 149 |
self.attn2_norm = None
|
|
@@ -163,7 +171,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 163 |
# 5. Scale-shift for PixArt-Alpha.
|
| 164 |
if adaptive_norm != "none":
|
| 165 |
num_ada_params = 4 if adaptive_norm == "single_scale" else 6
|
| 166 |
-
self.scale_shift_table = nn.Parameter(
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# let chunk size default to None
|
| 169 |
self._chunk_size = None
|
|
@@ -198,7 +208,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 198 |
) -> torch.FloatTensor:
|
| 199 |
if cross_attention_kwargs is not None:
|
| 200 |
if cross_attention_kwargs.get("scale", None) is not None:
|
| 201 |
-
logger.warning(
|
|
|
|
|
|
|
| 202 |
|
| 203 |
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 204 |
# 0. Self-Attention
|
|
@@ -214,7 +226,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 214 |
batch_size, timestep.shape[1], num_ada_params, -1
|
| 215 |
)
|
| 216 |
if self.adaptive_norm == "single_scale_shift":
|
| 217 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
|
|
|
|
|
|
|
| 218 |
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 219 |
else:
|
| 220 |
scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
|
@@ -224,15 +238,21 @@ class BasicTransformerBlock(nn.Module):
|
|
| 224 |
else:
|
| 225 |
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
| 226 |
|
| 227 |
-
norm_hidden_states = norm_hidden_states.squeeze(
|
|
|
|
|
|
|
| 228 |
|
| 229 |
# 1. Prepare GLIGEN inputs
|
| 230 |
-
cross_attention_kwargs =
|
|
|
|
|
|
|
| 231 |
|
| 232 |
attn_output = self.attn1(
|
| 233 |
norm_hidden_states,
|
| 234 |
freqs_cis=freqs_cis,
|
| 235 |
-
encoder_hidden_states=
|
|
|
|
|
|
|
| 236 |
attention_mask=attention_mask,
|
| 237 |
**cross_attention_kwargs,
|
| 238 |
)
|
|
@@ -271,7 +291,9 @@ class BasicTransformerBlock(nn.Module):
|
|
| 271 |
|
| 272 |
if self._chunk_size is not None:
|
| 273 |
# "feed_forward_chunk_size" can be used to save memory
|
| 274 |
-
ff_output = _chunked_feed_forward(
|
|
|
|
|
|
|
| 275 |
else:
|
| 276 |
ff_output = self.ff(norm_hidden_states)
|
| 277 |
if gate_mlp is not None:
|
|
@@ -371,7 +393,9 @@ class Attention(nn.Module):
|
|
| 371 |
self.query_dim = query_dim
|
| 372 |
self.use_bias = bias
|
| 373 |
self.is_cross_attention = cross_attention_dim is not None
|
| 374 |
-
self.cross_attention_dim =
|
|
|
|
|
|
|
| 375 |
self.upcast_attention = upcast_attention
|
| 376 |
self.upcast_softmax = upcast_softmax
|
| 377 |
self.rescale_output_factor = rescale_output_factor
|
|
@@ -416,12 +440,16 @@ class Attention(nn.Module):
|
|
| 416 |
)
|
| 417 |
|
| 418 |
if norm_num_groups is not None:
|
| 419 |
-
self.group_norm = nn.GroupNorm(
|
|
|
|
|
|
|
| 420 |
else:
|
| 421 |
self.group_norm = None
|
| 422 |
|
| 423 |
if spatial_norm_dim is not None:
|
| 424 |
-
self.spatial_norm = SpatialNorm(
|
|
|
|
|
|
|
| 425 |
else:
|
| 426 |
self.spatial_norm = None
|
| 427 |
|
|
@@ -441,7 +469,10 @@ class Attention(nn.Module):
|
|
| 441 |
norm_cross_num_channels = self.cross_attention_dim
|
| 442 |
|
| 443 |
self.norm_cross = nn.GroupNorm(
|
| 444 |
-
num_channels=norm_cross_num_channels,
|
|
|
|
|
|
|
|
|
|
| 445 |
)
|
| 446 |
else:
|
| 447 |
raise ValueError(
|
|
@@ -499,12 +530,16 @@ class Attention(nn.Module):
|
|
| 499 |
and isinstance(self.processor, torch.nn.Module)
|
| 500 |
and not isinstance(processor, torch.nn.Module)
|
| 501 |
):
|
| 502 |
-
logger.info(
|
|
|
|
|
|
|
| 503 |
self._modules.pop("processor")
|
| 504 |
|
| 505 |
self.processor = processor
|
| 506 |
|
| 507 |
-
def get_processor(
|
|
|
|
|
|
|
| 508 |
r"""
|
| 509 |
Get the attention processor in use.
|
| 510 |
|
|
@@ -542,12 +577,18 @@ class Attention(nn.Module):
|
|
| 542 |
|
| 543 |
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
| 544 |
non_lora_processor_cls_name = self.processor.__class__.__name__
|
| 545 |
-
lora_processor_cls = getattr(
|
|
|
|
|
|
|
| 546 |
|
| 547 |
hidden_size = self.inner_dim
|
| 548 |
|
| 549 |
# now create a LoRA attention processor from the LoRA layers
|
| 550 |
-
if lora_processor_cls in [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
kwargs = {
|
| 552 |
"cross_attention_dim": self.cross_attention_dim,
|
| 553 |
"rank": self.to_q.lora_layer.rank,
|
|
@@ -569,7 +610,9 @@ class Attention(nn.Module):
|
|
| 569 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
| 570 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
| 571 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
| 572 |
-
lora_processor.to_out_lora.load_state_dict(
|
|
|
|
|
|
|
| 573 |
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
| 574 |
lora_processor = lora_processor_cls(
|
| 575 |
hidden_size,
|
|
@@ -580,12 +623,18 @@ class Attention(nn.Module):
|
|
| 580 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
| 581 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
| 582 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
| 583 |
-
lora_processor.to_out_lora.load_state_dict(
|
|
|
|
|
|
|
| 584 |
|
| 585 |
# only save if used
|
| 586 |
if self.add_k_proj.lora_layer is not None:
|
| 587 |
-
lora_processor.add_k_proj_lora.load_state_dict(
|
| 588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
else:
|
| 590 |
lora_processor.add_k_proj_lora = None
|
| 591 |
lora_processor.add_v_proj_lora = None
|
|
Old version, removed lines only (the surviving fragment of each removed line is kept; the reformatted statements appear in full in the new version below):

@@ -622,14 +671,20 @@ class Attention(nn.Module):
- 625    attn_parameters = set(
- 632    cross_attention_kwargs = {
@@ -654,7 +709,9 @@ class Attention(nn.Module):
- 657    tensor = tensor.permute(0, 2, 1, 3).reshape(
@@ -677,16 +734,23 @@ class Attention(nn.Module):
- 680    tensor = tensor.reshape(
- 684    tensor = tensor.reshape(
- 689    self,
@@ -706,7 +770,11 @@ class Attention(nn.Module):
- 709    query.shape[0],
@@ -733,7 +801,11 @@ class Attention(nn.Module):
- 736    self,
@@ -760,8 +832,16 @@ class Attention(nn.Module):
- 763    padding_shape = (
@@ -779,7 +859,9 @@ class Attention(nn.Module):
- 782    def norm_encoder_hidden_states(
@@ -790,7 +872,9 @@ class Attention(nn.Module):
- 793    assert
@@ -857,27 +941,39 @@ class AttnProcessor2_0:
- 860    hidden_states = hidden_states.view(
- 863    hidden_states.shape
- 867    attention_mask = attn.prepare_attention_mask(
- 870    attention_mask = attention_mask.view(
- 873    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
- 880    encoder_hidden_states = attn.norm_encoder_hidden_states(
@@ -901,10 +997,14 @@ class AttnProcessor2_0:
- 904    if
- 907    q_segment_indexes = torch.ones(
@@ -927,10 +1027,17 @@ class AttnProcessor2_0:
- 930    query,
- 933    hidden_states = hidden_states.transpose(1, 2).reshape(
@@ -939,7 +1046,9 @@ class AttnProcessor2_0:
- 942    hidden_states = hidden_states.transpose(-1, -2).reshape(
@@ -977,22 +1086,32 @@ class AttnProcessor:
- 980    hidden_states = hidden_states.view(
- 983    hidden_states.shape
- 985    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- 988    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
- 995    encoder_hidden_states = attn.norm_encoder_hidden_states(
@@ -1014,7 +1133,9 @@ class AttnProcessor:
- 1017   hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 106 |
assert standardization_norm in ["layer_norm", "rms_norm"]
|
| 107 |
assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
|
| 108 |
|
| 109 |
+
make_norm_layer = (
|
| 110 |
+
nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
|
| 111 |
+
)
|
| 112 |
|
| 113 |
# Define 3 blocks. Each block has its own normalization layer.
|
| 114 |
# 1. Self-Attn
|
| 115 |
+
self.norm1 = make_norm_layer(
|
| 116 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
| 117 |
+
)
|
| 118 |
|
| 119 |
self.attn1 = Attention(
|
| 120 |
query_dim=dim,
|
|
|
|
| 134 |
if cross_attention_dim is not None or double_self_attention:
|
| 135 |
self.attn2 = Attention(
|
| 136 |
query_dim=dim,
|
| 137 |
+
cross_attention_dim=(
|
| 138 |
+
cross_attention_dim if not double_self_attention else None
|
| 139 |
+
),
|
| 140 |
heads=num_attention_heads,
|
| 141 |
dim_head=attention_head_dim,
|
| 142 |
dropout=dropout,
|
|
|
|
| 149 |
) # is self-attn if encoder_hidden_states is none
|
| 150 |
|
| 151 |
if adaptive_norm == "none":
|
| 152 |
+
self.attn2_norm = make_norm_layer(
|
| 153 |
+
dim, norm_eps, norm_elementwise_affine
|
| 154 |
+
)
|
| 155 |
else:
|
| 156 |
self.attn2 = None
|
| 157 |
self.attn2_norm = None
|
|
|
|
| 171 |
# 5. Scale-shift for PixArt-Alpha.
|
| 172 |
if adaptive_norm != "none":
|
| 173 |
num_ada_params = 4 if adaptive_norm == "single_scale" else 6
|
| 174 |
+
self.scale_shift_table = nn.Parameter(
|
| 175 |
+
torch.randn(num_ada_params, dim) / dim**0.5
|
| 176 |
+
)
|
| 177 |
|
| 178 |
# let chunk size default to None
|
| 179 |
self._chunk_size = None
|
|
|
|
| 208 |
) -> torch.FloatTensor:
|
| 209 |
if cross_attention_kwargs is not None:
|
| 210 |
if cross_attention_kwargs.get("scale", None) is not None:
|
| 211 |
+
logger.warning(
|
| 212 |
+
"Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
|
| 213 |
+
)
|
| 214 |
|
| 215 |
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 216 |
# 0. Self-Attention
|
|
|
|
| 226 |
batch_size, timestep.shape[1], num_ada_params, -1
|
| 227 |
)
|
| 228 |
if self.adaptive_norm == "single_scale_shift":
|
| 229 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 230 |
+
ada_values.unbind(dim=2)
|
| 231 |
+
)
|
| 232 |
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 233 |
else:
|
| 234 |
scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
|
|
|
| 238 |
else:
|
| 239 |
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
| 240 |
|
| 241 |
+
norm_hidden_states = norm_hidden_states.squeeze(
|
| 242 |
+
1
|
| 243 |
+
) # TODO: Check if this is needed
|
| 244 |
|
| 245 |
# 1. Prepare GLIGEN inputs
|
| 246 |
+
cross_attention_kwargs = (
|
| 247 |
+
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 248 |
+
)
|
| 249 |
|
| 250 |
attn_output = self.attn1(
|
| 251 |
norm_hidden_states,
|
| 252 |
freqs_cis=freqs_cis,
|
| 253 |
+
encoder_hidden_states=(
|
| 254 |
+
encoder_hidden_states if self.only_cross_attention else None
|
| 255 |
+
),
|
| 256 |
attention_mask=attention_mask,
|
| 257 |
**cross_attention_kwargs,
|
| 258 |
)
|
|
|
|
| 291 |
|
| 292 |
if self._chunk_size is not None:
|
| 293 |
# "feed_forward_chunk_size" can be used to save memory
|
| 294 |
+
ff_output = _chunked_feed_forward(
|
| 295 |
+
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
|
| 296 |
+
)
|
| 297 |
else:
|
| 298 |
ff_output = self.ff(norm_hidden_states)
|
| 299 |
if gate_mlp is not None:
|
|
|
|
| 393 |
self.query_dim = query_dim
|
| 394 |
self.use_bias = bias
|
| 395 |
self.is_cross_attention = cross_attention_dim is not None
|
| 396 |
+
self.cross_attention_dim = (
|
| 397 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 398 |
+
)
|
| 399 |
self.upcast_attention = upcast_attention
|
| 400 |
self.upcast_softmax = upcast_softmax
|
| 401 |
self.rescale_output_factor = rescale_output_factor
|
|
|
|
| 440 |
)
|
| 441 |
|
| 442 |
if norm_num_groups is not None:
|
| 443 |
+
self.group_norm = nn.GroupNorm(
|
| 444 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
| 445 |
+
)
|
| 446 |
else:
|
| 447 |
self.group_norm = None
|
| 448 |
|
| 449 |
if spatial_norm_dim is not None:
|
| 450 |
+
self.spatial_norm = SpatialNorm(
|
| 451 |
+
f_channels=query_dim, zq_channels=spatial_norm_dim
|
| 452 |
+
)
|
| 453 |
else:
|
| 454 |
self.spatial_norm = None
|
| 455 |
|
|
|
|
| 469 |
norm_cross_num_channels = self.cross_attention_dim
|
| 470 |
|
| 471 |
self.norm_cross = nn.GroupNorm(
|
| 472 |
+
num_channels=norm_cross_num_channels,
|
| 473 |
+
num_groups=cross_attention_norm_num_groups,
|
| 474 |
+
eps=1e-5,
|
| 475 |
+
affine=True,
|
| 476 |
)
|
| 477 |
else:
|
| 478 |
raise ValueError(
|
|
|
|
| 530 |
and isinstance(self.processor, torch.nn.Module)
|
| 531 |
and not isinstance(processor, torch.nn.Module)
|
| 532 |
):
|
| 533 |
+
logger.info(
|
| 534 |
+
f"You are removing possibly trained weights of {self.processor} with {processor}"
|
| 535 |
+
)
|
| 536 |
self._modules.pop("processor")
|
| 537 |
|
| 538 |
self.processor = processor
|
| 539 |
|
| 540 |
+
def get_processor(
|
| 541 |
+
self, return_deprecated_lora: bool = False
|
| 542 |
+
) -> "AttentionProcessor": # noqa: F821
|
| 543 |
r"""
|
| 544 |
Get the attention processor in use.
|
| 545 |
|
|
|
|
| 577 |
|
| 578 |
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
| 579 |
non_lora_processor_cls_name = self.processor.__class__.__name__
|
| 580 |
+
lora_processor_cls = getattr(
|
| 581 |
+
import_module(__name__), "LoRA" + non_lora_processor_cls_name
|
| 582 |
+
)
|
| 583 |
|
| 584 |
hidden_size = self.inner_dim
|
| 585 |
|
| 586 |
# now create a LoRA attention processor from the LoRA layers
|
| 587 |
+
if lora_processor_cls in [
|
| 588 |
+
LoRAAttnProcessor,
|
| 589 |
+
LoRAAttnProcessor2_0,
|
| 590 |
+
LoRAXFormersAttnProcessor,
|
| 591 |
+
]:
|
| 592 |
kwargs = {
|
| 593 |
"cross_attention_dim": self.cross_attention_dim,
|
| 594 |
"rank": self.to_q.lora_layer.rank,
|
|
|
|
| 610 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
| 611 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
| 612 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
| 613 |
+
lora_processor.to_out_lora.load_state_dict(
|
| 614 |
+
self.to_out[0].lora_layer.state_dict()
|
| 615 |
+
)
|
| 616 |
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
| 617 |
lora_processor = lora_processor_cls(
|
| 618 |
hidden_size,
|
|
|
|
| 623 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
| 624 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
| 625 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
| 626 |
+
lora_processor.to_out_lora.load_state_dict(
|
| 627 |
+
self.to_out[0].lora_layer.state_dict()
|
| 628 |
+
)
|
| 629 |
|
| 630 |
# only save if used
|
| 631 |
if self.add_k_proj.lora_layer is not None:
|
| 632 |
+
lora_processor.add_k_proj_lora.load_state_dict(
|
| 633 |
+
self.add_k_proj.lora_layer.state_dict()
|
| 634 |
+
)
|
| 635 |
+
lora_processor.add_v_proj_lora.load_state_dict(
|
| 636 |
+
self.add_v_proj.lora_layer.state_dict()
|
| 637 |
+
)
|
| 638 |
else:
|
| 639 |
lora_processor.add_k_proj_lora = None
|
| 640 |
lora_processor.add_v_proj_lora = None
|
|
|
|
| 671 |
# here we simply pass along all tensors to the selected processor class
|
| 672 |
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
| 673 |
|
| 674 |
+
attn_parameters = set(
|
| 675 |
+
inspect.signature(self.processor.__call__).parameters.keys()
|
| 676 |
+
)
|
| 677 |
+
unused_kwargs = [
|
| 678 |
+
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
|
| 679 |
+
]
|
| 680 |
if len(unused_kwargs) > 0:
|
| 681 |
logger.warning(
|
| 682 |
f"cross_attention_kwargs {unused_kwargs} are not expected by"
|
| 683 |
f" {self.processor.__class__.__name__} and will be ignored."
|
| 684 |
)
|
| 685 |
+
cross_attention_kwargs = {
|
| 686 |
+
k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
|
| 687 |
+
}
|
| 688 |
|
| 689 |
return self.processor(
|
| 690 |
self,
|
|
|
|
| 709 |
head_size = self.heads
|
| 710 |
batch_size, seq_len, dim = tensor.shape
|
| 711 |
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
| 712 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
| 713 |
+
batch_size // head_size, seq_len, dim * head_size
|
| 714 |
+
)
|
| 715 |
return tensor
|
| 716 |
|
| 717 |
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
|
|
|
| 734 |
extra_dim = 1
|
| 735 |
else:
|
| 736 |
batch_size, extra_dim, seq_len, dim = tensor.shape
|
| 737 |
+
tensor = tensor.reshape(
|
| 738 |
+
batch_size, seq_len * extra_dim, head_size, dim // head_size
|
| 739 |
+
)
|
| 740 |
tensor = tensor.permute(0, 2, 1, 3)
|
| 741 |
|
| 742 |
if out_dim == 3:
|
| 743 |
+
tensor = tensor.reshape(
|
| 744 |
+
batch_size * head_size, seq_len * extra_dim, dim // head_size
|
| 745 |
+
)
|
| 746 |
|
| 747 |
return tensor
|
| 748 |
|
| 749 |
def get_attention_scores(
|
| 750 |
+
self,
|
| 751 |
+
query: torch.Tensor,
|
| 752 |
+
key: torch.Tensor,
|
| 753 |
+
attention_mask: torch.Tensor = None,
|
| 754 |
) -> torch.Tensor:
|
| 755 |
r"""
|
| 756 |
Compute the attention scores.
|
|
|
|
| 770 |
|
| 771 |
if attention_mask is None:
|
| 772 |
baddbmm_input = torch.empty(
|
| 773 |
+
query.shape[0],
|
| 774 |
+
query.shape[1],
|
| 775 |
+
key.shape[1],
|
| 776 |
+
dtype=query.dtype,
|
| 777 |
+
device=query.device,
|
| 778 |
)
|
| 779 |
beta = 0
|
| 780 |
else:
|
|
|
|
| 801 |
return attention_probs
|
| 802 |
|
| 803 |
def prepare_attention_mask(
|
| 804 |
+
self,
|
| 805 |
+
attention_mask: torch.Tensor,
|
| 806 |
+
target_length: int,
|
| 807 |
+
batch_size: int,
|
| 808 |
+
out_dim: int = 3,
|
| 809 |
) -> torch.Tensor:
|
| 810 |
r"""
|
| 811 |
Prepare the attention mask for the attention computation.
|
|
|
|
| 832 |
if attention_mask.device.type == "mps":
|
| 833 |
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
| 834 |
# Instead, we can manually construct the padding tensor.
|
| 835 |
+
padding_shape = (
|
| 836 |
+
attention_mask.shape[0],
|
| 837 |
+
attention_mask.shape[1],
|
| 838 |
+
target_length,
|
| 839 |
+
)
|
| 840 |
+
padding = torch.zeros(
|
| 841 |
+
padding_shape,
|
| 842 |
+
dtype=attention_mask.dtype,
|
| 843 |
+
device=attention_mask.device,
|
| 844 |
+
)
|
| 845 |
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
| 846 |
else:
|
| 847 |
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
|
|
|
| 859 |
|
| 860 |
return attention_mask
|
| 861 |
|
| 862 |
+
def norm_encoder_hidden_states(
|
| 863 |
+
self, encoder_hidden_states: torch.Tensor
|
| 864 |
+
) -> torch.Tensor:
|
| 865 |
r"""
|
| 866 |
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
| 867 |
`Attention` class.
|
|
|
|
| 872 |
Returns:
|
| 873 |
`torch.Tensor`: The normalized encoder hidden states.
|
| 874 |
"""
|
| 875 |
+
assert (
|
| 876 |
+
self.norm_cross is not None
|
| 877 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
| 878 |
|
| 879 |
if isinstance(self.norm_cross, nn.LayerNorm):
|
| 880 |
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
|
|
|
| 941 |
|
| 942 |
if input_ndim == 4:
|
| 943 |
batch_size, channel, height, width = hidden_states.shape
|
| 944 |
+
hidden_states = hidden_states.view(
|
| 945 |
+
batch_size, channel, height * width
|
| 946 |
+
).transpose(1, 2)
|
| 947 |
|
| 948 |
batch_size, sequence_length, _ = (
|
| 949 |
+
hidden_states.shape
|
| 950 |
+
if encoder_hidden_states is None
|
| 951 |
+
else encoder_hidden_states.shape
|
| 952 |
)
|
| 953 |
|
| 954 |
if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
|
| 955 |
+
attention_mask = attn.prepare_attention_mask(
|
| 956 |
+
attention_mask, sequence_length, batch_size
|
| 957 |
+
)
|
| 958 |
# scaled_dot_product_attention expects attention_mask shape to be
|
| 959 |
# (batch, heads, source_length, target_length)
|
| 960 |
+
attention_mask = attention_mask.view(
|
| 961 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
| 962 |
+
)
|
| 963 |
|
| 964 |
if attn.group_norm is not None:
|
| 965 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 966 |
+
1, 2
|
| 967 |
+
)
|
| 968 |
|
| 969 |
query = attn.to_q(hidden_states)
|
| 970 |
query = attn.q_norm(query)
|
| 971 |
|
| 972 |
if encoder_hidden_states is not None:
|
| 973 |
if attn.norm_cross:
|
| 974 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 975 |
+
encoder_hidden_states
|
| 976 |
+
)
|
| 977 |
key = attn.to_k(encoder_hidden_states)
|
| 978 |
key = attn.k_norm(key)
|
| 979 |
else: # if no context provided do self-attention
|
|
|
|
| 997 |
|
| 998 |
if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
|
| 999 |
q_segment_indexes = None
|
| 1000 |
+
if (
|
| 1001 |
+
attention_mask is not None
|
| 1002 |
+
): # if a mask is required, both segment_ids fields need to be set
|
| 1003 |
# attention_mask = torch.squeeze(attention_mask).to(torch.float32)
|
| 1004 |
attention_mask = attention_mask.to(torch.float32)
|
| 1005 |
+
q_segment_indexes = torch.ones(
|
| 1006 |
+
batch_size, query.shape[2], device=query.device, dtype=torch.float32
|
| 1007 |
+
)
|
| 1008 |
assert (
|
| 1009 |
attention_mask.shape[1] == key.shape[2]
|
| 1010 |
), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
|
|
|
|
| 1027 |
)
|
| 1028 |
else:
|
| 1029 |
hidden_states = F.scaled_dot_product_attention(
|
| 1030 |
+
query,
|
| 1031 |
+
key,
|
| 1032 |
+
value,
|
| 1033 |
+
attn_mask=attention_mask,
|
| 1034 |
+
dropout_p=0.0,
|
| 1035 |
+
is_causal=False,
|
| 1036 |
)
|
| 1037 |
|
| 1038 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
| 1039 |
+
batch_size, -1, attn.heads * head_dim
|
| 1040 |
+
)
|
| 1041 |
hidden_states = hidden_states.to(query.dtype)
|
| 1042 |
|
| 1043 |
# linear proj
|
|
|
|
| 1046 |
hidden_states = attn.to_out[1](hidden_states)
|
| 1047 |
|
| 1048 |
if input_ndim == 4:
|
| 1049 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 1050 |
+
batch_size, channel, height, width
|
| 1051 |
+
)
|
| 1052 |
|
| 1053 |
if attn.residual_connection:
|
| 1054 |
hidden_states = hidden_states + residual
|
|
|
|
| 1086 |
|
| 1087 |
if input_ndim == 4:
|
| 1088 |
batch_size, channel, height, width = hidden_states.shape
|
| 1089 |
+
hidden_states = hidden_states.view(
|
| 1090 |
+
batch_size, channel, height * width
|
| 1091 |
+
).transpose(1, 2)
|
| 1092 |
|
| 1093 |
batch_size, sequence_length, _ = (
|
| 1094 |
+
hidden_states.shape
|
| 1095 |
+
if encoder_hidden_states is None
|
| 1096 |
+
else encoder_hidden_states.shape
|
| 1097 |
+
)
|
| 1098 |
+
attention_mask = attn.prepare_attention_mask(
|
| 1099 |
+
attention_mask, sequence_length, batch_size
|
| 1100 |
)
|
|
|
|
| 1101 |
|
| 1102 |
if attn.group_norm is not None:
|
| 1103 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 1104 |
+
1, 2
|
| 1105 |
+
)
|
| 1106 |
|
| 1107 |
query = attn.to_q(hidden_states)
|
| 1108 |
|
| 1109 |
if encoder_hidden_states is None:
|
| 1110 |
encoder_hidden_states = hidden_states
|
| 1111 |
elif attn.norm_cross:
|
| 1112 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 1113 |
+
encoder_hidden_states
|
| 1114 |
+
)
|
| 1115 |
|
| 1116 |
key = attn.to_k(encoder_hidden_states)
|
| 1117 |
value = attn.to_v(encoder_hidden_states)
|
|
|
|
| 1133 |
hidden_states = attn.to_out[1](hidden_states)
|
| 1134 |
|
| 1135 |
if input_ndim == 4:
|
| 1136 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 1137 |
+
batch_size, channel, height, width
|
| 1138 |
+
)
|
| 1139 |
|
| 1140 |
if attn.residual_connection:
|
| 1141 |
hidden_states = hidden_states + residual
|
xora/models/transformers/embeddings.py
CHANGED
|
Old version, removed lines only (the reformatted statements appear in full in the new version below):

@@ -26,7 +26,9 @@ def get_timestep_embedding(
- 29     exponent = -math.log(max_period) * torch.arange(
@@ -113,7 +115,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
- 116    div_term = torch.exp(
|
| 26 |
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 27 |
|
| 28 |
half_dim = embedding_dim // 2
|
| 29 |
+
exponent = -math.log(max_period) * torch.arange(
|
| 30 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 31 |
+
)
|
| 32 |
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 33 |
|
| 34 |
emb = torch.exp(exponent)
|
|
|
|
| 115 |
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
| 116 |
super().__init__()
|
| 117 |
position = torch.arange(max_seq_length).unsqueeze(1)
|
| 118 |
+
div_term = torch.exp(
|
| 119 |
+
torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
|
| 120 |
+
)
|
| 121 |
pe = torch.zeros(1, max_seq_length, embed_dim)
|
| 122 |
pe[0, :, 0::2] = torch.sin(position * div_term)
|
| 123 |
pe[0, :, 1::2] = torch.cos(position * div_term)
|
xora/models/transformers/symmetric_patchifier.py
CHANGED
|
Old version, removed lines only (the reformatted signatures appear in full in the new version below):

@@ -15,12 +15,19 @@ class Patchifier(ConfigMixin, ABC):
- 18     def patchify(
- 23     self,
@@ -28,7 +35,9 @@ class Patchifier(ConfigMixin, ABC):
- 31     def get_grid(
@@ -64,6 +73,7 @@ def pixart_alpha_patchify(
@@ -72,7 +82,12 @@ class SymmetricPatchifier(Patchifier):
- 75     self,
|
| 15 |
self._patch_size = (1, patch_size, patch_size)
|
| 16 |
|
| 17 |
@abstractmethod
|
| 18 |
+
def patchify(
|
| 19 |
+
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
| 20 |
+
) -> Tuple[Tensor, Tensor]:
|
| 21 |
pass
|
| 22 |
|
| 23 |
@abstractmethod
|
| 24 |
def unpatchify(
|
| 25 |
+
self,
|
| 26 |
+
latents: Tensor,
|
| 27 |
+
output_height: int,
|
| 28 |
+
output_width: int,
|
| 29 |
+
output_num_frames: int,
|
| 30 |
+
out_channels: int,
|
| 31 |
) -> Tuple[Tensor, Tensor]:
|
| 32 |
pass
|
| 33 |
|
|
|
|
| 35 |
def patch_size(self):
|
| 36 |
return self._patch_size
|
| 37 |
|
| 38 |
+
def get_grid(
|
| 39 |
+
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
| 40 |
+
):
|
| 41 |
f = orig_num_frames // self._patch_size[0]
|
| 42 |
h = orig_height // self._patch_size[1]
|
| 43 |
w = orig_width // self._patch_size[2]
|
|
|
|
| 73 |
)
|
| 74 |
return latents
|
| 75 |
|
| 76 |
+
|
| 77 |
class SymmetricPatchifier(Patchifier):
|
| 78 |
def patchify(
|
| 79 |
self,
|
|
|
|
| 82 |
return pixart_alpha_patchify(latents, self._patch_size)
|
| 83 |
|
| 84 |
def unpatchify(
|
| 85 |
+
self,
|
| 86 |
+
latents: Tensor,
|
| 87 |
+
output_height: int,
|
| 88 |
+
output_width: int,
|
| 89 |
+
output_num_frames: int,
|
| 90 |
+
out_channels: int,
|
| 91 |
) -> Tuple[Tensor, Tensor]:
|
| 92 |
output_height = output_height // self._patch_size[1]
|
| 93 |
output_width = output_width // self._patch_size[2]
|
xora/models/transformers/transformer3d.py
CHANGED
|
Old version, removed lines only (the surviving fragment of each removed line is kept; the reformatted statements appear in full in the new version below):

@@ -17,6 +17,7 @@ from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
@@ -68,7 +69,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 71     self.use_tpu_flash_attention =
@@ -86,7 +89,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 89     embed_dim_3d =
@@ -131,18 +136,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 134    self.scale_shift_table = nn.Parameter(
- 138    self.adaln_single = AdaLayerNormSingle(
- 145    self.caption_projection = PixArtAlphaTextProjection(
@@ -169,16 +180,32 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 172    nn.init.normal_(
- 177    nn.init.normal_(
- 180    nn.init.normal_(
@@ -220,7 +247,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 223    [
@@ -236,7 +267,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 239    torch.linspace(
@@ -245,14 +282,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 248    indices = torch.linspace(
- 253    freqs = (
- 255    freqs = (
@@ -336,7 +383,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 339    encoder_attention_mask = (
@@ -346,7 +395,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 349    pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
@@ -363,13 +414,17 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 366    embedded_timestep = embedded_timestep.view(
- 372    encoder_hidden_states = encoder_hidden_states.view(
@@ -383,7 +438,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 386    ckpt_kwargs: Dict[str, Any] =
@@ -409,7 +466,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 412    scale_shift_values =
@@ -422,7 +481,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
- 425    embed_dim_3d =
|
| 17 |
|
| 18 |
logger = logging.get_logger(__name__)
|
| 19 |
|
| 20 |
+
|
| 21 |
@dataclass
|
| 22 |
class Transformer3DModelOutput(BaseOutput):
|
| 23 |
"""
|
|
|
|
| 69 |
timestep_scale_multiplier: Optional[float] = None,
|
| 70 |
):
|
| 71 |
super().__init__()
|
| 72 |
+
self.use_tpu_flash_attention = (
|
| 73 |
+
use_tpu_flash_attention # FIXME: push config down to the attention modules
|
| 74 |
+
)
|
| 75 |
self.use_linear_projection = use_linear_projection
|
| 76 |
self.num_attention_heads = num_attention_heads
|
| 77 |
self.attention_head_dim = attention_head_dim
|
|
|
|
| 89 |
self.timestep_scale_multiplier = timestep_scale_multiplier
|
| 90 |
|
| 91 |
if self.positional_embedding_type == "absolute":
|
| 92 |
+
embed_dim_3d = (
|
| 93 |
+
math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim
|
| 94 |
+
)
|
| 95 |
if self.project_to_2d_pos:
|
| 96 |
self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
|
| 97 |
self._init_to_2d_proj_weights(self.to_2d_proj)
|
|
|
|
| 136 |
# 4. Define output layers
|
| 137 |
self.out_channels = in_channels if out_channels is None else out_channels
|
| 138 |
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
| 139 |
+
self.scale_shift_table = nn.Parameter(
|
| 140 |
+
torch.randn(2, inner_dim) / inner_dim**0.5
|
| 141 |
+
)
|
| 142 |
self.proj_out = nn.Linear(inner_dim, self.out_channels)
|
| 143 |
|
| 144 |
# 5. PixArt-Alpha blocks.
|
| 145 |
+
self.adaln_single = AdaLayerNormSingle(
|
| 146 |
+
inner_dim, use_additional_conditions=False
|
| 147 |
+
)
|
| 148 |
if adaptive_norm == "single_scale":
|
| 149 |
# Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
|
| 150 |
self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
|
| 151 |
|
| 152 |
self.caption_projection = None
|
| 153 |
if caption_channels is not None:
|
| 154 |
+
self.caption_projection = PixArtAlphaTextProjection(
|
| 155 |
+
in_features=caption_channels, hidden_size=inner_dim
|
| 156 |
+
)
|
| 157 |
|
| 158 |
self.gradient_checkpointing = False
|
| 159 |
|
|
|
|
| 180 |
self.apply(_basic_init)
|
| 181 |
|
| 182 |
# Initialize timestep embedding MLP:
|
| 183 |
+
nn.init.normal_(
|
| 184 |
+
self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std
|
| 185 |
+
)
|
| 186 |
+
nn.init.normal_(
|
| 187 |
+
self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std
|
| 188 |
+
)
|
| 189 |
nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
|
| 190 |
|
| 191 |
if hasattr(self.adaln_single.emb, "resolution_embedder"):
|
| 192 |
+
nn.init.normal_(
|
| 193 |
+
self.adaln_single.emb.resolution_embedder.linear_1.weight,
|
| 194 |
+
std=embedding_std,
|
| 195 |
+
)
|
| 196 |
+
nn.init.normal_(
|
| 197 |
+
self.adaln_single.emb.resolution_embedder.linear_2.weight,
|
| 198 |
+
std=embedding_std,
|
| 199 |
+
)
|
| 200 |
if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
|
| 201 |
+
nn.init.normal_(
|
| 202 |
+
self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight,
|
| 203 |
+
std=embedding_std,
|
| 204 |
+
)
|
| 205 |
+
nn.init.normal_(
|
| 206 |
+
self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight,
|
| 207 |
+
std=embedding_std,
|
| 208 |
+
)
|
| 209 |
|
| 210 |
# Initialize caption embedding MLP:
|
| 211 |
nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
|
|
|
|
| 247 |
|
| 248 |
def get_fractional_positions(self, indices_grid):
|
| 249 |
fractional_positions = torch.stack(
|
| 250 |
+
[
|
| 251 |
+
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
| 252 |
+
for i in range(3)
|
| 253 |
+
],
|
| 254 |
+
dim=-1,
|
| 255 |
)
|
| 256 |
return fractional_positions
|
| 257 |
|
|
|
|
| 267 |
device = fractional_positions.device
|
| 268 |
if spacing == "exp":
|
| 269 |
indices = theta ** (
|
| 270 |
+
torch.linspace(
|
| 271 |
+
math.log(start, theta),
|
| 272 |
+
math.log(end, theta),
|
| 273 |
+
dim // 6,
|
| 274 |
+
device=device,
|
| 275 |
+
dtype=dtype,
|
| 276 |
+
)
|
| 277 |
)
|
| 278 |
indices = indices.to(dtype=dtype)
|
| 279 |
elif spacing == "exp_2":
|
|
|
|
| 282 |
elif spacing == "linear":
|
| 283 |
indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
|
| 284 |
elif spacing == "sqrt":
|
| 285 |
+
indices = torch.linspace(
|
| 286 |
+
start**2, end**2, dim // 6, device=device, dtype=dtype
|
| 287 |
+
).sqrt()
|
| 288 |
|
| 289 |
indices = indices * math.pi / 2
|
| 290 |
|
| 291 |
if spacing == "exp_2":
|
| 292 |
+
freqs = (
|
| 293 |
+
(indices * fractional_positions.unsqueeze(-1))
|
| 294 |
+
.transpose(-1, -2)
|
| 295 |
+
.flatten(2)
|
| 296 |
+
)
|
| 297 |
else:
|
| 298 |
+
freqs = (
|
| 299 |
+
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
| 300 |
+
.transpose(-1, -2)
|
| 301 |
+
.flatten(2)
|
| 302 |
+
)
|
| 303 |
|
| 304 |
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
| 305 |
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
|
|
|
| 383 |
|
| 384 |
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
| 385 |
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
| 386 |
+
encoder_attention_mask = (
|
| 387 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
| 388 |
+
) * -10000.0
|
| 389 |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 390 |
|
| 391 |
# 1. Input
|
|
|
|
| 395 |
timestep = self.timestep_scale_multiplier * timestep
|
| 396 |
|
| 397 |
if self.positional_embedding_type == "absolute":
|
| 398 |
+
pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
|
| 399 |
+
hidden_states.device
|
| 400 |
+
)
|
| 401 |
if self.project_to_2d_pos:
|
| 402 |
pos_embed = self.to_2d_proj(pos_embed_3d)
|
| 403 |
hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
|
|
|
|
| 414 |
)
|
| 415 |
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
| 416 |
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
| 417 |
+
embedded_timestep = embedded_timestep.view(
|
| 418 |
+
batch_size, -1, embedded_timestep.shape[-1]
|
| 419 |
+
)
|
| 420 |
|
| 421 |
# 2. Blocks
|
| 422 |
if self.caption_projection is not None:
|
| 423 |
batch_size = hidden_states.shape[0]
|
| 424 |
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
| 425 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
| 426 |
+
batch_size, -1, hidden_states.shape[-1]
|
| 427 |
+
)
|
| 428 |
|
| 429 |
for block in self.transformer_blocks:
|
| 430 |
if self.training and self.gradient_checkpointing:
|
|
|
|
| 438 |
|
| 439 |
return custom_forward
|
| 440 |
|
| 441 |
+
ckpt_kwargs: Dict[str, Any] = (
|
| 442 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 443 |
+
)
|
| 444 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 445 |
create_custom_forward(block),
|
| 446 |
hidden_states,
|
|
|
|
| 466 |
)
|
| 467 |
|
| 468 |
# 3. Output
|
| 469 |
+
scale_shift_values = (
|
| 470 |
+
self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
| 471 |
+
)
|
| 472 |
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
| 473 |
hidden_states = self.norm_out(hidden_states)
|
| 474 |
# Modulation
|
|
|
|
| 481 |
|
| 482 |
def get_absolute_pos_embed(self, grid):
|
| 483 |
grid_np = grid[0].cpu().numpy()
|
| 484 |
+
embed_dim_3d = (
|
| 485 |
+
math.ceil((self.inner_dim / 2) * 3)
|
| 486 |
+
if self.project_to_2d_pos
|
| 487 |
+
else self.inner_dim
|
| 488 |
+
)
|
| 489 |
pos_embed = get_3d_sincos_pos_embed( # (f h w)
|
| 490 |
embed_dim_3d,
|
| 491 |
grid_np,
|
xora/pipelines/pipeline_video_pixart_alpha.py
CHANGED
|
Old version, removed lines only (the surviving fragment of each removed line is kept; the reformatted statements appear in full in the new version below):

@@ -5,12 +5,10 @@ import math
- 8      from abc import ABC, abstractmethod
- 13     from torch import Tensor
@@ -29,7 +27,11 @@ from transformers import T5EncoderModel, T5Tokenizer
- 32     from xora.models.autoencoders.vae_encode import
@@ -161,7 +163,9 @@ def retrieve_timesteps(
- 164    accepts_timesteps = "timesteps" in set(
@@ -238,7 +242,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 241    self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
@@ -320,12 +326,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 323    untruncated_ids = self.tokenizer(
@@ -334,7 +344,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 337    prompt_embeds = self.text_encoder(
@@ -349,14 +361,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 352    prompt_embeds = prompt_embeds.view(
- 354    prompt_attention_mask = prompt_attention_mask.view(
- 359    uncond_tokens = self._text_preprocessing(
@@ -371,7 +389,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 374    uncond_input.input_ids.to(device),
@@ -379,18 +398,33 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 382    negative_prompt_embeds = negative_prompt_embeds.to(
- 384    negative_prompt_embeds = negative_prompt_embeds.repeat(
- 387    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
- 393    return
@@ -399,13 +433,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 402    accepts_eta = "eta" in set(
- 408    accepts_generator = "generator" in set(
@@ -422,7 +460,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 425    raise ValueError(
@@ -433,8 +473,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 436    elif prompt is not None and (
@@ -449,10 +493,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 452    raise ValueError(
- 454    if
@@ -471,12 +522,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 474    logger.warn(
- 479    logger.warn(
@@ -564,13 +619,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 567    caption = re.sub(
- 573    caption = re.sub(
@@ -588,10 +647,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 591    caption = re.sub(
- 594    caption = re.sub(
@@ -610,7 +673,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 613    self,
@@ -625,10 +696,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 628    latents = randn_tensor(
- 631    latents = latents * latents_mask[..., None] + noise * (
@@ -637,7 +712,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 640    def classify_height_width_bin(
@@ -645,7 +722,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 648    def resize_and_crop_tensor(
@@ -656,7 +735,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 659    samples = F.interpolate(
@@ -821,14 +905,21 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 824    prompt_attention_mask = torch.cat(
- 831    media_items,
@@ -851,29 +942,46 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 854    conditioning_mask =
- 861    self.scheduler,
- 868    num_warmup_steps = max(
- 872    latent_model_input =
- 876    torch.ones(
@@ -885,13 +993,25 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 888    current_timestep = torch.tensor(
- 890    current_timestep = current_timestep[None].to(
- 892    current_timestep = current_timestep.expand(
- 894    (
@@ -920,11 +1040,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 923    noise_pred = noise_pred_uncond + guidance_scale * (
- 927    if
@@ -937,7 +1062,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 940    if i == len(timesteps) - 1 or (
@@ -948,11 +1075,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 951    out_channels=self.transformer.in_channels
- 955    latents,
@@ -1005,20 +1136,31 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
- 1008   init_len, target_len =
- 1014   init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
- 1019   if method in [
- 1021   if method in [
|
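The denoising-loop hunk listed above (@@ -920,11 +1040,16 @@) re-wraps the classifier-free guidance update. As a reminder of what that single statement computes, here is a tiny sketch on dummy tensors; the guidance_scale value and the assumption that the truncated term is the usual (noise_pred_text - noise_pred_uncond) difference are illustrative, not taken from the diff.

import torch

guidance_scale = 7.5  # illustrative value
noise_pred = torch.randn(4, 128, 64)  # stacked [unconditional; text-conditioned] batch

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Assumed standard CFG combination: move the prediction away from the unconditional
# branch toward the text-conditioned branch by guidance_scale.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([2, 128, 64])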
|
|
|
| 5 |
import re
|
| 6 |
import urllib.parse as ul
|
| 7 |
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
|
|
| 12 |
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
from diffusers.models import AutoencoderKL
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 27 |
|
| 28 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
| 29 |
from xora.models.transformers.symmetric_patchifier import Patchifier
|
| 30 |
+
from xora.models.autoencoders.vae_encode import (
|
| 31 |
+
get_vae_size_scale_factor,
|
| 32 |
+
vae_decode,
|
| 33 |
+
vae_encode,
|
| 34 |
+
)
|
| 35 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 36 |
from xora.schedulers.rf import TimestepShifter
|
| 37 |
from xora.utils.conditioning_method import ConditioningMethod
|
|
|
|
| 163 |         second element is the number of inference steps.
| 164 |     """
| 165 |     if timesteps is not None:
| 166 | +       accepts_timesteps = "timesteps" in set(
| 167 | +           inspect.signature(scheduler.set_timesteps).parameters.keys()
| 168 | +       )
| 169 |         if not accepts_timesteps:
| 170 |             raise ValueError(
| 171 |                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"

| 242 |             patchifier=patchifier,
| 243 |         )
| 244 |
| 245 | +       self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
| 246 | +           self.vae
| 247 | +       )
| 248 |         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
| 249 |
| 250 |     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py

| 326 |                 return_tensors="pt",
| 327 |             )
| 328 |             text_input_ids = text_inputs.input_ids
| 329 | +           untruncated_ids = self.tokenizer(
| 330 | +               prompt, padding="longest", return_tensors="pt"
| 331 | +           ).input_ids
| 332 | +
| 333 | +           if untruncated_ids.shape[-1] >= text_input_ids.shape[
| 334 | +               -1
| 335 | +           ] and not torch.equal(text_input_ids, untruncated_ids):
| 336 | +               removed_text = self.tokenizer.batch_decode(
| 337 | +                   untruncated_ids[:, max_length - 1 : -1]
| 338 | +               )
| 339 |                 logger.warning(
| 340 |                     "The following part of your input was truncated because CLIP can only handle sequences up to"
| 341 |                     f" {max_length} tokens: {removed_text}"

| 344 |             prompt_attention_mask = text_inputs.attention_mask
| 345 |             prompt_attention_mask = prompt_attention_mask.to(device)
| 346 |
| 347 | +           prompt_embeds = self.text_encoder(
| 348 | +               text_input_ids.to(device), attention_mask=prompt_attention_mask
| 349 | +           )
| 350 |             prompt_embeds = prompt_embeds[0]
| 351 |
| 352 |         if self.text_encoder is not None:

| 361 |         bs_embed, seq_len, _ = prompt_embeds.shape
| 362 |         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
| 363 |         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
| 364 | +       prompt_embeds = prompt_embeds.view(
| 365 | +           bs_embed * num_images_per_prompt, seq_len, -1
| 366 | +       )
| 367 |         prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
| 368 | +       prompt_attention_mask = prompt_attention_mask.view(
| 369 | +           bs_embed * num_images_per_prompt, -1
| 370 | +       )
| 371 |
| 372 |         # get unconditional embeddings for classifier free guidance
| 373 |         if do_classifier_free_guidance and negative_prompt_embeds is None:
| 374 |             uncond_tokens = [negative_prompt] * batch_size
| 375 | +           uncond_tokens = self._text_preprocessing(
| 376 | +               uncond_tokens, clean_caption=clean_caption
| 377 | +           )
| 378 |             max_length = prompt_embeds.shape[1]
| 379 |             uncond_input = self.tokenizer(
| 380 |                 uncond_tokens,

| 389 |             negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
| 390 |
| 391 |             negative_prompt_embeds = self.text_encoder(
| 392 | +               uncond_input.input_ids.to(device),
| 393 | +               attention_mask=negative_prompt_attention_mask,
| 394 |             )
| 395 |             negative_prompt_embeds = negative_prompt_embeds[0]

| 398 |             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
| 399 |             seq_len = negative_prompt_embeds.shape[1]
| 400 |
| 401 | +           negative_prompt_embeds = negative_prompt_embeds.to(
| 402 | +               dtype=dtype, device=device
| 403 | +           )
| 404 |
| 405 | +           negative_prompt_embeds = negative_prompt_embeds.repeat(
| 406 | +               1, num_images_per_prompt, 1
| 407 | +           )
| 408 | +           negative_prompt_embeds = negative_prompt_embeds.view(
| 409 | +               batch_size * num_images_per_prompt, seq_len, -1
| 410 | +           )
| 411 |
| 412 | +           negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
| 413 | +               1, num_images_per_prompt
| 414 | +           )
| 415 | +           negative_prompt_attention_mask = negative_prompt_attention_mask.view(
| 416 | +               bs_embed * num_images_per_prompt, -1
| 417 | +           )
| 418 |         else:
| 419 |             negative_prompt_embeds = None
| 420 |             negative_prompt_attention_mask = None
| 421 |
| 422 | +       return (
| 423 | +           prompt_embeds,
| 424 | +           prompt_attention_mask,
| 425 | +           negative_prompt_embeds,
| 426 | +           negative_prompt_attention_mask,
| 427 | +       )
| 428 |
| 429 |     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
| 430 |     def prepare_extra_step_kwargs(self, generator, eta):
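The repeat/view pairs in the hunks above duplicate each prompt embedding num_images_per_prompt times while staying MPS-friendly, as the diff's comment notes. A minimal sketch of that pattern on dummy tensors (all shapes here are invented for illustration):

import torch

# Hypothetical sizes, for illustration only.
bs_embed, seq_len, dim = 2, 16, 32
num_images_per_prompt = 3

prompt_embeds = torch.randn(bs_embed, seq_len, dim)

# Tile along the sequence axis, then fold the copies back into the batch axis.
duplicated = prompt_embeds.repeat(1, num_images_per_prompt, 1)
duplicated = duplicated.view(bs_embed * num_images_per_prompt, seq_len, -1)

assert duplicated.shape == (bs_embed * num_images_per_prompt, seq_len, dim)
# Copies of the same prompt end up adjacent in the batch dimension.
assert torch.equal(duplicated[0], duplicated[1])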
| 433 |         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
| 434 |         # and should be between [0, 1]
| 435 |
| 436 | +       accepts_eta = "eta" in set(
| 437 | +           inspect.signature(self.scheduler.step).parameters.keys()
| 438 | +       )
| 439 |         extra_step_kwargs = {}
| 440 |         if accepts_eta:
| 441 |             extra_step_kwargs["eta"] = eta
| 442 |
| 443 |         # check if the scheduler accepts generator
| 444 | +       accepts_generator = "generator" in set(
| 445 | +           inspect.signature(self.scheduler.step).parameters.keys()
| 446 | +       )
| 447 |         if accepts_generator:
| 448 |             extra_step_kwargs["generator"] = generator
| 449 |         return extra_step_kwargs
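The inspect.signature checks above let prepare_extra_step_kwargs forward eta and generator only to schedulers whose step method actually accepts them. A standalone sketch of the same pattern; the two step_* functions below are made-up stand-ins, not real schedulers:

import inspect


def step_without_eta(model_output, timestep, sample):
    return sample


def step_with_eta(model_output, timestep, sample, eta=0.0, generator=None):
    return sample


def optional_step_kwargs(step_fn, eta, generator):
    """Build the kwargs dict, keeping only the arguments `step_fn` accepts."""
    accepted = set(inspect.signature(step_fn).parameters.keys())
    kwargs = {}
    if "eta" in accepted:
        kwargs["eta"] = eta
    if "generator" in accepted:
        kwargs["generator"] = generator
    return kwargs


print(optional_step_kwargs(step_without_eta, eta=0.5, generator=None))  # {}
print(optional_step_kwargs(step_with_eta, eta=0.5, generator=None))     # {'eta': 0.5, 'generator': None}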
| 460 |         negative_prompt_attention_mask=None,
| 461 |     ):
| 462 |         if height % 8 != 0 or width % 8 != 0:
| 463 | +           raise ValueError(
| 464 | +               f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
| 465 | +           )
| 466 |
| 467 |         if prompt is not None and prompt_embeds is not None:
| 468 |             raise ValueError(

| 473 |             raise ValueError(
| 474 |                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
| 475 |             )
| 476 | +       elif prompt is not None and (
| 477 | +           not isinstance(prompt, str) and not isinstance(prompt, list)
| 478 | +       ):
| 479 | +           raise ValueError(
| 480 | +               f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
| 481 | +           )
| 482 |
| 483 |         if prompt is not None and negative_prompt_embeds is not None:
| 484 |             raise ValueError(

| 493 |             )
| 494 |
| 495 |         if prompt_embeds is not None and prompt_attention_mask is None:
| 496 | +           raise ValueError(
| 497 | +               "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
| 498 | +           )
| 499 |
| 500 | +       if (
| 501 | +           negative_prompt_embeds is not None
| 502 | +           and negative_prompt_attention_mask is None
| 503 | +       ):
| 504 | +           raise ValueError(
| 505 | +               "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
| 506 | +           )
| 507 |
| 508 |         if prompt_embeds is not None and negative_prompt_embeds is not None:
| 509 |             if prompt_embeds.shape != negative_prompt_embeds.shape:
| 522 |     # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
| 523 |     def _text_preprocessing(self, text, clean_caption=False):
| 524 |         if clean_caption and not is_bs4_available():
| 525 | +           logger.warn(
| 526 | +               BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
| 527 | +           )
| 528 |             logger.warn("Setting `clean_caption` to False...")
| 529 |             clean_caption = False
| 530 |
| 531 |         if clean_caption and not is_ftfy_available():
| 532 | +           logger.warn(
| 533 | +               BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
| 534 | +           )
| 535 |             logger.warn("Setting `clean_caption` to False...")
| 536 |             clean_caption = False
| 537 |

| 619 |         # "123456.."
| 620 |         caption = re.sub(r"\b\d{6,}\b", "", caption)
| 621 |         # filenames:
| 622 | +       caption = re.sub(
| 623 | +           r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
| 624 | +       )
| 625 |
| 626 |         #
| 627 |         caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
| 628 |         caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
| 629 |
| 630 | +       caption = re.sub(
| 631 | +           self.bad_punct_regex, r" ", caption
| 632 | +       )  # ***AUSVERKAUFT***, #AUSVERKAUFT
| 633 |         caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
| 634 |
| 635 |         # this-is-my-cute-cat / this_is_my_cute_cat

| 647 |         caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
| 648 |         caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
| 649 |         caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
| 650 | +       caption = re.sub(
| 651 | +           r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
| 652 | +       )
| 653 |         caption = re.sub(r"\bpage\s+\d+\b", "", caption)
| 654 |
| 655 | +       caption = re.sub(
| 656 | +           r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
| 657 | +       )  # j2d1a2a...
| 658 |
| 659 |         caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
| 660 |
| 673 |
| 674 |     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
| 675 |     def prepare_latents(
| 676 | +       self,
| 677 | +       batch_size,
| 678 | +       num_latent_channels,
| 679 | +       num_patches,
| 680 | +       dtype,
| 681 | +       device,
| 682 | +       generator,
| 683 | +       latents=None,
| 684 | +       latents_mask=None,
| 685 |     ):
| 686 |         shape = (
| 687 |             batch_size,

| 696 |         )
| 697 |
| 698 |         if latents is None:
| 699 | +           latents = randn_tensor(
| 700 | +               shape, generator=generator, device=device, dtype=dtype
| 701 | +           )
| 702 |         elif latents_mask is not None:
| 703 |             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
| 704 | +           latents = latents * latents_mask[..., None] + noise * (
| 705 | +               1 - latents_mask[..., None]
| 706 | +           )
| 707 |         else:
| 708 |             latents = latents.to(device)
| 709 |

| 712 |         return latents
| 713 |
| 714 |     @staticmethod
| 715 | +   def classify_height_width_bin(
| 716 | +       height: int, width: int, ratios: dict
| 717 | +   ) -> Tuple[int, int]:
| 718 |         """Returns binned height and width."""
| 719 |         ar = float(height / width)
| 720 |         closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))

| 722 |         return int(default_hw[0]), int(default_hw[1])
| 723 |
| 724 |     @staticmethod
| 725 | +   def resize_and_crop_tensor(
| 726 | +       samples: torch.Tensor, new_width: int, new_height: int
| 727 | +   ) -> torch.Tensor:
| 728 |         n_frames, orig_height, orig_width = samples.shape[-3:]
| 729 |
| 730 |         # Check if resizing is needed

| 735 |
| 736 |         # Resize
| 737 |         samples = rearrange(samples, "b c n h w -> (b n) c h w")
| 738 | +       samples = F.interpolate(
| 739 | +           samples,
| 740 | +           size=(resized_height, resized_width),
| 741 | +           mode="bilinear",
| 742 | +           align_corners=False,
| 743 | +       )
| 744 |         samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
| 745 |
| 746 |         # Center Crop
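resize_and_crop_tensor above resizes 5D video tensors by folding the frame axis into the batch axis, because bilinear F.interpolate operates on 4D inputs. A self-contained sketch of that fold/resize/unfold round trip, with arbitrary example sizes:

import torch
import torch.nn.functional as F
from einops import rearrange

# Arbitrary example sizes: batch, channels, frames, height, width.
samples = torch.randn(1, 3, 8, 64, 96)
n_frames = samples.shape[2]
resized_height, resized_width = 128, 192

# Fold frames into the batch axis so bilinear interpolation sees 4D input.
samples = rearrange(samples, "b c n h w -> (b n) c h w")
samples = F.interpolate(
    samples,
    size=(resized_height, resized_width),
    mode="bilinear",
    align_corners=False,
)
# Restore the original 5D layout.
samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)

assert samples.shape == (1, 3, 8, 128, 192)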
| 905 |         )
| 906 |         if do_classifier_free_guidance:
| 907 |             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
| 908 | +           prompt_attention_mask = torch.cat(
| 909 | +               [negative_prompt_attention_mask, prompt_attention_mask], dim=0
| 910 | +           )
| 911 |
| 912 |         # 3b. Encode and prepare conditioning data
| 913 |         self.video_scale_factor = self.video_scale_factor if is_video else 1
| 914 |         conditioning_method = kwargs.get("conditioning_method", None)
| 915 |         vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
| 916 |         init_latents, conditioning_mask = self.prepare_conditioning(
| 917 | +           media_items,
| 918 | +           num_frames,
| 919 | +           height,
| 920 | +           width,
| 921 | +           conditioning_method,
| 922 | +           vae_per_channel_normalize,
| 923 |         )
| 924 |
| 925 |         # 4. Prepare latents.
| 942 |         )
| 943 |         if conditioning_mask is not None and is_video:
| 944 |             assert num_images_per_prompt == 1
| 945 | +           conditioning_mask = (
| 946 | +               torch.cat([conditioning_mask] * 2)
| 947 | +               if do_classifier_free_guidance
| 948 | +               else conditioning_mask
| 949 | +           )
| 950 |
| 951 |         # 5. Prepare timesteps
| 952 |         retrieve_timesteps_kwargs = {}
| 953 |         if isinstance(self.scheduler, TimestepShifter):
| 954 |             retrieve_timesteps_kwargs["samples"] = latents
| 955 |         timesteps, num_inference_steps = retrieve_timesteps(
| 956 | +           self.scheduler,
| 957 | +           num_inference_steps,
| 958 | +           device,
| 959 | +           timesteps,
| 960 | +           **retrieve_timesteps_kwargs,
| 961 |         )
| 962 |
| 963 |         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
| 964 |         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
| 965 |
| 966 |         # 7. Denoising loop
| 967 | +       num_warmup_steps = max(
| 968 | +           len(timesteps) - num_inference_steps * self.scheduler.order, 0
| 969 | +       )
| 970 |
| 971 |         with self.progress_bar(total=num_inference_steps) as progress_bar:
| 972 |             for i, t in enumerate(timesteps):
| 973 | +               latent_model_input = (
| 974 | +                   torch.cat([latents] * 2) if do_classifier_free_guidance else latents
| 975 | +               )
| 976 | +               latent_model_input = self.scheduler.scale_model_input(
| 977 | +                   latent_model_input, t
| 978 | +               )
| 979 |
| 980 |                 latent_frame_rates = (
| 981 | +                   torch.ones(
| 982 | +                       latent_model_input.shape[0], 1, device=latent_model_input.device
| 983 | +                   )
| 984 | +                   * latent_frame_rate
| 985 |                 )
| 986 |
| 987 |                 current_timestep = t
| 993 |                         dtype = torch.float32 if is_mps else torch.float64
| 994 |                     else:
| 995 |                         dtype = torch.int32 if is_mps else torch.int64
| 996 | +                   current_timestep = torch.tensor(
| 997 | +                       [current_timestep],
| 998 | +                       dtype=dtype,
| 999 | +                       device=latent_model_input.device,
| 1000 | +                   )
| 1001 |                 elif len(current_timestep.shape) == 0:
| 1002 | +                   current_timestep = current_timestep[None].to(
| 1003 | +                       latent_model_input.device
| 1004 | +                   )
| 1005 |                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
| 1006 | +               current_timestep = current_timestep.expand(
| 1007 | +                   latent_model_input.shape[0]
| 1008 | +               ).unsqueeze(-1)
| 1009 |                 scale_grid = (
| 1010 | +                   (
| 1011 | +                       1 / latent_frame_rates,
| 1012 | +                       self.vae_scale_factor,
| 1013 | +                       self.vae_scale_factor,
| 1014 | +                   )
| 1015 |                     if self.transformer.use_rope
| 1016 |                     else None
| 1017 |                 )
| 1040 |                 # perform guidance
| 1041 |                 if do_classifier_free_guidance:
| 1042 |                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
| 1043 | +                   noise_pred = noise_pred_uncond + guidance_scale * (
| 1044 | +                       noise_pred_text - noise_pred_uncond
| 1045 | +                   )
| 1046 |                     current_timestep, _ = current_timestep.chunk(2)
| 1047 |
| 1048 |                 # learned sigma
| 1049 | +               if (
| 1050 | +                   self.transformer.config.out_channels // 2
| 1051 | +                   == self.transformer.config.in_channels
| 1052 | +               ):
| 1053 |                     noise_pred = noise_pred.chunk(2, dim=1)[0]
| 1054 |
| 1055 |                 # compute previous image: x_t -> x_t-1
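The guidance step above is the usual classifier-free guidance combination: the batch carries the unconditional and conditional predictions stacked together, and the final prediction moves from the unconditional output toward the conditional one by guidance_scale. A toy illustration on random tensors rather than real model outputs:

import torch

guidance_scale = 3.0

# Pretend model output for a CFG batch: first half unconditional, second half conditional.
noise_pred = torch.randn(2, 4, 8, 8)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# guidance_scale == 1.0 would reduce to the conditional prediction.
assert guided.shape == noise_pred_uncond.shape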
| 1062 |                 )[0]
| 1063 |
| 1064 |                 # call the callback, if provided
| 1065 | +               if i == len(timesteps) - 1 or (
| 1066 | +                   (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
| 1067 | +               ):
| 1068 |                     progress_bar.update()
| 1069 |
| 1070 |                 if callback_on_step_end is not None:

| 1075 |             output_height=latent_height,
| 1076 |             output_width=latent_width,
| 1077 |             output_num_frames=latent_num_frames,
| 1078 | +           out_channels=self.transformer.in_channels
| 1079 | +           // math.prod(self.patchifier.patch_size),
| 1080 |         )
| 1081 |         if output_type != "latent":
| 1082 |             image = vae_decode(
| 1083 | +               latents,
| 1084 | +               self.vae,
| 1085 | +               is_video,
| 1086 | +               vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
| 1087 |             )
| 1088 |             image = self.image_processor.postprocess(image, output_type=output_type)
| 1089 |
| 1136 |             vae_per_channel_normalize=vae_per_channel_normalize,
| 1137 |         ).float()
| 1138 |
| 1139 | +       init_len, target_len = (
| 1140 | +           init_latents.shape[2],
| 1141 | +           num_frames // self.video_scale_factor,
| 1142 | +       )
| 1143 |         if isinstance(self.vae, CausalVideoAutoencoder):
| 1144 |             target_len += 1
| 1145 |         init_latents = init_latents[:, :, :target_len]
| 1146 |         if target_len > init_len:
| 1147 |             repeat_factor = (target_len + init_len - 1) // init_len  # Ceiling division
| 1148 | +           init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
| 1149 | +               :, :, :target_len
| 1150 | +           ]
| 1151 |
| 1152 |         # Prepare the conditioning mask (1.0 = condition on this token)
| 1153 |         b, n, f, h, w = init_latents.shape
| 1154 |         conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
| 1155 | +       if method in [
| 1156 | +           ConditioningMethod.FIRST_FRAME,
| 1157 | +           ConditioningMethod.FIRST_AND_LAST_FRAME,
| 1158 | +       ]:
| 1159 |             conditioning_mask[:, :, 0] = 1.0
| 1160 | +       if method in [
| 1161 | +           ConditioningMethod.LAST_FRAME,
| 1162 | +           ConditioningMethod.FIRST_AND_LAST_FRAME,
| 1163 | +       ]:
| 1164 |             conditioning_mask[:, :, -1] = 1.0
| 1165 |
| 1166 |         # Patchify the init latents and the mask
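To make the conditioning logic above easier to follow, here is a stripped-down sketch of just the mask construction: a zero mask over the latent frames, with the first and/or last frame switched on according to the chosen ConditioningMethod. The helper and the latent sizes are invented for the example; the pipeline itself builds the mask inline inside prepare_conditioning.

import torch

from xora.utils.conditioning_method import ConditioningMethod


def make_conditioning_mask(b, f, h, w, method):
    """1.0 marks latent positions that are conditioned on (kept fixed)."""
    mask = torch.zeros([b, 1, f, h, w])
    if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
        mask[:, :, 0] = 1.0
    if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
        mask[:, :, -1] = 1.0
    return mask


mask = make_conditioning_mask(1, 9, 16, 16, ConditioningMethod.FIRST_FRAME)
assert mask[:, :, 0].sum() == 16 * 16 and mask[:, :, 1:].sum() == 0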
xora/schedulers/rf.py
CHANGED
@@ -22,7 +22,9 @@ def simple_diffusion_resolution_dependent_timestep_shift(
| 22 |     elif len(samples.shape) in [4, 5]:
| 23 |         m = math.prod(samples.shape[2:])
| 24 |     else:
| 25 | +       raise ValueError(
| 26 | +           "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
| 27 | +       )
| 28 |     snr = (timesteps / (1 - timesteps)) ** 2
| 29 |     shift_snr = torch.log(snr) + 2 * math.log(m / n)
| 30 |     shifted_timesteps = torch.sigmoid(0.5 * shift_snr)

@@ -46,7 +48,9 @@ def get_normal_shift(
| 48 |     return m * n_tokens + b
| 49 |
| 50 |
| 51 | + def sd3_resolution_dependent_timestep_shift(
| 52 | +     samples: Tensor, timesteps: Tensor
| 53 | + ) -> Tensor:
| 54 |     """
| 55 |     Shifts the timestep schedule as a function of the generated resolution.
| 56 |

@@ -70,7 +74,9 @@ def sd3_resolution_dependent_timestep_shift(samples: Tensor, timesteps: Tensor)
| 74 |     elif len(samples.shape) in [4, 5]:
| 75 |         m = math.prod(samples.shape[2:])
| 76 |     else:
| 77 | +       raise ValueError(
| 78 | +           "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
| 79 | +       )
| 80 |
| 81 |     shift = get_normal_shift(m)
| 82 |     return time_shift(shift, 1, timesteps)

@@ -104,12 +110,21 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
| 110 |     order = 1
| 111 |
| 112 |     @register_to_config
| 113 | +   def __init__(
| 114 | +       self,
| 115 | +       num_train_timesteps=1000,
| 116 | +       shifting: Optional[str] = None,
| 117 | +       base_resolution: int = 32**2,
| 118 | +   ):
| 119 |         super().__init__()
| 120 |         self.init_noise_sigma = 1.0
| 121 |         self.num_inference_steps = None
| 122 | +       self.timesteps = self.sigmas = torch.linspace(
| 123 | +           1, 1 / num_train_timesteps, num_train_timesteps
| 124 | +       )
| 125 | +       self.delta_timesteps = self.timesteps - torch.cat(
| 126 | +           [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
| 127 | +       )
| 128 |         self.shifting = shifting
| 129 |         self.base_resolution = base_resolution
| 130 |

@@ -117,10 +132,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
| 132 |         if self.shifting == "SD3":
| 133 |             return sd3_resolution_dependent_timestep_shift(samples, timesteps)
| 134 |         elif self.shifting == "SimpleDiffusion":
| 135 | +           return simple_diffusion_resolution_dependent_timestep_shift(
| 136 | +               samples, timesteps, self.base_resolution
| 137 | +           )
| 138 |         return timesteps
| 139 |
| 140 | +   def set_timesteps(
| 141 | +       self,
| 142 | +       num_inference_steps: int,
| 143 | +       samples: Tensor,
| 144 | +       device: Union[str, torch.device] = None,
| 145 | +   ):
| 146 |         """
| 147 |         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
| 148 |

@@ -130,13 +152,19 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
| 152 |             device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
| 153 |         """
| 154 |         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
| 155 | +       timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
| 156 | +           device
| 157 | +       )
| 158 |         self.timesteps = self.shift_timesteps(samples, timesteps)
| 159 | +       self.delta_timesteps = self.timesteps - torch.cat(
| 160 | +           [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
| 161 | +       )
| 162 |         self.num_inference_steps = num_inference_steps
| 163 |         self.sigmas = self.timesteps
| 164 |
| 165 | +   def scale_model_input(
| 166 | +       self, sample: torch.FloatTensor, timestep: Optional[int] = None
| 167 | +   ) -> torch.FloatTensor:
| 168 |         # pylint: disable=unused-argument
| 169 |         """
| 170 |         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

@@ -206,7 +234,9 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
| 234 |         else:
| 235 |             # Timestep per token
| 236 |             assert timestep.ndim == 2
| 237 | +           current_index = (
| 238 | +               (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
| 239 | +           )
| 240 |             dt = self.delta_timesteps[current_index]
| 241 |             # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
| 242 |             dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]

@@ -228,4 +258,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
| 258 |         sigmas = append_dims(sigmas, original_samples.ndim)
| 259 |         alphas = 1 - sigmas
| 260 |         noisy_samples = alphas * original_samples + sigmas * noise
| 261 | +       return noisy_samples
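The scheduler changes above keep the rectified-flow bookkeeping in two tensors: timesteps, a schedule descending from 1 toward 0, and delta_timesteps, the per-step decrement each update consumes, with noising defined as linear interpolation between data and noise (alphas = 1 - sigmas). A small numeric sketch of those pieces outside the SchedulerMixin machinery:

import torch

num_inference_steps = 4

# Descending schedule: 1.0, 0.75, 0.5, 0.25 (never exactly 0).
timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps)

# Delta between consecutive timesteps, with the last step stepping down to 0.
delta_timesteps = timesteps - torch.cat([timesteps[1:], torch.zeros_like(timesteps[-1:])])
print(delta_timesteps)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])

# Rectified-flow noising: linear interpolation between data and noise.
x0 = torch.randn(1, 4)
noise = torch.randn(1, 4)
t = timesteps[1]  # 0.75
noisy = (1 - t) * x0 + t * noise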
xora/utils/conditioning_method.py
CHANGED
@@ -1,7 +1,8 @@
| 1 | from enum import Enum
| 2 |
| 3 | +
| 4 | class ConditioningMethod(Enum):
| 5 |     UNCONDITIONAL = "unconditional"
| 6 |     FIRST_FRAME = "first_frame"
| 7 |     LAST_FRAME = "last_frame"
| 8 | +   FIRST_AND_LAST_FRAME = "first_and_last_frame"
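Since the pipeline receives the conditioning method through kwargs, the string-valued enum makes round-tripping from a user-facing string straightforward; a possible usage, assuming the caller passes the raw string value:

from xora.utils.conditioning_method import ConditioningMethod

method = ConditioningMethod("first_and_last_frame")
assert method is ConditioningMethod.FIRST_AND_LAST_FRAME
assert method.value == "first_and_last_frame"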
xora/utils/torch_utils.py
CHANGED
@@ -1,15 +1,19 @@
| 1 | import torch
| 2 | from torch import nn
| 3 |
| 4 | +
| 5 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
| 6 |     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
| 7 |     dims_to_append = target_dims - x.ndim
| 8 |     if dims_to_append < 0:
| 9 | +       raise ValueError(
| 10 | +          f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
| 11 | +      )
| 12 |    elif dims_to_append == 0:
| 13 |        return x
| 14 |    return x[(...,) + (None,) * dims_to_append]
| 15 |
| 16 | +
| 17 | class Identity(nn.Module):
| 18 |     """A placeholder identity operator that is argument-insensitive."""
| 19 |
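append_dims is what lets a per-sample 1D sigmas tensor broadcast against a 5D video latent in add_noise above. A quick usage sketch with made-up shapes:

import torch

from xora.utils.torch_utils import append_dims

sigmas = torch.tensor([0.25, 0.75])          # one sigma per batch element
latents = torch.randn(2, 4, 8, 16, 16)       # (b, c, f, h, w)

sigmas = append_dims(sigmas, latents.ndim)   # shape becomes (2, 1, 1, 1, 1)
noisy = (1 - sigmas) * latents + sigmas * torch.randn_like(latents)
assert sigmas.shape == (2, 1, 1, 1, 1)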