Update modular_Sewy2.py
modular_Sewy2.py  +22 -1
modular_Sewy2.py  CHANGED
@@ -2,7 +2,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 """ PyTorch Sewy model."""
-"""Used
+"""Used deepseekv3 as starting point"""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -214,6 +214,8 @@ class SewyV2Config(PretrainedConfig):
         unit_norm_eps = 1e-6,
         resformer_lambda = 2.0,
         neutreno_lambda=0.4,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -260,6 +262,8 @@ class SewyV2Config(PretrainedConfig):
         self.unit_norm_eps = unit_norm_eps
         self.resformer_lambda = resformer_lambda
         self.neutreno_lambda = neutreno_lambda
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
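
The two new knobs default to the values Gemma 2 uses for logit soft-capping (30.0 on the final logits, 50.0 on the attention logits). A minimal usage sketch, assuming the config class is importable from this module and that its remaining arguments keep their defaults (the import path is hypothetical):

from modular_Sewy2 import SewyV2Config  # hypothetical import path, for illustration only

config = SewyV2Config()
print(config.attn_logit_softcapping)    # 50.0 -> attention logits squashed into (-50, 50)
print(config.final_logit_softcapping)   # 30.0 -> lm_head logits squashed into (-30, 30)

# Both caps can be overridden per model:
config = SewyV2Config(attn_logit_softcapping=20.0, final_logit_softcapping=15.0)
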
@@ -907,6 +911,9 @@ class SewyV2Attention(nn.Module):
 
         self.neutreno_lambda = nn.Parameter(torch.tensor(float(config.neutreno_lambda)))
 
+        self.attn_logit_softcapping = self.config.attn_logit_softcapping
+
+
     def _get_unit_norm(self, x,eps=1e-6):
         """
         Normalize a tensor to unit norm
@@ -1080,6 +1087,10 @@ class SewyV2Attention(nn.Module):
             )
             attn_weights = attn_weights + attention_mask
 
+        ## tanh softcapping
+
+        attn_weights = self.attn_logit_softcapping * torch.tanh(attn_weights/self.attn_logit_softcapping)
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(
             attn_weights, dim=-1, dtype=torch.float32
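
The eager path applies the standard tanh soft-cap: cap * tanh(x / cap) is close to the identity for |x| much smaller than cap and saturates at +/-cap, so no attention logit can blow up before the softmax. A self-contained sketch of just that operation on dummy scores (not the model's forward pass):

import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # ~identity near zero, smoothly bounded to the open interval (-cap, cap)
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([0.5, 10.0, 100.0, -300.0])
print(softcap(scores, 50.0))                          # ~[0.50, 9.87, 48.20, -50.00]
probs = torch.softmax(softcap(scores, 50.0), dim=-1)  # the softmax only ever sees bounded inputs
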
@@ -1279,6 +1290,7 @@ class SewyV2FlashAttention2(SewyV2Attention):
             q_len,
             dropout=dropout_rate,
             softmax_scale=self.softmax_scale,
+            softcap=self.attn_logit_softcapping,
         )
         if self.q_head_dim != self.v_head_dim:
             attn_output = attn_output[:, :, :, : self.v_head_dim]
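
The flash path forwards the same cap to flash-attn's fused kernel instead of materializing the scores. To the best of my knowledge the softcap keyword only appeared around flash-attn v2.6, so older installs would reject it; a defensive variant of this call (my sketch, not what the commit does) could probe the signature before passing the keyword:

import inspect

from flash_attn import flash_attn_func

# Older flash-attn wheels do not accept `softcap` and raise a TypeError, so only
# forward the cap when the installed version actually exposes the keyword.
flash_kwargs = {}
if "softcap" in inspect.signature(flash_attn_func).parameters:
    flash_kwargs["softcap"] = 50.0  # same value as config.attn_logit_softcapping
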
@@ -1865,6 +1877,10 @@ class SewyV2ForCausalLM(SewyV2PreTrainedModel):
         self.s_z = nn.Parameter(torch.ones(self.vocab_size) * (1/config.hidden_size ** 0.5))
         self.s_z_init = 1
         self.s_z_scale = 1/config.hidden_size ** 0.5
+
+        # tanh softcapping
+
+        self.tanh_softcapping = config.final_logit_softcapping
         # Initialize weights and apply final processing
         self.post_init()
 
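
The unchanged context lines around this insertion are the nGPT-style logit scale: s_z is stored at scale 1/sqrt(hidden_size) and multiplied by s_z_init / s_z_scale when it is applied (see the last hunk), so the effective scale starts out at exactly s_z_init. A tiny numerical check of that reparameterization, with toy sizes rather than the model's real ones:

import torch

hidden_size, vocab_size = 1024, 8                  # toy values for illustration
s_z_init, s_z_scale = 1.0, 1 / hidden_size ** 0.5

s_z = torch.ones(vocab_size) * s_z_scale           # how the parameter is initialized
effective = s_z * (s_z_init / s_z_scale)           # how it is applied to the logits
print(torch.allclose(effective, torch.ones(vocab_size)))  # True: effective scale == s_z_init at init
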
@@ -2017,6 +2033,11 @@ class SewyV2ForCausalLM(SewyV2PreTrainedModel):
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
+        ## tanh softcapping
+
+        logits = self.tanh_softcapping * torch.tanh(logits/self.tanh_softcapping)
+
+
         ## nGPT
 
         s_z = self.s_z * (self.s_z_init/self.s_z_scale)
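
Final-logit capping is the same tanh squash applied to the lm_head output, so every vocabulary logit is bounded to (-30, 30) before the nGPT rescaling that follows and before the loss. A minimal sketch with random logits and made-up shapes:

import torch

final_logit_softcapping = 30.0
logits = torch.randn(2, 5, 32000) * 100               # (batch, seq, vocab); deliberately extreme
capped = final_logit_softcapping * torch.tanh(logits / final_logit_softcapping)
print(capped.abs().max() < final_logit_softcapping)   # tensor(True): everything lies inside (-30, 30)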