Update Model/latent_Recurrent.py
Model/latent_Recurrent.py CHANGED (+21 -21)
@@ -1,22 +1,22 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Optional, Tuple
-from
-from
-from
-
-# Full Latent Recurrent Depth Model
-class LatentRecurrentDepthLM(nn.Module):
-    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
-        super().__init__()
-        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
-        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
-        self.coda = CodaBlock(d_model, vocab_size)
-
-    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        hidden = self.prelude(x, mask)
-        recurrent_state = torch.zeros_like(hidden)
-        for _ in range(num_iterations):
-            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from prelude_Block import PreludeBlock
+from recurrent_Block import RecurrentBlock
+from codaBlock import CodaBlock
+
+# Full Latent Recurrent Depth Model
+class LatentRecurrentDepthLM(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
+        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
+        self.coda = CodaBlock(d_model, vocab_size)
+
+    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        hidden = self.prelude(x, mask)
+        recurrent_state = torch.zeros_like(hidden)
+        for _ in range(num_iterations):
+            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
         return self.coda(hidden)
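For context, the commit only completes the three broken import statements; the class itself is unchanged. Its effective depth is decided at call time: forward encodes token IDs with PreludeBlock, applies the shared RecurrentBlock num_iterations times to refine a latent state, and decodes with CodaBlock. Because the same recurrent block is reused on every iteration, compute can be scaled up or down without changing the parameter count. A minimal usage sketch, assuming the blocks follow the interfaces implied by the calls above (token IDs in, logits over the vocabulary out) and that the file is importable as latent_Recurrent; the sizes and iteration count are illustrative, not values from the repository:

# Minimal usage sketch; sizes and iteration count are illustrative.
import torch

from latent_Recurrent import LatentRecurrentDepthLM  # assumes Model/ is on the import path

model = LatentRecurrentDepthLM(vocab_size=32000, d_model=512, num_heads=8)

# A batch of 2 sequences, 16 token IDs each.
tokens = torch.randint(0, 32000, (2, 16))

# Depth is chosen per call: the recurrent block runs num_iterations times
# over the prelude's hidden state before the coda decodes it.
logits = model(tokens, num_iterations=4)  # expected shape: (2, 16, 32000)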