shivanandmn committed
Commit 53f13f1 · verified · 1 Parent(s): 092b6c5

Update modeling_duo_predict_gpt2.py

Files changed (1):
  modeling_duo_predict_gpt2.py  +9 -5
modeling_duo_predict_gpt2.py CHANGED
@@ -129,10 +129,10 @@ def sdpa_attention_forward(
         query,
         key,
         value,
-        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
+        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device) if module.training else None,
         dropout_p=dropout,
         scale=scaling,
-        is_causal=is_causal,
+        is_causal=False if module.training else True,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()

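Note on this hunk: the SDPA call now branches on the module's mode. In training, an explicit duo-predict mask from create_attention_mask_matrix is passed and is_causal is disabled (False if module.training else True is simply not module.training); in eval, no mask is passed and SDPA's fused causal path is used. A minimal sketch of the pattern, assuming standard torch.nn.functional.scaled_dot_product_attention semantics; make_mask stands in for create_attention_mask_matrix, whose definition is not shown in this diff:

    import torch
    import torch.nn.functional as F

    def sdpa_with_mode(query, key, value, training, make_mask):
        # Training: explicit (non-causal) mask built for the duo-predict layout.
        # Eval: fused causal attention. SDPA rejects attn_mask combined with
        # is_causal=True, so exactly one of the two is set per branch.
        if training:
            mask = make_mask(query.shape[-2]).to(query.device)
            return F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
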
@@ -582,9 +582,12 @@ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         ###TODO: correctly initialized
-        hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
-        hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
-        hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
+        if inputs_embeds.shape[1] != position_embeds.shape[1]:
+            hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
+            hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
+            hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
+        else:
+            hidden_states = inputs_embeds + position_embeds

         # Attention mask.
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
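
Note on this hunk: the even/odd write now only runs when the token sequence is longer than the position sequence, which appears to be the interleaved duo-predict layout where two adjacent token slots share one position id; when the lengths match, the standard GPT-2 sum is used instead. A toy sketch of the indexing, shapes only (the real code truncates the odd slots with max_position_embeddings-1; T // 2 is used here as an assumption to keep the toy self-contained):

    import torch

    # Toy shapes: T interleaved token slots, one position embedding per pair.
    B, T, D = 1, 6, 4
    inputs_embeds = torch.randn(B, T, D)
    position_embeds = torch.randn(B, (T + 1) // 2, D)

    hidden_states = torch.empty(B, T, D)
    # Even slots 0, 2, 4 get position embeddings 0, 1, 2.
    hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds
    # Odd slots 1, 3, 5 reuse the same position embeddings 0, 1, 2.
    hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :T // 2]
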
@@ -897,5 +900,6 @@ if __name__ == "__main__":
     model = DuoPredictGPT2LMHeadModel(cg)
     from src.utils.model_utlis import print_trainable_parameters
     print_trainable_parameters(model)
+    model.eval()
     model(torch.randint(0, 10000, (1, 100)))
     print()
 
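Note on the smoke test: with the attention branch above keyed on module.training, calling model.eval() before the forward pass routes the bare 100-token input through the plain causal path instead of the training-time duo-predict mask. A hypothetical variant of the __main__ check (model as constructed there; torch.no_grad() is an added assumption, not in the original):

    import torch

    model.eval()  # take the is_causal=True branch in sdpa_attention_forward
    with torch.no_grad():  # skip autograd bookkeeping for the smoke test
        out = model(torch.randint(0, 10000, (1, 100)))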