codewithdark committed
Commit 9cde2c9 · verified · 1 Parent(s): ff99207

Update Model/recurrent_Block.py

Files changed (1)
  1. Model/recurrent_Block.py +24 -24
Model/recurrent_Block.py CHANGED
@@ -1,25 +1,25 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Optional, Tuple
- from Model.multi_head_Attention import MultiHeadAttention
-
- # Recurrent Block (Processing Over Time)
- class RecurrentBlock(nn.Module):
-     def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
-         super().__init__()
-         self.attention = MultiHeadAttention(d_model, num_heads, dropout)
-         self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
-         self.feed_forward = nn.Sequential(
-             nn.Linear(d_model, 4 * d_model),
-             nn.GELU(),
-             nn.Linear(4 * d_model, d_model),
-             nn.Dropout(dropout)
-         )
-         self.state_proj = nn.Linear(d_model, d_model)
-
-     def forward(self, x: torch.Tensor, recurrent_state: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
-         recurrent_state = self.state_proj(recurrent_state)
-         x = x + recurrent_state
-         attended = self.attention(self.norm1(x), mask)
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+ from multi_head_Attention import MultiHeadAttention
+
+ # Recurrent Block (Processing Over Time)
+ class RecurrentBlock(nn.Module):
+     def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
+         super().__init__()
+         self.attention = MultiHeadAttention(d_model, num_heads, dropout)
+         self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
+         self.feed_forward = nn.Sequential(
+             nn.Linear(d_model, 4 * d_model),
+             nn.GELU(),
+             nn.Linear(4 * d_model, d_model),
+             nn.Dropout(dropout)
+         )
+         self.state_proj = nn.Linear(d_model, d_model)
+
+     def forward(self, x: torch.Tensor, recurrent_state: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+         recurrent_state = self.state_proj(recurrent_state)
+         x = x + recurrent_state
+         attended = self.attention(self.norm1(x), mask)
          return x + attended + self.feed_forward(self.norm2(x)), x
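
The only functional change in this commit is the import of MultiHeadAttention: it is now imported as a top-level module (from multi_head_Attention import MultiHeadAttention) instead of through the Model package. Below is a minimal usage sketch; it is an assumption, not part of the commit, and presumes the Model/ directory is on PYTHONPATH so both files import as top-level modules, and that MultiHeadAttention(x, mask) returns a tensor shaped like x.

# Usage sketch (assumption, not from the repo): run with Model/ on PYTHONPATH
# so the new top-level import inside recurrent_Block.py resolves, and assume
# MultiHeadAttention(x, mask) returns a tensor with the same shape as x.
import torch
from recurrent_Block import RecurrentBlock

d_model, num_heads, batch, seq_len = 64, 8, 2, 16
block = RecurrentBlock(d_model, num_heads, dropout=0.1)

x = torch.randn(batch, seq_len, d_model)      # token representations
state = torch.zeros(batch, seq_len, d_model)  # assumed initial recurrent state

# forward returns (block output, carried state); the carried state is the
# post-injection x, meant to be fed back in on the next time step.
out, new_state = block(x, state)
print(out.shape, new_state.shape)             # torch.Size([2, 16, 64]) for both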