Spaces:
Sleeping
Sleeping
File size: 5,259 Bytes
b3fb4dd 1379e6f 49ebc1f b3fb4dd 49ebc1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
from kan import KAN
class ConvBlock(nn.Module):
    """One 1-D convolutional unit: Conv1d -> BatchNorm1d -> activation.

    Channel widths come from ``args.encoder_dims`` indexed by ``num_layer``;
    layers past the end of that list keep the final width.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        # SiLU when requested, ReLU otherwise.
        self.activation = nn.SiLU() if args.activation == 'silu' else nn.ReLU()
        dims = args.encoder_dims
        if num_layer < len(dims):
            c_in, c_out = dims[num_layer - 1], dims[num_layer]
        else:
            c_in = c_out = dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=c_in,
                      out_channels=c_out,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=c_out),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, C_in, L) to (B, C_out, L); 'same' padding keeps L fixed."""
        return self.layers(x)
class CNNEncoder(nn.Module):
    """Stack of ConvBlocks with halving max-pooling after each block.

    Accepts (B, L), (B, L, 1) or (B, C, L) input; returns (B, C_out, L')
    features, or the time-averaged (B, C_out) when ``args.avg_output`` is set.
    """

    def __init__(self, args) -> None:
        super().__init__()
        # Fixed message typo: "wit" -> "with".
        print("Using CNN encoder with activation: ", args.activation,
              'args avg_output: ', args.avg_output)
        self.activation = nn.SiLU() if args.activation == 'silu' else nn.ReLU()
        # Stem: lift raw input channels to the first encoder width.
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels,
                      kernel_size=3,
                      out_channels=args.encoder_dims[0],
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        self.layers = nn.ModuleList(
            [ConvBlock(args, i + 1) for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        # Pooling stops once the sequence length reaches this floor.
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; see the class docstring for accepted layouts."""
        if x.dim() == 2:
            # (B, L) -> (B, 1, L)
            x = x.unsqueeze(1)
        if x.dim() == 3 and x.shape[-1] == 1:
            # (B, L, 1) -> (B, 1, L)
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for block in self.layers:
            x = block(x)
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            # Collapse the time axis to a single feature vector per sample.
            x = x.mean(dim=-1)
        return x
class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder with rotary positions."""

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the backbone's sequence axis; any reduction happens here.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # Rotary embeddings cover half of each attention head's dims.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        """Return (conformer output, raw backbone features)."""
        features = self.backbone(x)
        # Rearrange to (batch, seq, channels) for the conformer encoder.
        if features.dim() == 2:
            seq = features.unsqueeze(1).clone()
        else:
            seq = features.permute(0, 2, 1).clone()
        rope = self.pe(seq, seq.shape[1])
        seq = self.encoder(seq, rope)
        if seq.dim() == 3:
            # NOTE(review): flag is named "avg_output" but this SUMS over
            # the sequence axis — confirm whether a mean was intended.
            seq = seq.sum(dim=1) if self.avg_output else seq.permute(0, 2, 1)
        return seq, features
class DualEncoder(nn.Module):
    """Two-branch regressor: plain CNN branch plus conformer branch.

    Both branches encode the same input; their features are concatenated
    and fed through a small MLP that emits one scalar per sample.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Bug fix: size the regressor from the encoders' actual output dims.
        # The conformer branch emits conformer_args.encoder_dim features,
        # not args_f.encoder_dims[-1]; the old computation crashed whenever
        # those two values differed.
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (B,) tensor of regression outputs for input ``x``."""
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        combined = torch.cat([x1, x2], dim=-1)
        # squeeze(-1), not squeeze(): a batch of size 1 must keep its batch
        # dimension rather than collapsing to a 0-d scalar.
        return self.regressor(combined).squeeze(-1)
class CNNKan(nn.Module):
    """CNN feature extractor followed by a FasterKAN head."""

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        # NOTE(review): mean over dim=1 — confirm this is the intended axis,
        # since the backbone's avg_output already reduces the time axis.
        pooled = features.mean(dim=1)
        return self.kan(pooled)
class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches feeding a joint FasterKAN head."""

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        hidden = args['layers_hidden'][-1]
        # The joint head consumes both branch outputs concatenated.
        self.kan_out = FasterKAN(layers_hidden=[hidden * 2, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        branch_x = self.kan_x(x)
        branch_f = self.kan_f(f)
        joint = torch.cat([branch_x, branch_f], dim=-1)
        return self.kan_out(joint)
|