File size: 5,259 Bytes
b3fb4dd
 
1379e6f
 
49ebc1f
 
b3fb4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49ebc1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
from kan import KAN

class ConvBlock(nn.Module):
    """One Conv1d -> BatchNorm1d -> activation stage of the CNN encoder.

    Channel widths come from ``args.encoder_dims``: block ``num_layer`` maps
    ``encoder_dims[num_layer - 1]`` to ``encoder_dims[num_layer]``. Once
    ``num_layer`` reaches ``len(encoder_dims)``, both widths are clamped to
    ``encoder_dims[-1]``, so surplus blocks keep the last width.

    Args:
        args: config object providing ``activation`` ('silu' selects SiLU,
            anything else ReLU), ``encoder_dims`` and ``kernel_size``.
        num_layer: position of this block in the encoder stack.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        self.activation = nn.SiLU() if args.activation == 'silu' else nn.ReLU()

        dims = args.encoder_dims
        if num_layer < len(dims):
            in_channels, out_channels = dims[num_layer - 1], dims[num_layer]
        else:
            in_channels = out_channels = dims[-1]

        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',  # with stride 1 the sequence length is preserved
            bias=False,  # a conv bias would be cancelled by the following BatchNorm
        )
        norm = nn.BatchNorm1d(num_features=out_channels)
        self.layers = nn.Sequential(conv, norm, self.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv, batch norm and activation over ``x``."""
        out = self.layers(x)
        return out

