minho commited on
Commit
b54e897
·
verified ·
1 Parent(s): a924fd3

Upload tiny-random orion model

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_attn_implementation_autoset": false,
3
+ "_name_or_path": "OrionStarAI/Orion-14B-Base",
4
+ "architectures": [
5
+ "OrionForCausalLM"
6
+ ],
7
+ "attention_bias": false,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_orion.OrionConfig",
10
+ "AutoModelForCausalLM": "modeling_orion.OrionForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 512,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 1536,
18
+ "max_position_embeddings": 4096,
19
+ "max_sequence_length": 4096,
20
+ "model_type": "orion",
21
+ "num_attention_heads": 4,
22
+ "num_hidden_layers": 2,
23
+ "num_key_value_heads": 4,
24
+ "pad_token_id": 0,
25
+ "pretraining_tp": 1,
26
+ "rms_norm_eps": 1e-05,
27
+ "rope_scaling": null,
28
+ "rope_theta": 10000.0,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.44.0",
32
+ "use_cache": true,
33
+ "vocab_size": 84608
34
+ }
configuration_orion.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, OrionStar Inc. All rights reserved.
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+ class OrionConfig(PretrainedConfig):
6
+ model_type = "orion"
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=84608,
12
+ hidden_size=4096,
13
+ intermediate_size=15360,
14
+ num_hidden_layers=40,
15
+ num_attention_heads=40,
16
+ num_key_value_heads=40,
17
+ hidden_act="silu",
18
+ max_position_embeddings=4096,
19
+ initializer_range=0.02,
20
+ rms_norm_eps=1e-5,
21
+ use_cache=True,
22
+ pad_token_id=None,
23
+ bos_token_id=1,
24
+ eos_token_id=2,
25
+ pretraining_tp=1,
26
+ tie_word_embeddings=False,
27
+ rope_theta=10000.0,
28
+ rope_scaling=None,
29
+ attention_bias=False,
30
+ **kwargs,
31
+ ):
32
+ self.vocab_size = vocab_size
33
+ self.max_position_embeddings = max_position_embeddings
34
+ self.hidden_size = hidden_size
35
+ self.intermediate_size = intermediate_size
36
+ self.num_hidden_layers = num_hidden_layers
37
+ self.num_attention_heads = num_attention_heads
38
+
39
+ # for backward compatibility
40
+ if num_key_value_heads is None:
41
+ num_key_value_heads = num_attention_heads
42
+
43
+ self.num_key_value_heads = num_key_value_heads
44
+ self.hidden_act = hidden_act
45
+ self.initializer_range = initializer_range
46
+ self.rms_norm_eps = rms_norm_eps
47
+ self.pretraining_tp = pretraining_tp
48
+ self.use_cache = use_cache
49
+ self.rope_theta = rope_theta
50
+ self.rope_scaling = rope_scaling
51
+ self._rope_scaling_validation()
52
+ self.attention_bias = attention_bias
53
+
54
+ super().__init__(
55
+ pad_token_id=pad_token_id,
56
+ bos_token_id=bos_token_id,
57
+ eos_token_id=eos_token_id,
58
+ tie_word_embeddings=tie_word_embeddings,
59
+ **kwargs,
60
+ )
61
+
62
+ def _rope_scaling_validation(self):
63
+ """
64
+ Validate the `rope_scaling` configuration.
65
+ """
66
+ if self.rope_scaling is None:
67
+ return
68
+
69
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
70
+ raise ValueError(
71
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
72
+ f"got {self.rope_scaling}"
73
+ )
74
+ rope_scaling_type = self.rope_scaling.get("type", None)
75
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
76
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
77
+ raise ValueError(
78
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
79
+ )
80
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
81
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
82
+
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.44.0"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ffaa75e0042bbdb8e4a2a602f3e369b7401bed7ea42b1b452126dd4c404b2fd
3
+ size 373840656
modeling_orion.py ADDED
@@ -0,0 +1,1096 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 OrionStar Inc. team. All rights reserved.
2
+ # Copied and adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
3
+
4
+ from transformers import AutoConfig, AutoModel
5
+
6
+ from .configuration_orion import OrionConfig
7
+
8
+ import numbers
9
+ import importlib
10
+ import math
11
+ from typing import List, Optional, Tuple, Union
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.nn.parameter import Parameter
16
+ import torch.utils.checkpoint
17
+ from torch import nn
18
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
19
+ from torch.nn import init
20
+
21
+ from transformers.activations import ACT2FN
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
25
+ from transformers.utils import (
26
+ add_start_docstrings,
27
+ add_start_docstrings_to_model_forward,
28
+ is_flash_attn_2_available,
29
+ replace_return_docstrings,
30
+ )
31
+ import logging
32
+ if is_flash_attn_2_available():
33
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
34
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "OrionConfig"
39
+
40
+ def _get_unpad_data(padding_mask):
41
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
42
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
43
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
44
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
45
+ return (
46
+ indices,
47
+ cu_seqlens,
48
+ max_seqlen_in_batch,
49
+ )
50
+
51
+
52
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
53
+ def _make_causal_mask(
54
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
55
+ ):
56
+ """
57
+ Make causal mask used for bi-directional self-attention.
58
+ """
59
+ bsz, tgt_len = input_ids_shape
60
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
61
+ mask_cond = torch.arange(mask.size(-1), device=device)
62
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
63
+ mask = mask.to(dtype)
64
+
65
+ if past_key_values_length > 0:
66
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
67
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
68
+
69
+
70
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
71
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
72
+ """
73
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
74
+ """
75
+ bsz, src_len = mask.size()
76
+ tgt_len = tgt_len if tgt_len is not None else src_len
77
+
78
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
79
+
80
+ inverted_mask = 1.0 - expanded_mask
81
+
82
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
83
+
84
+ class OrionRotaryEmbedding(nn.Module):
85
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
86
+ super().__init__()
87
+
88
+ self.dim = dim
89
+ self.max_position_embeddings = max_position_embeddings
90
+ self.base = base
91
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
92
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93
+
94
+ # Build here to make `torch.jit.trace` work.
95
+ self._set_cos_sin_cache(
96
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
97
+ )
98
+
99
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
100
+ self.max_seq_len_cached = seq_len
101
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
102
+
103
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
104
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
105
+ emb = torch.cat((freqs, freqs), dim=-1)
106
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
107
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
108
+
109
+ def forward(self, x, seq_len=None):
110
+ # x: [bs, num_attention_heads, seq_len, head_size]
111
+ if seq_len > self.max_seq_len_cached:
112
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
113
+
114
+ return (
115
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
116
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
117
+ )
118
+
119
+
120
+ class OrionLinearScalingRotaryEmbedding(OrionRotaryEmbedding):
121
+ """OrionRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
122
+
123
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
124
+ self.scaling_factor = scaling_factor
125
+ super().__init__(dim, max_position_embeddings, base, device)
126
+
127
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
128
+ self.max_seq_len_cached = seq_len
129
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
130
+ t = t / self.scaling_factor
131
+
132
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
133
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
134
+ emb = torch.cat((freqs, freqs), dim=-1)
135
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
136
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
137
+
138
+
139
+ class OrionDynamicNTKScalingRotaryEmbedding(OrionRotaryEmbedding):
140
+ """OrionRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
141
+
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
143
+ self.scaling_factor = scaling_factor
144
+ super().__init__(dim, max_position_embeddings, base, device)
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+
149
+ if seq_len > self.max_position_embeddings:
150
+ base = self.base * (
151
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
152
+ ) ** (self.dim / (self.dim - 2))
153
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
154
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
155
+
156
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
157
+
158
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
159
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
+
164
+
165
+ def rotate_half(x):
166
+ """Rotates half the hidden dims of the input."""
167
+ x1 = x[..., : x.shape[-1] // 2]
168
+ x2 = x[..., x.shape[-1] // 2 :]
169
+ return torch.cat((-x2, x1), dim=-1)
170
+
171
+
172
+ # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
173
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
174
+ cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
175
+ sin = sin[position_ids].unsqueeze(1)
176
+ q_embed = (q * cos) + (rotate_half(q) * sin)
177
+ k_embed = (k * cos) + (rotate_half(k) * sin)
178
+ return q_embed, k_embed
179
+
180
+
181
+ class OrionMLP(nn.Module):
182
+ def __init__(self, config):
183
+ super().__init__()
184
+ self.config = config
185
+ self.hidden_size = config.hidden_size
186
+ self.intermediate_size = config.intermediate_size
187
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
188
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
189
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
190
+ self.act_fn = ACT2FN[config.hidden_act]
191
+
192
+ def forward(self, x):
193
+ if self.config.pretraining_tp > 1:
194
+ slice = self.intermediate_size // self.config.pretraining_tp
195
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
196
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
197
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
198
+
199
+ gate_proj = torch.cat(
200
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
201
+ )
202
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
203
+
204
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
205
+ down_proj = [
206
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
207
+ ]
208
+ down_proj = sum(down_proj)
209
+ else:
210
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
211
+
212
+ return down_proj
213
+
214
+
215
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
216
+ """
217
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
218
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
219
+ """
220
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
221
+ if n_rep == 1:
222
+ return hidden_states
223
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
224
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
225
+
226
+
227
+ class OrionAttention(nn.Module):
228
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
229
+
230
+ def __init__(self, config: OrionConfig):
231
+ super().__init__()
232
+ self.config = config
233
+ self.hidden_size = config.hidden_size
234
+ self.num_heads = config.num_attention_heads
235
+ self.head_dim = self.hidden_size // self.num_heads
236
+ self.num_key_value_heads = config.num_key_value_heads
237
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
238
+ self.max_position_embeddings = config.max_position_embeddings
239
+ self.rope_theta = config.rope_theta
240
+
241
+ if (self.head_dim * self.num_heads) != self.hidden_size:
242
+ raise ValueError(
243
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
244
+ f" and `num_heads`: {self.num_heads})."
245
+ )
246
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
247
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
248
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
249
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
250
+ self._init_rope()
251
+
252
+ def _init_rope(self):
253
+ if self.config.rope_scaling is None:
254
+ self.rotary_emb = OrionRotaryEmbedding(
255
+ self.head_dim,
256
+ max_position_embeddings=self.max_position_embeddings,
257
+ base=self.rope_theta,
258
+ )
259
+ else:
260
+ scaling_type = self.config.rope_scaling["type"]
261
+ scaling_factor = self.config.rope_scaling["factor"]
262
+ if scaling_type == "linear":
263
+ self.rotary_emb = OrionLinearScalingRotaryEmbedding(
264
+ self.head_dim,
265
+ max_position_embeddings=self.max_position_embeddings,
266
+ scaling_factor=scaling_factor,
267
+ base=self.rope_theta,
268
+ )
269
+ elif scaling_type == "dynamic":
270
+ self.rotary_emb = OrionDynamicNTKScalingRotaryEmbedding(
271
+ self.head_dim,
272
+ max_position_embeddings=self.max_position_embeddings,
273
+ scaling_factor=scaling_factor,
274
+ base=self.rope_theta,
275
+ )
276
+ else:
277
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
278
+
279
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
280
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
281
+
282
+ def forward(
283
+ self,
284
+ hidden_states: torch.Tensor,
285
+ attention_mask: Optional[torch.Tensor] = None,
286
+ position_ids: Optional[torch.LongTensor] = None,
287
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
288
+ output_attentions: bool = False,
289
+ use_cache: bool = False,
290
+ padding_mask: Optional[torch.LongTensor] = None,
291
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
292
+ bsz, q_len, _ = hidden_states.size()
293
+
294
+ if self.config.pretraining_tp > 1:
295
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
296
+ query_slices = self.q_proj.weight.split(
297
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
298
+ )
299
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
300
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
301
+
302
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
303
+ query_states = torch.cat(query_states, dim=-1)
304
+
305
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
306
+ key_states = torch.cat(key_states, dim=-1)
307
+
308
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
309
+ value_states = torch.cat(value_states, dim=-1)
310
+
311
+ else:
312
+ query_states = self.q_proj(hidden_states)
313
+ key_states = self.k_proj(hidden_states)
314
+ value_states = self.v_proj(hidden_states)
315
+
316
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
317
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
318
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319
+
320
+ kv_seq_len = key_states.shape[-2]
321
+ if past_key_value is not None:
322
+ kv_seq_len += past_key_value[0].shape[-2]
323
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
324
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
325
+
326
+ if past_key_value is not None:
327
+ # reuse k, v, self_attention
328
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
329
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
330
+
331
+ past_key_value = (key_states, value_states) if use_cache else None
332
+
333
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
334
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
335
+
336
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
337
+
338
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
339
+ raise ValueError(
340
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
341
+ f" {attn_weights.size()}"
342
+ )
343
+
344
+ if attention_mask is not None:
345
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
346
+ raise ValueError(
347
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
348
+ )
349
+ attn_weights = attn_weights + attention_mask
350
+
351
+ # upcast attention to fp32
352
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
353
+ attn_output = torch.matmul(attn_weights, value_states)
354
+
355
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
356
+ raise ValueError(
357
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
358
+ f" {attn_output.size()}"
359
+ )
360
+
361
+ attn_output = attn_output.transpose(1, 2).contiguous()
362
+
363
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
364
+
365
+ if self.config.pretraining_tp > 1:
366
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
367
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
368
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
369
+ else:
370
+ attn_output = self.o_proj(attn_output)
371
+
372
+ if not output_attentions:
373
+ attn_weights = None
374
+
375
+ return attn_output, attn_weights, past_key_value
376
+
377
+
378
+ class OrionFlashAttention2(OrionAttention):
379
+ """
380
+ Orion flash attention module. This module inherits from `OrionAttention` as the weights of the module stays
381
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
382
+ flash attention and deal with padding tokens in case the input contains any of them.
383
+ """
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ attention_mask: Optional[torch.Tensor] = None,
389
+ position_ids: Optional[torch.LongTensor] = None,
390
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
391
+ output_attentions: bool = False,
392
+ use_cache: bool = False,
393
+ padding_mask: Optional[torch.LongTensor] = None,
394
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
395
+ # OrionFlashAttention2 attention does not support output_attentions
396
+ output_attentions = False
397
+
398
+ bsz, q_len, _ = hidden_states.size()
399
+
400
+ query_states = self.q_proj(hidden_states)
401
+ key_states = self.k_proj(hidden_states)
402
+ value_states = self.v_proj(hidden_states)
403
+
404
+ # Flash attention requires the input to have the shape
405
+ # batch_size x seq_length x head_dime x hidden_dim
406
+ # therefore we just need to keep the original shape
407
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
408
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
409
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
410
+
411
+ kv_seq_len = key_states.shape[-2]
412
+ if past_key_value is not None:
413
+ kv_seq_len += past_key_value[0].shape[-2]
414
+
415
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
416
+
417
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
418
+
419
+ if past_key_value is not None:
420
+ # reuse k, v, self_attention
421
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
422
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
423
+
424
+ past_key_value = (key_states, value_states) if use_cache else None
425
+
426
+ query_states = query_states.transpose(1, 2)
427
+ key_states = key_states.transpose(1, 2)
428
+ value_states = value_states.transpose(1, 2)
429
+
430
+ # TODO: llama does not have dropout in the config??
431
+ # It is recommended to use dropout with FA according to the docs
432
+ # when training.
433
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
434
+
435
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
436
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
437
+ # cast them back in float16 just to be sure everything works as expected.
438
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
439
+ # in fp32. (LlamaRMSNorm handles it correctly)
440
+ input_dtype = query_states.dtype
441
+ if input_dtype == torch.float32:
442
+ logger.warning_once(
443
+ "The input hidden states seems to be silently casted in float32, this might be related to"
444
+ " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
445
+ " float16."
446
+ )
447
+
448
+ query_states = query_states.to(torch.float16)
449
+ key_states = key_states.to(torch.float16)
450
+ value_states = value_states.to(torch.float16)
451
+
452
+ attn_output = self._flash_attention_forward(
453
+ query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
454
+ )
455
+
456
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
457
+ attn_output = self.o_proj(attn_output)
458
+
459
+ if not output_attentions:
460
+ attn_weights = None
461
+
462
+ return attn_output, attn_weights, past_key_value
463
+
464
+ def _flash_attention_forward(
465
+ self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
466
+ ):
467
+ """
468
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
469
+ first unpad the input, then computes the attention scores and pad the final attention scores.
470
+
471
+ Args:
472
+ query_states (`torch.Tensor`):
473
+ Input query states to be passed to Flash Attention API
474
+ key_states (`torch.Tensor`):
475
+ Input key states to be passed to Flash Attention API
476
+ value_states (`torch.Tensor`):
477
+ Input value states to be passed to Flash Attention API
478
+ padding_mask (`torch.Tensor`):
479
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
480
+ position of padding tokens and 1 for the position of non-padding tokens.
481
+ dropout (`int`, *optional*):
482
+ Attention dropout
483
+ softmax_scale (`float`, *optional*):
484
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
485
+ """
486
+ # Contains at least one padding token in the sequence
487
+ if padding_mask is not None:
488
+ batch_size = query_states.shape[0]
489
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
490
+ query_states, key_states, value_states, padding_mask, query_length
491
+ )
492
+
493
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
494
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
495
+
496
+ attn_output_unpad = flash_attn_varlen_func(
497
+ query_states,
498
+ key_states,
499
+ value_states,
500
+ cu_seqlens_q=cu_seqlens_q,
501
+ cu_seqlens_k=cu_seqlens_k,
502
+ max_seqlen_q=max_seqlen_in_batch_q,
503
+ max_seqlen_k=max_seqlen_in_batch_k,
504
+ dropout_p=dropout,
505
+ softmax_scale=softmax_scale,
506
+ causal=True,
507
+ )
508
+
509
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
510
+ else:
511
+ attn_output = flash_attn_func(
512
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
513
+ )
514
+
515
+ return attn_output
516
+
517
+ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
518
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
519
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
520
+
521
+ key_layer = index_first_axis(
522
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
523
+ )
524
+ value_layer = index_first_axis(
525
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
526
+ )
527
+ if query_length == kv_seq_len:
528
+ query_layer = index_first_axis(
529
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
530
+ )
531
+ cu_seqlens_q = cu_seqlens_k
532
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
533
+ indices_q = indices_k
534
+ elif query_length == 1:
535
+ max_seqlen_in_batch_q = 1
536
+ cu_seqlens_q = torch.arange(
537
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
538
+ ) # There is a memcpy here, that is very bad.
539
+ indices_q = cu_seqlens_q[:-1]
540
+ query_layer = query_layer.squeeze(1)
541
+ else:
542
+ # The -q_len: slice assumes left padding.
543
+ padding_mask = padding_mask[:, -query_length:]
544
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
545
+
546
+ return (
547
+ query_layer,
548
+ key_layer,
549
+ value_layer,
550
+ indices_q,
551
+ (cu_seqlens_q, cu_seqlens_k),
552
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
553
+ )
554
+
555
+
556
+ class OrionDecoderLayer(nn.Module):
557
+ def __init__(self, config: OrionConfig):
558
+ super().__init__()
559
+ self.hidden_size = config.hidden_size
560
+ self.self_attn = (
561
+ OrionAttention(config=config)
562
+ if not getattr(config, "_flash_attn_2_enabled", False)
563
+ else OrionFlashAttention2(config=config)
564
+ )
565
+ self.mlp = OrionMLP(config)
566
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
567
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
568
+
569
+ def forward(
570
+ self,
571
+ hidden_states: torch.Tensor,
572
+ attention_mask: Optional[torch.Tensor] = None,
573
+ position_ids: Optional[torch.LongTensor] = None,
574
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
575
+ output_attentions: Optional[bool] = False,
576
+ use_cache: Optional[bool] = False,
577
+ padding_mask: Optional[torch.LongTensor] = None,
578
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
579
+ """
580
+ Args:
581
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
582
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
583
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
584
+ output_attentions (`bool`, *optional*):
585
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
586
+ returned tensors for more detail.
587
+ use_cache (`bool`, *optional*):
588
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
589
+ (see `past_key_values`).
590
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
591
+ """
592
+
593
+ residual = hidden_states
594
+
595
+ hidden_states = self.input_layernorm(hidden_states)
596
+
597
+ # Self Attention
598
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
599
+ hidden_states=hidden_states,
600
+ attention_mask=attention_mask,
601
+ position_ids=position_ids,
602
+ past_key_value=past_key_value,
603
+ output_attentions=output_attentions,
604
+ use_cache=use_cache,
605
+ padding_mask=padding_mask,
606
+ )
607
+ hidden_states = residual + hidden_states
608
+
609
+ # Fully Connected
610
+ residual = hidden_states
611
+ hidden_states = self.post_attention_layernorm(hidden_states)
612
+ hidden_states = self.mlp(hidden_states)
613
+ hidden_states = residual + hidden_states
614
+
615
+ outputs = (hidden_states,)
616
+
617
+ if output_attentions:
618
+ outputs += (self_attn_weights,)
619
+
620
+ if use_cache:
621
+ outputs += (present_key_value,)
622
+
623
+ return outputs
624
+
625
+ class OrionPreTrainedModel(PreTrainedModel):
626
+ config_class = OrionConfig
627
+ base_model_prefix = "model"
628
+ supports_gradient_checkpointing = True
629
+ _no_split_modules = ["OrionDecoderLayer"]
630
+ _skip_keys_device_placement = "past_key_values"
631
+ _supports_flash_attn_2 = True
632
+
633
+ def _init_weights(self, module):
634
+ std = self.config.initializer_range
635
+ if isinstance(module, nn.Linear):
636
+ module.weight.data.normal_(mean=0.0, std=std)
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.Embedding):
640
+ module.weight.data.normal_(mean=0.0, std=std)
641
+ if module.padding_idx is not None:
642
+ module.weight.data[module.padding_idx].zero_()
643
+
644
+ def _set_gradient_checkpointing(self, module, value=False):
645
+ if isinstance(module, OrionModel):
646
+ module.gradient_checkpointing = value
647
+
648
+ class OrionModel(OrionPreTrainedModel):
649
+ """
650
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OrionDecoderLayer`]
651
+
652
+ Args:
653
+ config: OrionConfig
654
+ """
655
+
656
+ def __init__(self, config: OrionConfig):
657
+ super().__init__(config)
658
+ self.padding_idx = config.pad_token_id
659
+ self.vocab_size = config.vocab_size
660
+
661
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
662
+ self.layers = nn.ModuleList([OrionDecoderLayer(config) for _ in range(config.num_hidden_layers)])
663
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
664
+
665
+ self.gradient_checkpointing = False
666
+ # Initialize weights and apply final processing
667
+ self.post_init()
668
+
669
+ def get_input_embeddings(self):
670
+ return self.embed_tokens
671
+
672
+ def set_input_embeddings(self, value):
673
+ self.embed_tokens = value
674
+
675
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
676
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
677
+ # create causal mask
678
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
679
+ combined_attention_mask = None
680
+ if input_shape[-1] > 1:
681
+ combined_attention_mask = _make_causal_mask(
682
+ input_shape,
683
+ inputs_embeds.dtype,
684
+ device=inputs_embeds.device,
685
+ past_key_values_length=past_key_values_length,
686
+ )
687
+
688
+ if attention_mask is not None:
689
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
690
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
691
+ inputs_embeds.device
692
+ )
693
+ combined_attention_mask = (
694
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
695
+ )
696
+
697
+ return combined_attention_mask
698
+
699
+ def forward(
700
+ self,
701
+ input_ids: torch.LongTensor = None,
702
+ attention_mask: Optional[torch.Tensor] = None,
703
+ position_ids: Optional[torch.LongTensor] = None,
704
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
705
+ inputs_embeds: Optional[torch.FloatTensor] = None,
706
+ use_cache: Optional[bool] = None,
707
+ output_attentions: Optional[bool] = None,
708
+ output_hidden_states: Optional[bool] = None,
709
+ return_dict: Optional[bool] = None,
710
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
711
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
712
+ output_hidden_states = (
713
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
714
+ )
715
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
716
+
717
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
718
+
719
+ # retrieve input_ids and inputs_embeds
720
+ if input_ids is not None and inputs_embeds is not None:
721
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
722
+ elif input_ids is not None:
723
+ batch_size, seq_length = input_ids.shape
724
+ elif inputs_embeds is not None:
725
+ batch_size, seq_length, _ = inputs_embeds.shape
726
+ else:
727
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
728
+
729
+ seq_length_with_past = seq_length
730
+ past_key_values_length = 0
731
+
732
+ if past_key_values is not None:
733
+ past_key_values_length = past_key_values[0][0].shape[2]
734
+ seq_length_with_past = seq_length_with_past + past_key_values_length
735
+
736
+ if position_ids is None:
737
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
738
+ position_ids = torch.arange(
739
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
740
+ )
741
+ position_ids = position_ids.unsqueeze(0)
742
+
743
+ if inputs_embeds is None:
744
+ inputs_embeds = self.embed_tokens(input_ids)
745
+ # embed positions
746
+ if attention_mask is None:
747
+ attention_mask = torch.ones(
748
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
749
+ )
750
+ padding_mask = None
751
+ else:
752
+ if 0 in attention_mask:
753
+ padding_mask = attention_mask
754
+ else:
755
+ padding_mask = None
756
+
757
+ attention_mask = self._prepare_decoder_attention_mask(
758
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
759
+ )
760
+
761
+ hidden_states = inputs_embeds
762
+
763
+ if self.gradient_checkpointing and self.training:
764
+ if use_cache:
765
+ logger.warning_once(
766
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
767
+ )
768
+ use_cache = False
769
+
770
+ # decoder layers
771
+ all_hidden_states = () if output_hidden_states else None
772
+ all_self_attns = () if output_attentions else None
773
+ next_decoder_cache = () if use_cache else None
774
+
775
+ for idx, decoder_layer in enumerate(self.layers):
776
+ if output_hidden_states:
777
+ all_hidden_states += (hidden_states,)
778
+
779
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
780
+
781
+ if self.gradient_checkpointing and self.training:
782
+
783
+ def create_custom_forward(module):
784
+ def custom_forward(*inputs):
785
+ # None for past_key_value
786
+ return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
787
+
788
+ return custom_forward
789
+
790
+ layer_outputs = torch.utils.checkpoint.checkpoint(
791
+ create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
792
+ )
793
+ else:
794
+ layer_outputs = decoder_layer(
795
+ hidden_states,
796
+ attention_mask=attention_mask,
797
+ position_ids=position_ids,
798
+ past_key_value=past_key_value,
799
+ output_attentions=output_attentions,
800
+ use_cache=use_cache,
801
+ padding_mask=padding_mask,
802
+ )
803
+
804
+ hidden_states = layer_outputs[0]
805
+
806
+ if use_cache:
807
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
808
+
809
+ if output_attentions:
810
+ all_self_attns += (layer_outputs[1],)
811
+
812
+ hidden_states = self.norm(hidden_states)
813
+
814
+ # add hidden states from the last decoder layer
815
+ if output_hidden_states:
816
+ all_hidden_states += (hidden_states,)
817
+
818
+ next_cache = next_decoder_cache if use_cache else None
819
+ if not return_dict:
820
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
821
+ return BaseModelOutputWithPast(
822
+ last_hidden_state=hidden_states,
823
+ past_key_values=next_cache,
824
+ hidden_states=all_hidden_states,
825
+ attentions=all_self_attns,
826
+ )
827
+
828
+
829
+ class OrionForCausalLM(OrionPreTrainedModel):
830
+ model_type = "orion"
831
+ _tied_weights_keys = ["lm_head.weight"]
832
+
833
+ def __init__(self, config):
834
+ super().__init__(config)
835
+ self.model = OrionModel(config)
836
+ self.vocab_size = config.vocab_size
837
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
838
+
839
+ # Initialize weights and apply final processing
840
+ self.post_init()
841
+
842
+ def get_input_embeddings(self):
843
+ return self.model.embed_tokens
844
+
845
+ def set_input_embeddings(self, value):
846
+ self.model.embed_tokens = value
847
+
848
+ def get_output_embeddings(self):
849
+ return self.lm_head
850
+
851
+ def set_output_embeddings(self, new_embeddings):
852
+ self.lm_head = new_embeddings
853
+
854
+ def set_decoder(self, decoder):
855
+ self.model = decoder
856
+
857
+ def get_decoder(self):
858
+ return self.model
859
+
860
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
861
+ def forward(
862
+ self,
863
+ input_ids: torch.LongTensor = None,
864
+ attention_mask: Optional[torch.Tensor] = None,
865
+ position_ids: Optional[torch.LongTensor] = None,
866
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
867
+ inputs_embeds: Optional[torch.FloatTensor] = None,
868
+ labels: Optional[torch.LongTensor] = None,
869
+ use_cache: Optional[bool] = None,
870
+ output_attentions: Optional[bool] = None,
871
+ output_hidden_states: Optional[bool] = None,
872
+ return_dict: Optional[bool] = None,
873
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
874
+ r"""
875
+ Args:
876
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
877
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
878
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
879
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
880
+
881
+ Returns:
882
+
883
+ Example:
884
+
885
+ ```python
886
+ >>> from transformers import AutoTokenizer, OrionForCausalLM
887
+
888
+ >>> model = OrionForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
889
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
890
+
891
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
892
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
893
+
894
+ >>> # Generate
895
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
896
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
897
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
898
+ ```"""
899
+
900
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
901
+ output_hidden_states = (
902
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
903
+ )
904
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
905
+
906
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
907
+ outputs = self.model(
908
+ input_ids=input_ids,
909
+ attention_mask=attention_mask,
910
+ position_ids=position_ids,
911
+ past_key_values=past_key_values,
912
+ inputs_embeds=inputs_embeds,
913
+ use_cache=use_cache,
914
+ output_attentions=output_attentions,
915
+ output_hidden_states=output_hidden_states,
916
+ return_dict=return_dict,
917
+ )
918
+
919
+ hidden_states = outputs[0]
920
+ if self.config.pretraining_tp > 1:
921
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
922
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
923
+ logits = torch.cat(logits, dim=-1)
924
+ else:
925
+ logits = self.lm_head(hidden_states)
926
+ logits = logits.float()
927
+
928
+ loss = None
929
+ if labels is not None:
930
+ # Shift so that tokens < n predict n
931
+ shift_logits = logits[..., :-1, :].contiguous()
932
+ shift_labels = labels[..., 1:].contiguous()
933
+ # Flatten the tokens
934
+ loss_fct = CrossEntropyLoss()
935
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
936
+ shift_labels = shift_labels.view(-1)
937
+ # Enable model parallelism
938
+ shift_labels = shift_labels.to(shift_logits.device)
939
+ loss = loss_fct(shift_logits, shift_labels)
940
+
941
+ if not return_dict:
942
+ output = (logits,) + outputs[1:]
943
+ return (loss,) + output if loss is not None else output
944
+
945
+ return CausalLMOutputWithPast(
946
+ loss=loss,
947
+ logits=logits,
948
+ past_key_values=outputs.past_key_values,
949
+ hidden_states=outputs.hidden_states,
950
+ attentions=outputs.attentions,
951
+ )
952
+
953
+ def prepare_inputs_for_generation(
954
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
955
+ ):
956
+ if past_key_values:
957
+ input_ids = input_ids[:, -1:]
958
+
959
+ position_ids = kwargs.get("position_ids", None)
960
+ if attention_mask is not None and position_ids is None:
961
+ # create position_ids on the fly for batch generation
962
+ position_ids = attention_mask.long().cumsum(-1) - 1
963
+ position_ids.masked_fill_(attention_mask == 0, 1)
964
+ if past_key_values:
965
+ position_ids = position_ids[:, -1].unsqueeze(-1)
966
+
967
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
968
+ if inputs_embeds is not None and past_key_values is None:
969
+ model_inputs = {"inputs_embeds": inputs_embeds}
970
+ else:
971
+ model_inputs = {"input_ids": input_ids}
972
+
973
+ model_inputs.update(
974
+ {
975
+ "position_ids": position_ids,
976
+ "past_key_values": past_key_values,
977
+ "use_cache": kwargs.get("use_cache"),
978
+ "attention_mask": attention_mask,
979
+ }
980
+ )
981
+ return model_inputs
982
+
983
+ @staticmethod
984
+ def _reorder_cache(past_key_values, beam_idx):
985
+ reordered_past = ()
986
+ for layer_past in past_key_values:
987
+ reordered_past += (
988
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
989
+ )
990
+ return reordered_past
991
+
992
+ class OrionForSequenceClassification(OrionPreTrainedModel):
993
+ def __init__(self, config):
994
+ super().__init__(config)
995
+ self.num_labels = config.num_labels
996
+ self.model = OrionModel(config)
997
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
998
+
999
+ # Initialize weights and apply final processing
1000
+ self.post_init()
1001
+
1002
+ def get_input_embeddings(self):
1003
+ return self.model.embed_tokens
1004
+
1005
+ def set_input_embeddings(self, value):
1006
+ self.model.embed_tokens = value
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids: torch.LongTensor = None,
1011
+ attention_mask: Optional[torch.Tensor] = None,
1012
+ position_ids: Optional[torch.LongTensor] = None,
1013
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1014
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1015
+ labels: Optional[torch.LongTensor] = None,
1016
+ use_cache: Optional[bool] = None,
1017
+ output_attentions: Optional[bool] = None,
1018
+ output_hidden_states: Optional[bool] = None,
1019
+ return_dict: Optional[bool] = None,
1020
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1021
+ r"""
1022
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1023
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1024
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1025
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1026
+ """
1027
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1028
+
1029
+ transformer_outputs = self.model(
1030
+ input_ids,
1031
+ attention_mask=attention_mask,
1032
+ position_ids=position_ids,
1033
+ past_key_values=past_key_values,
1034
+ inputs_embeds=inputs_embeds,
1035
+ use_cache=use_cache,
1036
+ output_attentions=output_attentions,
1037
+ output_hidden_states=output_hidden_states,
1038
+ return_dict=return_dict,
1039
+ )
1040
+ hidden_states = transformer_outputs[0]
1041
+ logits = self.score(hidden_states)
1042
+
1043
+ if input_ids is not None:
1044
+ batch_size = input_ids.shape[0]
1045
+ else:
1046
+ batch_size = inputs_embeds.shape[0]
1047
+
1048
+ if self.config.pad_token_id is None and batch_size != 1:
1049
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1050
+ if self.config.pad_token_id is None:
1051
+ sequence_lengths = -1
1052
+ else:
1053
+ if input_ids is not None:
1054
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1055
+ logits.device
1056
+ )
1057
+ else:
1058
+ sequence_lengths = -1
1059
+
1060
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1061
+
1062
+ loss = None
1063
+ if labels is not None:
1064
+ labels = labels.to(logits.device)
1065
+ if self.config.problem_type is None:
1066
+ if self.num_labels == 1:
1067
+ self.config.problem_type = "regression"
1068
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1069
+ self.config.problem_type = "single_label_classification"
1070
+ else:
1071
+ self.config.problem_type = "multi_label_classification"
1072
+
1073
+ if self.config.problem_type == "regression":
1074
+ loss_fct = MSELoss()
1075
+ if self.num_labels == 1:
1076
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1077
+ else:
1078
+ loss = loss_fct(pooled_logits, labels)
1079
+ elif self.config.problem_type == "single_label_classification":
1080
+ loss_fct = CrossEntropyLoss()
1081
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1082
+ elif self.config.problem_type == "multi_label_classification":
1083
+ loss_fct = BCEWithLogitsLoss()
1084
+ loss = loss_fct(pooled_logits, labels)
1085
+ if not return_dict:
1086
+ output = (pooled_logits,) + transformer_outputs[1:]
1087
+ return ((loss,) + output) if loss is not None else output
1088
+
1089
+ return SequenceClassifierOutputWithPast(
1090
+ loss=loss,
1091
+ logits=pooled_logits,
1092
+ past_key_values=transformer_outputs.past_key_values,
1093
+ hidden_states=transformer_outputs.hidden_states,
1094
+ attentions=transformer_outputs.attentions,
1095
+ )
1096
+