Update modeling_duo_predict_gpt2.py
modeling_duo_predict_gpt2.py
CHANGED
@@ -129,10 +129,10 @@ def sdpa_attention_forward(
         query,
         key,
         value,
-        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
+        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device) if module.training else None,
         dropout_p=dropout,
         scale=scaling,
-        is_causal=False,
+        is_causal=False if module.training else True,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
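Why the two changes in this hunk have to move together: torch.nn.functional.scaled_dot_product_attention raises an error when an explicit attn_mask is combined with is_causal=True, so the custom duo-predict mask (training) and the built-in causal path (eval) must be toggled as a pair. Below is a minimal sketch of the same dispatch; toy_mask and attend are hypothetical stand-ins (the real mask comes from create_attention_mask_matrix in this file), and "is_causal=not training" is equivalent to the diff's "False if module.training else True".

import torch
import torch.nn.functional as F

def toy_mask(t: int) -> torch.Tensor:
    # Stand-in for create_attention_mask_matrix; the real duo-predict
    # mask is defined elsewhere in modeling_duo_predict_gpt2.py.
    return torch.tril(torch.ones(t, t, dtype=torch.bool))

def attend(query, key, value, training: bool):
    # SDPA forbids attn_mask together with is_causal=True, so both
    # arguments switch on the training flag, as in the diff above.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=toy_mask(query.shape[-2]).to(query.device) if training else None,
        is_causal=not training,
    )

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
assert attend(q, k, v, training=True).shape == q.shape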
@@ -582,9 +582,12 @@ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         ###TODO: correctly initialized
-
-
-
+        if inputs_embeds.shape[1] != position_embeds.shape[1]:
+            hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
+            hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
+            hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
+        else:
+            hidden_states = inputs_embeds + position_embeds
 
         # Attention mask.
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
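The new branch fills even and odd sequence slots separately when the interleaved input is longer than the position table, so the two interleaved token streams share positional information; the extra slice to max_position_embeddings-1 appears to cover the case where the odd slots need one fewer position row. A toy sketch of the strided-slicing pattern with made-up shapes (the real shapes depend on how position_ids is built upstream):

import torch

B, T, D = 1, 8, 4                            # made-up: batch, interleaved length, embed dim
inputs_embeds = torch.randn(B, T, D)
position_embeds = torch.randn(B, T // 2, D)  # one row per shared position

hidden_states = torch.empty(B, T, D)
hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds    # even slots (stream A)
hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds  # odd slots (stream B)

# Both tokens of a pair now carry the same position embedding:
offset_even = hidden_states[:, 0] - inputs_embeds[:, 0]
offset_odd = hidden_states[:, 1] - inputs_embeds[:, 1]
assert torch.allclose(offset_even, offset_odd)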
@@ -897,5 +900,6 @@ if __name__ == "__main__":
     model = DuoPredictGPT2LMHeadModel(cg)
     from src.utils.model_utlis import print_trainable_parameters
     print_trainable_parameters(model)
+    model.eval()
     model(torch.randint(0, 10000, (1, 100)))
     print()
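Because the attention path now switches on module.training, the smoke test has to pin down a mode before the forward pass; model.eval() selects the plain causal branch. A sketch continuing the __main__ block above (model as defined there), exercising both branches:

model.eval()   # is_causal=True, no explicit mask
model(torch.randint(0, 10000, (1, 100)))

model.train()  # explicit duo-predict mask from create_attention_mask_matrix
model(torch.randint(0, 10000, (1, 100)))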