class CNNEncoder(nn.Module):
    """1-D convolutional encoder: a stem conv followed by ``num_layers`` ConvBlocks.

    After each ConvBlock the last (sequence) axis is max-pooled by 2 while its
    length still exceeds ``min_seq_len``. When ``avg_output`` is set, the final
    feature map is averaged over that axis.

    Args:
        args: config object read for ``activation`` ('silu' -> SiLU, otherwise
            ReLU), ``in_channels``, ``encoder_dims``, ``num_layers`` and
            ``avg_output``; ConvBlock additionally reads ``kernel_size``.

    Attributes:
        output_dim: number of channels produced by the last layer.
    """

    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        # Stem: in_channels -> encoder_dims[0]. The kernel is fixed at 3 here,
        # while the ConvBlocks use args.kernel_size.
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels, kernel_size=3,
                      out_channels=args.encoder_dims[0], stride=1,
                      padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )

        self.layers = nn.ModuleList([ConvBlock(args, i + 1) for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        # The last block (k = num_layers) emits encoder_dims[k], clamped to
        # encoder_dims[-1] once k runs past the list. With num_layers == 0 the
        # stem's encoder_dims[0] is the output. Using encoder_dims[-1]
        # unconditionally would misreport the width whenever
        # num_layers < len(encoder_dims) - 1.
        last_idx = min(args.num_layers, len(args.encoder_dims) - 1)
        self.output_dim = args.encoder_dims[last_idx]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a channels-first feature map (or pooled vector).

        Input layouts handled:
        - 2-D ``(B, L)`` gets a channel axis inserted, giving ``(B, 1, L)``.
        - 3-D with a trailing size-1 axis ``(B, L, 1)`` is transposed to ``(B, 1, L)``.
        - Any other input is handed to Conv1d unchanged, so it should already be
          ``(B, C, L)``.

        Returns:
            ``(B, output_dim, L')``, or ``(B, output_dim)`` when ``avg_output`` is set.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for block in self.layers:
            x = block(x)
            # Downsample only while the sequence is longer than min_seq_len.
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder with rotary position embeddings.

    The CNN backbone always returns its un-pooled feature map, because its
    ``avg_output`` is forced off. This module's own ``avg_output`` flag (read from
    ``args``) instead controls the reduction applied to the Conformer output.

    Args:
        args: CNNEncoder config; ``args.avg_output`` is reused for this module.
        conformer_args: Conformer config providing at least ``encoder_dim`` and
            ``num_heads``.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the backbone's sequence axis so the Conformer receives a sequence.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # The rotary embedding is built over half of the per-head width.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(x_enc, backbone_out)``. ``backbone_out`` is the raw CNN
            feature map. ``x_enc`` is the Conformer output: reduced over axis 1
            when ``avg_output`` is set, otherwise permuted back with
            ``permute(0, 2, 1)``.
        """
        backbone_out = self.backbone(x)

        # Move channels last for the Conformer (presumably it expects
        # (batch, seq, dim) — confirm). The clone keeps x_enc from aliasing
        # the returned backbone_out.
        if backbone_out.dim() == 2:
            seq = backbone_out.unsqueeze(1)
        else:
            seq = backbone_out.permute(0, 2, 1)
        x_enc = seq.clone()

        rope = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, rope)

        if x_enc.dim() != 3:
            return x_enc, backbone_out
        if self.avg_output:
            # NOTE(review): despite the flag name this SUMS over the sequence
            # axis (CNNEncoder uses mean). Kept as-is to preserve the behavior
            # of trained weights; confirm whether mean was intended.
            return x_enc.sum(dim=1), backbone_out
        return x_enc.permute(0, 2, 1), backbone_out

class DualEncoder(nn.Module):
    """Two-branch regressor: a plain CNN branch and a CNN + Conformer branch.

    Both branches see the same input. Their features are concatenated on the
    last axis and mapped to one value per sample by a small MLP head. Both
    branches must yield 2-D ``(batch, features)`` outputs for the concatenation
    and the BatchNorm1d head to line up. In practice the CNN branch needs
    ``args_x.avg_output`` and the Conformer branch needs ``args_f.avg_output``.

    Args:
        args_x: CNNEncoder config for the plain CNN branch.
        args_f: CNN config for the MultiEncoder branch.
        conformer_args: Conformer config for the MultiEncoder branch.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Size the head from what each branch reports it emits. MultiEncoder
        # reports conformer_args.encoder_dim, which need not equal
        # args_f.encoder_dims[-1].
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one regression value per sample, shape ``(batch,)``."""
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)  # backbone features are not used here
        logits = torch.cat([x1, x2], dim=-1)
        # Squeeze only the trailing size-1 output axis. A bare squeeze() would
        # also drop the batch axis for a batch of one, returning a 0-d tensor.
        return self.regressor(logits).squeeze(-1)

class CNNKan(nn.Module):
    """CNNEncoder backbone followed by a FasterKAN head.

    Args:
        args: CNNEncoder config.
        conformer_args: not used by this class.
        kan_args: keyword arguments forwarded to FasterKAN.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Alternative (disabled): KAN(width=kan_args['layers_hidden'])
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the CNN, average over axis 1, then apply the KAN head."""
        # NOTE(review): CNNEncoder returns (B, C, L') without avg_output, so
        # axis 1 is the channel axis here. With avg_output it returns (B, C),
        # which this collapses to (B,). Confirm this reduction is intended.
        features = self.backbone(x).mean(dim=1)
        return self.kan(features)

class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches fused by a third FasterKAN.

    ``x`` and ``f`` go through separately parameterized KANs built from the same
    ``args``. Their outputs are concatenated on the last axis and passed through
    ``kan_out``, whose layer widths are ``[2 * args['layers_hidden'][-1], 8, 8, 1]``.

    Args:
        args: dict of FasterKAN keyword arguments; must contain ``layers_hidden``.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        # Presumably each branch emits layers_hidden[-1] features, so the fused
        # input is twice that width — confirm against FasterKAN.
        fused_width = args['layers_hidden'][-1] * 2
        self.kan_out = FasterKAN(layers_hidden=[fused_width, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and ``f`` separately, concatenate, and fuse."""
        fused = torch.cat([self.kan_x(x), self.kan_f(f)], dim=-1)
        return self.kan_out(fused)