Aarushhh committed
Commit 99d8169 · verified · 1 Parent(s): 11253f5

Update modular_Sewy2.py

Files changed (1)
  1. modular_Sewy2.py +22 -1
modular_Sewy2.py CHANGED
@@ -2,7 +2,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 """ PyTorch Sewy model."""
-"""Used deepseek-V3 as a starting point."""
+"""Used deepseekv3 as starting point"""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -214,6 +214,8 @@ class SewyV2Config(PretrainedConfig):
         unit_norm_eps = 1e-6,
         resformer_lambda = 2.0,
         neutreno_lambda=0.4,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -260,6 +262,8 @@ class SewyV2Config(PretrainedConfig):
         self.unit_norm_eps = unit_norm_eps
         self.resformer_lambda = resformer_lambda
         self.neutreno_lambda = neutreno_lambda
+        self.final_logit_softcapping = final_logit_softcapping
+        self.attn_logit_softcapping = attn_logit_softcapping
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -907,6 +911,9 @@ class SewyV2Attention(nn.Module):
 
         self.neutreno_lambda = nn.Parameter(torch.tensor(float(config.neutreno_lambda)))
 
+        self.attn_logit_softcapping = self.config.attn_logit_softcapping
+
+
     def _get_unit_norm(self, x,eps=1e-6):
         """
         Normalize a tensor to unit norm
@@ -1080,6 +1087,10 @@ class SewyV2Attention(nn.Module):
                 )
             attn_weights = attn_weights + attention_mask
 
+        ## tanh softcapping
+
+        attn_weights = self.attn_logit_softcapping * torch.tanh(attn_weights/self.attn_logit_softcapping)
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(
             attn_weights, dim=-1, dtype=torch.float32
@@ -1279,6 +1290,7 @@ class SewyV2FlashAttention2(SewyV2Attention):
             q_len,
             dropout=dropout_rate,
             softmax_scale=self.softmax_scale,
+            softcap=self.attn_logit_softcapping,
         )
         if self.q_head_dim != self.v_head_dim:
             attn_output = attn_output[:, :, :, : self.v_head_dim]
@@ -1865,6 +1877,10 @@ class SewyV2ForCausalLM(SewyV2PreTrainedModel):
         self.s_z = nn.Parameter(torch.ones(self.vocab_size) * (1/config.hidden_size ** 0.5))
         self.s_z_init = 1
         self.s_z_scale = 1/config.hidden_size ** 0.5
+
+        # tanh softcapping
+
+        self.tanh_softcapping = config.final_logit_softcapping
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -2017,6 +2033,11 @@ class SewyV2ForCausalLM(SewyV2PreTrainedModel):
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
+        ## tanh softcapping
+
+        logits = self.tanh_softcapping * torch.tanh(logits/self.tanh_softcapping)
+
+
        ## nGPT
 
         s_z = self.s_z * (self.s_z_init/self.s_z_scale)
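The config hunk introduces two scalar knobs, final_logit_softcapping (default 30.0) and attn_logit_softcapping (default 50.0), which the attention and LM-head code below read back. A minimal standalone sketch of how such kwargs become config attributes; TinySoftcapConfig is a made-up illustration class, not the repo's SewyV2Config, and only transformers is assumed to be installed:

from transformers.configuration_utils import PretrainedConfig

class TinySoftcapConfig(PretrainedConfig):
    # Hypothetical demo config mirroring only the two fields added in this commit.
    model_type = "tiny_softcap_demo"

    def __init__(self, final_logit_softcapping=30.0, attn_logit_softcapping=50.0, **kwargs):
        self.final_logit_softcapping = final_logit_softcapping  # cap for the LM-head logits
        self.attn_logit_softcapping = attn_logit_softcapping    # cap for attention scores
        super().__init__(**kwargs)

cfg = TinySoftcapConfig()
print(cfg.final_logit_softcapping, cfg.attn_logit_softcapping)  # 30.0 50.0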
 
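In the eager attention path the raw scores are now squashed with Gemma-2-style tanh softcapping, cap * tanh(scores / cap), so they stay within (-50, 50) before the fp32 softmax. Note that here the cap is applied after the attention mask has been added, so the mask's large negative values are also pulled up to roughly -cap, whereas Gemma 2 applies the cap before masking. A small self-contained sketch of the operation itself, with toy numbers rather than model outputs:

import torch

def tanh_softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds values to (-cap, cap); near-identity when |scores| << cap.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-500.0, -5.0, 0.0, 5.0, 500.0])
print(tanh_softcap(scores, 50.0))
# large scores saturate near +/-50; small ones pass through almost unchanged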
 
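The flash-attention path forwards the same cap as softcap=. The call goes through the module's _flash_attention_forward helper (not shown in this diff), so the cap only takes effect if that helper hands it to a kernel that supports it; flash_attn_func gained a softcap argument around flash-attn 2.6, and older builds raise a TypeError on the unknown keyword. A hypothetical compatibility guard, not part of this commit:

import inspect

def softcap_kwargs(flash_attn_func, cap: float) -> dict:
    # Forward `softcap` only when the installed flash-attn kernel accepts it
    # (the kernel then applies cap * tanh(scores / cap) internally).
    if cap and "softcap" in inspect.signature(flash_attn_func).parameters:
        return {"softcap": cap}
    return {}

# usage sketch: flash_attn_func(q, k, v, causal=True, **softcap_kwargs(flash_attn_func, 50.0))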
 
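The LM head now caps the raw lm_head logits to (-30, 30) with the same tanh transform before the nGPT s_z rescaling that follows. A toy sketch of the effect, using the commit's default cap of 30.0 on random tensors rather than real model output:

import torch

final_logit_softcapping = 30.0                   # the commit's default
logits = torch.randn(2, 4, 32) * 200.0           # deliberately oversized fake logits
capped = final_logit_softcapping * torch.tanh(logits / final_logit_softcapping)
print(float(logits.abs().max()), float(capped.abs().max()))  # hundreds vs. just under 30.0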