from typing import Optional

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)

        # Project each input and split into heads: (batch, heads, seq_len, head_dim).
        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Scaled dot-product attention over the conditioning tokens.
        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)
        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])

        # Merge the heads back together and project to the latent dimension.
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )
        out = self.fc_out(out).squeeze(1)
        return out
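

# Shape sketch (an illustrative assumption, not from the original source): with a
# 2-D directional input of shape (batch, direction_input_dim) and conditioning
# tokens of shape (batch, num_tokens, conditioning_input_dim), the query sequence
# length is 1, so the output after squeeze(1) has shape (batch, latent_dim).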


class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        # Cross-attention from the directional input to the conditioning input,
        # followed by a residual connection, layer norm, and a feed-forward block.
        # Note: the residual add requires direction_input_dim == latent_dim.
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        out1 = self.norm1(attn_output + directional_input)
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2


class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of attention layers in the module.
            out_activation (Optional[nn.Module]): The activation function to apply to the output tensor, if any.
        """
        super().__init__()
        self.residual_projection = nn.Linear(
            in_dim, hidden_features
        )  # projection for residual connection
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)  # 3 for RGB
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        # Project the input into the hidden space, refine it through the stack of
        # cross-attention layers, then map to RGB.
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
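

# Minimal usage sketch (an illustration, not part of the original source): the
# dimensions below are arbitrary assumptions chosen only to exercise the shapes.
if __name__ == "__main__":
    decoder = Decoder(
        in_dim=3,
        conditioning_input_dim=64,
        hidden_features=128,
        num_heads=4,
        num_layers=2,
        out_activation=nn.Sigmoid(),
    )
    directions = torch.randn(8, 3)  # e.g. one view direction per sample
    conditioning = torch.randn(8, 16, 64)  # e.g. 16 conditioning tokens per sample
    rgb = decoder(directions, conditioning)
    print(rgb.shape)  # torch.Size([8, 3])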