Update modeling_latent_recurrent_depth.py
modeling_latent_recurrent_depth.py
CHANGED
@@ -4,7 +4,146 @@ import torch.nn.functional as F
 from typing import Optional, Tuple
 import math
 from transformers import PretrainedConfig, PreTrainedModel
-
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        assert d_model % num_heads == 0
+
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.head_dim = d_model // num_heads
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.o_proj = nn.Linear(d_model, d_model)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        batch_size, seq_len, d_model = x.shape
+
+        # Project and reshape for multi-head attention
+        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
+        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
+        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim)
+
+        # Transpose for attention computation
+        q = q.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Compute attention scores
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, float('-inf'))
+
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Apply attention to values
+        out = torch.matmul(attn_weights, v)  # (batch_size, num_heads, seq_len, head_dim)
+        out = out.transpose(1, 2)  # (batch_size, seq_len, num_heads, head_dim)
+        out = out.reshape(batch_size, seq_len, d_model)
+
+        return self.o_proj(out)
+
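The attention layer hides positions where mask == 0, so the mask must broadcast against scores of shape (batch, num_heads, seq_len, seq_len). A minimal sketch of a compatible causal mask, using a hypothetical build_causal_mask helper that is not part of this commit:

import torch

def build_causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular ones; the zeros above the diagonal are the positions
    # that masked_fill(mask == 0, -inf) will hide from attention.
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.view(1, 1, seq_len, seq_len)  # broadcasts over batch and heads

# Illustrative sizes only.
attn = MultiHeadAttention(d_model=256, num_heads=8)
x = torch.randn(2, 16, 256)                # (batch, seq_len, d_model)
out = attn(x, mask=build_causal_mask(16))  # -> (2, 16, 256)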
+class PreludeBlock(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        self.token_embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_encoding = nn.Parameter(torch.zeros(1, 1024, d_model))  # Max sequence length of 1024
+        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, 4 * d_model),
+            nn.GELU(),
+            nn.Linear(4 * d_model, d_model),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        seq_len = x.size(1)
+        # Embed tokens and add positional encoding
+        x = self.token_embedding(x) + self.pos_encoding[:, :seq_len, :]
+
+        # Self-attention block
+        attended = self.attention(self.norm1(x), mask)
+        x = x + attended
+
+        # Feed-forward block
+        x = x + self.feed_forward(self.norm2(x))
+        return x
+
+class RecurrentBlock(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, 4 * d_model),
+            nn.GELU(),
+            nn.Linear(4 * d_model, d_model),
+            nn.Dropout(dropout)
+        )
+
+        # Recurrent state projection
+        self.state_proj = nn.Linear(d_model, d_model)
+
+    def forward(self, x: torch.Tensor, recurrent_state: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Update recurrent state
+        recurrent_state = self.state_proj(recurrent_state)
+
+        # Combine input with recurrent state
+        x = x + recurrent_state
+
+        # Self-attention block
+        attended = self.attention(self.norm1(x), mask)
+        x = x + attended
+
+        # Feed-forward block
+        x = x + self.feed_forward(self.norm2(x))
+
+        return x, x  # Return both output and new recurrent state
+
+class CodaBlock(nn.Module):
+    def __init__(self, d_model: int, vocab_size: int):
+        super().__init__()
+        self.norm = nn.LayerNorm(d_model)
+        self.output_proj = nn.Linear(d_model, vocab_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        return self.output_proj(x)
+
+class LatentRecurrentDepthLM(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int, num_heads: int, dropout: float = 0.1):
+        super().__init__()
+        self.prelude = PreludeBlock(vocab_size, d_model, num_heads, dropout)
+        self.recurrent = RecurrentBlock(d_model, num_heads, dropout)
+        self.coda = CodaBlock(d_model, vocab_size)
+
+    def forward(self, x: torch.Tensor, num_iterations: int,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # Initial embedding and processing
+        hidden = self.prelude(x, mask)
+
+        # Initialize recurrent state
+        recurrent_state = torch.zeros_like(hidden)
+
+        # Apply recurrent block multiple times
+        for _ in range(num_iterations):
+            hidden, recurrent_state = self.recurrent(hidden, recurrent_state, mask)
+
+        # Final output projection
+        return self.coda(hidden)
+
+

 # Configuration for the Latent Recurrent Depth Model
 class LatentRecurrentDepthConfig(PretrainedConfig):
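A minimal end-to-end usage sketch, assuming the classes added above plus the hypothetical build_causal_mask helper from the earlier example; the vocabulary size, width, and iteration count are illustrative:

import torch

vocab_size, d_model, num_heads = 32000, 512, 8
model = LatentRecurrentDepthLM(vocab_size, d_model, num_heads)

input_ids = torch.randint(0, vocab_size, (2, 16))  # (batch, seq_len) token ids
mask = build_causal_mask(16)                       # (1, 1, 16, 16)

# The single RecurrentBlock is reused num_iterations times, so the effective
# depth can be varied at call time without changing the parameter count.
logits = model(input_ids, num_iterations=4, mask=mask)
print(logits.shape)  # torch.Size([2, 16, 32000])

Note that RecurrentBlock returns its output twice (return x, x), so in this version the recurrent state and the hidden sequence coincide after each iteration.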