import torch
import torch.nn as nn
import torch.nn.functional as F

# Changelog since original version:
# - xATGLU instead of top linear in transformer block.
# - Added a learned residual scale to all blocks and all residuals. This
#   allowed bfloat16 training to stabilize; prior to that it was just exploding.
#
# This architecture was my attempt at the following Simple Diffusion paper
# with some modifications: https://arxiv.org/pdf/2410.19324v1


class xATGLU(nn.Module):
    """Expanded arctan gated linear unit -- https://arxiv.org/pdf/2405.20768

    Very similar to GeGLU or SwiGLU: there's a learned gate fn, but it uses
    arctan as the activation, rescaled to (0, 1) and then expanded beyond
    [0, 1] by the learned scalar ``alpha``.
    """

    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # One projection produces both the GATE path and the VALUE path.
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')
        # alpha starts at 0 -> plain arctan gate with range (0, 1).
        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)
        # Map arctan's (-pi/2, pi/2) output into (0, 1)...
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        # ...then expand the range to (-alpha, 1 + alpha) via learned alpha.
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha
        return expanded_gate * value_path  # g(x) * y


class ResBlock(nn.Module):
    """Pre-activation residual conv block with a learned residual scale.

    NOTE: GroupNorm(32, ...) requires `channels` to be divisible by 32.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        # Small initial residual scale -- helped bf16 training stability.
        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h * self.learned_residual_scale


class TransformerBlock(nn.Module):
    """Pre-norm self-attention + xATGLU MLP block operating on B C H W maps."""

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        # Params recommended by TPA paper, seem to work fine.
        # FIX: batch_first=True so attention consumes [B, H*W, C] directly;
        # the original fed batch-first tensors to a sequence-first
        # MultiheadAttention, silently swapping batch and sequence dims.
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False)  # Candidate for a bias
        )
        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        # Input shape B C H W -> token sequence [B, H*W, C].
        # FIX: the original `reshape(b, h*w, c)` reinterpreted BCHW memory
        # instead of moving channels last; flatten + transpose is correct.
        b, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # Pre-norm architecture -- really helpful for stability with bf16.
        identity = x
        x = self.norm1(x)
        h_attn, _ = self.attn(x, x, x, need_weights=False)
        x = identity + h_attn * self.learned_residual_scale_attn

        identity = x
        x = self.norm2(x)
        h_mlp = self.mlp(x)
        x = identity + h_mlp * self.learned_residual_scale_mlp

        # Reshape back to B C H W.
        return x.transpose(1, 2).reshape(b, c, h, w)


class LevelBlock(nn.Module):
    """A stack of `num_blocks` ResBlocks or TransformerBlocks at one level.

    block_type: 'transformer' for TransformerBlocks, anything else -> ResBlocks.
    """

    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class AsymmetricResidualUDiT(nn.Module):
    """Asymmetric residual U-DiT: conv-heavy encoder, transformer middle/decoder.

    Implements f(x) = fu( U(fm(D(h)) - D(h)) + h ) where h = fd(x): the middle
    transformer stack operates on downsampled features, its input is subtracted
    back out as a residual, and encoder skips are re-added on the way up.
    """

    def __init__(self,
                 in_channels=3,       # Input color channels.
                 base_channels=128,   # Initial feature size; dramatically increases parameter count.
                 patch_size=2,        # Smaller patches dramatically increase flops. Recommend >=4 unless you have real compute.
                 num_levels=3,        # U-net depth: number of down/upsample steps. Dramatically increases parameters.
                 encoder_blocks=3,    # Can differ from decoder_blocks.
                 decoder_blocks=7,    # Can differ from encoder_blocks.
                 encoder_transformer_thresh=2,  # Level (>=) at which the encoder starts using transformer blocks.
                 decoder_transformer_thresh=4,  # Level (<=) until which the decoder keeps using transformer blocks.
                 mid_blocks=16,       # Middle transformer blocks; relatively cheap at the feature bottleneck.
                 ):
        super().__init__()
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

        # Initial projection from image space (patchify).
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        self.encoders = nn.ModuleList()
        curr_channels = base_channels
        for level in range(num_levels):
            # Use transformers for the latter (lower-resolution) levels.
            use_transformer = level >= encoder_transformer_thresh
            # FIX: the original passed the bool itself as block_type, which
            # never compares equal to 'transformer', so transformer blocks
            # were never actually instantiated in the encoder or decoder.
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks,
                           'transformer' if use_transformer else 'res')
            )

            # Every level but the last doubles channels via a 1x1 conv;
            # the spatial downsample itself happens in forward().
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks.
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels)
            for _ in range(mid_blocks)
        ])

        # Decoder levels: transformers for early levels (inverse of encoder).
        self.decoders = nn.ModuleList()
        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks,
                           'transformer' if use_transformer else 'res')
            )

            # Every level but the last halves channels via a 1x1 conv.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space (un-patchify).
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size,
                                             stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        # x shape B C H W; `t` (timestep conditioning) is currently unused.
        #
        # Patchify: e.g. 2,3,256,256 -> 2,128,64,64 (more channels, smaller H/W).
        x = self.patch_embed(x)

        # Per num_level resolution block, more or less:
        # f(x) = fu( U(fm(D(h)) - D(h)) + h ) where h = fd(x)
        #   1. h = fd(x)      : encoder path processes input
        #   2. D(h)           : downsample the encoded features
        #   3. fm(D(h))       : middle transformer blocks process them
        #   4. fm(D(h))-D(h)  : subtract the downsampled features (residual)
        #   5. U(...)         : upsample the processed features
        #   6. ... + h        : add back encoder features (skip connection)
        #   7. fu(...)        : decoder path processes the combined features
        residuals = []
        curr_res = x

        # Encoder path (computing h = fd(x)).
        h = x
        for blocks in self.encoders:
            if isinstance(blocks, LevelBlock):
                h = blocks(h)
            else:
                # Save the skip connection before downsampling.
                residuals.append(curr_res)
                # 1x1 channel-doubling conv, then spatial downsample.
                h = self.downsample(blocks(h))
                curr_res = h

        # Middle blocks (fm).
        x = h
        for block in self.middle:
            x = block(x)

        # Subtract the residual at this level (D(h)).
        x = x - curr_res * self.learned_middle_residual_scale

        # Decoder path (fu).
        for blocks in self.decoders:
            if isinstance(blocks, LevelBlock):
                x = blocks(x)
            else:
                # 1x1 channel-halving conv, then spatial upsample.
                x = self.upsample(blocks(x))
                # Skip connections pop LIFO: the last residual pushed is the
                # first one needed on the way back up the U.
                curr_res = residuals.pop()
                x = x + curr_res * self.learned_middle_residual_scale

        # Final projection back to image space.
        return self.final_proj(x)