codewithdark committed on
Commit 2a4e4bc · verified · 1 Parent(s): 0b515c5

Update Model/latent_Recurrent.py

Files changed (1):
  1. Model/latent_Recurrent.py +21 -21
Model/latent_Recurrent.py CHANGED
@@ -1,22 +1,22 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Optional, Tuple
-from Model.prelude_Block import PreludeBlock
-from Model.recurrent_Block import RecurrentBlock
-from Model.codaBlock import CodaBlock
-
-# Full Latent Recurrent Depth Model
-class LatentRecurrentDepthLM(nn.Module):
-    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
-        super().__init__()
-        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
-        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
-        self.coda = CodaBlock(d_model, vocab_size)
-
-    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        hidden = self.prelude(x, mask)
-        recurrent_state = torch.zeros_like(hidden)
-        for _ in range(num_iterations):
-            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from prelude_Block import PreludeBlock
+from recurrent_Block import RecurrentBlock
+from codaBlock import CodaBlock
+
+# Full Latent Recurrent Depth Model
+class LatentRecurrentDepthLM(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
+        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
+        self.coda = CodaBlock(d_model, vocab_size)
+
+    def forward(self, x: torch.Tensor, num_iterations: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        hidden = self.prelude(x, mask)
+        recurrent_state = torch.zeros_like(hidden)
+        for _ in range(num_iterations):
+            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
         return self.coda(hidden)
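
The substantive change is in the three block imports: the Model. package prefix is dropped, so the file now expects prelude_Block.py, recurrent_Block.py, and codaBlock.py to be importable directly (e.g. when the script is run from inside the Model/ directory). For context, here is a minimal usage sketch of the model this file defines; the hyperparameter values and tensor shapes are illustrative assumptions, not taken from this repository:

    import torch

    # Illustrative hyperparameters (assumptions, not from this commit).
    model = LatentRecurrentDepthLM(vocab_size=32000, d_model=512, num_heads=8, dropout=0.1)

    # Dummy batch of token ids: (batch_size=2, seq_len=16).
    tokens = torch.randint(0, 32000, (2, 16))

    # Depth is chosen per call: the same recurrent block is applied
    # num_iterations times to refine the latent state before the coda decodes it.
    logits_shallow = model(tokens, num_iterations=2)
    logits_deep = model(tokens, num_iterations=8)  # more refinement steps, same weights

Because num_iterations is a forward() argument rather than a construction-time constant, the effective depth of the network can be scaled at inference time without changing any parameters.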