Upload model
Browse files- README.md +79 -3
- config.json +34 -0
- configuration_generanno.py +190 -0
- model.safetensors +3 -0
- modeling_generanno.py +1071 -0
- special_tokens_map.json +12 -0
- tokenizer.model +3 -0
- tokenizer_config.json +42 -0
README.md
CHANGED
@@ -1,3 +1,79 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
pipeline_tag: fill-mask
|
4 |
+
tags:
|
5 |
+
- biology
|
6 |
+
- genomics
|
7 |
+
- long-context
|
8 |
+
library_name: transformers
|
9 |
+
---
|
10 |
+
# GENERanno-eukaryote-0.5b-base model
|
11 |
+
|
12 |
+
## Abouts
|
13 |
+
In this repository, we present GENERanno, a genomic foundation model featuring a context length of 8k base pairs and 500M parameters, trained on an expansive dataset comprising 386 billion base pairs of eukaryotic DNA. Our evaluations demonstrate that the GENERator consistently achieves state-of-the-art performance across a wide spectrum of benchmarks, including [Genomic Benchmarks](https://huggingface.co/datasets/katielink/genomic-benchmarks/tree/main), [NT tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks_revised), and our newly proposed [Gener tasks](https://huggingface.co/GenerTeam).
|
14 |
+
|
15 |
+
Beyond benchmark performance, the GENERanno model is meticulously designed with its specialization in gene annotation. The model efficiently and accurately identifies gene locations, predicts gene function, and annotates gene structure, highlighting its potential to revolutionize genomic research by significantly enhancing the precision and efficiency of gene annotation processes.
|
16 |
+
|
17 |
+
Please note that the GENERanno is currently in the developmental phase. We are actively refining the model and will release more technical details soon. Stay tuned for updates!
|
18 |
+
|
19 |
+
## How to use
|
20 |
+
### Simple example: embedding
|
21 |
+
|
22 |
+
```python
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from transformers import AutoTokenizer, AutoModel
|
26 |
+
|
27 |
+
# Load the tokenizer and model using the pretrained model name
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained("GenerTeam/GENERanno-eukaryote-0.5b-base")
|
29 |
+
model = AutoModel.from_pretrained("GenerTeam/GENERanno-eukaryote-0.5b-base", trust_remote_code=True)
|
30 |
+
|
31 |
+
# Get model configuration and maximum sequence length
|
32 |
+
config = model.config
|
33 |
+
max_length = config.max_position_embeddings
|
34 |
+
|
35 |
+
# Define input sequences
|
36 |
+
sequences = [
|
37 |
+
"ATGAGGTGGCAAGAAATGGGCTAC",
|
38 |
+
"GAATTCCATGAGGCTATAGAATAATCTAAGAGAAAT"
|
39 |
+
]
|
40 |
+
|
41 |
+
# Tokenize the sequences
|
42 |
+
# The add_special_tokens=True adds special tokens
|
43 |
+
tokenizer.padding_side = "right"
|
44 |
+
inputs = tokenizer(
|
45 |
+
sequences,
|
46 |
+
add_special_tokens=True,
|
47 |
+
return_tensors="pt",
|
48 |
+
padding=True,
|
49 |
+
truncation=True,
|
50 |
+
max_length=max_length
|
51 |
+
)
|
52 |
+
|
53 |
+
# Perform a forward pass through the model to obtain the outputs, including hidden states
|
54 |
+
with torch.inference_mode():
|
55 |
+
outputs = model(**inputs, output_hidden_states=True)
|
56 |
+
|
57 |
+
# Retrieve the hidden states from the last layer
|
58 |
+
# hidden_states shape: (batch_size, sequence_length, hidden_size)
|
59 |
+
hidden_states = outputs.hidden_states[-1]
|
60 |
+
|
61 |
+
# Option 1: Use the first token (BOS) as the sentence embedding
|
62 |
+
cls_embeddings = hidden_states[:, 0, :]
|
63 |
+
|
64 |
+
# Option 2: Use mean pooling over the token embeddings
|
65 |
+
# Use the attention mask to take care of the padded tokens
|
66 |
+
attention_mask = inputs["attention_mask"] # Shape: (batch_size, sequence_length)
|
67 |
+
# Expand the attention mask dimensions so that it matches the hidden_states dimensions
|
68 |
+
expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
|
69 |
+
# Sum the token embeddings, taking the mask into account
|
70 |
+
sum_embeddings = torch.sum(hidden_states * expanded_mask, dim=1)
|
71 |
+
# Compute the average by dividing with the sum of the attention mask
|
72 |
+
mean_embeddings = sum_embeddings / expanded_mask.sum(dim=1)
|
73 |
+
|
74 |
+
print("BOS Embeddings:", cls_embeddings)
|
75 |
+
print("Mean Embeddings:", mean_embeddings)
|
76 |
+
```
|
77 |
+
|
78 |
+
## Citation
|
79 |
+
TBD
|
config.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GenerannoForMaskedLM"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_generanno.GenerannoConfig",
|
7 |
+
"AutoModel": "modeling_generanno.GenerannoModel",
|
8 |
+
"AutoModelForMaskedLM": "modeling_generanno.GenerannoForMaskedLM"
|
9 |
+
},
|
10 |
+
"attention_bias": false,
|
11 |
+
"attention_dropout": 0.0,
|
12 |
+
"bos_token_id": 1,
|
13 |
+
"eos_token_id": 2,
|
14 |
+
"hidden_act": "silu",
|
15 |
+
"hidden_size": 1280,
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"intermediate_size": 3520,
|
18 |
+
"mask_token_id": 4,
|
19 |
+
"max_position_embeddings": 8192,
|
20 |
+
"mlp_bias": false,
|
21 |
+
"model_type": "generanno",
|
22 |
+
"num_attention_heads": 16,
|
23 |
+
"num_hidden_layers": 28,
|
24 |
+
"num_key_value_heads": 4,
|
25 |
+
"pad_token_id": 3,
|
26 |
+
"pretraining_tp": 1,
|
27 |
+
"rms_norm_eps": 1e-05,
|
28 |
+
"rope_scaling": null,
|
29 |
+
"rope_theta": 500000.0,
|
30 |
+
"tie_word_embeddings": false,
|
31 |
+
"torch_dtype": "float32",
|
32 |
+
"transformers_version": "4.44.0",
|
33 |
+
"vocab_size": 64
|
34 |
+
}
|
configuration_generanno.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""LLaMA model configuration"""
|
21 |
+
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
24 |
+
|
25 |
+
|
26 |
+
class GenerannoConfig(PretrainedConfig):
|
27 |
+
r"""
|
28 |
+
This is the configuration class to store the configuration of a [`GenerannoModel`]. It is used to instantiate an LLaMA
|
29 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
30 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
31 |
+
|
32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
33 |
+
documentation from [`PretrainedConfig`] for more information.
|
34 |
+
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
38 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
39 |
+
`inputs_ids` passed when calling [`GenerannoModel`]
|
40 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
41 |
+
Dimension of the hidden representations.
|
42 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
43 |
+
Dimension of the MLP representations.
|
44 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
45 |
+
Number of hidden layers in the Transformer decoder.
|
46 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
47 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
48 |
+
num_key_value_heads (`int`, *optional*):
|
49 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
50 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
51 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
52 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
53 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
54 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
55 |
+
`num_attention_heads`.
|
56 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
57 |
+
The non-linear activation function (function or string) in the decoder.
|
58 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
59 |
+
The maximum sequence length that this model might ever be used with. Generanno 1 supports up to 2048 tokens,
|
60 |
+
Generanno 2 up to 4096, CodeGeneranno up to 16384.
|
61 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
62 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
63 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
64 |
+
The epsilon used by the rms normalization layers.
|
65 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
66 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
67 |
+
relevant if `config.is_decoder=True`.
|
68 |
+
pad_token_id (`int`, *optional*):
|
69 |
+
Padding token id.
|
70 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
71 |
+
Beginning of stream token id.
|
72 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
73 |
+
End of stream token id.
|
74 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
75 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
76 |
+
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
77 |
+
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
78 |
+
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
79 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to tie weight embeddings
|
81 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
82 |
+
The base period of the RoPE embeddings.
|
83 |
+
rope_scaling (`Dict`, *optional*):
|
84 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
85 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
86 |
+
accordingly.
|
87 |
+
Expected contents:
|
88 |
+
`rope_type` (`str`):
|
89 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
90 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
91 |
+
`factor` (`float`, *optional*):
|
92 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
93 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
94 |
+
original maximum pre-trained length.
|
95 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
96 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
97 |
+
pretraining.
|
98 |
+
`attention_factor` (`float`, *optional*):
|
99 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
100 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
101 |
+
`factor` field to infer the suggested value.
|
102 |
+
`beta_fast` (`float`, *optional*):
|
103 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
104 |
+
ramp function. If unspecified, it defaults to 32.
|
105 |
+
`beta_slow` (`float`, *optional*):
|
106 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
107 |
+
ramp function. If unspecified, it defaults to 1.
|
108 |
+
`short_factor` (`List[float]`, *optional*):
|
109 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
110 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
111 |
+
size divided by the number of attention heads divided by 2
|
112 |
+
`long_factor` (`List[float]`, *optional*):
|
113 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
114 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
115 |
+
size divided by the number of attention heads divided by 2
|
116 |
+
`low_freq_factor` (`float`, *optional*):
|
117 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
118 |
+
`high_freq_factor` (`float`, *optional*):
|
119 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
120 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
121 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
122 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
123 |
+
The dropout ratio for the attention probabilities.
|
124 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
125 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
126 |
+
"""
|
127 |
+
|
128 |
+
model_type = "generanno"
|
129 |
+
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
vocab_size=32000,
|
133 |
+
hidden_size=4096,
|
134 |
+
intermediate_size=11008,
|
135 |
+
num_hidden_layers=32,
|
136 |
+
num_attention_heads=32,
|
137 |
+
num_key_value_heads=None,
|
138 |
+
hidden_act="silu",
|
139 |
+
max_position_embeddings=2048,
|
140 |
+
initializer_range=0.02,
|
141 |
+
rms_norm_eps=1e-6,
|
142 |
+
pad_token_id=None,
|
143 |
+
bos_token_id=1,
|
144 |
+
eos_token_id=2,
|
145 |
+
mask_token_id=4,
|
146 |
+
pretraining_tp=1,
|
147 |
+
tie_word_embeddings=False,
|
148 |
+
rope_theta=10000.0,
|
149 |
+
rope_scaling=None,
|
150 |
+
attention_bias=False,
|
151 |
+
attention_dropout=0.0,
|
152 |
+
mlp_bias=False,
|
153 |
+
**kwargs,
|
154 |
+
):
|
155 |
+
self.vocab_size = vocab_size
|
156 |
+
self.max_position_embeddings = max_position_embeddings
|
157 |
+
self.hidden_size = hidden_size
|
158 |
+
self.intermediate_size = intermediate_size
|
159 |
+
self.num_hidden_layers = num_hidden_layers
|
160 |
+
self.num_attention_heads = num_attention_heads
|
161 |
+
|
162 |
+
# for backward compatibility
|
163 |
+
if num_key_value_heads is None:
|
164 |
+
num_key_value_heads = num_attention_heads
|
165 |
+
|
166 |
+
self.num_key_value_heads = num_key_value_heads
|
167 |
+
self.hidden_act = hidden_act
|
168 |
+
self.initializer_range = initializer_range
|
169 |
+
self.rms_norm_eps = rms_norm_eps
|
170 |
+
self.pretraining_tp = pretraining_tp
|
171 |
+
self.rope_theta = rope_theta
|
172 |
+
self.rope_scaling = rope_scaling
|
173 |
+
self.attention_bias = attention_bias
|
174 |
+
self.attention_dropout = attention_dropout
|
175 |
+
self.mlp_bias = mlp_bias
|
176 |
+
|
177 |
+
# Validate the correctness of rotary position embeddings parameters
|
178 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
179 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
180 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
181 |
+
rope_config_validation(self)
|
182 |
+
|
183 |
+
super().__init__(
|
184 |
+
pad_token_id=pad_token_id,
|
185 |
+
bos_token_id=bos_token_id,
|
186 |
+
eos_token_id=eos_token_id,
|
187 |
+
mask_token_id=mask_token_id,
|
188 |
+
tie_word_embeddings=tie_word_embeddings,
|
189 |
+
**kwargs,
|
190 |
+
)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da8baaca6ba858f43c0f2911301597f1a99f755f70773a1be4d815c3492e6660
|
3 |
+
size 1973609768
|
modeling_generanno.py
ADDED
@@ -0,0 +1,1071 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
import math
|
21 |
+
from typing import Optional, Tuple, Union, Any
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from torch import nn, Tensor
|
27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.modeling_attn_mask_utils import (
|
30 |
+
_prepare_4d_attention_mask_for_sdpa,
|
31 |
+
_prepare_4d_attention_mask,
|
32 |
+
)
|
33 |
+
from transformers.modeling_outputs import (
|
34 |
+
TokenClassifierOutput,
|
35 |
+
BaseModelOutput,
|
36 |
+
MaskedLMOutput,
|
37 |
+
SequenceClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
40 |
+
from transformers.modeling_utils import PreTrainedModel
|
41 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
42 |
+
from transformers.utils import (
|
43 |
+
logging,
|
44 |
+
)
|
45 |
+
|
46 |
+
from .configuration_generanno import GenerannoConfig
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
_CONFIG_FOR_DOC = "GenerannoConfig"
|
51 |
+
|
52 |
+
|
53 |
+
class GenerannoRMSNorm(nn.Module):
|
54 |
+
def __init__(self, hidden_size, eps=1e-6):
|
55 |
+
"""
|
56 |
+
GenerannoRMSNorm is equivalent to T5LayerNorm
|
57 |
+
"""
|
58 |
+
super().__init__()
|
59 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
60 |
+
self.variance_epsilon = eps
|
61 |
+
|
62 |
+
def forward(self, hidden_states):
|
63 |
+
input_dtype = hidden_states.dtype
|
64 |
+
hidden_states = hidden_states.to(torch.float32)
|
65 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
66 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
67 |
+
return self.weight * hidden_states.to(input_dtype)
|
68 |
+
|
69 |
+
def extra_repr(self):
|
70 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
71 |
+
|
72 |
+
|
73 |
+
ALL_LAYERNORM_LAYERS.append(GenerannoRMSNorm)
|
74 |
+
|
75 |
+
|
76 |
+
class GenerannoRotaryEmbedding(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
dim=None,
|
80 |
+
max_position_embeddings=2048,
|
81 |
+
base=10000,
|
82 |
+
device=None,
|
83 |
+
scaling_factor=1.0,
|
84 |
+
rope_type="default",
|
85 |
+
config: Optional[GenerannoConfig] = None,
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
# TODO (joao): remove the `if` below, only used for BC
|
89 |
+
self.rope_kwargs = {}
|
90 |
+
if config is None:
|
91 |
+
logger.warning_once(
|
92 |
+
"`GenerannoRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
93 |
+
"`config` argument. All other arguments will be removed in v4.45"
|
94 |
+
)
|
95 |
+
self.rope_kwargs = {
|
96 |
+
"rope_type": rope_type,
|
97 |
+
"factor": scaling_factor,
|
98 |
+
"dim": dim,
|
99 |
+
"base": base,
|
100 |
+
"max_position_embeddings": max_position_embeddings,
|
101 |
+
}
|
102 |
+
self.rope_type = rope_type
|
103 |
+
self.max_seq_len_cached = max_position_embeddings
|
104 |
+
self.original_max_seq_len = max_position_embeddings
|
105 |
+
else:
|
106 |
+
# BC: "rope_type" was originally "type"
|
107 |
+
if config.rope_scaling is not None:
|
108 |
+
self.rope_type = config.rope_scaling.get(
|
109 |
+
"rope_type", config.rope_scaling.get("type")
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
self.rope_type = "default"
|
113 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
114 |
+
self.original_max_seq_len = config.max_position_embeddings
|
115 |
+
|
116 |
+
self.config = config
|
117 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
118 |
+
|
119 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
120 |
+
self.config, device, **self.rope_kwargs
|
121 |
+
)
|
122 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
123 |
+
self.original_inv_freq = self.inv_freq
|
124 |
+
|
125 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
126 |
+
"""
|
127 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
128 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
129 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
130 |
+
"""
|
131 |
+
seq_len = torch.max(position_ids) + 1
|
132 |
+
if seq_len > self.max_seq_len_cached: # growth
|
133 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
134 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
135 |
+
)
|
136 |
+
self.register_buffer(
|
137 |
+
"inv_freq", inv_freq, persistent=False
|
138 |
+
) # TODO joao: may break with compilation
|
139 |
+
self.max_seq_len_cached = seq_len
|
140 |
+
|
141 |
+
if (
|
142 |
+
seq_len < self.original_max_seq_len
|
143 |
+
and self.max_seq_len_cached > self.original_max_seq_len
|
144 |
+
): # reset
|
145 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
146 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
147 |
+
|
148 |
+
@torch.no_grad()
|
149 |
+
def forward(self, x, position_ids):
|
150 |
+
if "dynamic" in self.rope_type:
|
151 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
152 |
+
|
153 |
+
# Core RoPE block
|
154 |
+
inv_freq_expanded = (
|
155 |
+
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
156 |
+
)
|
157 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
158 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
159 |
+
device_type = x.device.type
|
160 |
+
device_type = (
|
161 |
+
device_type
|
162 |
+
if isinstance(device_type, str) and device_type != "mps"
|
163 |
+
else "cpu"
|
164 |
+
)
|
165 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
166 |
+
freqs = (
|
167 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
168 |
+
).transpose(1, 2)
|
169 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
170 |
+
cos = emb.cos()
|
171 |
+
sin = emb.sin()
|
172 |
+
|
173 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
174 |
+
cos = cos * self.attention_scaling
|
175 |
+
sin = sin * self.attention_scaling
|
176 |
+
|
177 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
178 |
+
|
179 |
+
|
180 |
+
class GenerannoLinearScalingRotaryEmbedding(GenerannoRotaryEmbedding):
|
181 |
+
"""GenerannoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
182 |
+
|
183 |
+
def __init__(self, *args, **kwargs):
|
184 |
+
logger.warning_once(
|
185 |
+
"`GenerannoLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
186 |
+
"`GenerannoRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
187 |
+
)
|
188 |
+
kwargs["rope_type"] = "linear"
|
189 |
+
super().__init__(*args, **kwargs)
|
190 |
+
|
191 |
+
|
192 |
+
class GenerannoDynamicNTKScalingRotaryEmbedding(GenerannoRotaryEmbedding):
|
193 |
+
"""GenerannoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
194 |
+
|
195 |
+
def __init__(self, *args, **kwargs):
|
196 |
+
logger.warning_once(
|
197 |
+
"`GenerannoDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
198 |
+
"`GenerannoRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
199 |
+
"__init__)."
|
200 |
+
)
|
201 |
+
kwargs["rope_type"] = "dynamic"
|
202 |
+
super().__init__(*args, **kwargs)
|
203 |
+
|
204 |
+
|
205 |
+
def rotate_half(x):
|
206 |
+
"""Rotates half the hidden dims of the input."""
|
207 |
+
x1 = x[..., : x.shape[-1] // 2]
|
208 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
209 |
+
return torch.cat((-x2, x1), dim=-1)
|
210 |
+
|
211 |
+
|
212 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
213 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
q (`torch.Tensor`): The query tensor.
|
217 |
+
k (`torch.Tensor`): The key tensor.
|
218 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
219 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
220 |
+
position_ids (`torch.Tensor`, *optional*):
|
221 |
+
Deprecated and unused.
|
222 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
223 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
224 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
225 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
226 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
227 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
228 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
229 |
+
Returns:
|
230 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
231 |
+
"""
|
232 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
233 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
234 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
235 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
236 |
+
return q_embed, k_embed
|
237 |
+
|
238 |
+
|
239 |
+
class GenerannoMLP(nn.Module):
|
240 |
+
def __init__(self, config):
|
241 |
+
super().__init__()
|
242 |
+
self.config = config
|
243 |
+
self.hidden_size = config.hidden_size
|
244 |
+
self.intermediate_size = config.intermediate_size
|
245 |
+
self.gate_proj = nn.Linear(
|
246 |
+
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
|
247 |
+
)
|
248 |
+
self.up_proj = nn.Linear(
|
249 |
+
self.hidden_size, self.intermediate_size, bias=config.mlp_bias
|
250 |
+
)
|
251 |
+
self.down_proj = nn.Linear(
|
252 |
+
self.intermediate_size, self.hidden_size, bias=config.mlp_bias
|
253 |
+
)
|
254 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
255 |
+
|
256 |
+
def forward(self, x):
|
257 |
+
if self.config.pretraining_tp > 1:
|
258 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
259 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
260 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
261 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
262 |
+
|
263 |
+
gate_proj = torch.cat(
|
264 |
+
[
|
265 |
+
F.linear(x, gate_proj_slices[i])
|
266 |
+
for i in range(self.config.pretraining_tp)
|
267 |
+
],
|
268 |
+
dim=-1,
|
269 |
+
)
|
270 |
+
up_proj = torch.cat(
|
271 |
+
[
|
272 |
+
F.linear(x, up_proj_slices[i])
|
273 |
+
for i in range(self.config.pretraining_tp)
|
274 |
+
],
|
275 |
+
dim=-1,
|
276 |
+
)
|
277 |
+
|
278 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
279 |
+
down_proj = [
|
280 |
+
F.linear(intermediate_states[i], down_proj_slices[i])
|
281 |
+
for i in range(self.config.pretraining_tp)
|
282 |
+
]
|
283 |
+
down_proj = sum(down_proj)
|
284 |
+
else:
|
285 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
286 |
+
|
287 |
+
return down_proj
|
288 |
+
|
289 |
+
|
290 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
291 |
+
"""
|
292 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
293 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
294 |
+
"""
|
295 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
296 |
+
if n_rep == 1:
|
297 |
+
return hidden_states
|
298 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
299 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
300 |
+
)
|
301 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
302 |
+
|
303 |
+
|
304 |
+
class GenerannoAttention(nn.Module):
|
305 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
306 |
+
|
307 |
+
def __init__(self, config: GenerannoConfig, layer_idx: Optional[int] = None):
|
308 |
+
super().__init__()
|
309 |
+
self.config = config
|
310 |
+
self.layer_idx = layer_idx
|
311 |
+
if layer_idx is None:
|
312 |
+
logger.warning_once(
|
313 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
314 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
315 |
+
"when creating this class."
|
316 |
+
)
|
317 |
+
|
318 |
+
self.attention_dropout = config.attention_dropout
|
319 |
+
self.hidden_size = config.hidden_size
|
320 |
+
self.num_heads = config.num_attention_heads
|
321 |
+
self.head_dim = self.hidden_size // self.num_heads
|
322 |
+
self.num_key_value_heads = config.num_key_value_heads
|
323 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
324 |
+
self.max_position_embeddings = config.max_position_embeddings
|
325 |
+
self.rope_theta = config.rope_theta
|
326 |
+
|
327 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
328 |
+
raise ValueError(
|
329 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
330 |
+
f" and `num_heads`: {self.num_heads})."
|
331 |
+
)
|
332 |
+
|
333 |
+
self.q_proj = nn.Linear(
|
334 |
+
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
|
335 |
+
)
|
336 |
+
self.k_proj = nn.Linear(
|
337 |
+
self.hidden_size,
|
338 |
+
self.num_key_value_heads * self.head_dim,
|
339 |
+
bias=config.attention_bias,
|
340 |
+
)
|
341 |
+
self.v_proj = nn.Linear(
|
342 |
+
self.hidden_size,
|
343 |
+
self.num_key_value_heads * self.head_dim,
|
344 |
+
bias=config.attention_bias,
|
345 |
+
)
|
346 |
+
self.o_proj = nn.Linear(
|
347 |
+
self.hidden_size, self.hidden_size, bias=config.attention_bias
|
348 |
+
)
|
349 |
+
|
350 |
+
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the encoder layers)
|
351 |
+
self.rotary_emb = GenerannoRotaryEmbedding(config=self.config)
|
352 |
+
|
353 |
+
def forward(
|
354 |
+
self,
|
355 |
+
hidden_states: torch.Tensor,
|
356 |
+
attention_mask: Optional[torch.Tensor] = None,
|
357 |
+
position_ids: Optional[torch.LongTensor] = None,
|
358 |
+
output_attentions: bool = False,
|
359 |
+
position_embeddings: Optional[
|
360 |
+
Tuple[torch.Tensor, torch.Tensor]
|
361 |
+
] = None, # will become mandatory in v4.45
|
362 |
+
**kwargs,
|
363 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
364 |
+
bsz, q_len, _ = hidden_states.size()
|
365 |
+
|
366 |
+
if self.config.pretraining_tp > 1:
|
367 |
+
key_value_slicing = (
|
368 |
+
self.num_key_value_heads * self.head_dim
|
369 |
+
) // self.config.pretraining_tp
|
370 |
+
query_slices = self.q_proj.weight.split(
|
371 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
372 |
+
)
|
373 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
374 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
375 |
+
|
376 |
+
query_states = [
|
377 |
+
F.linear(hidden_states, query_slices[i])
|
378 |
+
for i in range(self.config.pretraining_tp)
|
379 |
+
]
|
380 |
+
query_states = torch.cat(query_states, dim=-1)
|
381 |
+
|
382 |
+
key_states = [
|
383 |
+
F.linear(hidden_states, key_slices[i])
|
384 |
+
for i in range(self.config.pretraining_tp)
|
385 |
+
]
|
386 |
+
key_states = torch.cat(key_states, dim=-1)
|
387 |
+
|
388 |
+
value_states = [
|
389 |
+
F.linear(hidden_states, value_slices[i])
|
390 |
+
for i in range(self.config.pretraining_tp)
|
391 |
+
]
|
392 |
+
value_states = torch.cat(value_states, dim=-1)
|
393 |
+
|
394 |
+
else:
|
395 |
+
query_states = self.q_proj(hidden_states)
|
396 |
+
key_states = self.k_proj(hidden_states)
|
397 |
+
value_states = self.v_proj(hidden_states)
|
398 |
+
|
399 |
+
query_states = query_states.view(
|
400 |
+
bsz, q_len, self.num_heads, self.head_dim
|
401 |
+
).transpose(1, 2)
|
402 |
+
key_states = key_states.view(
|
403 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
404 |
+
).transpose(1, 2)
|
405 |
+
value_states = value_states.view(
|
406 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
407 |
+
).transpose(1, 2)
|
408 |
+
|
409 |
+
if position_embeddings is None:
|
410 |
+
logger.warning_once(
|
411 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
412 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
413 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
414 |
+
"removed and `position_embeddings` will be mandatory."
|
415 |
+
)
|
416 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
417 |
+
else:
|
418 |
+
cos, sin = position_embeddings
|
419 |
+
query_states, key_states = apply_rotary_pos_emb(
|
420 |
+
query_states, key_states, cos, sin
|
421 |
+
)
|
422 |
+
|
423 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
424 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
425 |
+
|
426 |
+
attn_weights = torch.matmul(
|
427 |
+
query_states, key_states.transpose(2, 3)
|
428 |
+
) / math.sqrt(self.head_dim)
|
429 |
+
|
430 |
+
if attention_mask is not None:
|
431 |
+
attn_weights = attn_weights + attention_mask
|
432 |
+
|
433 |
+
# upcast attention to fp32
|
434 |
+
attn_weights = nn.functional.softmax(
|
435 |
+
attn_weights, dim=-1, dtype=torch.float32
|
436 |
+
).to(query_states.dtype)
|
437 |
+
attn_weights = nn.functional.dropout(
|
438 |
+
attn_weights, p=self.attention_dropout, training=self.training
|
439 |
+
)
|
440 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
441 |
+
|
442 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
443 |
+
raise ValueError(
|
444 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
445 |
+
f" {attn_output.size()}"
|
446 |
+
)
|
447 |
+
|
448 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
449 |
+
|
450 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
451 |
+
|
452 |
+
if self.config.pretraining_tp > 1:
|
453 |
+
attn_output = attn_output.split(
|
454 |
+
self.hidden_size // self.config.pretraining_tp, dim=2
|
455 |
+
)
|
456 |
+
o_proj_slices = self.o_proj.weight.split(
|
457 |
+
self.hidden_size // self.config.pretraining_tp, dim=1
|
458 |
+
)
|
459 |
+
attn_output = sum(
|
460 |
+
[
|
461 |
+
F.linear(attn_output[i], o_proj_slices[i])
|
462 |
+
for i in range(self.config.pretraining_tp)
|
463 |
+
]
|
464 |
+
)
|
465 |
+
else:
|
466 |
+
attn_output = self.o_proj(attn_output)
|
467 |
+
|
468 |
+
if not output_attentions:
|
469 |
+
attn_weights = None
|
470 |
+
|
471 |
+
return attn_output, attn_weights
|
472 |
+
|
473 |
+
|
474 |
+
class GenerannoSdpaAttention(GenerannoAttention):
|
475 |
+
"""
|
476 |
+
Generanno attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
477 |
+
`GenerannoAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
478 |
+
SDPA API.
|
479 |
+
"""
|
480 |
+
|
481 |
+
# Adapted from GenerannoAttention.forward
|
482 |
+
def forward(
|
483 |
+
self,
|
484 |
+
hidden_states: torch.Tensor,
|
485 |
+
attention_mask: Optional[torch.Tensor] = None,
|
486 |
+
position_ids: Optional[torch.LongTensor] = None,
|
487 |
+
output_attentions: bool = False,
|
488 |
+
position_embeddings: Optional[
|
489 |
+
Tuple[torch.Tensor, torch.Tensor]
|
490 |
+
] = None, # will become mandatory in v4.45
|
491 |
+
**kwargs,
|
492 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
493 |
+
if output_attentions:
|
494 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
495 |
+
logger.warning_once(
|
496 |
+
"GenerannoModel is using GenerannoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
497 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
498 |
+
)
|
499 |
+
return super().forward(
|
500 |
+
hidden_states=hidden_states,
|
501 |
+
attention_mask=attention_mask,
|
502 |
+
position_ids=position_ids,
|
503 |
+
output_attentions=output_attentions,
|
504 |
+
position_embeddings=position_embeddings,
|
505 |
+
)
|
506 |
+
|
507 |
+
bsz, q_len, _ = hidden_states.size()
|
508 |
+
|
509 |
+
query_states = self.q_proj(hidden_states)
|
510 |
+
key_states = self.k_proj(hidden_states)
|
511 |
+
value_states = self.v_proj(hidden_states)
|
512 |
+
|
513 |
+
query_states = query_states.view(
|
514 |
+
bsz, q_len, self.num_heads, self.head_dim
|
515 |
+
).transpose(1, 2)
|
516 |
+
key_states = key_states.view(
|
517 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
518 |
+
).transpose(1, 2)
|
519 |
+
value_states = value_states.view(
|
520 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
521 |
+
).transpose(1, 2)
|
522 |
+
|
523 |
+
if position_embeddings is None:
|
524 |
+
logger.warning_once(
|
525 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
526 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
527 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
528 |
+
"removed and `position_embeddings` will be mandatory."
|
529 |
+
)
|
530 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
531 |
+
else:
|
532 |
+
cos, sin = position_embeddings
|
533 |
+
query_states, key_states = apply_rotary_pos_emb(
|
534 |
+
query_states, key_states, cos, sin
|
535 |
+
)
|
536 |
+
|
537 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
538 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
539 |
+
|
540 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
541 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
542 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
543 |
+
query_states = query_states.contiguous()
|
544 |
+
key_states = key_states.contiguous()
|
545 |
+
value_states = value_states.contiguous()
|
546 |
+
|
547 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
548 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
549 |
+
|
550 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
551 |
+
query_states,
|
552 |
+
key_states,
|
553 |
+
value_states,
|
554 |
+
attn_mask=attention_mask,
|
555 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
556 |
+
is_causal=False,
|
557 |
+
)
|
558 |
+
|
559 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
560 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
561 |
+
|
562 |
+
attn_output = self.o_proj(attn_output)
|
563 |
+
|
564 |
+
return attn_output, None
|
565 |
+
|
566 |
+
|
567 |
+
GENERANNO_ATTENTION_CLASSES = {
|
568 |
+
"eager": GenerannoAttention,
|
569 |
+
"sdpa": GenerannoSdpaAttention,
|
570 |
+
}
|
571 |
+
|
572 |
+
|
573 |
+
class GenerannoEncoderLayer(nn.Module):
|
574 |
+
def __init__(self, config: GenerannoConfig, layer_idx: int):
|
575 |
+
super().__init__()
|
576 |
+
self.hidden_size = config.hidden_size
|
577 |
+
|
578 |
+
self.self_attn = GENERANNO_ATTENTION_CLASSES[config._attn_implementation](
|
579 |
+
config=config, layer_idx=layer_idx
|
580 |
+
)
|
581 |
+
|
582 |
+
self.mlp = GenerannoMLP(config)
|
583 |
+
self.input_layernorm = GenerannoRMSNorm(
|
584 |
+
config.hidden_size, eps=config.rms_norm_eps
|
585 |
+
)
|
586 |
+
self.post_attention_layernorm = GenerannoRMSNorm(
|
587 |
+
config.hidden_size, eps=config.rms_norm_eps
|
588 |
+
)
|
589 |
+
|
590 |
+
def forward(
|
591 |
+
self,
|
592 |
+
hidden_states: torch.Tensor,
|
593 |
+
attention_mask: Optional[torch.Tensor] = None,
|
594 |
+
position_ids: Optional[torch.LongTensor] = None,
|
595 |
+
output_attentions: Optional[bool] = False,
|
596 |
+
position_embeddings: Optional[
|
597 |
+
Tuple[torch.Tensor, torch.Tensor]
|
598 |
+
] = None, # will become mandatory in v4.45
|
599 |
+
**kwargs,
|
600 |
+
) -> tuple[Tensor | Any]:
|
601 |
+
"""
|
602 |
+
Args:
|
603 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
604 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
605 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
606 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
607 |
+
output_attentions (`bool`, *optional*):
|
608 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
609 |
+
returned tensors for more detail.
|
610 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
611 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
612 |
+
with `head_dim` being the embedding dimension of each attention head.
|
613 |
+
kwargs (`dict`, *optional*):
|
614 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
615 |
+
into the model
|
616 |
+
"""
|
617 |
+
residual = hidden_states
|
618 |
+
|
619 |
+
hidden_states = self.input_layernorm(hidden_states)
|
620 |
+
|
621 |
+
# Self Attention
|
622 |
+
hidden_states, self_attn_weights = self.self_attn(
|
623 |
+
hidden_states=hidden_states,
|
624 |
+
attention_mask=attention_mask,
|
625 |
+
position_ids=position_ids,
|
626 |
+
output_attentions=output_attentions,
|
627 |
+
position_embeddings=position_embeddings,
|
628 |
+
**kwargs,
|
629 |
+
)
|
630 |
+
hidden_states = residual + hidden_states
|
631 |
+
|
632 |
+
# Fully Connected
|
633 |
+
residual = hidden_states
|
634 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
635 |
+
hidden_states = self.mlp(hidden_states)
|
636 |
+
hidden_states = residual + hidden_states
|
637 |
+
|
638 |
+
outputs = (hidden_states,)
|
639 |
+
|
640 |
+
if output_attentions:
|
641 |
+
outputs += (self_attn_weights,)
|
642 |
+
|
643 |
+
return outputs
|
644 |
+
|
645 |
+
|
646 |
+
class GenerannoPreTrainedModel(PreTrainedModel):
|
647 |
+
config_class = GenerannoConfig
|
648 |
+
base_model_prefix = "model"
|
649 |
+
supports_gradient_checkpointing = True
|
650 |
+
_no_split_modules = ["GenerannoEncoderLayer"]
|
651 |
+
_supports_flash_attn_2 = False # TODO
|
652 |
+
_supports_sdpa = True
|
653 |
+
|
654 |
+
def _init_weights(self, module):
|
655 |
+
std = self.config.initializer_range
|
656 |
+
if isinstance(module, nn.Linear):
|
657 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
658 |
+
if module.bias is not None:
|
659 |
+
module.bias.data.zero_()
|
660 |
+
elif isinstance(module, nn.Embedding):
|
661 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
662 |
+
if module.padding_idx is not None:
|
663 |
+
module.weight.data[module.padding_idx].zero_()
|
664 |
+
|
665 |
+
|
666 |
+
class GenerannoModel(GenerannoPreTrainedModel):
|
667 |
+
"""
|
668 |
+
Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GenerannoEncoderLayer`]
|
669 |
+
|
670 |
+
Args:
|
671 |
+
config: GenerannoConfig
|
672 |
+
"""
|
673 |
+
|
674 |
+
def __init__(self, config: GenerannoConfig):
|
675 |
+
super().__init__(config)
|
676 |
+
self.padding_idx = config.pad_token_id
|
677 |
+
self.vocab_size = config.vocab_size
|
678 |
+
|
679 |
+
self.embed_tokens = nn.Embedding(
|
680 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
681 |
+
)
|
682 |
+
self.layers = nn.ModuleList(
|
683 |
+
[
|
684 |
+
GenerannoEncoderLayer(config, layer_idx)
|
685 |
+
for layer_idx in range(config.num_hidden_layers)
|
686 |
+
]
|
687 |
+
)
|
688 |
+
self.norm = GenerannoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
689 |
+
self.rotary_emb = GenerannoRotaryEmbedding(config=config)
|
690 |
+
self.gradient_checkpointing = False
|
691 |
+
|
692 |
+
# Initialize weights and apply final processing
|
693 |
+
self.post_init()
|
694 |
+
|
695 |
+
def get_input_embeddings(self):
|
696 |
+
return self.embed_tokens
|
697 |
+
|
698 |
+
def set_input_embeddings(self, value):
|
699 |
+
self.embed_tokens = value
|
700 |
+
|
701 |
+
def forward(
|
702 |
+
self,
|
703 |
+
input_ids: torch.LongTensor = None,
|
704 |
+
attention_mask: Optional[torch.Tensor] = None,
|
705 |
+
position_ids: Optional[torch.LongTensor] = None,
|
706 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
707 |
+
output_attentions: Optional[bool] = None,
|
708 |
+
output_hidden_states: Optional[bool] = None,
|
709 |
+
return_dict: Optional[bool] = None,
|
710 |
+
) -> tuple[tuple, ...] | BaseModelOutput:
|
711 |
+
output_attentions = (
|
712 |
+
output_attentions
|
713 |
+
if output_attentions is not None
|
714 |
+
else self.config.output_attentions
|
715 |
+
)
|
716 |
+
output_hidden_states = (
|
717 |
+
output_hidden_states
|
718 |
+
if output_hidden_states is not None
|
719 |
+
else self.config.output_hidden_states
|
720 |
+
)
|
721 |
+
return_dict = (
|
722 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
723 |
+
)
|
724 |
+
|
725 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
726 |
+
raise ValueError(
|
727 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
728 |
+
)
|
729 |
+
|
730 |
+
if inputs_embeds is None:
|
731 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
732 |
+
|
733 |
+
if position_ids is None:
|
734 |
+
position_ids = torch.arange(
|
735 |
+
0, inputs_embeds.shape[1], device=inputs_embeds.device
|
736 |
+
).unsqueeze(0)
|
737 |
+
|
738 |
+
if attention_mask is None:
|
739 |
+
attention_mask = torch.ones(
|
740 |
+
(inputs_embeds.shape[0], inputs_embeds.shape[1]),
|
741 |
+
device=inputs_embeds.device,
|
742 |
+
)
|
743 |
+
|
744 |
+
attention_mask_converter = (
|
745 |
+
_prepare_4d_attention_mask_for_sdpa
|
746 |
+
if self.config._attn_implementation == "sdpa"
|
747 |
+
else _prepare_4d_attention_mask
|
748 |
+
)
|
749 |
+
|
750 |
+
attention_mask = attention_mask_converter(
|
751 |
+
attention_mask, inputs_embeds.dtype, tgt_len=inputs_embeds.shape[1]
|
752 |
+
)
|
753 |
+
|
754 |
+
hidden_states = inputs_embeds
|
755 |
+
|
756 |
+
# create position embeddings to be shared across the encoder layers
|
757 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
758 |
+
|
759 |
+
# encoder layers
|
760 |
+
all_hidden_states = () if output_hidden_states else None
|
761 |
+
all_self_attns = () if output_attentions else None
|
762 |
+
|
763 |
+
for encoder_layer in self.layers:
|
764 |
+
if output_hidden_states:
|
765 |
+
all_hidden_states += (hidden_states,)
|
766 |
+
|
767 |
+
if self.gradient_checkpointing and self.training:
|
768 |
+
layer_outputs = self._gradient_checkpointing_func(
|
769 |
+
encoder_layer.__call__,
|
770 |
+
hidden_states,
|
771 |
+
attention_mask,
|
772 |
+
position_ids,
|
773 |
+
output_attentions,
|
774 |
+
position_embeddings,
|
775 |
+
)
|
776 |
+
else:
|
777 |
+
layer_outputs = encoder_layer(
|
778 |
+
hidden_states,
|
779 |
+
attention_mask=attention_mask,
|
780 |
+
position_ids=position_ids,
|
781 |
+
output_attentions=output_attentions,
|
782 |
+
position_embeddings=position_embeddings,
|
783 |
+
)
|
784 |
+
|
785 |
+
hidden_states = layer_outputs[0]
|
786 |
+
|
787 |
+
if output_attentions:
|
788 |
+
all_self_attns += (layer_outputs[1],)
|
789 |
+
|
790 |
+
hidden_states = self.norm(hidden_states)
|
791 |
+
|
792 |
+
# add hidden states from the last encoder layer
|
793 |
+
if output_hidden_states:
|
794 |
+
all_hidden_states += (hidden_states,)
|
795 |
+
|
796 |
+
if not return_dict:
|
797 |
+
return tuple(
|
798 |
+
v
|
799 |
+
for v in [hidden_states, all_hidden_states, all_self_attns]
|
800 |
+
if v is not None
|
801 |
+
)
|
802 |
+
return BaseModelOutput(
|
803 |
+
last_hidden_state=hidden_states,
|
804 |
+
hidden_states=all_hidden_states,
|
805 |
+
attentions=all_self_attns,
|
806 |
+
)
|
807 |
+
|
808 |
+
|
809 |
+
class GenerannoForMaskedLM(GenerannoPreTrainedModel):
|
810 |
+
_tied_weights_keys = ["lm_head.weight"]
|
811 |
+
|
812 |
+
def __init__(self, config):
|
813 |
+
super().__init__(config)
|
814 |
+
|
815 |
+
self.model = GenerannoModel(config)
|
816 |
+
self.vocab_size = config.vocab_size
|
817 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
818 |
+
|
819 |
+
self.init_weights()
|
820 |
+
|
821 |
+
def get_input_embeddings(self):
|
822 |
+
return self.model.embed_tokens
|
823 |
+
|
824 |
+
def set_input_embeddings(self, value):
|
825 |
+
self.model.embed_tokens = value
|
826 |
+
|
827 |
+
def get_output_embeddings(self):
|
828 |
+
return self.lm_head
|
829 |
+
|
830 |
+
def set_output_embeddings(self, new_embeddings):
|
831 |
+
self.lm_head = new_embeddings
|
832 |
+
|
833 |
+
def set_encoder(self, encoder):
|
834 |
+
self.model = encoder
|
835 |
+
|
836 |
+
def get_encoder(self):
|
837 |
+
return self.model
|
838 |
+
|
839 |
+
def forward(
|
840 |
+
self,
|
841 |
+
input_ids: torch.LongTensor = None,
|
842 |
+
attention_mask: Optional[torch.Tensor] = None,
|
843 |
+
position_ids: Optional[torch.LongTensor] = None,
|
844 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
845 |
+
labels: Optional[torch.LongTensor] = None,
|
846 |
+
output_attentions: Optional[bool] = None,
|
847 |
+
output_hidden_states: Optional[bool] = None,
|
848 |
+
return_dict: Optional[bool] = None,
|
849 |
+
) -> Union[Tuple, MaskedLMOutput]:
|
850 |
+
r"""
|
851 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
852 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
853 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
854 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
855 |
+
kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
|
856 |
+
Used to hide legacy arguments that have been deprecated.
|
857 |
+
"""
|
858 |
+
return_dict = (
|
859 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
860 |
+
)
|
861 |
+
|
862 |
+
outputs = self.model(
|
863 |
+
input_ids,
|
864 |
+
attention_mask=attention_mask,
|
865 |
+
position_ids=position_ids,
|
866 |
+
inputs_embeds=inputs_embeds,
|
867 |
+
output_attentions=output_attentions,
|
868 |
+
output_hidden_states=output_hidden_states,
|
869 |
+
return_dict=return_dict,
|
870 |
+
)
|
871 |
+
hidden_states = outputs[0]
|
872 |
+
if self.config.pretraining_tp > 1:
|
873 |
+
lm_head_slices = self.lm_head.weight.split(
|
874 |
+
self.vocab_size // self.config.pretraining_tp, dim=0
|
875 |
+
)
|
876 |
+
logits = [
|
877 |
+
F.linear(hidden_states, lm_head_slices[i])
|
878 |
+
for i in range(self.config.pretraining_tp)
|
879 |
+
]
|
880 |
+
logits = torch.cat(logits, dim=-1)
|
881 |
+
else:
|
882 |
+
logits = self.lm_head(hidden_states)
|
883 |
+
|
884 |
+
masked_lm_loss = None
|
885 |
+
if labels is not None:
|
886 |
+
loss_fct = CrossEntropyLoss()
|
887 |
+
|
888 |
+
labels = labels.to(logits.device)
|
889 |
+
masked_lm_loss = loss_fct(
|
890 |
+
logits.view(-1, self.config.vocab_size).float(), labels.view(-1)
|
891 |
+
)
|
892 |
+
|
893 |
+
if not return_dict:
|
894 |
+
output = (logits,) + outputs[2:]
|
895 |
+
return (
|
896 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
897 |
+
)
|
898 |
+
|
899 |
+
return MaskedLMOutput(
|
900 |
+
loss=masked_lm_loss,
|
901 |
+
logits=logits,
|
902 |
+
hidden_states=outputs.hidden_states,
|
903 |
+
attentions=outputs.attentions,
|
904 |
+
)
|
905 |
+
|
906 |
+
|
907 |
+
class GenerannoForTokenClassification(GenerannoPreTrainedModel):
|
908 |
+
def __init__(self, config):
|
909 |
+
super().__init__(config)
|
910 |
+
self.num_labels = config.num_labels
|
911 |
+
|
912 |
+
self.model = GenerannoModel(config)
|
913 |
+
self.feature_layer = getattr(config, "feature_layer", -1)
|
914 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
915 |
+
if getattr(config, "use_mlp_classifier", False):
|
916 |
+
self.score = nn.Sequential(
|
917 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
918 |
+
nn.GELU(),
|
919 |
+
nn.Dropout(0.1),
|
920 |
+
nn.Linear(config.hidden_size, self.num_labels, bias=False),
|
921 |
+
)
|
922 |
+
|
923 |
+
self.label_weights = (
|
924 |
+
torch.tensor(config.label_weights)
|
925 |
+
if hasattr(config, "label_weights")
|
926 |
+
else None
|
927 |
+
)
|
928 |
+
|
929 |
+
self.init_weights()
|
930 |
+
|
931 |
+
def forward(
|
932 |
+
self,
|
933 |
+
input_ids: Optional[torch.LongTensor] = None,
|
934 |
+
attention_mask: Optional[torch.Tensor] = None,
|
935 |
+
position_ids: Optional[torch.LongTensor] = None,
|
936 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
937 |
+
labels: Optional[torch.LongTensor] = None,
|
938 |
+
output_attentions: Optional[bool] = None,
|
939 |
+
output_hidden_states: Optional[bool] = None,
|
940 |
+
return_dict: Optional[bool] = None,
|
941 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
942 |
+
r"""
|
943 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
944 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
945 |
+
"""
|
946 |
+
return_dict = (
|
947 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
948 |
+
)
|
949 |
+
|
950 |
+
output_hidden_states = True
|
951 |
+
outputs = self.model(
|
952 |
+
input_ids,
|
953 |
+
attention_mask=attention_mask,
|
954 |
+
position_ids=position_ids,
|
955 |
+
inputs_embeds=inputs_embeds,
|
956 |
+
output_attentions=output_attentions,
|
957 |
+
output_hidden_states=output_hidden_states,
|
958 |
+
return_dict=return_dict,
|
959 |
+
)
|
960 |
+
|
961 |
+
hidden_states = outputs["hidden_states"][
|
962 |
+
self.feature_layer if hasattr(self, "feature_layer") else -1
|
963 |
+
]
|
964 |
+
logits = self.score(hidden_states)
|
965 |
+
|
966 |
+
loss = None
|
967 |
+
if labels is not None:
|
968 |
+
if self.label_weights is not None:
|
969 |
+
self.label_weights = self.label_weights.to(
|
970 |
+
device=logits.device, dtype=logits.dtype
|
971 |
+
)
|
972 |
+
loss_fct = CrossEntropyLoss(weight=self.label_weights)
|
973 |
+
else:
|
974 |
+
loss_fct = CrossEntropyLoss()
|
975 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
976 |
+
|
977 |
+
if not return_dict:
|
978 |
+
output = (logits,)
|
979 |
+
return ((loss,) + output) if loss is not None else output
|
980 |
+
|
981 |
+
return TokenClassifierOutput(loss=loss, logits=logits)
|
982 |
+
|
983 |
+
|
984 |
+
class GenerannoForSequenceClassification(GenerannoPreTrainedModel):
|
985 |
+
def __init__(self, config):
|
986 |
+
super().__init__(config)
|
987 |
+
self.num_labels = config.num_labels
|
988 |
+
self.config = config
|
989 |
+
|
990 |
+
self.model = GenerannoModel(config)
|
991 |
+
self.feature_layer = getattr(config, "feature_layer", -1)
|
992 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
993 |
+
if getattr(config, "use_mlp_classifier", False):
|
994 |
+
self.score = nn.Sequential(
|
995 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
996 |
+
nn.GELU(),
|
997 |
+
nn.Dropout(0.1),
|
998 |
+
nn.Linear(config.hidden_size, self.num_labels, bias=False),
|
999 |
+
)
|
1000 |
+
|
1001 |
+
# Initialize weights and apply final processing
|
1002 |
+
self.post_init()
|
1003 |
+
|
1004 |
+
def forward(
|
1005 |
+
self,
|
1006 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1007 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1008 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1009 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1010 |
+
labels: Optional[torch.LongTensor] = None,
|
1011 |
+
output_attentions: Optional[bool] = None,
|
1012 |
+
output_hidden_states: Optional[bool] = None,
|
1013 |
+
return_dict: Optional[bool] = None,
|
1014 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
1015 |
+
r"""
|
1016 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1017 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1018 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1019 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1020 |
+
"""
|
1021 |
+
return_dict = (
|
1022 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
output_hidden_states = True
|
1026 |
+
outputs = self.model(
|
1027 |
+
input_ids,
|
1028 |
+
attention_mask=attention_mask,
|
1029 |
+
position_ids=position_ids,
|
1030 |
+
inputs_embeds=inputs_embeds,
|
1031 |
+
output_attentions=output_attentions,
|
1032 |
+
output_hidden_states=output_hidden_states,
|
1033 |
+
return_dict=return_dict,
|
1034 |
+
)
|
1035 |
+
hidden_states = outputs["hidden_states"][
|
1036 |
+
self.feature_layer if hasattr(self, "feature_layer") else -1
|
1037 |
+
]
|
1038 |
+
pooled_hidden_states = hidden_states[:, 0]
|
1039 |
+
logits = self.score(pooled_hidden_states)
|
1040 |
+
|
1041 |
+
loss = None
|
1042 |
+
if labels is not None:
|
1043 |
+
labels = labels.to(logits.device)
|
1044 |
+
|
1045 |
+
if self.config.problem_type is None:
|
1046 |
+
if self.num_labels == 1:
|
1047 |
+
self.config.problem_type = "regression"
|
1048 |
+
elif self.num_labels > 1 and (
|
1049 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
1050 |
+
):
|
1051 |
+
self.config.problem_type = "single_label_classification"
|
1052 |
+
else:
|
1053 |
+
self.config.problem_type = "multi_label_classification"
|
1054 |
+
|
1055 |
+
if self.config.problem_type == "regression":
|
1056 |
+
loss_fct = MSELoss()
|
1057 |
+
if self.num_labels == 1:
|
1058 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1059 |
+
else:
|
1060 |
+
loss = loss_fct(logits, labels)
|
1061 |
+
elif self.config.problem_type == "single_label_classification":
|
1062 |
+
loss_fct = CrossEntropyLoss()
|
1063 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1064 |
+
elif self.config.problem_type == "multi_label_classification":
|
1065 |
+
loss_fct = BCEWithLogitsLoss()
|
1066 |
+
loss = loss_fct(logits, labels)
|
1067 |
+
if not return_dict:
|
1068 |
+
output = (logits,)
|
1069 |
+
return ((loss,) + output) if loss is not None else output
|
1070 |
+
|
1071 |
+
return SequenceClassifierOutput(loss=loss, logits=logits)
|
special_tokens_map.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<s>",
|
3 |
+
"eos_token": "</s>",
|
4 |
+
"pad_token": "<pad>",
|
5 |
+
"unk_token": {
|
6 |
+
"content": "<oov>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false
|
11 |
+
}
|
12 |
+
}
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c04f653f2a0e6f1f3c0e744aaa6109fde1a5b862d7f837185d29470b2aba563
|
3 |
+
size 238575
|
tokenizer_config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"add_prefix_space": false,
|
5 |
+
"added_tokens_decoder": {
|
6 |
+
"0": {
|
7 |
+
"content": "<oov>",
|
8 |
+
"lstrip": false,
|
9 |
+
"normalized": false,
|
10 |
+
"rstrip": false,
|
11 |
+
"single_word": false,
|
12 |
+
"special": true
|
13 |
+
},
|
14 |
+
"1": {
|
15 |
+
"content": "<s>",
|
16 |
+
"lstrip": false,
|
17 |
+
"normalized": false,
|
18 |
+
"rstrip": false,
|
19 |
+
"single_word": false,
|
20 |
+
"special": true
|
21 |
+
},
|
22 |
+
"2": {
|
23 |
+
"content": "</s>",
|
24 |
+
"lstrip": false,
|
25 |
+
"normalized": false,
|
26 |
+
"rstrip": false,
|
27 |
+
"single_word": false,
|
28 |
+
"special": true
|
29 |
+
}
|
30 |
+
},
|
31 |
+
"bos_token": "<s>",
|
32 |
+
"clean_up_tokenization_spaces": false,
|
33 |
+
"eos_token": "</s>",
|
34 |
+
"legacy": true,
|
35 |
+
"model_max_length": 1000000000000000019884624838656,
|
36 |
+
"pad_token": "<pad>",
|
37 |
+
"sp_model_kwargs": {},
|
38 |
+
"spaces_between_special_tokens": false,
|
39 |
+
"tokenizer_class": "LlamaTokenizer",
|
40 |
+
"unk_token": "<oov>",
|
41 |
+
"use_default_system_prompt": false
|
42 |
+
}
|