Spaces:
Paused
Paused
File size: 472 Bytes
2d9b22b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
import torch.nn as nn
class BiGRU(nn.Module):
    """Bidirectional, batch-first GRU that yields only the per-step outputs.

    Thin wrapper around ``nn.GRU``: the forward pass discards the final
    hidden state and returns just the output sequence.

    Args:
        input_features: Feature size of each input step (``nn.GRU``'s
            ``input_size``).
        hidden_features: Hidden size of each direction (``hidden_size``).
            Because the GRU is bidirectional, each output step carries
            ``2 * hidden_features`` features.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        # NOTE: the attribute name ``gru`` fixes the state_dict keys
        # (e.g. ``gru.weight_ih_l0``); keep it stable so saved weights load.
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        # Compact the recurrent weights into one contiguous buffer; per the
        # PyTorch docs this only has an effect on GPU with cuDNN enabled.
        self.gru.flatten_parameters()

    def forward(self, x: torch.Tensor):
        """Run the GRU over ``x`` and return its output sequence.

        With ``batch_first=True`` the GRU expects ``x`` laid out as
        ``(batch, seq, input_features)`` and the returned tensor's last
        dimension is ``2 * hidden_features`` (forward and backward halves).
        The final hidden state ``h_n`` is dropped.
        """
        outputs, _final_hidden = self.gru(x)
        return outputs
|