working.

- xora/examples/image_to_video.py +87 -0
- xora/models/autoencoders/causal_video_autoencoder.py +3 -1
- xora/models/autoencoders/vae_encode.py +11 -41
- xora/models/autoencoders/video_autoencoder.py +912 -0
- xora/models/transformers/embeddings.py +125 -0
- xora/models/transformers/transformer3d.py +77 -4
- xora/pipelines/pipeline_video_pixart_alpha.py +181 -13
- xora/schedulers/rf.py +13 -4
- xora/utils/conditioning_method.py +7 -0
- xora/utils/dist_util.py +11 -0
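
Taken together, the new video_autoencoder.py module and the reworked vae_encode.py helpers let a non-causal VideoAutoencoder checkpoint go through the same encode path as CausalVideoAutoencoder. The snippet below is a minimal sketch of that path and is not part of the commit: it assumes a local checkpoint directory laid out the way VideoAutoencoder.from_pretrained (added in this commit) expects, with config.json, autoencoder.pth and, for per-channel normalization, per_channel_statistics.json; the path and tensor shape are placeholders.

import torch
from pathlib import Path

from xora.models.autoencoders.video_autoencoder import VideoAutoencoder
from xora.models.autoencoders.vae_encode import vae_encode

# Placeholder checkpoint directory; from_pretrained reads config.json, autoencoder.pth
# and (optionally) per_channel_statistics.json from it, and requires torch_dtype.
vae = VideoAutoencoder.from_pretrained(
    pretrained_model_name_or_path=Path("/path/to/video_vae_checkpoint"),
    torch_dtype=torch.bfloat16,
).cuda()

# (B, C, F, H, W) video clip; the shape is chosen only for illustration.
media_items = torch.randn(1, 3, 9, 256, 256, dtype=torch.bfloat16, device="cuda")

# Because the VAE is video-aware, vae_encode keeps the 5D layout instead of
# folding frames into the batch dimension before encoding.
latents = vae_encode(media_items, vae, split_size=1, vae_per_channel_normalize=True)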
    	
xora/examples/image_to_video.py    ADDED
@@ -0,0 +1,87 @@
+import torch
+from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+from xora.models.transformers.transformer3d import Transformer3DModel
+from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
+from xora.schedulers.rf import RectifiedFlowScheduler
+from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
+from pathlib import Path
+from transformers import T5EncoderModel
+
+
+model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
+vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
+dtype = torch.float32
+vae = CausalVideoAutoencoder.from_pretrained(
+            pretrained_model_name_or_path=vae_local_path,
+            revision=False,
+            torch_dtype=torch.bfloat16,
+            load_in_8bit=False,
+).cuda()
+transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
+transformer_config = Transformer3DModel.load_config(transformer_config_path)
+transformer = Transformer3DModel.from_config(transformer_config)
+transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
+transformer_ckpt_state_dict = torch.load(transformer_local_path)
+transformer.load_state_dict(transformer_ckpt_state_dict, True)
+transformer = transformer.cuda()
+unet = transformer
+scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
+scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
+patchifier = SymmetricPatchifier(patch_size=1)
+# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
+
+submodel_dict = {
+    "unet": unet,
+    "transformer": transformer,
+    "patchifier": patchifier,
+    "text_encoder": None,
+    "scheduler": scheduler,
+    "vae": vae,
+
+}
+
+pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
+                                                    safety_checker=None,
+            revision=None,
+            torch_dtype=dtype,
+            **submodel_dict,
+        )
+
+num_inference_steps=20
+num_images_per_prompt=2
+guidance_scale=3
+height=512
+width=768
+num_frames=57
+frame_rate=25
+# sample = {
+#     "prompt": "A cat", # (B, L, E)
+#     'prompt_attention_mask': None, # (B , L)
+#     'negative_prompt': "Ugly deformed",
+#     'negative_prompt_attention_mask': None # (B , L)
+# }
+
+sample = torch.load("/opt/sample.pt")
+for _, item in sample.items():
+    if item is not None:
+        item = item.cuda()
+media_items = torch.load("/opt/sample_media.pt")
+
+images = pipeline(
+    num_inference_steps=num_inference_steps,
+    num_images_per_prompt=num_images_per_prompt,
+    guidance_scale=guidance_scale,
+    generator=None,
+    output_type="pt",
+    callback_on_step_end=None,
+    height=height,
+    width=width,
+    num_frames=num_frames,
+    frame_rate=frame_rate,
+    **sample,
+    is_video=True,
+    vae_per_channel_normalize=True,
+).images
+
+print()
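
Two things are worth noting about this example as committed: the T5 text encoder is never instantiated (the T5EncoderModel.from_pretrained call is commented out and "text_encoder" is passed as None), so prompt conditioning comes from pre-computed tensors loaded from /opt/sample.pt; and all checkpoint and config paths are hard-coded local paths, so the script is a development harness rather than a portable example. Also note that the `item = item.cuda()` loop rebinds the loop variable without updating `sample`, and `media_items` is loaded but not passed to the pipeline call.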
    	
xora/models/autoencoders/causal_video_autoencoder.py    CHANGED
@@ -8,11 +8,13 @@ import torch
 import numpy as np
 from einops import rearrange
 from torch import nn
+from diffusers.utils import logging
 
 from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
 from xora.models.autoencoders.pixel_norm import PixelNorm
 from xora.models.autoencoders.vae import AutoencoderKLWrapper
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 class CausalVideoAutoencoder(AutoencoderKLWrapper):
     @classmethod
@@ -138,7 +140,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
                 key = key.replace(k, v)
 
             if "norm" in key and key not in model_keys:
-
+                logger.info(f"Removing key {key} from state_dict as it is not present in the model")
                 continue
 
             converted_state_dict[key] = value
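
The change to causal_video_autoencoder.py is purely diagnostic: unmatched "norm" keys were previously dropped silently during the load_state_dict key conversion, and are now dropped with a logger.info message via diffusers.utils.logging.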
    	
xora/models/autoencoders/vae_encode.py    CHANGED
@@ -1,44 +1,12 @@
 import torch
-from torch import nn
 from diffusers import AutoencoderKL
 from einops import rearrange
 from torch import Tensor
-from torch.nn import functional
 
 
 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-
-
-    def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
-        super().__init__()
-        stride: int = 2
-        self.padding = padding
-        self.in_channels = in_channels
-        self.dims = dims
-        self.conv = make_conv_nd(
-            dims=dims,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-        )
-
-    def forward(self, x, downsample_in_time=True):
-        conv = self.conv
-        if self.padding == 0:
-            if self.dims == 2:
-                padding = (0, 1, 0, 1)
-            else:
-                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
-
-            x = functional.pad(x, padding, mode="constant", value=0)
-
-            if self.dims == (2, 1) and not downsample_in_time:
-                return conv(x, skip_time_conv=True)
-
-        return conv(x)
-
+from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
+import xora.utils.dist_util
 
 
 def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
@@ -78,7 +46,7 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
     if channels != 3:
         raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
 
-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
     if split_size > 1:
         if len(media_items) % split_size != 0:
@@ -86,14 +54,16 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
         encode_bs = len(media_items) // split_size
         # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
         latents = []
+        dist_util.execute_graph()
         for image_batch in media_items.split(encode_bs):
             latents.append(vae.encode(image_batch).latent_dist.sample())
+            dist_util.execute_graph()
         latents = torch.cat(latents, dim=0)
     else:
         latents = vae.encode(media_items).latent_dist.sample()
 
     latents = normalize_latents(latents, vae, vae_per_channel_normalize)
-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
     return latents
 
@@ -104,7 +74,7 @@ def vae_decode(
     is_video_shaped = latents.dim() == 5
     batch_size = latents.shape[0]
 
-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         latents = rearrange(latents, "b c n h w -> (b n) c h w")
     if split_size > 1:
         if len(latents) % split_size != 0:
@@ -118,13 +88,13 @@ def vae_decode(
     else:
         images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
 
-    if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
+    if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
     return images
 
 
 def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
-    if isinstance(vae, (CausalVideoAutoencoder)):
+    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
         *_, fl, hl, wl = latents.shape
         temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
         latents = latents.to(vae.dtype)
@@ -148,7 +118,7 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
     else:
         down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
         spatial = vae.config.patch_size * 2**down_blocks
-        temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae) else 1
+        temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1
 
     return (temporal, spatial, spatial)
 
@@ -168,4 +138,4 @@ def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_no
         + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
         if vae_per_channel_normalize
         else latents / vae.config.scaling_factor
-    )
+    )
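
Three things change in vae_encode.py: the locally defined downsampling block is removed in favour of importing Downsample3D from the new video_autoencoder module, every isinstance(vae, (CausalVideoAutoencoder)) check is widened to also accept VideoAutoencoder, and dist_util.execute_graph() calls are inserted around the split-batch encode loop (dist_util.py is added elsewhere in this commit; its body is not shown in this excerpt). The edit in get_vae_size_scale_factor also repairs the previous isinstance(vae) call, which was missing its class argument and would have raised a TypeError.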
    	
xora/models/autoencoders/video_autoencoder.py    ADDED
@@ -0,0 +1,912 @@
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from functools import partial
         | 
| 4 | 
            +
            from types import SimpleNamespace
         | 
| 5 | 
            +
            from typing import Any, Mapping, Optional, Tuple, Union
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from einops import rearrange
         | 
| 9 | 
            +
            from torch import nn
         | 
| 10 | 
            +
            from torch.nn import functional
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from diffusers.utils import logging
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from txt2img.models.layers.nn import Identity
         | 
| 15 | 
            +
            from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
         | 
| 16 | 
            +
            from xora.models.autoencoders.pixel_norm import PixelNorm
         | 
| 17 | 
            +
            from xora.models.autoencoders.vae import AutoencoderKLWrapper
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class VideoAutoencoder(AutoencoderKLWrapper):
         | 
| 23 | 
            +
                @classmethod
         | 
| 24 | 
            +
                def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
         | 
| 25 | 
            +
                    config_local_path = pretrained_model_name_or_path / "config.json"
         | 
| 26 | 
            +
                    config = cls.load_config(config_local_path, **kwargs)
         | 
| 27 | 
            +
                    video_vae = cls.from_config(config)
         | 
| 28 | 
            +
                    video_vae.to(kwargs["torch_dtype"])
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
         | 
| 31 | 
            +
                    ckpt_state_dict = torch.load(model_local_path)
         | 
| 32 | 
            +
                    video_vae.load_state_dict(ckpt_state_dict)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
         | 
| 35 | 
            +
                    if statistics_local_path.exists():
         | 
| 36 | 
            +
                        with open(statistics_local_path, "r") as file:
         | 
| 37 | 
            +
                            data = json.load(file)
         | 
| 38 | 
            +
                        transposed_data = list(zip(*data["data"]))
         | 
| 39 | 
            +
                        data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
         | 
| 40 | 
            +
                        video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
         | 
| 41 | 
            +
                        video_vae.register_buffer(
         | 
| 42 | 
            +
                            "mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
         | 
| 43 | 
            +
                        )
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    return video_vae
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                @staticmethod
         | 
| 48 | 
            +
                def from_config(config):
         | 
| 49 | 
            +
                    assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder"
         | 
| 50 | 
            +
                    if isinstance(config["dims"], list):
         | 
| 51 | 
            +
                        config["dims"] = tuple(config["dims"])
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    double_z = config.get("double_z", True)
         | 
| 56 | 
            +
                    latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
         | 
| 57 | 
            +
                    use_quant_conv = config.get("use_quant_conv", True)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    if use_quant_conv and latent_log_var == "uniform":
         | 
| 60 | 
            +
                        raise ValueError("uniform latent_log_var requires use_quant_conv=False")
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    encoder = Encoder(
         | 
| 63 | 
            +
                        dims=config["dims"],
         | 
| 64 | 
            +
                        in_channels=config.get("in_channels", 3),
         | 
| 65 | 
            +
                        out_channels=config["latent_channels"],
         | 
| 66 | 
            +
                        block_out_channels=config["block_out_channels"],
         | 
| 67 | 
            +
                        patch_size=config.get("patch_size", 1),
         | 
| 68 | 
            +
                        latent_log_var=latent_log_var,
         | 
| 69 | 
            +
                        norm_layer=config.get("norm_layer", "group_norm"),
         | 
| 70 | 
            +
                        patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
         | 
| 71 | 
            +
                        add_channel_padding=config.get("add_channel_padding", False),
         | 
| 72 | 
            +
                    )
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    decoder = Decoder(
         | 
| 75 | 
            +
                        dims=config["dims"],
         | 
| 76 | 
            +
                        in_channels=config["latent_channels"],
         | 
| 77 | 
            +
                        out_channels=config.get("out_channels", 3),
         | 
| 78 | 
            +
                        block_out_channels=config["block_out_channels"],
         | 
| 79 | 
            +
                        patch_size=config.get("patch_size", 1),
         | 
| 80 | 
            +
                        norm_layer=config.get("norm_layer", "group_norm"),
         | 
| 81 | 
            +
                        patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
         | 
| 82 | 
            +
                        add_channel_padding=config.get("add_channel_padding", False),
         | 
| 83 | 
            +
                    )
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    dims = config["dims"]
         | 
| 86 | 
            +
                    return VideoAutoencoder(
         | 
| 87 | 
            +
                        encoder=encoder,
         | 
| 88 | 
            +
                        decoder=decoder,
         | 
| 89 | 
            +
                        latent_channels=config["latent_channels"],
         | 
| 90 | 
            +
                        dims=dims,
         | 
| 91 | 
            +
                        use_quant_conv=use_quant_conv,
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                @property
         | 
| 95 | 
            +
                def config(self):
         | 
| 96 | 
            +
                    return SimpleNamespace(
         | 
| 97 | 
            +
                        _class_name="VideoAutoencoder",
         | 
| 98 | 
            +
                        dims=self.dims,
         | 
| 99 | 
            +
                        in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2),
         | 
| 100 | 
            +
                        out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2),
         | 
| 101 | 
            +
                        latent_channels=self.decoder.conv_in.in_channels,
         | 
| 102 | 
            +
                        block_out_channels=[
         | 
| 103 | 
            +
                            self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
         | 
| 104 | 
            +
                            for i in range(len(self.encoder.down_blocks))
         | 
| 105 | 
            +
                        ],
         | 
| 106 | 
            +
                        scaling_factor=1.0,
         | 
| 107 | 
            +
                        norm_layer=self.encoder.norm_layer,
         | 
| 108 | 
            +
                        patch_size=self.encoder.patch_size,
         | 
| 109 | 
            +
                        latent_log_var=self.encoder.latent_log_var,
         | 
| 110 | 
            +
                        use_quant_conv=self.use_quant_conv,
         | 
| 111 | 
            +
                        patch_size_t=self.encoder.patch_size_t,
         | 
| 112 | 
            +
                        add_channel_padding=self.encoder.add_channel_padding,
         | 
| 113 | 
            +
                    )
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                @property
         | 
| 116 | 
            +
                def is_video_supported(self):
         | 
| 117 | 
            +
                    """
         | 
| 118 | 
            +
                    Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
         | 
| 119 | 
            +
                    """
         | 
| 120 | 
            +
                    return self.dims != 2
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                @property
         | 
| 123 | 
            +
                def downscale_factor(self):
         | 
| 124 | 
            +
                    return self.encoder.downsample_factor
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def to_json_string(self) -> str:
         | 
| 127 | 
            +
                    import json
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    return json.dumps(self.config.__dict__)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         | 
| 132 | 
            +
                    model_keys = set(name for name, _ in self.named_parameters())
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    key_mapping = {
         | 
| 135 | 
            +
                        ".resnets.": ".res_blocks.",
         | 
| 136 | 
            +
                        "downsamplers.0": "downsample",
         | 
| 137 | 
            +
                        "upsamplers.0": "upsample",
         | 
| 138 | 
            +
                    }
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    converted_state_dict = {}
         | 
| 141 | 
            +
                    for key, value in state_dict.items():
         | 
| 142 | 
            +
                        for k, v in key_mapping.items():
         | 
| 143 | 
            +
                            key = key.replace(k, v)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                        if "norm" in key and key not in model_keys:
         | 
| 146 | 
            +
                            logger.info(f"Removing key {key} from state_dict as it is not present in the model")
         | 
| 147 | 
            +
                            continue
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        converted_state_dict[key] = value
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    super().load_state_dict(converted_state_dict, strict=strict)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                def last_layer(self):
         | 
| 154 | 
            +
                    if hasattr(self.decoder, "conv_out"):
         | 
| 155 | 
            +
                        if isinstance(self.decoder.conv_out, nn.Sequential):
         | 
| 156 | 
            +
                            last_layer = self.decoder.conv_out[-1]
         | 
| 157 | 
            +
                        else:
         | 
| 158 | 
            +
                            last_layer = self.decoder.conv_out
         | 
| 159 | 
            +
                    else:
         | 
| 160 | 
            +
                        last_layer = self.decoder.layers[-1]
         | 
| 161 | 
            +
                    return last_layer
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            class Encoder(nn.Module):
         | 
| 165 | 
            +
                r"""
         | 
| 166 | 
            +
                The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                Args:
         | 
| 169 | 
            +
                    in_channels (`int`, *optional*, defaults to 3):
         | 
| 170 | 
            +
                        The number of input channels.
         | 
| 171 | 
            +
                    out_channels (`int`, *optional*, defaults to 3):
         | 
| 172 | 
            +
                        The number of output channels.
         | 
| 173 | 
            +
                    block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
         | 
| 174 | 
            +
                        The number of output channels for each block.
         | 
| 175 | 
            +
                    layers_per_block (`int`, *optional*, defaults to 2):
         | 
| 176 | 
            +
                        The number of layers per block.
         | 
| 177 | 
            +
                    norm_num_groups (`int`, *optional*, defaults to 32):
         | 
| 178 | 
            +
                        The number of groups for normalization.
         | 
| 179 | 
            +
                    patch_size (`int`, *optional*, defaults to 1):
         | 
| 180 | 
            +
                        The patch size to use. Should be a power of 2.
         | 
| 181 | 
            +
                    norm_layer (`str`, *optional*, defaults to `group_norm`):
         | 
| 182 | 
            +
                        The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
         | 
| 183 | 
            +
                    latent_log_var (`str`, *optional*, defaults to `per_channel`):
         | 
| 184 | 
            +
                        The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
         | 
| 185 | 
            +
                """
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def __init__(
         | 
| 188 | 
            +
                    self,
         | 
| 189 | 
            +
                    dims: Union[int, Tuple[int, int]] = 3,
         | 
| 190 | 
            +
                    in_channels: int = 3,
         | 
| 191 | 
            +
                    out_channels: int = 3,
         | 
| 192 | 
            +
                    block_out_channels: Tuple[int, ...] = (64,),
         | 
| 193 | 
            +
                    layers_per_block: int = 2,
         | 
| 194 | 
            +
                    norm_num_groups: int = 32,
         | 
| 195 | 
            +
                    patch_size: Union[int, Tuple[int]] = 1,
         | 
| 196 | 
            +
                    norm_layer: str = "group_norm",  # group_norm, pixel_norm
         | 
| 197 | 
            +
                    latent_log_var: str = "per_channel",
         | 
| 198 | 
            +
                    patch_size_t: Optional[int] = None,
         | 
| 199 | 
            +
                    add_channel_padding: Optional[bool] = False,
         | 
| 200 | 
            +
                ):
         | 
| 201 | 
            +
                    super().__init__()
         | 
| 202 | 
            +
                    self.patch_size = patch_size
         | 
| 203 | 
            +
                    self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
         | 
| 204 | 
            +
                    self.add_channel_padding = add_channel_padding
         | 
| 205 | 
            +
                    self.layers_per_block = layers_per_block
         | 
| 206 | 
            +
                    self.norm_layer = norm_layer
         | 
| 207 | 
            +
                    self.latent_channels = out_channels
         | 
| 208 | 
            +
                    self.latent_log_var = latent_log_var
         | 
| 209 | 
            +
                    if add_channel_padding:
         | 
| 210 | 
            +
                        in_channels = in_channels * self.patch_size**3
         | 
| 211 | 
            +
                    else:
         | 
| 212 | 
            +
                        in_channels = in_channels * self.patch_size_t * self.patch_size**2
         | 
| 213 | 
            +
                    self.in_channels = in_channels
         | 
| 214 | 
            +
                    output_channel = block_out_channels[0]
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    self.conv_in = make_conv_nd(
         | 
| 217 | 
            +
                        dims=dims,
         | 
| 218 | 
            +
                        in_channels=in_channels,
         | 
| 219 | 
            +
                        out_channels=output_channel,
         | 
| 220 | 
            +
                        kernel_size=3,
         | 
| 221 | 
            +
                        stride=1,
         | 
| 222 | 
            +
                        padding=1,
         | 
| 223 | 
            +
                    )
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    self.down_blocks = nn.ModuleList([])
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    for i in range(len(block_out_channels)):
         | 
| 228 | 
            +
                        input_channel = output_channel
         | 
| 229 | 
            +
                        output_channel = block_out_channels[i]
         | 
| 230 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                        down_block = DownEncoderBlock3D(
         | 
| 233 | 
            +
                            dims=dims,
         | 
| 234 | 
            +
                            in_channels=input_channel,
         | 
| 235 | 
            +
                            out_channels=output_channel,
         | 
| 236 | 
            +
                            num_layers=self.layers_per_block,
         | 
| 237 | 
            +
                            add_downsample=not is_final_block and 2**i >= patch_size,
         | 
| 238 | 
            +
                            resnet_eps=1e-6,
         | 
| 239 | 
            +
                            downsample_padding=0,
         | 
| 240 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 241 | 
            +
                            norm_layer=norm_layer,
         | 
| 242 | 
            +
                        )
         | 
| 243 | 
            +
                        self.down_blocks.append(down_block)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    self.mid_block = UNetMidBlock3D(
         | 
| 246 | 
            +
                        dims=dims,
         | 
| 247 | 
            +
                        in_channels=block_out_channels[-1],
         | 
| 248 | 
            +
                        num_layers=self.layers_per_block,
         | 
| 249 | 
            +
                        resnet_eps=1e-6,
         | 
| 250 | 
            +
                        resnet_groups=norm_num_groups,
         | 
| 251 | 
            +
                        norm_layer=norm_layer,
         | 
| 252 | 
            +
                    )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    # out
         | 
| 255 | 
            +
                    if norm_layer == "group_norm":
         | 
| 256 | 
            +
                        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
         | 
| 257 | 
            +
                    elif norm_layer == "pixel_norm":
         | 
| 258 | 
            +
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    @property
    def downscale_factor(self):
        return (
            2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)])
            * self.patch_size
        )

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        downsample_in_time = sample.shape[2] != 1

        # patchify
        patch_size_t = self.patch_size_t if downsample_in_time else 1
        sample = patchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        sample = self.conv_in(sample)

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time)

        sample = checkpoint_fn(self.mid_block)(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        return sample

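With latent_log_var="uniform", conv_out emits the latent channels plus one shared log-variance channel, and the branch above repeats that channel so the encoder output ends up with 2 * latent_channels channels (mean plus a per-channel log-variance), presumably so a diagonal Gaussian posterior can be built downstream. A minimal, standalone sketch of that broadcast, assuming only torch:

    import torch

    latent_channels = 4
    sample = torch.randn(1, latent_channels + 1, 2, 8, 8)   # C latents + 1 shared log-var channel
    last_channel = sample[:, -1:, ...]
    repeated = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
    sample = torch.cat([sample, repeated], dim=1)
    print(sample.shape[1])                                   # 8 == 2 * latent_channels
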
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        if add_channel_padding:
            out_channels = out_channels * self.patch_size**3
        else:
            out_channels = out_channels * self.patch_size_t * self.patch_size**2
        self.out_channels = out_channels

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpDecoderBlock3D(
                dims=dims,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size,
                resnet_eps=1e-6,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.up_blocks.append(up_block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        upsample_in_time = sample.shape[2] < target_shape[2]

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = checkpoint_fn(self.mid_block)(sample)
        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # un-patchify
        patch_size_t = self.patch_size_t if upsample_in_time else 1
        sample = unpatchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        return sample


class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_downsample:
            self.downsample = Downsample3D(dims, out_channels, out_channels=out_channels, padding=downsample_padding)
        else:
            self.downsample = Identity()

    def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time)

        return hidden_states


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        return hidden_states


class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_upsample: bool = True,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_upsample:
            self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels)
        else:
            self.upsample = Identity()

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)

        return hidden_states


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_layer == "group_norm":
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if norm_layer == "group_norm":
            self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.conv_shortcut = (
            make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class Downsample3D(nn.Module):
    def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        conv = self.conv
        if self.padding == 0:
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)

            x = functional.pad(x, padding, mode="constant", value=0)

            if self.dims == (2, 1) and not downsample_in_time:
                return conv(x, skip_time_conv=True)

        return conv(x)

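When Downsample3D is built with padding=0, the forward pass zero-pads only the right/bottom edges (and, when downsampling time, the trailing frames) before the stride-2 convolution, so an even-sized input is reduced to exactly half its resolution. A minimal sketch of the spatial case, using a plain Conv2d rather than make_conv_nd:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 64, 64)
    x = F.pad(x, (0, 1, 0, 1))                                    # pad right and bottom only -> 65x65
    conv = torch.nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)
    print(conv(x).shape)                                          # torch.Size([1, 8, 32, 32])
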
class Upsample3D(nn.Module):
    """
    An upsampling layer for 3D tensors of shape (B, C, D, H, W).

    :param channels: channels in the inputs and outputs.
    """

    def __init__(self, dims, channels, out_channels=None):
        super().__init__()
        self.dims = dims
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = make_conv_nd(dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True)

    def forward(self, x, upsample_in_time):
        if self.dims == 2:
            x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
        else:
            time_scale_factor = 2 if upsample_in_time else 1
            b, c, d, h, w = x.shape
            x = rearrange(x, "b c d h w -> (b d) c h w")
            # height and width interpolation
            x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
            _, _, h, w = x.shape

            if not upsample_in_time and self.dims == (2, 1):
                x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
                return self.conv(x, skip_time_conv=True)

            # Second, temporal upsampling: the 'd' axis is treated as a 1D signal per spatial location
            x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)

            # (b h w) c 1 d
            new_d = x.shape[-1] * time_scale_factor
            x = functional.interpolate(x, (1, new_d), mode="nearest")
            # (b h w) c 1 new_d
            x = rearrange(x, "(b h w) c 1 new_d  -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
            # b c d h w

        return self.conv(x)

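Upsample3D factorizes the 3D upsample: it first interpolates H and W on a (b*d) batch of frames, then folds every spatial location into the batch and interpolates along d. A minimal, standalone sketch of the temporal step, assuming only torch and einops:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    x = torch.randn(1, 4, 8, 16, 16)                      # (b, c, d, h, w)
    b, c, d, h, w = x.shape
    x = rearrange(x, "b c d h w -> (b h w) c 1 d")        # one 1D temporal signal per spatial location
    x = F.interpolate(x, (1, d * 2), mode="nearest")      # 2x along the temporal axis
    x = rearrange(x, "(b h w) c 1 d -> b c d h w", b=b, h=h, w=w)
    print(x.shape)                                        # torch.Size([1, 4, 16, 16, 16])
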
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
    elif x.dim() == 5:
        x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
        channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
        padding_zeros = torch.zeros(
            x.shape[0],
            channels_to_pad,
            x.shape[2],
            x.shape[3],
            x.shape[4],
            device=x.device,
            dtype=x.dtype,
        )
        x = torch.cat([padding_zeros, x], dim=1)

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
        channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
        x = x[:, :channels_to_keep, :, :, :]

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
    elif x.dim() == 5:
        x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)

    return x

         | 
| 817 | 
            +
                latent_channels: int = 4,
         | 
| 818 | 
            +
            ):
         | 
| 819 | 
            +
                config = {
         | 
| 820 | 
            +
                    "_class_name": "VideoAutoencoder",
         | 
| 821 | 
            +
                    "dims": (2, 1),  # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
         | 
| 822 | 
            +
                    "in_channels": 3,  # Number of input color channels (e.g., RGB)
         | 
| 823 | 
            +
                    "out_channels": 3,  # Number of output color channels
         | 
| 824 | 
            +
                    "latent_channels": latent_channels,  # Number of channels in the latent space representation
         | 
| 825 | 
            +
                    "block_out_channels": [128, 256, 512, 512],  # Number of output channels of each encoder / decoder inner block
         | 
| 826 | 
            +
                    "patch_size": 1,
         | 
| 827 | 
            +
                }
         | 
| 828 | 
            +
             | 
| 829 | 
            +
                return config
         | 
| 830 | 
            +
             | 
| 831 | 
            +
             | 
| 832 | 
            +
            def create_video_autoencoder_pathify4x4x4_config(
         | 
| 833 | 
            +
                latent_channels: int = 4,
         | 
| 834 | 
            +
            ):
         | 
| 835 | 
            +
                config = {
         | 
| 836 | 
            +
                    "_class_name": "VideoAutoencoder",
         | 
| 837 | 
            +
                    "dims": (2, 1),  # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
         | 
| 838 | 
            +
                    "in_channels": 3,  # Number of input color channels (e.g., RGB)
         | 
| 839 | 
            +
                    "out_channels": 3,  # Number of output color channels
         | 
| 840 | 
            +
                    "latent_channels": latent_channels,  # Number of channels in the latent space representation
         | 
| 841 | 
            +
                    "block_out_channels": [512] * 4,  # Number of output channels of each encoder / decoder inner block
         | 
| 842 | 
            +
                    "patch_size": 4,
         | 
| 843 | 
            +
                    "latent_log_var": "uniform",
         | 
| 844 | 
            +
                }
         | 
| 845 | 
            +
             | 
| 846 | 
            +
                return config
         | 
| 847 | 
            +
             | 
| 848 | 
            +
             | 
| 849 | 
            +
            def create_video_autoencoder_pathify4x4_config(
         | 
| 850 | 
            +
                latent_channels: int = 4,
         | 
| 851 | 
            +
            ):
         | 
| 852 | 
            +
                config = {
         | 
| 853 | 
            +
                    "_class_name": "VideoAutoencoder",
         | 
| 854 | 
            +
                    "dims": 2,  # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
         | 
| 855 | 
            +
                    "in_channels": 3,  # Number of input color channels (e.g., RGB)
         | 
| 856 | 
            +
                    "out_channels": 3,  # Number of output color channels
         | 
| 857 | 
            +
                    "latent_channels": latent_channels,  # Number of channels in the latent space representation
         | 
| 858 | 
            +
                    "block_out_channels": [512] * 4,  # Number of output channels of each encoder / decoder inner block
         | 
| 859 | 
            +
                    "patch_size": 4,
         | 
| 860 | 
            +
                    "norm_layer": "pixel_norm",
         | 
| 861 | 
            +
                }
         | 
| 862 | 
            +
             | 
| 863 | 
            +
                return config
         | 
| 864 | 
            +
             | 
| 865 | 
            +
             | 
| 866 | 
            +
            def test_vae_patchify_unpatchify():
         | 
| 867 | 
            +
                import torch
         | 
| 868 | 
            +
             | 
| 869 | 
            +
                x = torch.randn(2, 3, 8, 64, 64)
         | 
| 870 | 
            +
                x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
         | 
| 871 | 
            +
                x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
         | 
| 872 | 
            +
                assert torch.allclose(x, x_unpatched)
         | 
| 873 | 
            +
             | 
| 874 | 
            +
             | 
| 875 | 
            +
            def demo_video_autoencoder_forward_backward():
         | 
| 876 | 
            +
                # Configuration for the VideoAutoencoder
         | 
| 877 | 
            +
                config = create_video_autoencoder_pathify4x4x4_config()
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                # Instantiate the VideoAutoencoder with the specified configuration
         | 
| 880 | 
            +
                video_autoencoder = VideoAutoencoder.from_config(config)
         | 
| 881 | 
            +
             | 
| 882 | 
            +
                print(video_autoencoder)
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                # Print the total number of parameters in the video autoencoder
         | 
| 885 | 
            +
                total_params = sum(p.numel() for p in video_autoencoder.parameters())
         | 
| 886 | 
            +
                print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
         | 
| 887 | 
            +
             | 
| 888 | 
            +
                # Create a mock input tensor simulating a batch of videos
         | 
| 889 | 
            +
                # Shape: (batch_size, channels, depth, height, width)
         | 
| 890 | 
            +
                # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame
         | 
| 891 | 
            +
                input_videos = torch.randn(2, 3, 8, 64, 64)
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                # Forward pass: encode and decode the input videos
         | 
| 894 | 
            +
                latent = video_autoencoder.encode(input_videos).latent_dist.mode()
         | 
| 895 | 
            +
                print(f"input shape={input_videos.shape}")
         | 
| 896 | 
            +
                print(f"latent shape={latent.shape}")
         | 
| 897 | 
            +
                reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample
         | 
| 898 | 
            +
             | 
| 899 | 
            +
                print(f"reconstructed shape={reconstructed_videos.shape}")
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                # Calculate the loss (e.g., mean squared error)
         | 
| 902 | 
            +
                loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
         | 
| 903 | 
            +
             | 
| 904 | 
            +
                # Perform backward pass
         | 
| 905 | 
            +
                loss.backward()
         | 
| 906 | 
            +
             | 
| 907 | 
            +
                print(f"Demo completed with loss: {loss.item()}")
         | 
| 908 | 
            +
             | 
| 909 | 
            +
             | 
| 910 | 
            +
            # Ensure to call the demo function to execute the forward and backward pass
         | 
| 911 | 
            +
            if __name__ == "__main__":
         | 
| 912 | 
            +
                demo_video_autoencoder_forward_backward()
         | 
    	
xora/models/transformers/embeddings.py ADDED
@@ -0,0 +1,125 @@
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math

import numpy as np
import torch
from einops import rearrange
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
                grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
         | 
| 54 | 
            +
                [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
         | 
| 57 | 
            +
                grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
         | 
| 58 | 
            +
                grid = grid.reshape([3, 1, w, h, f])
         | 
| 59 | 
            +
                pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
         | 
| 60 | 
            +
                pos_embed = pos_embed.transpose(1, 0, 2, 3)
         | 
| 61 | 
            +
                return rearrange(pos_embed, "h w f c -> (f h w) c")
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
         | 
| 65 | 
            +
                if embed_dim % 3 != 0:
         | 
| 66 | 
            +
                    raise ValueError("embed_dim must be divisible by 3")
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                # use half of dimensions to encode grid_h
         | 
| 69 | 
            +
                emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0])  # (H*W*T, D/3)
         | 
| 70 | 
            +
                emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1])  # (H*W*T, D/3)
         | 
| 71 | 
            +
                emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2])  # (H*W*T, D/3)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1)  # (H*W*T, D)
         | 
| 74 | 
            +
                return emb
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
         | 
| 78 | 
            +
                """
         | 
| 79 | 
            +
                embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
         | 
| 80 | 
            +
                """
         | 
| 81 | 
            +
                if embed_dim % 2 != 0:
         | 
| 82 | 
            +
                    raise ValueError("embed_dim must be divisible by 2")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                omega = np.arange(embed_dim // 2, dtype=np.float64)
         | 
| 85 | 
            +
                omega /= embed_dim / 2.0
         | 
| 86 | 
            +
                omega = 1.0 / 10000**omega  # (D/2,)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                pos_shape = pos.shape
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                pos = pos.reshape(-1)
         | 
| 91 | 
            +
                out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
         | 
| 92 | 
            +
                out = out.reshape([*pos_shape, -1])[0]
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                emb_sin = np.sin(out)  # (M, D/2)
         | 
| 95 | 
            +
                emb_cos = np.cos(out)  # (M, D/2)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (M, D)
         | 
| 98 | 
            +
                return emb
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            class SinusoidalPositionalEmbedding(nn.Module):
         | 
| 102 | 
            +
                """Apply positional information to a sequence of embeddings.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
         | 
| 105 | 
            +
                them
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                Args:
         | 
| 108 | 
            +
                    embed_dim: (int): Dimension of the positional embedding.
         | 
| 109 | 
            +
                    max_seq_length: Maximum sequence length to apply positional embeddings
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                """
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                def __init__(self, embed_dim: int, max_seq_length: int = 32):
         | 
| 114 | 
            +
                    super().__init__()
         | 
| 115 | 
            +
                    position = torch.arange(max_seq_length).unsqueeze(1)
         | 
| 116 | 
            +
                    div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
         | 
| 117 | 
            +
                    pe = torch.zeros(1, max_seq_length, embed_dim)
         | 
| 118 | 
            +
                    pe[0, :, 0::2] = torch.sin(position * div_term)
         | 
| 119 | 
            +
                    pe[0, :, 1::2] = torch.cos(position * div_term)
         | 
| 120 | 
            +
                    self.register_buffer("pe", pe)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def forward(self, x):
         | 
| 123 | 
            +
                    _, seq_length, _ = x.shape
         | 
| 124 | 
            +
                    x = x + self.pe[:, :seq_length]
         | 
| 125 | 
            +
                    return x
         | 
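For reference, a minimal sketch (not part of this commit) of how the new 3D sin-cos helpers are driven: it builds a flattened (f h w) coordinate grid, requests an embedding whose dimension is divisible by 3, and checks the resulting shape. The grid layout (channel 0 = frame, 1 = height, 2 = width) mirrors what get_absolute_pos_embed in transformer3d.py passes in; the concrete sizes below are illustrative.

    import numpy as np
    import torch
    from xora.models.transformers.embeddings import get_3d_sincos_pos_embed

    f, h, w = 4, 8, 8
    embed_dim = 96  # must be divisible by 3, and each third must be even for the sin/cos split

    # Flattened (f h w) grid with channel order (frame, height, width), shape (3, f*h*w)
    grid = np.stack(
        np.meshgrid(np.arange(f), np.arange(h), np.arange(w), indexing="ij"), axis=0
    ).reshape(3, -1)

    pos_embed = get_3d_sincos_pos_embed(embed_dim, grid, w=w, h=h, f=f)  # numpy, (f*h*w, embed_dim)
    assert pos_embed.shape == (f * h * w, embed_dim)
    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)         # (1, f*h*w, embed_dim)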
    	
        xora/models/transformers/transformer3d.py
    CHANGED
    
@@ -1,7 +1,7 @@
 # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
 import math
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Literal
 
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config

@@ -9,10 +9,13 @@ from diffusers.models.embeddings import PixArtAlphaTextProjection
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNormSingle
 from diffusers.utils import BaseOutput, is_torch_version
+from diffusers.utils import logging
 from torch import nn
 
 from xora.models.transformers.attention import BasicTransformerBlock
+from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
 
+logger = logging.get_logger(__name__)
 
 @dataclass
 class Transformer3DModelOutput(BaseOutput):

@@ -143,6 +146,61 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    def set_use_tpu_flash_attention(self):
+        r"""
+        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
+        attention kernel.
+        """
+        logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
+        # if using TPU -> configure components to use TPU flash attention
+        if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
+            self.use_tpu_flash_attention = True
+            # push config down to the attention modules
+            for block in self.transformer_blocks:
+                block.set_use_tpu_flash_attention()
+
+    def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP:
+        nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std)
+        nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std)
+        nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
+
+        if hasattr(self.adaln_single.emb, "resolution_embedder"):
+            nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=embedding_std)
+            nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=embedding_std)
+        if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
+            nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=embedding_std)
+            nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=embedding_std)
+
+        # Initialize caption embedding MLP:
+        nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
+        nn.init.normal_(self.caption_projection.linear_2.weight, std=embedding_std)
+
+        # Zero-out adaLN modulation layers in PixArt blocks:
+        for block in self.transformer_blocks:
+            if mode == "xora":
+                nn.init.constant_(block.attn1.to_out[0].weight, 0)
+                nn.init.constant_(block.attn1.to_out[0].bias, 0)
+
+            nn.init.constant_(block.attn2.to_out[0].weight, 0)
+            nn.init.constant_(block.attn2.to_out[0].bias, 0)
+
+            if mode == "xora":
+                nn.init.constant_(block.ff.net[2].weight, 0)
+                nn.init.constant_(block.ff.net[2].bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.proj_out.weight, 0)
+        nn.init.constant_(self.proj_out.bias, 0)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

@@ -287,10 +345,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         if self.timestep_scale_multiplier:
             timestep = self.timestep_scale_multiplier * timestep
 
-        if self.positional_embedding_type == "rope":
+        if self.positional_embedding_type == "absolute":
+            pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
+            if self.project_to_2d_pos:
+                pos_embed = self.to_2d_proj(pos_embed_3d)
+            hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
+            freqs_cis = None
+        elif self.positional_embedding_type == "rope":
             freqs_cis = self.precompute_freqs_cis(indices_grid)
-        else:
-            raise NotImplementedError("Only rope pos embed supported.")
 
         batch_size = hidden_states.shape[0]
         timestep, embedded_timestep = self.adaln_single(

@@ -358,3 +420,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
 
         return Transformer3DModelOutput(sample=hidden_states)
 
+    def get_absolute_pos_embed(self, grid):
+        grid_np = grid[0].cpu().numpy()
+        embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
+        pos_embed = get_3d_sincos_pos_embed(  # (f h w)
+            embed_dim_3d,
+            grid_np,
+            h=int(max(grid_np[1]) + 1),
+            w=int(max(grid_np[2]) + 1),
+            f=int(max(grid_np[0] + 1)),
+        )
+        return torch.from_numpy(pos_embed).float().unsqueeze(0)
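As a rough illustration of the shape bookkeeping behind the new "absolute" positional-embedding path (a sketch under assumed sizes, not code from this commit): when project_to_2d_pos is set, the 3D sin-cos table is built with ceil(inner_dim / 2) * 3 channels so it splits evenly across the frame/height/width axes, and a learned linear layer maps it back to inner_dim before it is added to hidden_states. The inner_dim value and the bias-free Linear below are assumptions standing in for the model's real to_2d_proj.

    import math
    import torch
    from torch import nn

    inner_dim = 2048                               # hypothetical transformer width
    embed_dim_3d = math.ceil((inner_dim / 2) * 3)  # 3072: divisible by 3, one even third per axis
    to_2d_proj = nn.Linear(embed_dim_3d, inner_dim, bias=False)  # stand-in for the model's to_2d_proj

    num_tokens = 4 * 8 * 8                         # f * h * w latent patches
    pos_embed_3d = torch.randn(1, num_tokens, embed_dim_3d)  # placeholder for get_3d_sincos_pos_embed output
    pos_embed = to_2d_proj(pos_embed_3d)           # (1, num_tokens, inner_dim), added to hidden_states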
    	
        xora/pipelines/pipeline_video_pixart_alpha.py
    CHANGED
    
@@ -32,16 +32,106 @@ from xora.models.transformers.symmetric_patchifier import Patchifier
 from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
 from xora.schedulers.rf import TimestepShifter
+from xora.utils.conditioning_method import ConditioningMethod
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
 if is_ftfy_available():
     import ftfy
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import PixArtAlphaPipeline
+
+        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+        >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+        >>> # Enable memory optimizations.
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> prompt = "A small cactus with a happy face in the Sahara desert."
+        >>> image = pipe(prompt).images[0]
+        ```
+"""
+
+ASPECT_RATIO_1024_BIN = {
+    "0.25": [512.0, 2048.0],
+    "0.28": [512.0, 1856.0],
+    "0.32": [576.0, 1792.0],
+    "0.33": [576.0, 1728.0],
+    "0.35": [576.0, 1664.0],
+    "0.4": [640.0, 1600.0],
+    "0.42": [640.0, 1536.0],
+    "0.48": [704.0, 1472.0],
+    "0.5": [704.0, 1408.0],
+    "0.52": [704.0, 1344.0],
+    "0.57": [768.0, 1344.0],
+    "0.6": [768.0, 1280.0],
+    "0.68": [832.0, 1216.0],
+    "0.72": [832.0, 1152.0],
+    "0.78": [896.0, 1152.0],
+    "0.82": [896.0, 1088.0],
+    "0.88": [960.0, 1088.0],
+    "0.94": [960.0, 1024.0],
+    "1.0": [1024.0, 1024.0],
+    "1.07": [1024.0, 960.0],
+    "1.13": [1088.0, 960.0],
+    "1.21": [1088.0, 896.0],
+    "1.29": [1152.0, 896.0],
+    "1.38": [1152.0, 832.0],
+    "1.46": [1216.0, 832.0],
+    "1.67": [1280.0, 768.0],
+    "1.75": [1344.0, 768.0],
+    "2.0": [1408.0, 704.0],
+    "2.09": [1472.0, 704.0],
+    "2.4": [1536.0, 640.0],
+    "2.5": [1600.0, 640.0],
+    "3.0": [1728.0, 576.0],
+    "4.0": [2048.0, 512.0],
+}
+
+ASPECT_RATIO_512_BIN = {
+    "0.25": [256.0, 1024.0],
+    "0.28": [256.0, 928.0],
+    "0.32": [288.0, 896.0],
+    "0.33": [288.0, 864.0],
+    "0.35": [288.0, 832.0],
+    "0.4": [320.0, 800.0],
+    "0.42": [320.0, 768.0],
+    "0.48": [352.0, 736.0],
+    "0.5": [352.0, 704.0],
+    "0.52": [352.0, 672.0],
+    "0.57": [384.0, 672.0],
+    "0.6": [384.0, 640.0],
+    "0.68": [416.0, 608.0],
+    "0.72": [416.0, 576.0],
+    "0.78": [448.0, 576.0],
+    "0.82": [448.0, 544.0],
+    "0.88": [480.0, 544.0],
+    "0.94": [480.0, 512.0],
+    "1.0": [512.0, 512.0],
+    "1.07": [512.0, 480.0],
+    "1.13": [544.0, 480.0],
+    "1.21": [544.0, 448.0],
+    "1.29": [576.0, 448.0],
+    "1.38": [576.0, 416.0],
+    "1.46": [608.0, 416.0],
+    "1.67": [640.0, 384.0],
+    "1.75": [672.0, 384.0],
+    "2.0": [704.0, 352.0],
+    "2.09": [736.0, 352.0],
+    "2.4": [768.0, 320.0],
+    "2.5": [800.0, 320.0],
+    "3.0": [864.0, 288.0],
+    "4.0": [1024.0, 256.0],
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,

@@ -520,14 +610,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(
-        self,
-        batch_size,
-        num_latent_channels,
-        num_patches,
-        dtype,
-        device,
-        generator,
-        latents=None,
+        self, batch_size, num_latent_channels, num_patches, dtype, device, generator, latents=None, latents_mask=None
     ):
         shape = (
             batch_size,

@@ -543,6 +626,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        elif latents_mask is not None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = latents * latents_mask[..., None] + noise * (1 - latents_mask[..., None])
         else:
             latents = latents.to(device)
 

@@ -582,8 +668,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
 
         return samples
 
-
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         height: int,

@@ -607,6 +693,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         clean_caption: bool = True,
+        media_items: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """

@@ -736,8 +823,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
-        # 4. Prepare latents.
+        # 3b. Encode and prepare conditioning data
         self.video_scale_factor = self.video_scale_factor if is_video else 1
+        conditioning_method = kwargs.get("conditioning_method", None)
+        vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
+        init_latents, conditioning_mask = self.prepare_conditioning(
+            media_items, num_frames, height, width, conditioning_method, vae_per_channel_normalize
+        )
+
+        # 4. Prepare latents.
         latent_height = height // self.vae_scale_factor
         latent_width = width // self.vae_scale_factor
         latent_num_frames = num_frames // self.video_scale_factor

@@ -752,7 +846,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
             dtype=prompt_embeds.dtype,
             device=device,
             generator=generator,
+            latents=init_latents,
+            latents_mask=conditioning_mask,
         )
+        if conditioning_mask is not None and is_video:
+            assert num_images_per_prompt == 1
+            conditioning_mask = torch.cat([conditioning_mask] * 2) if do_classifier_free_guidance else conditioning_mask
 
         # 5. Prepare timesteps
         retrieve_timesteps_kwargs = {}

@@ -790,7 +889,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                current_timestep = current_timestep.expand(latent_model_input.shape[0])
+                current_timestep = current_timestep.expand(latent_model_input.shape[0]).unsqueeze(-1)
                 scale_grid = (
                     (1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
                     if self.transformer.use_rope

@@ -805,6 +904,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                     device=latents.device,
                 )
 
+                if conditioning_mask is not None:
+                    current_timestep = current_timestep * (1 - conditioning_mask)
+
                 # predict noise model_output
                 noise_pred = self.transformer(
                     latent_model_input.to(self.transformer.dtype),

@@ -819,13 +921,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    current_timestep, _ = current_timestep.chunk(2)
 
                 # learned sigma
                 if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
                     noise_pred = noise_pred.chunk(2, dim=1)[0]
 
                 # compute previous image: x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents = self.scheduler.step(
+                    noise_pred,
+                    t if current_timestep is None else current_timestep,
+                    latents,
+                    **extra_step_kwargs,
+                    return_dict=False,
+                )[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

@@ -857,3 +966,62 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
             return (image,)
 
         return ImagePipelineOutput(images=image)
+
+    def prepare_conditioning(
+        self,
+        media_items: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL,
+        vae_per_channel_normalize: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Prepare the conditioning data for the video generation. If an input media item is provided, encode it
+        and set the conditioning_mask to indicate which tokens to condition on. Input media item should have
+        the same height and width as the generated video.
+
+        Args:
+            media_items (torch.Tensor): media items to condition on (images or videos)
+            num_frames (int): number of frames to generate
+            height (int): height of the generated video
+            width (int): width of the generated video
+            method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL.
+            vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask
+        """
+        if media_items is None or method == ConditioningMethod.UNCONDITIONAL:
+            return None, None
+
+        assert media_items.ndim == 5
+        assert height == media_items.shape[-2] and width == media_items.shape[-1]
+
+        # Encode the input video and repeat to the required number of frame-tokens
+        init_latents = vae_encode(
+            media_items.to(dtype=self.vae.dtype, device=self.vae.device),
+            self.vae,
+            vae_per_channel_normalize=vae_per_channel_normalize,
+        ).float()
+
+        init_len, target_len = init_latents.shape[2], num_frames // self.video_scale_factor
+        if isinstance(self.vae, CausalVideoAutoencoder):
+            target_len += 1
+        init_latents = init_latents[:, :, :target_len]
+        if target_len > init_len:
+            repeat_factor = (target_len + init_len - 1) // init_len  # Ceiling division
+            init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[:, :, :target_len]
+
+        # Prepare the conditioning mask (1.0 = condition on this token)
+        b, n, f, h, w = init_latents.shape
+        conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
+        if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
+            conditioning_mask[:, :, 0] = 1.0
+        if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
+            conditioning_mask[:, :, -1] = 1.0
+
+        # Patchify the init latents and the mask
+        conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1)
+        init_latents = self.patchifier.patchify(latents=init_latents)
+        return init_latents, conditioning_mask
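A sketch of how the new conditioning path is exercised from the caller's side, assuming `pipeline` is an already-built VideoPixArtAlphaPipeline. `media_items` must be a 5D (batch, channels, frames, height, width) tensor whose spatial size matches the requested video; `conditioning_method` and `vae_per_channel_normalize` travel through **kwargs into prepare_conditioning(). The prompt text, the concrete sizes, and the remaining call arguments are illustrative assumptions about the existing __call__ signature.

    import torch
    from xora.utils.conditioning_method import ConditioningMethod

    # Hypothetical conditioning image: one frame, same spatial size as the generated video
    first_frame = torch.rand(1, 3, 1, 512, 768)

    video = pipeline(
        prompt="A small cactus with a happy face in the Sahara desert.",
        height=512,
        width=768,
        num_frames=57,
        media_items=first_frame,                             # encoded by prepare_conditioning()
        conditioning_method=ConditioningMethod.FIRST_FRAME,  # keep the first frame-token fixed
        vae_per_channel_normalize=True,                      # forwarded to vae_encode()
    )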
    	
        xora/schedulers/rf.py
    CHANGED
    
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import BaseOutput
 from torch import Tensor
 
-from …
+from txt2img.common.torch_utils import append_dims
 
 
 def simple_diffusion_resolution_dependent_timestep_shift(

@@ -199,8 +199,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        current_index = (self.timesteps - timestep).abs().argmin()
-        dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
+        if timestep.ndim == 0:
+            # Global timestep
+            current_index = (self.timesteps - timestep).abs().argmin()
+            dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
+        else:
+            # Timestep per token
+            assert timestep.ndim == 2
+            current_index = (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
+            dt = self.delta_timesteps[current_index]
+            # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
+            dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
 
         prev_sample = sample - dt * model_output
 

@@ -219,4 +228,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
         sigmas = append_dims(sigmas, original_samples.ndim)
         alphas = 1 - sigmas
         noisy_samples = alphas * original_samples + sigmas * noise
-        return noisy_samples
+        return noisy_samples
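To make the per-token branch of RectifiedFlowScheduler.step concrete, here is a small standalone sketch of the same dt lookup (using hypothetical timesteps/delta_timesteps tensors instead of a scheduler instance): tokens whose timestep is zero, i.e. the conditioned tokens, get dt = 0 and are copied through unchanged.

    import torch

    # Hypothetical discretization: four sigma levels and the step sizes between them (last step goes to 0)
    timesteps = torch.tensor([1.0, 0.75, 0.5, 0.25])
    delta_timesteps = timesteps - torch.cat([timesteps[1:], torch.zeros(1)])

    # One sample, three tokens: token 0 is conditioned (timestep 0), the others sit at sigma 0.75
    timestep = torch.tensor([[0.0, 0.75, 0.75]])

    # Same lookup as the "timestep per token" branch above
    current_index = (timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
    dt = delta_timesteps[current_index]
    dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]

    sample = torch.randn(1, 3, 8)            # (batch, tokens, channels)
    model_output = torch.randn_like(sample)
    prev_sample = sample - dt * model_output
    assert torch.equal(prev_sample[:, 0], sample[:, 0])  # conditioned token is left untouched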
    	
        xora/utils/conditioning_method.py
    ADDED
    
@@ -0,0 +1,7 @@
from enum import Enum

class ConditioningMethod(Enum):
    UNCONDITIONAL = "unconditional"
    FIRST_FRAME = "first_frame"
    LAST_FRAME = "last_frame"
    FIRST_AND_LAST_FRAME = "first_and_last_frame"
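Since ConditioningMethod is a plain Enum, a member can be recovered from its string value and compared by identity; a quick illustration using the values defined above:

    from xora.utils.conditioning_method import ConditioningMethod

    method = ConditioningMethod("first_frame")  # look up a member by its value
    assert method is ConditioningMethod.FIRST_FRAME
    assert method.value == "first_frame"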
    	
        xora/utils/dist_util.py
    ADDED
    
@@ -0,0 +1,11 @@
from enum import Enum

class AccelerationType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"

def execute_graph() -> None:
    if _acceleration_type == AccelerationType.TPU:
        xm.mark_step()
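The module as added relies on two names that are not defined in the lines above: `xm` (the torch_xla xla_model module used by execute_graph) and a module-level `_acceleration_type`, and transformer3d.py additionally calls dist_util.acceleration_type(). Purely as an illustration of how such a detection helper could be wired up (an assumption, not code from this commit):

    import torch

    from xora.utils.dist_util import AccelerationType

    def acceleration_type() -> AccelerationType:
        # Hypothetical detector; the real module presumably caches its result in _acceleration_type.
        try:
            import torch_xla.core.xla_model as xm  # also provides xm.mark_step() for execute_graph()
            if xm.xla_device() is not None:
                return AccelerationType.TPU
        except ImportError:
            pass
        if torch.cuda.is_available():
            return AccelerationType.GPU
        if torch.backends.mps.is_available():
            return AccelerationType.MPS
        return AccelerationType.CPU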
