import torch
import torch.nn as nn


class TxnAnomalyGRU(nn.Module):
    def __init__(self, input_dim=32, hidden_dim=128, num_layers=2, dropout=0.3):
        super().__init__()
        # Bidirectional GRU over the transaction sequence; the dropout
        # argument only applies between stacked layers (num_layers > 1).
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=num_layers,
                          bidirectional=True, batch_first=True, dropout=dropout)
        self.batchnorm = nn.BatchNorm1d(hidden_dim * 2)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim * 2, 64)
        self.relu = nn.ReLU()
        self.out = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, seq_len, input_dim). Explicit zero initial hidden state;
        # the leading dim is num_layers * 2 because the GRU is bidirectional.
        h0 = torch.zeros(self.gru.num_layers * 2, x.size(0),
                         self.gru.hidden_size, device=x.device)
        out, _ = self.gru(x, h0)  # (batch, seq_len, hidden_dim * 2)
        # Readout: last step of the forward direction plus first step of the
        # backward direction. Slicing out[:, -1, :] whole would pair the
        # forward summary with a backward state that has seen only one step.
        h = self.gru.hidden_size
        out = torch.cat((out[:, -1, :h], out[:, 0, h:]), dim=1)
        out = self.batchnorm(out)
        out = self.dropout(out)
        out = self.relu(self.fc1(out))
        return self.sigmoid(self.out(out))
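

# --- Usage sketch (illustrative; not from the original source) ---
# A minimal smoke test, assuming transaction sequences are batched as
# (batch, seq_len, input_dim) float tensors; the batch size and sequence
# length below are arbitrary.
if __name__ == "__main__":
    model = TxnAnomalyGRU(input_dim=32, hidden_dim=128, num_layers=2, dropout=0.3)
    model.eval()  # use running BatchNorm stats; avoids train-mode batch-size quirks

    x = torch.randn(8, 50, 32)  # 8 sequences of 50 transactions, 32 features each
    with torch.no_grad():
        scores = model(x)
    print(scores.shape)  # torch.Size([8, 1]): one anomaly probability per sequence

# Design note (an assumption, not stated in the source): because the model
# applies sigmoid internally, training would pair it with nn.BCELoss; dropping
# the sigmoid and using nn.BCEWithLogitsLoss is a common, more numerically
# stable alternative.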