Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from .Modules.conformer import ConformerEncoder, ConformerDecoder | |
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding | |
from .kan.fasterkan import FasterKAN | |
class Sine(nn.Module):
    """Elementwise activation computing ``sin(w0 * x)``.

    Args:
        w0: constant frequency multiplier applied to the input before the sine.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = self.w0 * x
        return scaled.sin()
class MLPEncoder(nn.Module):
    """Stack of ``Linear -> BatchNorm1d -> SiLU [-> Dropout]`` hidden layers.

    Args:
        args: configuration namespace; the following attributes are read:
            ``input_dim`` (int): width of the input features.
            ``hidden_dims`` (sequence of int): width of each hidden layer, in order.
            ``dropout`` (float): dropout probability; a ``Dropout`` layer is
                appended after each hidden layer only when this is > 0.

    Attributes:
        output_dim (int): width of the encoder output, i.e. the last entry of
            ``hidden_dims``, or ``input_dim`` when ``hidden_dims`` is empty
            (the encoder is then an identity).
    """

    def __init__(self, args):
        super().__init__()
        layers = []
        prev_dim = args.input_dim
        for hidden_dim in args.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            if args.dropout > 0.0:
                layers.append(nn.Dropout(args.dropout))
            prev_dim = hidden_dim
        self.model = nn.Sequential(*layers)
        # prev_dim holds the last hidden width. Unlike the loop variable, it is
        # also defined when hidden_dims is empty (previously a NameError).
        self.output_dim = prev_dim

    def forward(self, x):
        """Apply the MLP to ``x``.

        Assumes ``x`` is 2-D ``(batch, input_dim)``. With a 3-D input,
        BatchNorm1d would normalise over axis 1 rather than the feature
        axis. TODO confirm with callers.
        """
        return self.model(x)
class ConvBlock(nn.Module):
    """``Conv1d -> BatchNorm1d -> activation`` with stride 1 and 'same' padding.

    Args:
        args: namespace providing ``activation`` ('silu', 'sine', anything
            else -> ReLU), ``sine_w0`` (read only for 'sine'),
            ``encoder_dims`` (channel widths) and ``kernel_size``.
        num_layer: index of this block. While ``num_layer`` is a valid index
            into ``encoder_dims``, the block maps ``encoder_dims[num_layer - 1]``
            to ``encoder_dims[num_layer]``. Beyond the end of the list, both
            sides fall back to ``encoder_dims[-1]``.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        kind = args.activation
        if kind == 'sine':
            act = Sine(w0=args.sine_w0)
        elif kind == 'silu':
            act = nn.SiLU()
        else:
            act = nn.ReLU()
        self.activation = act

        dims = args.encoder_dims
        if num_layer < len(dims):
            c_in, c_out = dims[num_layer - 1], dims[num_layer]
        else:
            c_in = c_out = dims[-1]

        conv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,
        )
        self.layers = nn.Sequential(conv, nn.BatchNorm1d(num_features=c_out), self.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.layers(x)
        return out
class CNNEncoder(nn.Module):
    """1-D convolutional encoder: a stem conv, ``args.num_layers`` ConvBlocks, max-pooling.

    Args:
        args: namespace providing ``activation`` ('silu', 'sine', anything
            else -> ReLU), ``sine_w0`` (read only for 'sine'), ``in_channels``,
            ``encoder_dims``, ``num_layers``, ``kernel_size`` (read by
            ConvBlock) and ``avg_output``.

    Attributes:
        output_dim: ``args.encoder_dims[-1]``.
        min_seq_len: pooling is skipped once the last axis is this short (2).
        avg_output: when true, ``forward`` averages over the last axis.
    """

    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
        kind = args.activation
        if kind == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        elif kind == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()

        stem_width = args.encoder_dims[0]
        self.embedding = nn.Sequential(
            nn.Conv1d(
                in_channels=args.in_channels,
                out_channels=stem_width,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm1d(stem_width),
            self.activation,
        )
        # Blocks are numbered from 1 so block k maps encoder_dims[k-1] -> encoder_dims[k].
        self.layers = nn.ModuleList(ConvBlock(args, idx) for idx in range(1, args.num_layers + 1))
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        A 2-D input gets a singleton channel axis at position 1. A 3-D input
        whose last axis has size 1 is transposed so that axis becomes the
        channel axis. After every block, the sequence is halved by max-pooling
        while it is longer than ``min_seq_len``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        out = self.embedding(x)
        for block in self.layers:
            out = block(out)
            if out.shape[-1] > self.min_seq_len:
                out = self.pool(out)
        return out.mean(dim=-1) if self.avg_output else out
class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder driven by rotary position embeddings.

    ``forward`` returns ``(encoded, backbone_out)``, where ``backbone_out`` is
    the raw CNN output.

    Args:
        args: CNNEncoder config; ``args.avg_output`` controls this module's
            pooling, while the backbone itself is forced to keep its sequence axis.
        conformer_args: ConformerEncoder config; ``encoder_dim`` and
            ``num_heads`` are also read here.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the CNN's sequence axis so the Conformer receives a sequence.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # Half of the per-head width is handed to RotaryEmbedding
        # (presumably the number of rotated dims per head).
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        backbone_out = self.backbone(x)
        # Axis 1 of `seq` is treated as the sequence axis (its length is passed
        # to the positional embedding). The clone keeps it separate from
        # backbone_out, which is returned as-is.
        if backbone_out.dim() == 2:
            seq = backbone_out.unsqueeze(1).clone()
        else:
            seq = backbone_out.permute(0, 2, 1).clone()
        rope = self.pe(seq, seq.shape[1])
        encoded = self.encoder(seq, rope)
        if encoded.dim() == 3:
            # NOTE(review): the flag is named avg_output, but this SUMS over the
            # sequence axis (CNNEncoder uses a mean) — confirm this is intended.
            encoded = encoded.sum(dim=1) if self.avg_output else encoded.permute(0, 2, 1)
        return encoded, backbone_out
class DualEncoder(nn.Module):
    """Scalar regressor over two encodings of the same input.

    The two branches are a plain CNN (``encoder_x``) and a CNN + Conformer
    (``encoder_f``). Their features are concatenated and passed through
    ``Linear -> BatchNorm1d -> SiLU -> Linear(1)``.

    Args:
        args_x: CNNEncoder config for the first branch.
        args_f: CNNEncoder config for the Conformer branch's backbone.
        conformer_args: ConformerEncoder config for the second branch.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Size the head from the widths each branch declares it emits. The
        # Conformer branch declares conformer_args.encoder_dim, which need not
        # equal args_f.encoder_dims[-1] (the CNN width used previously).
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        hidden = total_output_dim // 2
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Concatenation along the last axis needs both branches to return 2-D
        # (batch, features) tensors, i.e. avg_output enabled in args_x and args_f.
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        features = torch.cat([x1, x2], dim=-1)
        # NOTE(review): squeeze() also drops the batch axis when batch == 1.
        return self.regressor(features).squeeze()
class CNNKan(nn.Module):
    """CNN backbone whose output is averaged over dim 1 and fed to a FasterKAN head.

    Args:
        args: CNNEncoder config.
        conformer_args: accepted for signature compatibility; not used.
        kan_args: keyword arguments forwarded to ``FasterKAN``.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): with a (batch, channels, length) backbone output, dim 1
        # is the channel axis, so this averages channels, not time — confirm.
        pooled = self.backbone(x).mean(dim=1)
        return self.kan(pooled)
class CNNKanFeaturesEncoder(nn.Module):
    """FasterKAN head over CNN features concatenated with MLP-encoded side features.

    Args:
        args: CNNEncoder config for the signal branch.
        mlp_args: MLPEncoder config for the side-feature branch.
        kan_args: keyword arguments for ``FasterKAN``. ``layers_hidden[0]``
            should be the CNN feature width; the MLP output width is added to
            it here. The caller's mapping is not modified.
    """

    def __init__(self, args, mlp_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.mlp = MLPEncoder(mlp_args)
        # Widen the KAN input by the MLP width, working on copies. Mutating the
        # caller's dict/list in place made the width grow every time the same
        # config was reused to build a model.
        kan_args = dict(kan_args)
        layers_hidden = list(kan_args['layers_hidden'])
        layers_hidden[0] += self.mlp.output_dim
        kan_args['layers_hidden'] = layers_hidden
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        # NOTE(review): mean over dim 1, as in CNNKan — averages channels when the
        # backbone returns (batch, channels, length); confirm intended.
        x = self.backbone(x).mean(dim=1)
        f = self.mlp(f)
        return self.kan(torch.cat([x, f], dim=-1))
class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches whose outputs are concatenated and fused by a third KAN.

    Args:
        args: keyword arguments used for both branch KANs. The fuser's input
            width is taken as twice ``args['layers_hidden'][-1]`` (assumed to be
            each branch's output width), followed by hidden widths 8, 8 and a
            single output.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        branch_width = args['layers_hidden'][-1]
        self.kan_out = FasterKAN(layers_hidden=[branch_width * 2, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([self.kan_x(x), self.kan_f(f)], dim=-1)
        return self.kan_out(fused)
class MultiGraph(nn.Module):
    """CNN encoder with a ``Linear -> BatchNorm1d -> SiLU -> Linear(1)`` projection head.

    ``graph_net`` is stored and registered but not used in the forward pass.

    Args:
        graph_net: graph module, kept as a submodule.
        args: CNNEncoder config; ``encoder_dims[-1]`` sizes the head.
    """

    def __init__(self, graph_net, args):
        super().__init__()
        self.graph_net = graph_net
        self.cnn = CNNEncoder(args)
        width = args.encoder_dims[-1]
        half = width // 2
        self.projection = nn.Sequential(
            nn.Linear(width, half),
            nn.BatchNorm1d(half),
            nn.SiLU(),
            nn.Linear(half, 1),
        )

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # NOTE: graph fusion is disabled — `g` and `self.graph_net` are ignored,
        # and only the CNN features of `x` reach the projection head.
        cnn_features = self.cnn(x)
        return self.projection(cnn_features)
class ImplicitEncoder(nn.Module):
    """Run ``encoder_net`` on ``transform_net``'s parameters plus its output on the input.

    Args:
        transform_net: module applied to the input with axes 1 and 2 swapped;
            its ``state_dict`` also supplies the weight/bias inputs.
        encoder_net: callable invoked as
            ``encoder_net((weights, biases), transformed_x)``.
    """

    def __init__(self, transform_net, encoder_net):
        super().__init__()
        self.transform_net = transform_net
        self.encoder_net = encoder_net

    def get_weights_and_bises(self):
        """Return ``(weights, biases)`` tuples built from ``transform_net.state_dict()``.

        An entry whose key contains "weight" is transposed (permute(1, 0), so
        assumed 2-D, e.g. a Linear weight) and reshaped to ``(1, d1, d0, 1)``.
        An entry whose key contains "bias" becomes ``(1, n, 1)``. Order follows
        the state_dict.

        NOTE(review): state_dict() (keep_vars=False) yields detached tensors, so
        presumably no gradient reaches transform_net along this path — confirm.
        """
        weights, biases = [], []
        for key, tensor in self.transform_net.state_dict().items():
            if "weight" in key:
                weights.append(tensor.permute(1, 0).unsqueeze(-1).unsqueeze(0))
            if "bias" in key:
                biases.append(tensor.unsqueeze(-1).unsqueeze(0))
        return tuple(weights), tuple(biases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # transform_net sees x with axes 1 and 2 swapped; the result is swapped back.
        swapped = x.permute(0, 2, 1)
        transformed_x = self.transform_net(swapped).permute(0, 2, 1)
        return self.encoder_net(self.get_weights_and_bises(), transformed_x)