Model save
Browse files- README.md +39 -39
- model.safetensors +1 -1
- modeling_duo_predict_gpt2.py +23 -15
README.md
CHANGED
@@ -17,9 +17,9 @@ should probably proofread and complete it, then remove this comment. -->
|
|
17 |
|
18 |
This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
|
19 |
It achieves the following results on the evaluation set:
|
20 |
-
- Loss:
|
21 |
- Accuracy: 0.0073
|
22 |
-
- Perplexity:
|
23 |
- Bleu: 1.0
|
24 |
|
25 |
## Model description
|
@@ -50,43 +50,43 @@ The following hyperparameters were used during training:
|
|
50 |
|
51 |
### Training results
|
52 |
|
53 |
-
| Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu
|
54 |
-
|
55 |
-
| 7.
|
56 |
-
|
|
57 |
-
|
|
58 |
-
|
|
59 |
-
|
|
60 |
-
|
|
61 |
-
|
|
62 |
-
|
|
63 |
-
|
|
64 |
-
|
|
65 |
-
|
|
66 |
-
|
|
67 |
-
|
|
68 |
-
|
|
69 |
-
|
|
70 |
-
|
|
71 |
-
|
|
72 |
-
|
|
73 |
-
|
|
74 |
-
|
|
75 |
-
|
|
76 |
-
|
|
77 |
-
|
|
78 |
-
|
|
79 |
-
|
|
80 |
-
|
|
81 |
-
|
|
82 |
-
|
|
83 |
-
|
|
84 |
-
|
|
85 |
-
|
|
86 |
-
|
|
87 |
-
|
|
88 |
-
|
|
89 |
-
|
|
90 |
|
91 |
|
92 |
### Framework versions
|
|
|
17 |
|
18 |
This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
|
19 |
It achieves the following results on the evaluation set:
|
20 |
+
- Loss: 2.2546
|
21 |
- Accuracy: 0.0073
|
22 |
+
- Perplexity: 9.5311
|
23 |
- Bleu: 1.0
|
24 |
|
25 |
## Model description
|
|
|
50 |
|
51 |
### Training results
|
52 |
|
53 |
+
| Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu |
|
54 |
+
|:-------------:|:------:|:-----:|:---------------:|:--------:|:----------:|:----:|
|
55 |
+
| 7.6654 | 0.1403 | 500 | 3.7315 | 0.0073 | 41.7396 | 1.0 |
|
56 |
+
| 7.0276 | 0.2807 | 1000 | 3.4735 | 0.0073 | 32.2490 | 1.0 |
|
57 |
+
| 6.4629 | 0.4210 | 1500 | 3.1863 | 0.0073 | 24.1987 | 1.0 |
|
58 |
+
| 5.9671 | 0.5613 | 2000 | 2.9542 | 0.0073 | 19.1873 | 1.0 |
|
59 |
+
| 5.6969 | 0.7017 | 2500 | 2.8233 | 0.0073 | 16.8331 | 1.0 |
|
60 |
+
| 5.5077 | 0.8420 | 3000 | 2.7351 | 0.0073 | 15.4112 | 1.0 |
|
61 |
+
| 5.3536 | 0.9823 | 3500 | 2.6607 | 0.0073 | 14.3059 | 1.0 |
|
62 |
+
| 5.2099 | 1.1226 | 4000 | 2.6000 | 0.0073 | 13.4641 | 1.0 |
|
63 |
+
| 5.1158 | 1.2630 | 4500 | 2.5493 | 0.0073 | 12.7980 | 1.0 |
|
64 |
+
| 5.0453 | 1.4033 | 5000 | 2.5125 | 0.0073 | 12.3362 | 1.0 |
|
65 |
+
| 4.955 | 1.5436 | 5500 | 2.4806 | 0.0073 | 11.9489 | 1.0 |
|
66 |
+
| 4.9157 | 1.6840 | 6000 | 2.4537 | 0.0073 | 11.6310 | 1.0 |
|
67 |
+
| 4.8756 | 1.8243 | 6500 | 2.4300 | 0.0073 | 11.3584 | 1.0 |
|
68 |
+
| 4.844 | 1.9646 | 7000 | 2.4100 | 0.0073 | 11.1342 | 1.0 |
|
69 |
+
| 4.7136 | 2.1050 | 7500 | 2.3948 | 0.0073 | 10.9657 | 1.0 |
|
70 |
+
| 4.6911 | 2.2453 | 8000 | 2.3805 | 0.0073 | 10.8105 | 1.0 |
|
71 |
+
| 4.6741 | 2.3856 | 8500 | 2.3668 | 0.0073 | 10.6637 | 1.0 |
|
72 |
+
| 4.6485 | 2.5260 | 9000 | 2.3538 | 0.0073 | 10.5257 | 1.0 |
|
73 |
+
| 4.623 | 2.6663 | 9500 | 2.3416 | 0.0073 | 10.3976 | 1.0 |
|
74 |
+
| 4.6016 | 2.8066 | 10000 | 2.3303 | 0.0073 | 10.2806 | 1.0 |
|
75 |
+
| 4.5823 | 2.9470 | 10500 | 2.3202 | 0.0073 | 10.1776 | 1.0 |
|
76 |
+
| 4.4802 | 3.0873 | 11000 | 2.3143 | 0.0073 | 10.1182 | 1.0 |
|
77 |
+
| 4.4671 | 3.2276 | 11500 | 2.3073 | 0.0073 | 10.0469 | 1.0 |
|
78 |
+
| 4.4557 | 3.3679 | 12000 | 2.3006 | 0.0073 | 9.9800 | 1.0 |
|
79 |
+
| 4.4437 | 3.5083 | 12500 | 2.2928 | 0.0073 | 9.9023 | 1.0 |
|
80 |
+
| 4.4402 | 3.6486 | 13000 | 2.2862 | 0.0073 | 9.8375 | 1.0 |
|
81 |
+
| 4.4482 | 3.7889 | 13500 | 2.2800 | 0.0073 | 9.7763 | 1.0 |
|
82 |
+
| 4.4279 | 3.9293 | 14000 | 2.2752 | 0.0073 | 9.7303 | 1.0 |
|
83 |
+
| 4.3188 | 4.0696 | 14500 | 2.2730 | 0.0073 | 9.7087 | 1.0 |
|
84 |
+
| 4.3193 | 4.2099 | 15000 | 2.2691 | 0.0073 | 9.6704 | 1.0 |
|
85 |
+
| 4.3158 | 4.3503 | 15500 | 2.2652 | 0.0073 | 9.6329 | 1.0 |
|
86 |
+
| 4.3196 | 4.4906 | 16000 | 2.2619 | 0.0073 | 9.6012 | 1.0 |
|
87 |
+
| 4.2946 | 4.6309 | 16500 | 2.2589 | 0.0073 | 9.5722 | 1.0 |
|
88 |
+
| 4.3078 | 4.7713 | 17000 | 2.2564 | 0.0073 | 9.5487 | 1.0 |
|
89 |
+
| 4.2974 | 4.9116 | 17500 | 2.2546 | 0.0073 | 9.5311 | 1.0 |
|
90 |
|
91 |
|
92 |
### Framework versions
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1417229824
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9839d5f4cd171dea6e15ba583be6e1cce05dea4e8444d950b8ccd0ca73da4483
|
3 |
size 1417229824
|
modeling_duo_predict_gpt2.py
CHANGED
@@ -77,7 +77,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|
77 |
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
|
78 |
if is_causal:
|
79 |
assert attn_mask is None
|
80 |
-
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
81 |
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
82 |
attn_bias.to(query.dtype)
|
83 |
|
@@ -129,10 +129,10 @@ def sdpa_attention_forward(
|
|
129 |
query,
|
130 |
key,
|
131 |
value,
|
132 |
-
attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
|
133 |
dropout_p=dropout,
|
134 |
scale=scaling,
|
135 |
-
is_causal=
|
136 |
)
|
137 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
138 |
|
@@ -582,9 +582,12 @@ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
|
|
582 |
inputs_embeds = self.wte(input_ids)
|
583 |
position_embeds = self.wpe(position_ids)
|
584 |
###TODO: correctly initialized
|
585 |
-
|
586 |
-
|
587 |
-
|
|
|
|
|
|
|
588 |
|
589 |
# Attention mask.
|
590 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
@@ -836,17 +839,21 @@ class DuoPredictGPT2LMHeadModel(DuoPredictGPT2PretrainedModel, GenerationMixin):
|
|
836 |
lm_logits = self.lm_head(hidden_states)
|
837 |
|
838 |
loss = None
|
|
|
839 |
if labels is not None:
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
|
|
|
|
|
|
844 |
loss = self.loss_function(
|
845 |
-
lm_logits,
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
|
851 |
if not return_dict:
|
852 |
output = (lm_logits,) + transformer_outputs[1:]
|
@@ -897,5 +904,6 @@ if __name__ == "__main__":
|
|
897 |
model = DuoPredictGPT2LMHeadModel(cg)
|
898 |
from src.utils.model_utlis import print_trainable_parameters
|
899 |
print_trainable_parameters(model)
|
|
|
900 |
model(torch.randint(0, 10000, (1, 100)))
|
901 |
print()
|
|
|
77 |
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
|
78 |
if is_causal:
|
79 |
assert attn_mask is None
|
80 |
+
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
81 |
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
82 |
attn_bias.to(query.dtype)
|
83 |
|
|
|
129 |
query,
|
130 |
key,
|
131 |
value,
|
132 |
+
attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device) if query.shape[1]>module.config.max_position_embeddings else None,
|
133 |
dropout_p=dropout,
|
134 |
scale=scaling,
|
135 |
+
is_causal=False if query.shape[1]>module.config.max_position_embeddings else True,
|
136 |
)
|
137 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
138 |
|
|
|
582 |
inputs_embeds = self.wte(input_ids)
|
583 |
position_embeds = self.wpe(position_ids)
|
584 |
###TODO: correctly initialized
|
585 |
+
if inputs_embeds.shape[1] != position_embeds.shape[1]:
|
586 |
+
hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
|
587 |
+
hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
|
588 |
+
hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
|
589 |
+
else:
|
590 |
+
hidden_states = inputs_embeds + position_embeds
|
591 |
|
592 |
# Attention mask.
|
593 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
|
|
839 |
lm_logits = self.lm_head(hidden_states)
|
840 |
|
841 |
loss = None
|
842 |
+
bs, seq = lm_logits.shape[:2]
|
843 |
if labels is not None:
|
844 |
+
if seq>labels.shape[1]:
|
845 |
+
# Flatten the tokens
|
846 |
+
total_labels = torch.full((bs, seq-1), -100, dtype=input_ids.dtype, device=input_ids.device)
|
847 |
+
total_labels[:, :-1:2] = labels[:, 1: ]
|
848 |
+
total_labels[:, 1::2] = labels[:, :-1]
|
849 |
+
else:
|
850 |
+
total_labels = labels[:, 1:]
|
851 |
loss = self.loss_function(
|
852 |
+
lm_logits[:, :-1],
|
853 |
+
total_labels,
|
854 |
+
vocab_size=self.config.vocab_size,
|
855 |
+
**kwargs,
|
856 |
+
)
|
857 |
|
858 |
if not return_dict:
|
859 |
output = (lm_logits,) + transformer_outputs[1:]
|
|
|
904 |
model = DuoPredictGPT2LMHeadModel(cg)
|
905 |
from src.utils.model_utlis import print_trainable_parameters
|
906 |
print_trainable_parameters(model)
|
907 |
+
model.eval()
|
908 |
model(torch.randint(0, 10000, (1, 100)))
|
909 |
print()
|