from abc import abstractmethod
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.Attention import Attention
from modules.NeuralNetwork import transformer
from modules.cond import cast
from modules.sample import sampling_util

oai_ops = cast.disable_weight_init
class TimestepBlock1(nn.Module):
    """#### Abstract class representing a timestep block."""

    @abstractmethod
    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the timestep block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        pass
def forward_timestep_embed1(
    ts: nn.ModuleList,
    x: torch.Tensor,
    emb: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    transformer_options: Optional[Dict[str, Any]] = None,
    output_shape: Optional[torch.Size] = None,
    time_context: Optional[torch.Tensor] = None,
    num_video_frames: Optional[int] = None,
    image_only_indicator: Optional[bool] = None,
) -> torch.Tensor:
    """#### Forward pass for timestep embedding.

    #### Args:
        - `ts` (nn.ModuleList): The list of timestep blocks.
        - `x` (torch.Tensor): The input tensor.
        - `emb` (torch.Tensor): The embedding tensor.
        - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
        - `transformer_options` (dict, optional): The transformer options. Defaults to None.
        - `output_shape` (torch.Size, optional): The output shape. Defaults to None.
        - `time_context` (torch.Tensor, optional): The time context tensor. Defaults to None.
        - `num_video_frames` (int, optional): The number of video frames. Defaults to None.
        - `image_only_indicator` (bool, optional): The image-only indicator. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    # Avoid the mutable-default-argument pitfall: a shared dict default would
    # leak state (e.g. "transformer_index") across calls.
    if transformer_options is None:
        transformer_options = {}
    for layer in ts:
        if isinstance(layer, TimestepBlock1):
            # Timestep-conditioned blocks also receive the embedding.
            x = layer(x, emb)
        elif isinstance(layer, transformer.SpatialTransformer):
            # Cross-attention blocks receive the conditioning context.
            x = layer(x, context, transformer_options)
            if "transformer_index" in transformer_options:
                transformer_options["transformer_index"] += 1
        elif isinstance(layer, Upsample1):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
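
# A minimal usage sketch (illustrative, not from the source): the dispatcher
# walks a ModuleList and routes each layer by type, so a list holding a single
# ResBlock1 (defined below) is enough to exercise the TimestepBlock1 branch.
#
#   blocks = nn.ModuleList([ResBlock1(channels=64, emb_channels=128, dropout=0.0)])
#   out = forward_timestep_embed1(
#       blocks,
#       torch.randn(1, 64, 32, 32),   # feature map
#       torch.randn(1, 128),          # timestep embedding
#   )
#   # out.shape == (1, 64, 32, 32)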
class Upsample1(nn.Module):
    """#### Class representing an upsample layer."""

    def __init__(
        self,
        channels: int,
        use_conv: bool,
        dims: int = 2,
        out_channels: Optional[int] = None,
        padding: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the upsample layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `use_conv` (bool): Whether to apply a 3x3 convolution after interpolation.
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `out_channels` (int, optional): The number of output channels. Defaults to None.
            - `padding` (int, optional): The padding size. Defaults to 1.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device. Defaults to None.
            - `operations` (Any, optional): The operations namespace. Defaults to oai_ops.
        """
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = operations.conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                padding=padding,
                dtype=dtype,
                device=device,
            )

    def forward(
        self, x: torch.Tensor, output_shape: Optional[torch.Size] = None
    ) -> torch.Tensor:
        """#### Forward pass for the upsample layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor (NCHW; the shape computation assumes 2D spatial dims).
            - `output_shape` (torch.Size, optional): The output shape. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        assert x.shape[1] == self.channels
        # Default to doubling the spatial resolution; an explicit output_shape
        # overrides this (useful when the encoder side had odd dimensions).
        shape = [x.shape[2] * 2, x.shape[3] * 2]
        if output_shape is not None:
            shape[0] = output_shape[2]
            shape[1] = output_shape[3]
        x = F.interpolate(x, size=shape, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
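
# A quick shape check (illustrative sketch, assuming the default `operations`
# namespace constructs a standard Conv2d): nearest-neighbour interpolation
# doubles the spatial size and the optional 3x3 convolution preserves it.
#
#   up = Upsample1(channels=64, use_conv=True)
#   y = up(torch.randn(1, 64, 32, 32))
#   # y.shape == (1, 64, 64, 64)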
class Downsample1(nn.Module):
    """#### Class representing a downsample layer."""

    def __init__(
        self,
        channels: int,
        use_conv: bool,
        dims: int = 2,
        out_channels: Optional[int] = None,
        padding: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the downsample layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `use_conv` (bool): Whether to use convolution (kept for interface parity; this implementation always downsamples with a strided convolution).
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `out_channels` (int, optional): The number of output channels. Defaults to None.
            - `padding` (int, optional): The padding size. Defaults to 1.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device. Defaults to None.
            - `operations` (Any, optional): The operations namespace. Defaults to oai_ops.
        """
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D (video) inputs, keep the temporal dimension and halve only
        # the spatial ones; otherwise halve every spatial dimension.
        stride = 2 if dims != 3 else (1, 2, 2)
        self.op = operations.conv_nd(
            dims,
            self.channels,
            self.out_channels,
            3,
            stride=stride,
            padding=padding,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the downsample layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        assert x.shape[1] == self.channels
        return self.op(x)
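
# Corresponding shape check for the strided-convolution downsample (sketch,
# same assumptions as above): kernel 3, stride 2, padding 1 halves H and W.
#
#   down = Downsample1(channels=64, use_conv=True)
#   y = down(torch.randn(1, 64, 32, 32))
#   # y.shape == (1, 64, 16, 16)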
class ResBlock1(TimestepBlock1):
    """#### Class representing a residual block layer."""

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        kernel_size: int = 3,
        exchange_temb_dims: bool = False,
        skip_t_emb: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the residual block layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `emb_channels` (int): The number of embedding channels.
            - `dropout` (float): The dropout rate.
            - `out_channels` (int, optional): The number of output channels. Defaults to None.
            - `use_conv` (bool, optional): Whether to use convolution. Defaults to False.
            - `use_scale_shift_norm` (bool, optional): Whether to use scale-shift normalization. Defaults to False.
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `use_checkpoint` (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            - `up` (bool, optional): Whether to use upsampling. Defaults to False.
            - `down` (bool, optional): Whether to use downsampling. Defaults to False.
            - `kernel_size` (int, optional): The kernel size. Defaults to 3.
            - `exchange_temb_dims` (bool, optional): Whether to swap the channel and temporal axes of the embedding. Defaults to False.
            - `skip_t_emb` (bool, optional): Whether to skip the timestep embedding. Defaults to False.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device. Defaults to None.
            - `operations` (Any, optional): The operations namespace. Defaults to oai_ops.
        """
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            operations.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(
                dims,
                channels,
                self.out_channels,
                kernel_size,
                padding=padding,
                dtype=dtype,
                device=device,
            ),
        )

        self.updown = up or down
        self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        # With scale-shift norm the embedding supplies both a scale and a
        # shift, hence twice the output channels.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            operations.Linear(
                emb_channels,
                (2 * self.out_channels if use_scale_shift_norm else self.out_channels),
                dtype=dtype,
                device=device,
            ),
        )
        self.out_layers = nn.Sequential(
            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            operations.conv_nd(
                dims,
                self.out_channels,
                self.out_channels,
                kernel_size,
                padding=padding,
                dtype=dtype,
                device=device,
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, 1, dtype=dtype, device=device
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the residual block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return sampling_util.checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )
    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Internal forward pass for the residual block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h = self.in_layers(x)
        emb_out = None
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            # Broadcast the embedding over the spatial dimensions.
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # The embedding carries 2 * out_channels features: split into a
            # scale and a shift applied around the output normalization.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            h = out_norm(h)
            if emb_out is not None:
                scale, shift = torch.chunk(emb_out, 2, dim=1)
                h = h * (1 + scale) + shift
            h = out_rest(h)
        else:
            if emb_out is not None:
                if self.exchange_temb_dims:
                    # Video embeddings arrive as "b t c ..."; move the channel
                    # axis forward to match h ("b c t ...").
                    emb_out = emb_out.movedim(1, 2)
                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
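
# Shape sketch for the residual block (illustrative; dropout=0.0 keeps it
# deterministic). The 128-dim embedding is projected to the channel count and
# broadcast over H and W before the residual sum.
#
#   block = ResBlock1(channels=64, emb_channels=128, dropout=0.0)
#   y = block(torch.randn(1, 64, 32, 32), torch.randn(1, 128))
#   # y.shape == (1, 64, 32, 32)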
ops = cast.disable_weight_init


class ResnetBlock(nn.Module):
    """#### Class representing a ResNet block layer."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float,
        temb_channels: int = 512,
    ):
        """#### Initialize the ResNet block layer.

        #### Args:
            - `in_channels` (int): The number of input channels.
            - `out_channels` (int, optional): The number of output channels. Defaults to None.
            - `conv_shortcut` (bool, optional): Whether to use a convolutional shortcut. Defaults to False.
            - `dropout` (float): The dropout rate.
            - `temb_channels` (int, optional): The number of timestep-embedding channels. Defaults to 512.
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Attention.Normalize(in_channels)
        self.conv1 = ops.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Attention.Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = ops.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # 1x1 "network-in-network" shortcut to match channel counts.
            self.nin_shortcut = ops.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the ResNet block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `temb` (torch.Tensor): The embedding tensor (accepted for interface compatibility; unused in this VAE-style block).

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
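
# Shape sketch for the VAE-style block (illustrative; `temb` is ignored here,
# so passing None is fine at runtime despite the annotation). Mismatched
# channel counts route through the 1x1 shortcut.
#
#   block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0)
#   y = block(torch.randn(1, 64, 32, 32), None)
#   # y.shape == (1, 128, 32, 32)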