mitmul committed (verified)
Commit 03703df · 1 Parent(s): cf57853

Add files using upload-large-folder tool

README.md ADDED
@@ -0,0 +1,39 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ - ja
+ pipeline_tag: text-generation
+ library_name: transformers
+ base_model: pfnet/plamo-2-1b
+ tags:
+ - mlx
+ ---
+
+ # mlx-community/plamo-2-1b
+
+ The Model [mlx-community/plamo-2-1b](https://huggingface.co/mlx-community/plamo-2-1b) was
+ converted to MLX format from [pfnet/plamo-2-1b](https://huggingface.co/pfnet/plamo-2-1b)
+ using mlx-lm version **0.21.0**.
+
+ ## Use with mlx
+
+ ```bash
+ pip install mlx-lm
+ ```
+
+ ```python
+ from mlx_lm import load, generate
+
+ model, tokenizer = load("mlx-community/plamo-2-1b")
+
+ prompt = "hello"
+
+ if tokenizer.chat_template is not None:
+     messages = [{"role": "user", "content": prompt}]
+     prompt = tokenizer.apply_chat_template(
+         messages, add_generation_prompt=True
+     )
+
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
+ ```
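A quick, hedged aside (not part of the model card itself): the same `load`/`generate` API also accepts a token limit, which is handy for a smoke test. The `max_tokens` argument and the Japanese prompt below are illustrative assumptions rather than something the card specifies.

```python
from mlx_lm import load, generate

# Hypothetical smoke test: cap generation length and try a Japanese prompt,
# since the card lists both `en` and `ja`.
model, tokenizer = load("mlx-community/plamo-2-1b")
response = generate(
    model,
    tokenizer,
    prompt="日本の首都はどこですか？",
    max_tokens=128,  # assumed to behave as in recent mlx-lm releases
    verbose=True,
)
```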
config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "architectures": [
+     "PlamoForCausalLM"
+   ],
+   "attention_window_size": 2048,
+   "auto_map": {
+     "AutoConfig": "modeling_plamo.PlamoConfig",
+     "AutoModelForCausalLM": "modeling_plamo.PlamoForCausalLM"
+   },
+   "bos_token_id": 1,
+   "capacity_factor": 1.0,
+   "eos_token_id": 2,
+   "eval_attention_n_bit": null,
+   "eval_mlp_n_bit": null,
+   "expert_dropout": 0.0,
+   "fp8_accum_dtype": "bfloat16",
+   "group_size": 1024,
+   "hidden_size": 2048,
+   "hidden_size_per_head": 128,
+   "image_feature_size": null,
+   "image_proj_type": "linear",
+   "image_token_id": null,
+   "intermediate_size": 8192,
+   "k_expert": null,
+   "linear_type": "fp8",
+   "mamba_chunk_size": 256,
+   "mamba_d_conv": 4,
+   "mamba_d_state": 64,
+   "mamba_enabled": true,
+   "mamba_num_heads": 32,
+   "mamba_step": 2,
+   "max_position_embeddings": 10485760,
+   "model_type": "plamo2",
+   "n_expert": null,
+   "num_attention_heads": 16,
+   "num_hidden_layers": 16,
+   "num_key_value_heads": 1,
+   "rms_norm_eps": 1e-06,
+   "shared_intermediate_size": null,
+   "sliding_window": 2048,
+   "sparse_intermediate_size": null,
+   "sparse_step": null,
+   "tokenizer_class": "PlamoTokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.2",
+   "use_cache": true,
+   "use_predefined_initial_state": false,
+   "vocab_size": 100000
+ }
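To make the numbers above concrete, here is a small sketch that reads this config and prints a few derived widths. The formulas mirror the modeling code further down (the Mamba block's inner width is `mamba_num_heads * hidden_size_per_head`); the local `config.json` path is assumed.

```python
import json

with open("config.json") as f:  # assumed local path
    cfg = json.load(f)

# Attention uses multi-query attention: 16 query heads but a single KV head.
attn_q_width = cfg["num_attention_heads"] * cfg["hidden_size_per_head"]   # 16 * 128 = 2048
attn_kv_width = cfg["num_key_value_heads"] * cfg["hidden_size_per_head"]  # 1 * 128 = 128
# Mamba blocks expand to num_heads * head_size channels internally.
mamba_width = cfg["mamba_num_heads"] * cfg["hidden_size_per_head"]        # 32 * 128 = 4096

print(f"hidden_size={cfg['hidden_size']}, num_hidden_layers={cfg['num_hidden_layers']}")
print(f"attention: q width={attn_q_width}, kv width={attn_kv_width}, window={cfg['attention_window_size']}")
print(f"mamba: inner width={mamba_width}, d_state={cfg['mamba_d_state']}, d_conv={cfg['mamba_d_conv']}")
```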
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3a5eacef896d4ebe5ce590df7fbfc8dc2c4f3dc4886e2ae01e7a609dd7bd827
+ size 2582909060
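The file above is a Git LFS pointer, so only the hash and byte count live in the repository; the roughly 2.6 GB of weights are fetched separately. A minimal sketch for checking a downloaded `model.safetensors` against the pointer (local path assumed):

```python
import hashlib
from pathlib import Path

path = Path("model.safetensors")  # assumed download location
expected_oid = "a3a5eacef896d4ebe5ce590df7fbfc8dc2c4f3dc4886e2ae01e7a609dd7bd827"
expected_size = 2582909060

digest = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
        digest.update(chunk)

assert path.stat().st_size == expected_size, "size does not match the LFS pointer"
assert digest.hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("model.safetensors matches the LFS pointer")
```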
model.safetensors.index.json ADDED
@@ -0,0 +1,225 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 2582883840
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model.safetensors",
7
+ "model.layers.layers.0.mixer.A_log": "model.safetensors",
8
+ "model.layers.layers.0.mixer.B_norm_weight": "model.safetensors",
9
+ "model.layers.layers.0.mixer.C_norm_weight": "model.safetensors",
10
+ "model.layers.layers.0.mixer.D": "model.safetensors",
11
+ "model.layers.layers.0.mixer.bcdt_proj.weight": "model.safetensors",
12
+ "model.layers.layers.0.mixer.conv1d.weight": "model.safetensors",
13
+ "model.layers.layers.0.mixer.dt_bias": "model.safetensors",
14
+ "model.layers.layers.0.mixer.dt_norm_weight": "model.safetensors",
15
+ "model.layers.layers.0.mixer.dt_proj.weight": "model.safetensors",
16
+ "model.layers.layers.0.mixer.in_proj.weight": "model.safetensors",
17
+ "model.layers.layers.0.mixer.out_proj.weight": "model.safetensors",
18
+ "model.layers.layers.0.mlp.down_proj.weight": "model.safetensors",
19
+ "model.layers.layers.0.mlp.gate_up_proj.weight": "model.safetensors",
20
+ "model.layers.layers.0.post_mixer_norm.weight": "model.safetensors",
21
+ "model.layers.layers.0.post_mlp_norm.weight": "model.safetensors",
22
+ "model.layers.layers.0.pre_mixer_norm.weight": "model.safetensors",
23
+ "model.layers.layers.0.pre_mlp_norm.weight": "model.safetensors",
24
+ "model.layers.layers.1.mixer.k_weight": "model.safetensors",
25
+ "model.layers.layers.1.mixer.o_proj.weight": "model.safetensors",
26
+ "model.layers.layers.1.mixer.q_weight": "model.safetensors",
27
+ "model.layers.layers.1.mixer.qkv_proj.weight": "model.safetensors",
28
+ "model.layers.layers.1.mlp.down_proj.weight": "model.safetensors",
29
+ "model.layers.layers.1.mlp.gate_up_proj.weight": "model.safetensors",
30
+ "model.layers.layers.1.post_mixer_norm.weight": "model.safetensors",
31
+ "model.layers.layers.1.post_mlp_norm.weight": "model.safetensors",
32
+ "model.layers.layers.1.pre_mixer_norm.weight": "model.safetensors",
33
+ "model.layers.layers.1.pre_mlp_norm.weight": "model.safetensors",
34
+ "model.layers.layers.10.mixer.A_log": "model.safetensors",
35
+ "model.layers.layers.10.mixer.B_norm_weight": "model.safetensors",
36
+ "model.layers.layers.10.mixer.C_norm_weight": "model.safetensors",
37
+ "model.layers.layers.10.mixer.D": "model.safetensors",
38
+ "model.layers.layers.10.mixer.bcdt_proj.weight": "model.safetensors",
39
+ "model.layers.layers.10.mixer.conv1d.weight": "model.safetensors",
40
+ "model.layers.layers.10.mixer.dt_bias": "model.safetensors",
41
+ "model.layers.layers.10.mixer.dt_norm_weight": "model.safetensors",
42
+ "model.layers.layers.10.mixer.dt_proj.weight": "model.safetensors",
43
+ "model.layers.layers.10.mixer.in_proj.weight": "model.safetensors",
44
+ "model.layers.layers.10.mixer.out_proj.weight": "model.safetensors",
45
+ "model.layers.layers.10.mlp.down_proj.weight": "model.safetensors",
46
+ "model.layers.layers.10.mlp.gate_up_proj.weight": "model.safetensors",
47
+ "model.layers.layers.10.post_mixer_norm.weight": "model.safetensors",
48
+ "model.layers.layers.10.post_mlp_norm.weight": "model.safetensors",
49
+ "model.layers.layers.10.pre_mixer_norm.weight": "model.safetensors",
50
+ "model.layers.layers.10.pre_mlp_norm.weight": "model.safetensors",
51
+ "model.layers.layers.11.mixer.k_weight": "model.safetensors",
52
+ "model.layers.layers.11.mixer.o_proj.weight": "model.safetensors",
53
+ "model.layers.layers.11.mixer.q_weight": "model.safetensors",
54
+ "model.layers.layers.11.mixer.qkv_proj.weight": "model.safetensors",
55
+ "model.layers.layers.11.mlp.down_proj.weight": "model.safetensors",
56
+ "model.layers.layers.11.mlp.gate_up_proj.weight": "model.safetensors",
57
+ "model.layers.layers.11.post_mixer_norm.weight": "model.safetensors",
58
+ "model.layers.layers.11.post_mlp_norm.weight": "model.safetensors",
59
+ "model.layers.layers.11.pre_mixer_norm.weight": "model.safetensors",
60
+ "model.layers.layers.11.pre_mlp_norm.weight": "model.safetensors",
61
+ "model.layers.layers.12.mixer.A_log": "model.safetensors",
62
+ "model.layers.layers.12.mixer.B_norm_weight": "model.safetensors",
63
+ "model.layers.layers.12.mixer.C_norm_weight": "model.safetensors",
64
+ "model.layers.layers.12.mixer.D": "model.safetensors",
65
+ "model.layers.layers.12.mixer.bcdt_proj.weight": "model.safetensors",
66
+ "model.layers.layers.12.mixer.conv1d.weight": "model.safetensors",
67
+ "model.layers.layers.12.mixer.dt_bias": "model.safetensors",
68
+ "model.layers.layers.12.mixer.dt_norm_weight": "model.safetensors",
69
+ "model.layers.layers.12.mixer.dt_proj.weight": "model.safetensors",
70
+ "model.layers.layers.12.mixer.in_proj.weight": "model.safetensors",
71
+ "model.layers.layers.12.mixer.out_proj.weight": "model.safetensors",
72
+ "model.layers.layers.12.mlp.down_proj.weight": "model.safetensors",
73
+ "model.layers.layers.12.mlp.gate_up_proj.weight": "model.safetensors",
74
+ "model.layers.layers.12.post_mixer_norm.weight": "model.safetensors",
75
+ "model.layers.layers.12.post_mlp_norm.weight": "model.safetensors",
76
+ "model.layers.layers.12.pre_mixer_norm.weight": "model.safetensors",
77
+ "model.layers.layers.12.pre_mlp_norm.weight": "model.safetensors",
78
+ "model.layers.layers.13.mixer.k_weight": "model.safetensors",
79
+ "model.layers.layers.13.mixer.o_proj.weight": "model.safetensors",
80
+ "model.layers.layers.13.mixer.q_weight": "model.safetensors",
81
+ "model.layers.layers.13.mixer.qkv_proj.weight": "model.safetensors",
82
+ "model.layers.layers.13.mlp.down_proj.weight": "model.safetensors",
83
+ "model.layers.layers.13.mlp.gate_up_proj.weight": "model.safetensors",
84
+ "model.layers.layers.13.post_mixer_norm.weight": "model.safetensors",
85
+ "model.layers.layers.13.post_mlp_norm.weight": "model.safetensors",
86
+ "model.layers.layers.13.pre_mixer_norm.weight": "model.safetensors",
87
+ "model.layers.layers.13.pre_mlp_norm.weight": "model.safetensors",
88
+ "model.layers.layers.14.mixer.A_log": "model.safetensors",
89
+ "model.layers.layers.14.mixer.B_norm_weight": "model.safetensors",
90
+ "model.layers.layers.14.mixer.C_norm_weight": "model.safetensors",
91
+ "model.layers.layers.14.mixer.D": "model.safetensors",
92
+ "model.layers.layers.14.mixer.bcdt_proj.weight": "model.safetensors",
93
+ "model.layers.layers.14.mixer.conv1d.weight": "model.safetensors",
94
+ "model.layers.layers.14.mixer.dt_bias": "model.safetensors",
95
+ "model.layers.layers.14.mixer.dt_norm_weight": "model.safetensors",
96
+ "model.layers.layers.14.mixer.dt_proj.weight": "model.safetensors",
97
+ "model.layers.layers.14.mixer.in_proj.weight": "model.safetensors",
98
+ "model.layers.layers.14.mixer.out_proj.weight": "model.safetensors",
99
+ "model.layers.layers.14.mlp.down_proj.weight": "model.safetensors",
100
+ "model.layers.layers.14.mlp.gate_up_proj.weight": "model.safetensors",
101
+ "model.layers.layers.14.post_mixer_norm.weight": "model.safetensors",
102
+ "model.layers.layers.14.post_mlp_norm.weight": "model.safetensors",
103
+ "model.layers.layers.14.pre_mixer_norm.weight": "model.safetensors",
104
+ "model.layers.layers.14.pre_mlp_norm.weight": "model.safetensors",
105
+ "model.layers.layers.15.mixer.k_weight": "model.safetensors",
106
+ "model.layers.layers.15.mixer.o_proj.weight": "model.safetensors",
107
+ "model.layers.layers.15.mixer.q_weight": "model.safetensors",
108
+ "model.layers.layers.15.mixer.qkv_proj.weight": "model.safetensors",
109
+ "model.layers.layers.15.mlp.down_proj.weight": "model.safetensors",
110
+ "model.layers.layers.15.mlp.gate_up_proj.weight": "model.safetensors",
111
+ "model.layers.layers.15.post_mixer_norm.weight": "model.safetensors",
112
+ "model.layers.layers.15.post_mlp_norm.weight": "model.safetensors",
113
+ "model.layers.layers.15.pre_mixer_norm.weight": "model.safetensors",
114
+ "model.layers.layers.15.pre_mlp_norm.weight": "model.safetensors",
115
+ "model.layers.layers.2.mixer.A_log": "model.safetensors",
116
+ "model.layers.layers.2.mixer.B_norm_weight": "model.safetensors",
117
+ "model.layers.layers.2.mixer.C_norm_weight": "model.safetensors",
118
+ "model.layers.layers.2.mixer.D": "model.safetensors",
119
+ "model.layers.layers.2.mixer.bcdt_proj.weight": "model.safetensors",
120
+ "model.layers.layers.2.mixer.conv1d.weight": "model.safetensors",
121
+ "model.layers.layers.2.mixer.dt_bias": "model.safetensors",
122
+ "model.layers.layers.2.mixer.dt_norm_weight": "model.safetensors",
123
+ "model.layers.layers.2.mixer.dt_proj.weight": "model.safetensors",
124
+ "model.layers.layers.2.mixer.in_proj.weight": "model.safetensors",
125
+ "model.layers.layers.2.mixer.out_proj.weight": "model.safetensors",
126
+ "model.layers.layers.2.mlp.down_proj.weight": "model.safetensors",
127
+ "model.layers.layers.2.mlp.gate_up_proj.weight": "model.safetensors",
128
+ "model.layers.layers.2.post_mixer_norm.weight": "model.safetensors",
129
+ "model.layers.layers.2.post_mlp_norm.weight": "model.safetensors",
130
+ "model.layers.layers.2.pre_mixer_norm.weight": "model.safetensors",
131
+ "model.layers.layers.2.pre_mlp_norm.weight": "model.safetensors",
132
+ "model.layers.layers.3.mixer.k_weight": "model.safetensors",
133
+ "model.layers.layers.3.mixer.o_proj.weight": "model.safetensors",
134
+ "model.layers.layers.3.mixer.q_weight": "model.safetensors",
135
+ "model.layers.layers.3.mixer.qkv_proj.weight": "model.safetensors",
136
+ "model.layers.layers.3.mlp.down_proj.weight": "model.safetensors",
137
+ "model.layers.layers.3.mlp.gate_up_proj.weight": "model.safetensors",
138
+ "model.layers.layers.3.post_mixer_norm.weight": "model.safetensors",
139
+ "model.layers.layers.3.post_mlp_norm.weight": "model.safetensors",
140
+ "model.layers.layers.3.pre_mixer_norm.weight": "model.safetensors",
141
+ "model.layers.layers.3.pre_mlp_norm.weight": "model.safetensors",
142
+ "model.layers.layers.4.mixer.A_log": "model.safetensors",
143
+ "model.layers.layers.4.mixer.B_norm_weight": "model.safetensors",
144
+ "model.layers.layers.4.mixer.C_norm_weight": "model.safetensors",
145
+ "model.layers.layers.4.mixer.D": "model.safetensors",
146
+ "model.layers.layers.4.mixer.bcdt_proj.weight": "model.safetensors",
147
+ "model.layers.layers.4.mixer.conv1d.weight": "model.safetensors",
148
+ "model.layers.layers.4.mixer.dt_bias": "model.safetensors",
149
+ "model.layers.layers.4.mixer.dt_norm_weight": "model.safetensors",
150
+ "model.layers.layers.4.mixer.dt_proj.weight": "model.safetensors",
151
+ "model.layers.layers.4.mixer.in_proj.weight": "model.safetensors",
152
+ "model.layers.layers.4.mixer.out_proj.weight": "model.safetensors",
153
+ "model.layers.layers.4.mlp.down_proj.weight": "model.safetensors",
154
+ "model.layers.layers.4.mlp.gate_up_proj.weight": "model.safetensors",
155
+ "model.layers.layers.4.post_mixer_norm.weight": "model.safetensors",
156
+ "model.layers.layers.4.post_mlp_norm.weight": "model.safetensors",
157
+ "model.layers.layers.4.pre_mixer_norm.weight": "model.safetensors",
158
+ "model.layers.layers.4.pre_mlp_norm.weight": "model.safetensors",
159
+ "model.layers.layers.5.mixer.k_weight": "model.safetensors",
160
+ "model.layers.layers.5.mixer.o_proj.weight": "model.safetensors",
161
+ "model.layers.layers.5.mixer.q_weight": "model.safetensors",
162
+ "model.layers.layers.5.mixer.qkv_proj.weight": "model.safetensors",
163
+ "model.layers.layers.5.mlp.down_proj.weight": "model.safetensors",
164
+ "model.layers.layers.5.mlp.gate_up_proj.weight": "model.safetensors",
165
+ "model.layers.layers.5.post_mixer_norm.weight": "model.safetensors",
166
+ "model.layers.layers.5.post_mlp_norm.weight": "model.safetensors",
167
+ "model.layers.layers.5.pre_mixer_norm.weight": "model.safetensors",
168
+ "model.layers.layers.5.pre_mlp_norm.weight": "model.safetensors",
169
+ "model.layers.layers.6.mixer.A_log": "model.safetensors",
170
+ "model.layers.layers.6.mixer.B_norm_weight": "model.safetensors",
171
+ "model.layers.layers.6.mixer.C_norm_weight": "model.safetensors",
172
+ "model.layers.layers.6.mixer.D": "model.safetensors",
173
+ "model.layers.layers.6.mixer.bcdt_proj.weight": "model.safetensors",
174
+ "model.layers.layers.6.mixer.conv1d.weight": "model.safetensors",
175
+ "model.layers.layers.6.mixer.dt_bias": "model.safetensors",
176
+ "model.layers.layers.6.mixer.dt_norm_weight": "model.safetensors",
177
+ "model.layers.layers.6.mixer.dt_proj.weight": "model.safetensors",
178
+ "model.layers.layers.6.mixer.in_proj.weight": "model.safetensors",
179
+ "model.layers.layers.6.mixer.out_proj.weight": "model.safetensors",
180
+ "model.layers.layers.6.mlp.down_proj.weight": "model.safetensors",
181
+ "model.layers.layers.6.mlp.gate_up_proj.weight": "model.safetensors",
182
+ "model.layers.layers.6.post_mixer_norm.weight": "model.safetensors",
183
+ "model.layers.layers.6.post_mlp_norm.weight": "model.safetensors",
184
+ "model.layers.layers.6.pre_mixer_norm.weight": "model.safetensors",
185
+ "model.layers.layers.6.pre_mlp_norm.weight": "model.safetensors",
186
+ "model.layers.layers.7.mixer.k_weight": "model.safetensors",
187
+ "model.layers.layers.7.mixer.o_proj.weight": "model.safetensors",
188
+ "model.layers.layers.7.mixer.q_weight": "model.safetensors",
189
+ "model.layers.layers.7.mixer.qkv_proj.weight": "model.safetensors",
190
+ "model.layers.layers.7.mlp.down_proj.weight": "model.safetensors",
191
+ "model.layers.layers.7.mlp.gate_up_proj.weight": "model.safetensors",
192
+ "model.layers.layers.7.post_mixer_norm.weight": "model.safetensors",
193
+ "model.layers.layers.7.post_mlp_norm.weight": "model.safetensors",
194
+ "model.layers.layers.7.pre_mixer_norm.weight": "model.safetensors",
195
+ "model.layers.layers.7.pre_mlp_norm.weight": "model.safetensors",
196
+ "model.layers.layers.8.mixer.A_log": "model.safetensors",
197
+ "model.layers.layers.8.mixer.B_norm_weight": "model.safetensors",
198
+ "model.layers.layers.8.mixer.C_norm_weight": "model.safetensors",
199
+ "model.layers.layers.8.mixer.D": "model.safetensors",
200
+ "model.layers.layers.8.mixer.bcdt_proj.weight": "model.safetensors",
201
+ "model.layers.layers.8.mixer.conv1d.weight": "model.safetensors",
202
+ "model.layers.layers.8.mixer.dt_bias": "model.safetensors",
203
+ "model.layers.layers.8.mixer.dt_norm_weight": "model.safetensors",
204
+ "model.layers.layers.8.mixer.dt_proj.weight": "model.safetensors",
205
+ "model.layers.layers.8.mixer.in_proj.weight": "model.safetensors",
206
+ "model.layers.layers.8.mixer.out_proj.weight": "model.safetensors",
207
+ "model.layers.layers.8.mlp.down_proj.weight": "model.safetensors",
208
+ "model.layers.layers.8.mlp.gate_up_proj.weight": "model.safetensors",
209
+ "model.layers.layers.8.post_mixer_norm.weight": "model.safetensors",
210
+ "model.layers.layers.8.post_mlp_norm.weight": "model.safetensors",
211
+ "model.layers.layers.8.pre_mixer_norm.weight": "model.safetensors",
212
+ "model.layers.layers.8.pre_mlp_norm.weight": "model.safetensors",
213
+ "model.layers.layers.9.mixer.k_weight": "model.safetensors",
214
+ "model.layers.layers.9.mixer.o_proj.weight": "model.safetensors",
215
+ "model.layers.layers.9.mixer.q_weight": "model.safetensors",
216
+ "model.layers.layers.9.mixer.qkv_proj.weight": "model.safetensors",
217
+ "model.layers.layers.9.mlp.down_proj.weight": "model.safetensors",
218
+ "model.layers.layers.9.mlp.gate_up_proj.weight": "model.safetensors",
219
+ "model.layers.layers.9.post_mixer_norm.weight": "model.safetensors",
220
+ "model.layers.layers.9.post_mlp_norm.weight": "model.safetensors",
221
+ "model.layers.layers.9.pre_mixer_norm.weight": "model.safetensors",
222
+ "model.layers.layers.9.pre_mlp_norm.weight": "model.safetensors",
223
+ "model.norm.weight": "model.safetensors"
224
+ }
225
+ }
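The weight map above already shows the hybrid layout: layers whose mixer carries `A_log`/`conv1d`/`dt_proj` parameters are Mamba blocks, while layers with `qkv_proj`/`o_proj` are sliding-window attention blocks (even indices are Mamba, odd are attention in this checkpoint). Below is a small sketch that recovers the pattern from the index file; the local path is an assumption.

```python
import json
import re

with open("model.safetensors.index.json") as f:  # assumed local path
    weight_map = json.load(f)["weight_map"]

layer_kind = {}
for name in weight_map:
    m = re.match(r"model\.layers\.layers\.(\d+)\.mixer\.(A_log|qkv_proj\.weight)$", name)
    if m:
        layer_kind[int(m.group(1))] = "mamba" if m.group(2) == "A_log" else "attention"

for idx in sorted(layer_kind):
    print(f"layer {idx:2d}: {layer_kind[idx]}")  # expected: mamba on even layers, attention on odd
```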
modeling_plamo.py ADDED
@@ -0,0 +1,1801 @@
1
+ import enum
2
+ import math
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from typing import Any, Literal, NamedTuple, Optional, Union
6
+
7
+ import mlx.core as mx
8
+ import mlx.nn as nn
9
+
10
+ from .base import BaseModelArgs, create_attention_mask
11
+
12
+
13
+ def _is_first_token(mask: mx.array) -> mx.array:
14
+ assert mask.dtype == mx.bool_ # type: ignore
15
+ B, Nh, q_len, kv_len = mask.shape
16
+ mask = mask[:, :, :, -q_len:]
17
+ cont = q_len != kv_len
18
+ v = not cont
19
+ out = mx.logical_not(mx.diagonal(mask, offset=-1, axis1=-2, axis2=-1).astype(mx.bool_)) # type: ignore
20
+ out = mx.concatenate([mx.full(shape=(B, Nh, 1), dtype=mx.bool_, vals=v), out], axis=-1) # type: ignore
21
+ return out
22
+
23
+
24
+ def _swiglu(h: mx.array) -> mx.array:
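+ # SwiGLU: split the last dimension into two halves and gate silu(first half) with the second half.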
25
+ size = h.shape[-1]
26
+ chunks = 2
27
+ _current_idx = 0
28
+ split_sizes = []
29
+ for i in range(chunks - 1):
30
+ _current_idx += size // chunks + (1 if i < size % chunks else 0)
31
+ split_sizes.append(_current_idx)
32
+ hs = mx.split(h, split_sizes, axis=-1)
33
+ return nn.silu(hs[0]) * hs[1]
34
+
35
+
36
+ class RotaryEmbedding(nn.Module):
37
+ def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000) -> None:
38
+ super().__init__()
39
+
40
+ self.dim = dim
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.base = base
43
+ inv_freq = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
44
+ self._inv_freq = inv_freq
45
+
46
+ # Precompute the cos/sin cache up to max_position_embeddings.
47
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=mx.float32)
48
+
49
+ def _set_cos_sin_cache(self, seq_len: int, dtype: Any) -> None:
50
+ self.max_seq_len_cached = seq_len
51
+ t = mx.arange(self.max_seq_len_cached, dtype=self._inv_freq.dtype) # type: ignore
52
+
53
+ freqs = mx.einsum("i,j->ij", t, self._inv_freq)
54
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
55
+ emb = mx.concatenate([freqs, freqs], axis=-1)
56
+ self._cos_cached = emb.cos()[None, None, :, :].astype(mx.float32)
57
+ self._sin_cached = emb.sin()[None, None, :, :].astype(mx.float32)
58
+
59
+ def __call__(self, x: mx.array, seq_len: int) -> tuple[mx.array, mx.array]:
60
+ # x: [bs, num_attention_heads, seq_len, head_size]
61
+ if seq_len > self.max_seq_len_cached:
62
+ self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)
63
+
64
+ return (
65
+ self._cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
66
+ self._sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
67
+ )
68
+
69
+
70
+ def _rotate_half(x: mx.array) -> mx.array:
71
+ """Rotates half the hidden dims of the input."""
72
+ x1 = x[..., : x.shape[-1] // 2]
73
+ x2 = x[..., x.shape[-1] // 2 :]
74
+ return mx.concatenate([-x2, x1], axis=-1)
75
+
76
+
77
+ def _rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array) -> mx.array:
78
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
79
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
80
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
81
+ cos = mx.expand_dims(cos[position_ids], 1) # [bs, 1, seq_len, dim]
82
+ sin = mx.expand_dims(sin[position_ids], 1) # [bs, 1, seq_len, dim]
83
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
84
+ return x_embed
85
+
86
+
87
+ class LinearType(str, enum.Enum):
88
+ Normal = "normal"
89
+ Fp8 = "fp8"
90
+ Fp8Retain = "fp8-retain"
91
+
92
+
93
+ @dataclass
94
+ class ModelArgs(BaseModelArgs): # type: ignore
95
+ model_type: str = "plamo2"
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_size: int = 4096,
100
+ num_hidden_layers: int = 32,
101
+ rms_norm_eps: float = 1e-6,
102
+ tie_word_embeddings: bool = True,
103
+ # Attention
104
+ num_attention_heads: int = 32,
105
+ num_key_value_heads: int = 4,
106
+ hidden_size_per_head: int = 128,
107
+ max_position_embeddings: int = 2048,
108
+ attention_window_size: int = 2048,
109
+ full_attention_idx: list[int] | None = None,
110
+ # Mamba
111
+ mamba_d_state: int = 64,
112
+ mamba_d_conv: int = 4,
113
+ mamba_num_heads: int = 64,
114
+ mamba_step: int = 2,
115
+ mamba_chunk_size: int = 256,
116
+ mamba_enabled: bool = True,
117
+ # MLP
118
+ intermediate_size: int = 13312,
119
+ # Tokenizer
120
+ vocab_size: int = 32000,
121
+ tokenizer_class: str = "PlamoTokenizer",
122
+ pad_token_id: Optional[int] = None,
123
+ bos_token_id: int = 1,
124
+ eos_token_id: int = 2,
125
+ # Multimodal
126
+ image_token_id: Optional[int] = None,
127
+ image_feature_size: Optional[int] = None,
128
+ image_proj_type: Literal["linear", "mlp"] = "linear",
129
+ # FP8
130
+ linear_type: LinearType = LinearType.Normal,
131
+ fp8_accum_dtype: Optional[str] = None,
132
+ # Evaluation
133
+ eval_attention_n_bit: Optional[int] = None,
134
+ eval_mlp_n_bit: Optional[int] = None,
135
+ use_cache: bool = True,
136
+ **kwargs: Any,
137
+ ) -> None:
138
+ # max_position_embeddings is often used to determine the max length during inference,
139
+ # but samba should have extrapolation abilities
140
+ self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
141
+ self.hidden_size = hidden_size
142
+ self.rms_norm_eps = rms_norm_eps
143
+
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+ self.hidden_size_per_head = hidden_size_per_head
147
+ self.num_key_value_heads = num_key_value_heads
148
+ self.attention_window_size = attention_window_size
149
+ self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
150
+
151
+ self.mamba_d_state = mamba_d_state
152
+ self.mamba_d_conv = mamba_d_conv
153
+ self.mamba_num_heads = mamba_num_heads
154
+ self.mamba_step = mamba_step
155
+ self.mamba_chunk_size = mamba_chunk_size
156
+ self.mamba_enabled = mamba_enabled
157
+
158
+ self.intermediate_size = intermediate_size
159
+
160
+ self.vocab_size = vocab_size
161
+
162
+ self.image_token_id = image_token_id
163
+ self.image_feature_size = image_feature_size
164
+ self.image_proj_type = image_proj_type
165
+
166
+ self.linear_type = linear_type
167
+ self.fp8_accum_dtype = fp8_accum_dtype
168
+
169
+ self.eval_attention_n_bit = eval_attention_n_bit
170
+ self.eval_mlp_n_bit = eval_mlp_n_bit
171
+ self.use_cache = use_cache
172
+
173
+ # fields for vLLM
174
+ self.sliding_window = attention_window_size
175
+
176
+ self.tokenizer_class = tokenizer_class
177
+ self.pad_token_id = pad_token_id
178
+ self.bos_token_id = bos_token_id
179
+ self.eos_token_id = eos_token_id
180
+ self.tie_word_embeddings = tie_word_embeddings
181
+
182
+ # From PretrainedConfig of transformers
183
+ self.use_return_dict = kwargs.pop("use_return_dict", True)
184
+ self.output_attentions = kwargs.pop("output_attentions", False)
185
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
186
+
187
+
188
+ class PlamoAttentionCache(nn.Module):
189
+ def __init__(self, key: mx.array, value: mx.array) -> None:
190
+ super().__init__()
191
+ B, nh, L, c = key.shape
192
+ assert len(value.shape) == 4
193
+ assert value.shape[0] == B
194
+ assert value.shape[2] == L
195
+ self.key = key
196
+ self.value = value
197
+
198
+
199
+ class PlamoMambaCache(nn.Module):
200
+ def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
201
+ super().__init__()
202
+ # conv_state: [B, C, d_conv]
203
+ # ssm_state: [B, nhead, nchannel_per_head, d_state]
204
+ assert len(conv_state.shape) == 3
205
+ assert len(ssm_state.shape) == 4
206
+ assert conv_state.shape[0] == ssm_state.shape[0]
207
+ self.conv_state = conv_state
208
+ self.ssm_state = ssm_state
209
+
210
+
211
+ PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
212
+
213
+
214
+ class PlamoCache(nn.Module):
215
+ """
216
+ Stores states of the model for fast decoding.
217
+ `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
218
+ deeply dependent on the Transformers architecture (e.g., `key_states`), which makes it difficult to use
219
+ with other architectures (e.g., Mamba).
220
+ This class provides a similar interface to `transformers.Cache`, but is designed to also handle
221
+ the state of Mamba properly.
222
+ """
223
+
224
+ def __init__(self, config: ModelArgs) -> None:
225
+ super().__init__()
226
+ self.config = config
227
+ self.cache: list[Optional[PlamoLayerCache]] = [None for _ in range(config.num_hidden_layers)]
228
+
229
+ def append_kv(self, key: mx.array, value: mx.array, layer_idx: int) -> tuple[mx.array, mx.array]:
230
+ c = self.cache[layer_idx]
231
+ if c is None:
232
+ return key, value
233
+ assert isinstance(c, PlamoAttentionCache)
234
+
235
+ def _validate(cache: mx.array, new_tensor: mx.array) -> None:
236
+ assert len(cache.shape) == 4
237
+ assert len(new_tensor.shape) == 4
238
+ assert cache.shape[0] == new_tensor.shape[0]
239
+ assert cache.shape[1] == new_tensor.shape[1]
240
+ assert cache.shape[3] == new_tensor.shape[3]
241
+
242
+ _validate(c.key, key)
243
+ _validate(c.value, value)
244
+ assert key.shape[2] == value.shape[2]
245
+ return mx.concatenate([c.key, key], axis=2), mx.concatenate([c.value, value], axis=2)
246
+
247
+ def update_attention(self, key_states: mx.array, value_states: mx.array, layer_idx: int) -> PlamoAttentionCache:
248
+ full_attn = layer_idx in self.config.full_attention_idx
249
+ window_size = self.config.attention_window_size
250
+
251
+ if self.cache[layer_idx] is None:
252
+ if full_attn:
253
+ self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
254
+ else:
255
+ self.cache[layer_idx] = PlamoAttentionCache(
256
+ key_states[:, :, -window_size:, :],
257
+ value_states[:, :, -window_size:, :],
258
+ )
259
+ else:
260
+ c = self.cache[layer_idx]
261
+ assert isinstance(c, PlamoAttentionCache)
262
+ k, v = self.append_kv(key_states, value_states, layer_idx)
263
+ if full_attn:
264
+ c.key = k
265
+ c.value = v
266
+ else:
267
+ c.key = k[:, :, -window_size:, :]
268
+ c.value = v[:, :, -window_size:, :]
269
+ self.cache[layer_idx] = c
270
+ return self.cache[layer_idx] # type: ignore
271
+
272
+ def update_mamba(self, conv_state: mx.array, ssm_state: mx.array, layer_idx: int) -> PlamoMambaCache:
273
+ if self.cache[layer_idx] is None:
274
+ self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
275
+ else:
276
+ c = self.cache[layer_idx]
277
+ assert isinstance(c, PlamoMambaCache)
278
+ assert c.conv_state.shape == conv_state.shape
279
+ assert c.ssm_state.shape == ssm_state.shape
280
+ c.conv_state = conv_state
281
+ c.ssm_state = ssm_state
282
+ return self.cache[layer_idx] # type: ignore
283
+
284
+ def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
285
+ assert layer_idx < len(self.cache)
286
+ layer_cache = self.cache[layer_idx]
287
+ return layer_cache # type: ignore
288
+
289
+ @property
290
+ def state(self):
291
+ return self.cache
292
+
293
+ @state.setter
294
+ def state(self, v):
295
+ self.cache = v
296
+
297
+ def __len__(self) -> int:
298
+ return len(self.cache)
299
+
300
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
301
+ if layer_idx is not None:
302
+ c = self.cache[layer_idx]
303
+ assert isinstance(c, PlamoAttentionCache)
304
+ return c.key.shape[2] # type: ignore
305
+
306
+ sequence_length: int = 0
307
+ for layer_cache in self.cache:
308
+ if isinstance(layer_cache, PlamoAttentionCache):
309
+ sequence_length = (
310
+ max(layer_cache.key.shape[2], sequence_length)
311
+ if sequence_length is not None
312
+ else layer_cache.key.shape[2]
313
+ )
314
+ return sequence_length
315
+
316
+ def get_max_length(self) -> int | None:
317
+ return None
318
+
319
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
320
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
321
+ # Cache without size limit -> all cache is usable
322
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
323
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
324
+ max_length = self.get_max_length()
325
+ previous_seq_length = self.get_seq_length(layer_idx)
326
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
327
+ return max_length - new_seq_length
328
+ return previous_seq_length
329
+
330
+ def reorder_cache(self, beam_idx: mx.array) -> None:
331
+ def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
332
+ return PlamoMambaCache(
333
+ conv_state=mx.take(cache.conv_state, beam_idx, axis=0),
334
+ ssm_state=mx.take(cache.ssm_state, beam_idx, axis=0),
335
+ )
336
+
337
+ def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
338
+ return PlamoAttentionCache(
339
+ key=mx.take(cache.key, beam_idx, axis=0),
340
+ value=mx.take(cache.value, beam_idx, axis=0),
341
+ )
342
+
343
+ for i in range(len(self.cache)):
344
+ if self.cache[i] is None:
345
+ continue
346
+ layer_cache = self.cache[i]
347
+ if isinstance(layer_cache, PlamoMambaCache):
348
+ self.cache[i] = _mamba(layer_cache)
349
+ else:
350
+ assert isinstance(layer_cache, PlamoAttentionCache)
351
+ self.cache[i] = _attention(layer_cache)
352
+
353
+ @property
354
+ def seen_tokens(self) -> int | None:
355
+ return None
356
+
357
+
358
+ class DecoderInput(NamedTuple):
359
+ hidden_states: mx.array
360
+ attention_mask: Optional[mx.array] = None
361
+ past_states: Optional[PlamoCache] = None
362
+ output_hidden_states: Optional[bool] = False
363
+ output_attentions: Optional[bool] = False
364
+ gradient_checkpointing: bool = False
365
+ input_ids: Optional[mx.array] = None
366
+
367
+
368
+ class DecoderOutput(NamedTuple):
369
+ hidden_states: mx.array
370
+ all_hidden_states: Optional[tuple[mx.array, ...]]
371
+ all_self_attns: Optional[tuple[mx.array, ...]]
372
+
373
+
374
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
375
+ def _make_causal_mask(input_ids_shape: tuple[int, int], dtype: mx.Dtype, past_key_values_length: int = 0) -> mx.array:
376
+ """
377
+ Make causal mask used for bi-directional self-attention.
378
+ """
379
+ bsz, tgt_len = input_ids_shape
380
+ mask = mx.full((tgt_len, tgt_len), float("-inf"))
381
+ mask_cond = mx.arange(mask.shape[-1])
382
+ mask = mx.where(mask_cond < (mask_cond + 1).reshape((mask.shape[-1], 1)), 0, mask)
383
+ mask = mask.astype(dtype)
384
+
385
+ if past_key_values_length > 0:
386
+ mask = mx.concatenate([mx.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
387
+ return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
388
+
389
+
390
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
391
+ def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None) -> mx.array:
392
+ """
393
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
394
+ """
395
+ bsz, src_len = mask.shape
396
+ tgt_len = tgt_len if tgt_len is not None else src_len
397
+
398
+ expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)
399
+
400
+ inverted_mask = 1.0 - expanded_mask
401
+
402
+ return mx.where(inverted_mask.astype(mx.bool_), float("-inf"), inverted_mask) # type: ignore
403
+
404
+
405
+ def _rms_norm(hidden_states: mx.array, weight: Optional[mx.array], eps: float, offset: float = 1.0) -> mx.array:
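+ # Note: the learned `weight` is stored as a deviation from `offset` (default 1.0), so a zero weight means identity scaling.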
406
+ input_dtype = hidden_states.dtype
407
+ hidden_states = hidden_states.astype(mx.float32)
408
+ variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
409
+ hidden_states = hidden_states * mx.rsqrt(variance + eps)
410
+ hidden_states = hidden_states.astype(input_dtype)
411
+ if weight is not None:
412
+ hidden_states = (offset + weight) * hidden_states
413
+ return hidden_states
414
+
415
+
416
+ class RMSNorm(nn.Module):
417
+ def __init__(
418
+ self,
419
+ hidden_size: int,
420
+ eps: float = 1e-6,
421
+ offset: float = 1.0,
422
+ ) -> None:
423
+ super().__init__()
424
+ self.weight = mx.zeros(hidden_size)
425
+ self.variance_epsilon = eps
426
+ self.offset = offset
427
+
428
+ def __call__(self, hidden_states: mx.array) -> mx.array:
429
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
430
+
431
+
432
+ def get_initial_dt_bias(num_heads: int) -> mx.array:
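+ # Standard Mamba dt initialization: sample dt log-uniformly in [dt_min, dt_max], then store softplus^{-1}(dt) as the bias.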
433
+ dt_min = 0.001
434
+ dt_max = 0.1
435
+ dt = mx.exp(mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
436
+ dt = mx.clip(dt, a_min=1e-4, a_max=None)
437
+ inv_dt = dt + mx.log(-mx.expm1(-dt))
438
+ return inv_dt
439
+
440
+
441
+ def get_initial_A(num_heads: int) -> mx.array:
442
+ A = mx.arange(1, num_heads + 1, dtype=mx.float32)
443
+ return mx.log(A)
444
+
445
+
446
+ def selective_state_update_ref(
447
+ state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
448
+ ) -> tuple[mx.array, mx.array]:
449
+ """
450
+ Argument:
451
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
452
+ x: (batch, dim) or (batch, nheads, dim)
453
+ dt: (batch, dim) or (batch, nheads, dim)
454
+ A: (dim, dstate) or (nheads, dim, dstate)
455
+ B: (batch, dstate) or (batch, ngroups, dstate)
456
+ C: (batch, dstate) or (batch, ngroups, dstate)
457
+ D: (dim,) or (nheads, dim)
458
+ z: (batch, dim) or (batch, nheads, dim)
459
+ dt_bias: (dim,) or (nheads, dim)
460
+ Return:
461
+ out: (batch, dim) or (batch, nheads, dim)
462
+ """
463
+ has_heads = state.ndim > 3
464
+ if state.ndim == 3:
465
+ state = mx.expand_dims(state, 1)
466
+ if x.ndim == 2:
467
+ x = mx.expand_dims(x, 1)
468
+ if dt.ndim == 2:
469
+ dt = mx.expand_dims(dt, 1)
470
+ if A.ndim == 2:
471
+ A = mx.expand_dims(A, 0)
472
+ if B.ndim == 2:
473
+ B = mx.expand_dims(B, 1)
474
+ if C.ndim == 2:
475
+ C = mx.expand_dims(C, 1)
476
+ if D is not None and D.ndim == 1:
477
+ D = mx.expand_dims(D, 0)
478
+ if z is not None and z.ndim == 2:
479
+ z = mx.expand_dims(z, 1)
480
+ if dt_bias is not None and dt_bias.ndim == 1:
481
+ dt_bias = mx.expand_dims(dt_bias, 0)
482
+ batch, nheads, dim, dstate = state.shape
483
+ assert x.shape == (batch, nheads, dim)
484
+ assert dt.shape == x.shape
485
+ assert A.shape == (nheads, dim, dstate)
486
+ ngroups = B.shape[1]
487
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
488
+ assert B.shape == (batch, ngroups, dstate)
489
+ assert C.shape == B.shape
490
+ if D is not None:
491
+ assert D.shape == (nheads, dim)
492
+ if z is not None:
493
+ assert z.shape == x.shape
494
+ if dt_bias is not None:
495
+ assert dt_bias.shape == (nheads, dim)
496
+ dt = dt + dt_bias
497
+ dt = nn.softplus(dt) if dt_softplus else dt
498
+ dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
499
+ B = mx.reshape(
500
+ mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
501
+ (batch, nheads, dstate),
502
+ ) # (batch, nheads, dstate)
503
+ C = mx.reshape(
504
+ mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
505
+ (batch, nheads, dstate),
506
+ ) # (batch, nheads, dstate)
507
+ dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
508
+ state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, nheads, dim, dstate)
509
+ out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
510
+ if D is not None:
511
+ out += (x * D).astype(out.dtype)
512
+ out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
513
+ if not has_heads:
514
+ out = out.squeeze(1)
515
+ return out, state
516
+
517
+
518
+ def ssd_update_state(
519
+ ssm_state: mx.array,
520
+ x: mx.array,
521
+ dt: mx.array,
522
+ A: mx.array,
523
+ B: mx.array,
524
+ C: mx.array,
525
+ D: mx.array,
526
+ z: mx.array,
527
+ dt_bias: mx.array,
528
+ dt_softplus: bool,
529
+ ) -> tuple[mx.array, mx.array]:
530
+ assert ssm_state.dtype == mx.float32
531
+ dtype = x.dtype
532
+
533
+ hidden_size_per_head = x.shape[-1]
534
+ d_state = B.shape[-1]
535
+ A = mx.broadcast_to(A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)).astype(mx.float32)
536
+ dt = mx.broadcast_to(dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head))
537
+ dt_bias = mx.broadcast_to(dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head))
538
+ D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
539
+ out, ssm_state = selective_state_update_ref(
540
+ ssm_state,
541
+ x.astype(dtype),
542
+ dt.astype(dtype),
543
+ A.astype(mx.float32),
544
+ B.astype(dtype),
545
+ C.astype(dtype),
546
+ D.astype(mx.float32),
547
+ z.astype(dtype),
548
+ dt_bias.astype(mx.float32),
549
+ dt_softplus=dt_softplus,
550
+ )
551
+ return out[:, None], ssm_state
552
+
553
+
554
+ def _ssd_chunk_scan_combined_naive(
555
+ x: mx.array,
556
+ dt: mx.array,
557
+ A: mx.array,
558
+ B: mx.array,
559
+ C: mx.array,
560
+ D: mx.array,
561
+ z: mx.array,
562
+ dt_bias: mx.array,
563
+ dt_softplus: bool,
564
+ seq_idx: mx.array | None,
565
+ ssm_state: mx.array,
566
+ ) -> tuple[mx.array, mx.array]:
567
+ assert ssm_state.dtype == mx.float32
568
+ length = x.shape[1]
569
+ ys = []
570
+ for i in range(length):
571
+ if i != 0 and seq_idx is not None:
572
+ ssm_state = mx.where(
573
+ mx.array(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
574
+ mx.zeros_like(ssm_state),
575
+ ssm_state,
576
+ )
577
+ y, ssm_state = ssd_update_state(
578
+ ssm_state,
579
+ x[:, i],
580
+ dt[:, i],
581
+ A,
582
+ B[:, i],
583
+ C[:, i],
584
+ D if D.ndim == 1 else D[:, i],
585
+ z=z[:, i],
586
+ dt_bias=dt_bias,
587
+ dt_softplus=dt_softplus,
588
+ )
589
+ ys.append(y)
590
+ return mx.concatenate(ys, axis=1), ssm_state
591
+
592
+
593
+ def ssd_chunk_scan_combined(
594
+ x: mx.array,
595
+ dt: mx.array,
596
+ A: mx.array,
597
+ B: mx.array,
598
+ C: mx.array,
599
+ chunk_size: int,
600
+ D: mx.array,
601
+ z: mx.array,
602
+ dt_bias: mx.array,
603
+ dt_softplus: bool,
604
+ return_final_states: bool,
605
+ seq_idx: mx.array | None,
606
+ ssm_state: mx.array | None,
607
+ ) -> tuple[mx.array, mx.array] | mx.array:
608
+ if seq_idx is not None:
609
+ assert seq_idx.dtype == mx.int32
610
+ assert ssm_state is None
611
+ assert not return_final_states
612
+ if ssm_state is not None:
613
+ assert ssm_state.dtype == mx.float32
614
+ assert seq_idx is None
615
+ """
616
+ state will be updated as follows:
617
+ ```
618
+ dt = softplus(dt)
619
+ dA = exp(dt * A)
620
+ state_next = state * dA + dB * x
621
+ ```
622
+ To avoid updating state, we set dt to -inf and x to 0
623
+ because `softplus(-inf) = 0` and `exp(0) = 1`
624
+ """
625
+ if ssm_state is None:
626
+ bsize, _, num_heads, channel = x.shape
627
+ state = B.shape[-1]
628
+ ssm_state = mx.zeros((bsize, num_heads, channel, state), dtype=mx.float32)
629
+ tmp, ssm_state = _ssd_chunk_scan_combined_naive(
630
+ x,
631
+ dt,
632
+ A,
633
+ B,
634
+ C,
635
+ D,
636
+ z=z,
637
+ dt_bias=dt_bias,
638
+ dt_softplus=dt_softplus,
639
+ seq_idx=seq_idx,
640
+ ssm_state=ssm_state,
641
+ )
642
+ if return_final_states:
643
+ return tmp, ssm_state
644
+ else:
645
+ return tmp
646
+
647
+
648
+ def _causal_conv1d(
649
+ conv_state: mx.array | None, weight: mx.array, x: mx.array, seq_idx: mx.array | None
650
+ ) -> tuple[mx.array, mx.array | None]:
651
+ dtype = x.dtype
652
+ if conv_state is not None:
653
+ dtype = conv_state.dtype
654
+ assert seq_idx is None
655
+ if seq_idx is not None:
656
+ assert seq_idx.dtype == mx.int32
657
+ assert conv_state is None
658
+ weight = weight.astype(dtype)
659
+ x = x.astype(dtype)
660
+
661
+ return_final_states = conv_state is not None
662
+ if conv_state is None:
663
+ bsize = x.shape[0]
664
+ dim = weight.shape[0]
665
+ d_conv = weight.shape[-1]
666
+ conv_state = mx.zeros((bsize, dim, d_conv - 1), dtype=x.dtype)
667
+ length = x.shape[-1]
668
+ out = mx.zeros_like(x)
669
+ for i in range(length):
670
+ if i != 0 and seq_idx is not None:
671
+ conv_state = mx.where(
672
+ seq_idx[:, i - 1][:, None, None] != seq_idx[:, i][:, None, None],
673
+ mx.zeros_like(conv_state),
674
+ conv_state,
675
+ )
676
+ out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
677
+ x = out
678
+ if return_final_states:
679
+ return x, conv_state
680
+ else:
681
+ return x, None
682
+
683
+
684
+ def causal_conv1d_update(
685
+ x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
686
+ ) -> tuple[mx.array, mx.array]:
687
+ """
688
+ x: (batch, dim) or (batch, dim, seqlen)
689
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
690
+ weight: (dim, width)
691
+ bias: (dim,)
692
+ cache_seqlens: (batch,), dtype int32.
693
+ If not None, the conv_state is treated as a circular buffer.
694
+ The conv_state will be updated by copying x to the conv_state starting at the index
695
+ @cache_seqlens % state_len before performing the convolution.
696
+
697
+ out: (batch, dim) or (batch, dim, seqlen)
698
+ """
699
+ if activation not in [None, "silu", "swish"]:
700
+ raise NotImplementedError("activation must be None, silu, or swish")
701
+ dtype_in = x.dtype
702
+ unsqueeze = x.ndim == 2
703
+ if unsqueeze:
704
+ x = x.unsqueeze(-1)
705
+ batch, dim, seqlen = x.shape
706
+ width = weight.shape[1]
707
+ state_len = conv_state.shape[-1]
708
+ assert conv_state.shape == (batch, dim, state_len)
709
+ assert weight.shape == (dim, width)
710
+ if cache_seqlens is None:
711
+ x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype) # (batch, dim, state_len + seqlen)
712
+ conv_state = x_new[:, :, -state_len:]
713
+ else:
714
+ width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + mx.expand_dims(
715
+ cache_seqlens, axis=1
716
+ )
717
+ width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1)
718
+ width_idx = mx.broadcast_to(width_idx, (width_idx.shape[0], dim, width_idx.shape[2]))
719
+ x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1)
720
+ x_new = x_new.astype(weight.dtype)
721
+ copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + mx.expand_dims(cache_seqlens, axis=1)
722
+ copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1)
723
+ copy_idx = mx.broadcast_to(copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]))
724
+ conv_state.scatter_(2, copy_idx, x)
725
+ assert bias is None
726
+ # x_new: (N, C, L) -> (N, L, C)
727
+ out = mx.conv1d(
728
+ x_new.transpose(0, 2, 1),
729
+ mx.expand_dims(weight, axis=2),
730
+ padding=0,
731
+ groups=dim,
732
+ ).transpose(0, 2, 1)[:, :, -seqlen:]
733
+ if unsqueeze:
734
+ out = out.squeeze(-1)
735
+ return (out if activation is None else nn.silu(out)).astype(dtype_in), conv_state
736
+
737
+
738
+ def _causal_conv1d_update(conv_state: mx.array, weight: mx.array, xBC: mx.array) -> tuple[mx.array, mx.array]:
739
+ dtype = conv_state.dtype
740
+ xBC = xBC.astype(dtype)
741
+ weight = weight.astype(dtype)
742
+
743
+ x, conv_state = causal_conv1d_update(
744
+ x=xBC,
745
+ conv_state=conv_state,
746
+ weight=weight[:, :, 0],
747
+ activation="silu",
748
+ )
749
+ return x, conv_state
750
+
751
+
752
+ # Based on: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
753
+ def causal_conv1d(x, weight, bias=None, activation=None):
754
+ """
755
+ MLX implementation of a causal depthwise 1D convolution.
756
+ Args:
757
+ x (mx.array): Input tensor of shape (batch, channels, seq_len).
758
+ weight (mx.array): Convolution filters of shape (channels, kernel_width).
759
+ Each channel has its own filter (depthwise conv).
760
+ bias (mx.array, optional): Bias for each channel of shape (channels,).
761
+ activation (str, optional): Activation to apply ("silu" or "swish" supported).
762
+ Returns:
763
+ mx.array: Output tensor of shape (batch, channels, seq_len).
764
+ """
765
+ x = mx.array(x) if not isinstance(x, mx.array) else x
766
+ weight = mx.array(weight) if not isinstance(weight, mx.array) else weight
767
+ if bias is not None:
768
+ bias = mx.array(bias) if not isinstance(bias, mx.array) else bias
769
+
770
+ batch, channels, seq_len = x.shape
771
+ _, kernel_width = weight.shape # weight shape: (channels, kernel_width)
772
+
773
+ # Reshape weight for depthwise conv: (out_channels, in_channels/groups, kernel_width)
774
+ # Here out_channels = channels, in_channels/groups = 1 (depthwise conv per channel)
775
+ w = weight.reshape((channels, 1, kernel_width))
776
+
777
+ # Pad input on the left with (kernel_width-1) zeros for causal convolution
778
+ if kernel_width > 1:
779
+ pad_shape = (batch, channels, kernel_width - 1)
780
+ pad_zeros = mx.zeros(pad_shape, dtype=x.dtype)
781
+ x_padded = mx.concatenate([pad_zeros, x], axis=2) # concat along time axis
782
+ else:
783
+ x_padded = x
784
+
785
+ # Perform depthwise convolution. Padding is already applied manually, so use padding=0 in conv1d.
786
+ y = mx.conv1d(x_padded, w, stride=1, padding=0, groups=channels)
787
+ # After convolution, y shape = (batch, channels, seq_len) because:
788
+ # input length = seq_len + kernel_width - 1, no padding in conv, so output length = seq_len.
789
+
790
+ # Add bias if provided (bias shape (channels,) broadcasts to (batch, channels, seq_len))
791
+ if bias is not None:
792
+ y = y + bias.reshape((1, channels, 1))
793
+
794
+ # Apply activation if specified
795
+ if activation in ("silu", "swish"):
796
+ # SiLU (swish) activation: y * sigmoid(y)
797
+ y = y * mx.sigmoid(y)
798
+ elif activation is not None:
799
+ raise ValueError(f"Unsupported activation: {activation}")
800
+
801
+ return y
802
+
803
+
804
+ class Mamba(nn.Module):
805
+ def __init__(self, config: ModelArgs, layer_idx: int) -> None:
806
+ super().__init__()
807
+ self.config = config
808
+ self.layer_idx = layer_idx
809
+ self.hidden_size = config.hidden_size
810
+ self.d_state = config.mamba_d_state
811
+ self.d_conv = config.mamba_d_conv
812
+ self.chunk_size = config.mamba_chunk_size
813
+ self.num_heads = config.mamba_num_heads
814
+ # TODO add mamba_hidden_size_per_head config (?)
815
+ self.hidden_size_per_head = config.hidden_size_per_head
816
+
817
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
818
+
819
+ self.in_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
820
+ self.conv1d = nn.Conv1d(
821
+ in_channels=self.intermediate_size,
822
+ out_channels=self.intermediate_size,
823
+ bias=False, # TODO the original implementation uses bias
824
+ kernel_size=self.d_conv,
825
+ groups=self.intermediate_size,
826
+ padding=0,
827
+ )
828
+ self.dt_dim = max(64, self.hidden_size // 16)
829
+ # Notes:
830
+ # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
831
+ # but it may degrade the ability of context-length extrapolation.
832
+ self.bcdt_proj = nn.Linear(
833
+ self.intermediate_size,
834
+ self.dt_dim + 2 * self.d_state,
835
+ bias=False,
836
+ )
837
+ self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
838
+
839
+ self.dt_bias = get_initial_dt_bias(self.num_heads)
840
+ self.A_log = get_initial_A(self.num_heads)
841
+ self.D = mx.ones(self.num_heads, dtype=mx.float32)
842
+
843
+ # TODO norm weight before gating like Mamba2
844
+ self.dt_norm_weight = mx.ones(self.dt_dim)
845
+ self.B_norm_weight = mx.ones(self.d_state)
846
+ self.C_norm_weight = mx.ones(self.d_state)
847
+
848
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
849
+
850
+ def _no_weight_decay_param_names(self) -> set[str]:
851
+ return set(["D", "dt_bias", "A_log"])
852
+
853
+ def __call__(
854
+ self,
855
+ hidden_states: mx.array,
856
+ attention_mask: Optional[mx.array] = None,
857
+ past_states: Optional[PlamoCache] = None,
858
+ ) -> tuple[mx.array, Optional[PlamoCache]]:
859
+ bsize, length, _ = hidden_states.shape
860
+ is_update = length == 1 and past_states is not None
861
+
862
+ bool_mask: mx.array | None = None
863
+ seq_idx: mx.array | None = None
864
+ if attention_mask is not None:
865
+ if len(attention_mask.shape) == 2:
866
+ attention_mask = mx.broadcast_to(
867
+ attention_mask[None, None],
868
+ (bsize, 1, attention_mask.shape[0], attention_mask.shape[1]),
869
+ )
870
+ assert len(attention_mask.shape) == 4
871
+
872
+ if past_states is None:
873
+ # TODO: support seq_idx with cache
874
+ bool_mask_4d = mx.array(attention_mask == 0, dtype=mx.bool_) # type: ignore
875
+ is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
876
+ seq_idx = mx.cumsum(is_first_token, axis=-1) - 1
877
+ seq_idx = seq_idx.astype(mx.int32)
878
+
879
+ # The `generate` function creates an attention mask that contains past tokens,
880
+ # but Mamba does not use them.
881
+ attention_mask = attention_mask[:, 0, -length:, -length:]
882
+ bool_mask = mx.array(mx.diagonal(attention_mask, axis1=-2, axis2=-1) == 0)
883
+
884
+ conv_state: mx.array | None
885
+ ssm_state: mx.array | None
886
+ if past_states is None:
887
+ conv_state = None
888
+ ssm_state = None
889
+ elif past_states[self.layer_idx] is None:
890
+ conv_state = mx.zeros(
891
+ (bsize, self.intermediate_size, self.d_conv - 1),
892
+ dtype=hidden_states.dtype,
893
+ )
894
+ ssm_state = mx.zeros(
895
+ (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
896
+ dtype=mx.float32,
897
+ )
898
+ else:
899
+ c = past_states[self.layer_idx]
900
+ assert isinstance(c, PlamoMambaCache)
901
+ conv_state = c.conv_state
902
+ ssm_state = c.ssm_state
903
+
904
+ zx = self.in_proj(hidden_states)
905
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
906
+ # z: (bsize, length, num_heads, hidden_size_per_head)
907
+ # x: (bsize, length, num_heads, hidden_size_per_head)
908
+ z, x = mx.split(
909
+ zx,
910
+ [
911
+ self.hidden_size_per_head,
912
+ ],
913
+ axis=-1,
914
+ )
915
+
916
+ # conv
917
+ x = x.reshape(bsize, length, -1).transpose(0, 2, 1) # (bsize, intermediate_size, length)
918
+ if bool_mask is not None:
919
+ x = mx.where(bool_mask[:, None, :], x, 0.0)
920
+ if is_update:
921
+ assert conv_state is not None
922
+ x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
923
+ else:
924
+ x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
925
+ x = x.astype(hidden_states.dtype)
926
+ x = x.transpose(0, 2, 1) # (bsize, length, intermediate_size)
927
+ x = x.reshape(bsize, length, -1)
928
+ # x: (bsize, length, num_heads, hidden_size_per_head)
929
+ # B: (bsize, length, 1, d_state)
930
+ # C: (bsize, length, 1, d_state)
931
+ # dt: (bsize, length, dt_dim)
932
+ BCdt = self.bcdt_proj(x)
933
+ x = x.reshape(bsize, length, self.num_heads, -1)
934
+ B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
935
+ B = B[:, :, None, :]
936
+ C = C[:, :, None, :]
937
+
938
+ A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
939
+ dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
940
+ B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
941
+ C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
942
+
943
+ # (bsize, length, num_heads, 1)
944
+ dt = self.dt_proj(dt)[..., None]
945
+
946
+ # TODO it may not be required
947
+ B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
948
+ C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
949
+
950
+ if bool_mask is not None:
951
+ """
952
+ The state is updated as follows:
953
+ ```
954
+ dt = softplus(dt)
955
+ dA = exp(dt * A)
956
+ state_next = state * dA + dB * x
957
+ ```
958
+ To avoid updating the state, we set dt to -inf and x to 0
959
+ because `softplus(-inf) = 0` and `exp(0) = 1`
960
+ """
961
+ dt = mx.where(bool_mask[:, :, None, None], dt, float("-inf"))
962
+ x = mx.where(bool_mask[:, :, None, None], x, 0.0)
963
+
964
+ # ssm
965
+ if is_update:
966
+ assert ssm_state is not None
967
+ out, ssm_state = ssd_update_state(
968
+ ssm_state,
969
+ x[:, 0],
970
+ dt[:, 0].reshape(bsize, -1),
971
+ A,
972
+ B[:, 0],
973
+ C[:, 0],
974
+ D=self.D,
975
+ z=z[:, 0],
976
+ dt_bias=self.dt_bias,
977
+ dt_softplus=True,
978
+ )
979
+ else:
980
+ tmp = ssd_chunk_scan_combined(
981
+ x,
982
+ dt.reshape(bsize, length, -1),
983
+ A,
984
+ B,
985
+ C,
986
+ self.chunk_size,
987
+ D=self.D,
988
+ z=z,
989
+ dt_bias=self.dt_bias,
990
+ dt_softplus=True,
991
+ return_final_states=past_states is not None,
992
+ seq_idx=seq_idx,
993
+ ssm_state=ssm_state,
994
+ )
995
+ if past_states is not None:
996
+ out, ssm_state = tmp
997
+ else:
998
+ assert isinstance(tmp, mx.array)
999
+ out = tmp
1000
+
1001
+ y = self.out_proj(out.reshape(bsize, length, -1))
1002
+
1003
+ if past_states is not None:
1004
+ assert ssm_state is not None
1005
+ assert conv_state is not None
1006
+ past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
1007
+
1008
+ return y, past_states
1009
+
1010
+
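The comment above relies on `softplus(-inf) = 0` and `exp(0) = 1` to freeze the SSM state at masked positions. A minimal NumPy sketch of that trick on a single scalar channel (a simplified recurrence, not the fused mlx kernels used above):

```python
import numpy as np

def softplus(x):
    # softplus(-inf) -> 0, which is what the masking relies on.
    return np.logaddexp(x, 0.0)

# One scalar SSM channel: state_next = state * exp(dt * A) + (dt * B) * x
state, A, B = 1.5, -0.5, 0.8

# Regular token: the state moves.
dt, x = softplus(0.3), 2.0
updated = state * np.exp(dt * A) + (dt * B) * x

# Masked token: dt = softplus(-inf) = 0 and x = 0, so exp(0) = 1 and (dt * B) * x = 0.
dt, x = softplus(-np.inf), 0.0
frozen = state * np.exp(dt * A) + (dt * B) * x

print(updated)  # a new value
print(frozen)   # 1.5, the state is left untouched
```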
1011
+ def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
1012
+ max_len = max(q_len, kv_len)
1013
+ mask = mx.tril(
1014
+ mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size), # type: ignore
1015
+ k=window_size,
1016
+ )
1017
+ return mask[-q_len:, -kv_len:]
1018
+
1019
+
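`swa_mask` builds a symmetric band of width `window_size`; the causal cut comes from the separate additive attention mask. A quick check with small sizes (mlx, mirroring the definition above):

```python
import mlx.core as mx

def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
    # Same construction as above: a band around the diagonal, cropped to the
    # last q_len rows (queries) and kv_len columns (keys).
    max_len = max(q_len, kv_len)
    mask = mx.tril(
        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),
        k=window_size,
    )
    return mask[-q_len:, -kv_len:]

# 2 new queries over 6 cached keys with a window of 3:
print(swa_mask(2, 6, 3))
# Each row is True only for keys within 3 positions of that query.
```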
1020
+ class Attention(nn.Module):
1021
+ def __init__(self, config: ModelArgs, layer_idx: int) -> None:
1022
+ super().__init__()
1023
+ self.config = config
1024
+ self.layer_idx = layer_idx
1025
+ self.hidden_size = config.hidden_size
1026
+ head_dim = config.hidden_size_per_head
1027
+ self.max_position_embeddings = config.max_position_embeddings
1028
+ self.scale = head_dim**-0.5
1029
+
1030
+ self.q_num_heads = config.num_attention_heads
1031
+ self.qk_dim = self.v_dim = head_dim
1032
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
1033
+ assert self.q_num_heads % self.k_num_heads == 0
1034
+ self.n_group = self.q_num_heads // self.k_num_heads
1035
+
1036
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
1037
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
1038
+ self.v_proj_dim = self.k_num_heads * self.v_dim
1039
+ self.qkv_proj = nn.Linear(
1040
+ self.hidden_size,
1041
+ self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
1042
+ bias=False,
1043
+ )
1044
+ self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
1045
+
1046
+ self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
1047
+ self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))
1048
+
1049
+ self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
1050
+
1051
+ def __call__(
1052
+ self,
1053
+ hidden_states: mx.array,
1054
+ attention_mask: Optional[mx.array] = None,
1055
+ past_states: Optional[PlamoCache] = None,
1056
+ output_attentions: bool = False,
1057
+ ) -> tuple[mx.array, Optional[mx.array], Optional[PlamoCache]]:
1058
+ bsz, q_len, _ = hidden_states.shape
1059
+
1060
+ qkv = self.qkv_proj(hidden_states)
1061
+ query_states, key_states, value_states = mx.split(
1062
+ qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
1063
+ )
1064
+ query_states = query_states.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
1065
+ key_states = key_states.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
1066
+ value_states = value_states.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
1067
+
1068
+ attn_dtype = query_states.dtype
1069
+
1070
+ query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
1071
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
1072
+
1073
+ if past_states is not None:
1074
+ # reuse k, v, self_attention
1075
+ key_states_new = key_states
1076
+ value_states_new = value_states
1077
+ key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
1078
+ past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
1079
+
1080
+ kv_seq_len = key_states.shape[-2]
1081
+ position_ids = mx.arange(kv_seq_len, dtype=mx.int64)[None]
1082
+ q_position_ids = position_ids[:, -query_states.shape[2] :]
1083
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1084
+ query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
1085
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
1086
+ # [bsz, nh, t, hd]
1087
+
1088
+ # expand shared kv
1089
+ assert self.k_num_heads == self.v_num_heads
1090
+ key_states = mx.tile(key_states, (1, self.n_group, 1, 1))
1091
+ value_states = mx.tile(value_states, (1, self.n_group, 1, 1))
1092
+
1093
+ full_attn = self.layer_idx in self.config.full_attention_idx
1094
+
1095
+ query_states = query_states.astype(attn_dtype)
1096
+ key_states = key_states.astype(attn_dtype)
1097
+ value_states = value_states.astype(attn_dtype)
1098
+ if attention_mask is not None and attention_mask.dtype != bool:
1099
+ attention_mask = attention_mask.astype(attn_dtype)
1100
+ if attention_mask is None:
1101
+ if not full_attn:
1102
+ assert key_states.shape[2] <= self.config.attention_window_size + 1
1103
+ mask = create_attention_mask(hidden_states)
1104
+ attn_output = mx.fast.scaled_dot_product_attention(
1105
+ query_states,
1106
+ key_states,
1107
+ value_states,
1108
+ scale=self.scale,
1109
+ mask=mask,
1110
+ )
1111
+ else:
1112
+ if attention_mask.dtype == bool:
1113
+ attention_mask = mx.where(attention_mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
1114
+ if len(attention_mask.shape) == 2:
1115
+ attention_mask = attention_mask[None, None]
1116
+ assert len(attention_mask.shape) == 4
1117
+
1118
+ if not full_attn:
1119
+ m_swa = swa_mask(
1120
+ query_states.shape[2],
1121
+ key_states.shape[2],
1122
+ self.config.attention_window_size,
1123
+ )
1124
+ # the `generate` function creates an attention mask that does not consider the sliding window
1125
+ m_swa = m_swa[None, None]
1126
+ attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
1127
+ attention_mask = mx.where(m_swa, attention_mask, float("-inf"))
1128
+
1129
+ # like AttentionMaskConverter._unmask_unattended in huggingface.transformers,
1130
+ # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
1131
+ bool_mask = mx.logical_not(mx.isneginf(attention_mask))
1132
+ valid_tokens = mx.sum(bool_mask, axis=-1).astype(mx.bool_) # type: ignore # (..., q_len)
1133
+ attention_mask = mx.where(valid_tokens[..., None], attention_mask, float(0.0))
1134
+ attn_output = mx.fast.scaled_dot_product_attention(
1135
+ query_states,
1136
+ key_states,
1137
+ value_states,
1138
+ scale=self.scale,
1139
+ mask=attention_mask,
1140
+ )
1141
+
1142
+ attn_output = attn_output.transpose(0, 2, 1, 3)
1143
+
1144
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
1145
+ attn_output = self.o_proj(attn_output)
1146
+
1147
+ if not output_attentions:
1148
+ attn_weights = None
1149
+
1150
+ return attn_output, attn_weights, past_states
1151
+
1152
+
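With `num_key_value_heads = 1` in the shipped config, the single K/V head is tiled across all 16 query heads before `mx.fast.scaled_dot_product_attention`. A shape-only sketch of that expansion (sizes illustrative):

```python
import mlx.core as mx

bsz, q_len = 1, 4
q_num_heads, k_num_heads, head_dim = 16, 1, 128
n_group = q_num_heads // k_num_heads  # all 16 query heads share one KV head

key_states = mx.zeros((bsz, k_num_heads, q_len, head_dim))
value_states = mx.zeros((bsz, k_num_heads, q_len, head_dim))

# Expand the shared KV head so the attention kernel sees one KV head per query head.
key_states = mx.tile(key_states, (1, n_group, 1, 1))
value_states = mx.tile(value_states, (1, n_group, 1, 1))

print(key_states.shape)    # (1, 16, 4, 128)
print(value_states.shape)  # (1, 16, 4, 128)
```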
1153
+ class MLP(nn.Module):
1154
+ def __init__(self, config: ModelArgs) -> None:
1155
+ super().__init__()
1156
+ self.config = config
1157
+ self.hidden_size = config.hidden_size
1158
+ self.intermediate_size = config.intermediate_size
1159
+ self.gate_up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
1160
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1161
+
1162
+ def __call__(self, x: mx.array) -> mx.array:
1163
+ h = self.gate_up_proj(x)
1164
+ h = _swiglu(h)
1165
+ return self.down_proj(h) # type: ignore
1166
+
1167
+
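The MLP fuses the gate and up projections into one matmul and relies on a `_swiglu` helper defined earlier in this file (outside this hunk). A hedged stand-alone sketch of SwiGLU gating, assuming the conventional gate/up split order:

```python
import mlx.core as mx

def swiglu(h: mx.array) -> mx.array:
    # Split the fused projection into halves and gate with SiLU.
    # NOTE: the split order used by `_swiglu` in this file is assumed here.
    gate, up = mx.split(h, 2, axis=-1)
    return (gate * mx.sigmoid(gate)) * up  # silu(gate) * up

x = mx.random.normal((2, 3, 16))  # (batch, seq, 2 * intermediate)
print(swiglu(x).shape)            # (2, 3, 8)
```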
1168
+ class PlamoDecoderLayer(nn.Module):
1169
+ def __init__(self, config: ModelArgs, is_mamba: bool, layer_idx: int) -> None:
1170
+ super().__init__()
1171
+ self.config = config
1172
+ self.hidden_size = config.hidden_size
1173
+ self.is_mamba = is_mamba
1174
+ self.mixer: nn.Module
1175
+ if is_mamba:
1176
+ self.mixer = Mamba(config, layer_idx)
1177
+ else:
1178
+ self.mixer = Attention(config, layer_idx)
1179
+ self.mlp = MLP(config)
1180
+ """
1181
+ Notes: The model performance was degraded when setting all offsets to 1.
1182
+ """
1183
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1184
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
1185
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1186
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
1187
+
1188
+ def __call__(
1189
+ self,
1190
+ hidden_states: mx.array,
1191
+ attention_mask: Optional[mx.array] = None,
1192
+ past_state: Optional[PlamoCache] = None,
1193
+ output_attentions: Optional[bool] = False,
1194
+ ) -> tuple[Any, ...]:
1195
+ # from LlamaDecoder
1196
+ residual = hidden_states
1197
+ hidden_states = self.pre_mixer_norm(hidden_states)
1198
+
1199
+ # Self Attention
1200
+ if self.is_mamba:
1201
+ hidden_states_sa, present_key_value = self.mixer(
1202
+ hidden_states=hidden_states,
1203
+ attention_mask=attention_mask,
1204
+ past_states=past_state,
1205
+ )
1206
+ self_attn_weights = None
1207
+ else:
1208
+ hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
1209
+ hidden_states=hidden_states,
1210
+ attention_mask=attention_mask,
1211
+ past_states=past_state,
1212
+ output_attentions=output_attentions,
1213
+ )
1214
+
1215
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
1216
+ hidden_states = residual + hidden_states_sa
1217
+
1218
+ residual = hidden_states
1219
+ hidden_states = self.pre_mlp_norm(hidden_states)
1220
+
1221
+ # Fully Connected
1222
+ hidden_states_mlp = self.mlp(hidden_states)
1223
+
1224
+ # Residual
1225
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
1226
+ hidden_states = residual + hidden_states_mlp
1227
+
1228
+ outputs: Any = (hidden_states,)
1229
+
1230
+ if output_attentions:
1231
+ outputs += (self_attn_weights,)
1232
+
1233
+ return outputs # type: ignore
1234
+
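Each sub-layer above is wrapped in RMSNorms with different `offset` values (1.0 for the pre-norms, 1/5 and 1/5**1.5 for the post-norms). A minimal sketch of what an offset RMSNorm is assumed to compute here, namely scaling the normalized input by `(weight + offset)` so a zero-initialized weight starts out as a plain scale of `offset`:

```python
import mlx.core as mx

def rms_norm_with_offset(x: mx.array, weight: mx.array, offset: float, eps: float = 1e-6) -> mx.array:
    # Assumed form: RMS-normalize, then scale by (weight + offset).
    variance = mx.mean(x.astype(mx.float32) ** 2, axis=-1, keepdims=True)
    normed = x * mx.rsqrt(variance + eps)
    return normed.astype(x.dtype) * (weight + offset)

x = mx.random.normal((1, 4, 8))
w = mx.zeros((8,))
print(rms_norm_with_offset(x, w, offset=1.0).shape)      # pre-mixer / pre-MLP norms
print(rms_norm_with_offset(x, w, offset=1.0 / 5).shape)  # post-mixer norm
```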
1235
+
1236
+ def is_mamba(config: ModelArgs, i: int) -> bool:
1237
+ if not config.mamba_enabled:
1238
+ return False
1239
+ assert config.mamba_step > 1
1240
+ assert i < config.num_hidden_layers
1241
+
1242
+ if config.num_hidden_layers <= (config.mamba_step // 2):
1243
+ # use attention in last layer
1244
+ return i != config.num_hidden_layers - 1
1245
+ return (i % config.mamba_step) != (config.mamba_step // 2)
1246
+
1247
+
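With the shipped config (`num_hidden_layers = 16`, `mamba_step = 2`), this schedule alternates Mamba and attention layers, starting with Mamba. A quick enumeration of the rule above:

```python
# Reproduces the scheduling rule for the shipped config
# (mamba_enabled=True, mamba_step=2, num_hidden_layers=16).
mamba_step, num_hidden_layers = 2, 16

def layer_is_mamba(i: int) -> bool:
    if num_hidden_layers <= (mamba_step // 2):
        return i != num_hidden_layers - 1
    return (i % mamba_step) != (mamba_step // 2)

layout = ["mamba" if layer_is_mamba(i) else "attention" for i in range(num_hidden_layers)]
print(layout)  # ['mamba', 'attention', 'mamba', 'attention', ...]
```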
1248
+ class PlamoDecoder(nn.Module):
1249
+ def __init__(self, config: ModelArgs) -> None:
1250
+ super().__init__()
1251
+
1252
+ self.layers = [
1253
+ PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
1254
+ for i in range(config.num_hidden_layers)
1255
+ ]
1256
+ self.gradient_checkpointing = False
1257
+
1258
+ def __call__(self, x: DecoderInput) -> DecoderOutput:
1259
+ all_hidden_states: Optional[tuple[mx.array, ...]] = () if x.output_hidden_states else None
1260
+ all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
1261
+ hidden_states = x.hidden_states
1262
+
1263
+ for decoder_layer in self.layers:
1264
+ if x.output_hidden_states:
1265
+ assert all_hidden_states is not None
1266
+ all_hidden_states += (hidden_states,)
1267
+
1268
+ if self.training and x.gradient_checkpointing:
1269
+ layer_outputs = self._gradient_checkpointing_func(
1270
+ decoder_layer.__call__,
1271
+ hidden_states,
1272
+ x.attention_mask,
1273
+ x.past_states,
1274
+ x.output_attentions,
1275
+ )
1276
+ else:
1277
+ layer_outputs = decoder_layer(
1278
+ hidden_states,
1279
+ attention_mask=x.attention_mask,
1280
+ past_state=x.past_states,
1281
+ output_attentions=x.output_attentions,
1282
+ )
1283
+
1284
+ hidden_states = layer_outputs[0]
1285
+
1286
+ if x.output_attentions:
1287
+ assert layer_outputs[1] is not None
1288
+ assert all_self_attns is not None
1289
+ all_self_attns += (layer_outputs[1],)
1290
+ return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1291
+
1292
+
1293
+ class ModelOutput(OrderedDict):
1294
+ def __init__(self, *args, **kwargs):
1295
+ super().__init__(*args, **kwargs)
1296
+
1297
+ def __getitem__(self, k):
1298
+ if isinstance(k, str):
1299
+ inner_dict = dict(self.items())
1300
+ return inner_dict[k]
1301
+ else:
1302
+ return self.to_tuple()[k]
1303
+
1304
+ def to_tuple(self) -> tuple[Any]:
1305
+ """
1306
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
1307
+ """
1308
+ return tuple(self[k] for k in self.keys())
1309
+
1310
+
1311
+ class BaseModelOutputWithPast(ModelOutput):
1312
+ """
1313
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
1314
+
1315
+ Args:
1316
+ last_hidden_state (:obj:`mx.array` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
1317
+ Sequence of hidden-states at the output of the last layer of the model.
1318
+
1319
+ If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape
1320
+ :obj:`(batch_size, 1, hidden_size)` is output.
1321
+ past_key_values (:obj:`list[mx.array]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
1322
+ list of :obj:`mx.array` of length :obj:`config.n_layers`, with each tensor of shape
1323
+ :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
1324
+
1325
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
1326
+ ``past_key_values`` input) to speed up sequential decoding.
1327
+ hidden_states (:obj:`tuple(mx.array)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1328
+ Tuple of :obj:`mx.array` (one for the output of the embeddings + one for the output of each layer)
1329
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1330
+
1331
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1332
+ attentions (:obj:`tuple(mx.array)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1333
+ Tuple of :obj:`mx.array` (one for each layer) of shape
1334
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1335
+
1336
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1337
+ heads.
1338
+ """
1339
+
1340
+ def __init__(self, *args, **kwargs) -> None:
1341
+ super().__init__(*args, **kwargs)
1342
+ self.last_hidden_state: mx.array = kwargs.pop("last_hidden_state")
1343
+ self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None)
1344
+ self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None)
1345
+ self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None)
1346
+
1347
+
1348
+ class CausalLMOutputWithPast(ModelOutput):
1349
+ """
1350
+ Base class for causal language model (or autoregressive) outputs.
1351
+
1352
+ Args:
1353
+ loss (`mx.array` of shape `(1,)`, *optional*, returned when `labels` is provided):
1354
+ Language modeling loss (for next-token prediction).
1355
+ logits (`mx.array` of shape `(batch_size, sequence_length, config.vocab_size)`):
1356
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1357
+ past_key_values (`tuple(tuple(mx.array))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1358
+ Tuple of `tuple(mx.array)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1359
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
1360
+
1361
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1362
+ `past_key_values` input) to speed up sequential decoding.
1363
+ hidden_states (`tuple(mx.array)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1364
+ Tuple of `mx.array` (one for the output of the embeddings, if the model has an embedding layer, +
1365
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1366
+
1367
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1368
+ attentions (`tuple(mx.array)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
1369
+ Tuple of `mx.array` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1370
+ sequence_length)`.
1371
+
1372
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1373
+ heads.
1374
+ """
1375
+
1376
+ def __init__(self, *args, **kwargs) -> None:
1377
+ super().__init__(*args, **kwargs)
1378
+
1379
+ self.loss: Optional[mx.array] = kwargs.pop("loss", None)
1380
+ self.logits: mx.array | None = kwargs.pop("logits", None)
1381
+ self.past_key_values: Optional[tuple[tuple[mx.array]]] = kwargs.pop("past_key_values", None)
1382
+ self.hidden_states: Optional[tuple[mx.array, ...]] = kwargs.pop("hidden_states", None)
1383
+ self.attentions: Optional[tuple[mx.array, ...]] = kwargs.pop("attentions", None)
1384
+
1385
+
1386
+ class PlamoPreTrainedModel(nn.Module): # type: ignore
1387
+ config_class = ModelArgs
1388
+ _no_split_modules: list[str]
1389
+ base_model_prefix = "model"
1390
+ supports_gradient_checkpointing = True
1391
+ _no_split_modules = ["PlamoDecoderLayer"]
1392
+ _skip_keys_device_placement = "past_key_values"
1393
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1394
+
1395
+ def __init__(self, config: ModelArgs):
1396
+ super().__init__()
1397
+ self.config = config
1398
+
1399
+ def _init_weights(self, module: nn.Module) -> None:
1400
+ std = 0.02
1401
+ if isinstance(module, nn.Linear):
1402
+ module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
1403
+ if module.bias is not None:
1404
+ module.bias = mx.zeros_like(module.bias)
1405
+ elif isinstance(module, nn.Embedding):
1406
+ module.weight = mx.random.normal(loc=0.0, scale=std, shape=module.weight.shape)
1407
+ if module.padding_idx is not None:
1408
+ module.weight[module.padding_idx] = mx.zeros_like(module.weight[module.padding_idx])
1409
+
1410
+
1411
+ class PlamoModel(PlamoPreTrainedModel):
1412
+ def __init__(self, config: ModelArgs):
1413
+ super().__init__(config)
1414
+ assert config.eval_attention_n_bit is None
1415
+ assert config.eval_mlp_n_bit is None
1416
+
1417
+ self.padding_idx = config.pad_token_id
1418
+ self.vocab_size = config.vocab_size
1419
+
1420
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
1421
+ self.layers = PlamoDecoder(config) # type: ignore
1422
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1423
+
1424
+ self.gradient_checkpointing = False
1425
+ # Initialize weights and apply final processing
1426
+ # self.post_init()
1427
+
1428
+ def get_input_embeddings(self) -> nn.Embedding:
1429
+ return self.embed_tokens
1430
+
1431
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
1432
+ self.embed_tokens = value
1433
+
1434
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1435
+ def _prepare_decoder_attention_mask(
1436
+ self,
1437
+ attention_mask: mx.array,
1438
+ input_shape: tuple[int, int],
1439
+ inputs_embeds: Optional[mx.array],
1440
+ past_key_values_length: int,
1441
+ ) -> Optional[mx.array]:
1442
+ # create causal mask
1443
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1444
+ combined_attention_mask: Optional[mx.array] = None
1445
+ if input_shape[-1] > 1:
1446
+ assert inputs_embeds is not None
1447
+ combined_attention_mask = _make_causal_mask(
1448
+ input_shape,
1449
+ inputs_embeds.dtype,
1450
+ past_key_values_length=past_key_values_length,
1451
+ )
1452
+ input_shape = (input_shape[0], combined_attention_mask.shape[2])
1453
+
1454
+ if attention_mask is not None:
1455
+ if attention_mask.ndim == 4:
1456
+ # Custom 4D attention mask
1457
+ expanded_attn_mask = attention_mask
1458
+ else:
1459
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1460
+ assert inputs_embeds is not None
1461
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1462
+ combined_attention_mask = (
1463
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1464
+ )
1465
+
1466
+ return combined_attention_mask
1467
+
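`_make_causal_mask` and `_expand_mask` are defined elsewhere in this file, so the sketch below uses minimal stand-ins just to show how the causal mask and the padding mask combine into the additive 4-D mask returned above (0 where attention is allowed, -inf where it is not):

```python
import mlx.core as mx

NEG_INF = float("-inf")

def make_causal_mask(seq_len: int) -> mx.array:
    # (1, 1, seq_len, seq_len): 0 on and below the diagonal, -inf above it.
    allowed = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
    return mx.where(allowed, 0.0, NEG_INF)[None, None]

def expand_padding_mask(padding: mx.array, tgt_len: int) -> mx.array:
    # padding: (bsz, src_len), 1 for real tokens and 0 for padding.
    bsz, src_len = padding.shape
    expanded = mx.broadcast_to(padding[:, None, None, :], (bsz, 1, tgt_len, src_len))
    return mx.where(expanded.astype(mx.bool_), 0.0, NEG_INF)

padding = mx.array([[1, 1, 1, 0]])  # the last position is padding
combined = make_causal_mask(4) + expand_padding_mask(padding, tgt_len=4)
print(combined.shape)  # (1, 1, 4, 4)
```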
1468
+ def __call__(
1469
+ self,
1470
+ input_ids: Optional[mx.array] = None,
1471
+ attention_mask: Optional[mx.array] = None,
1472
+ position_ids: Optional[mx.array] = None,
1473
+ past_key_values: Optional[PlamoCache] = None,
1474
+ inputs_embeds: Optional[mx.array] = None,
1475
+ image_features: Optional[mx.array] = None,
1476
+ use_cache: Optional[bool] = None,
1477
+ output_attentions: Optional[bool] = None,
1478
+ output_hidden_states: Optional[bool] = None,
1479
+ return_dict: Optional[bool] = None,
1480
+ ) -> Union[tuple, BaseModelOutputWithPast]:
1481
+ assert input_ids is not None
1482
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1483
+ output_hidden_states = (
1484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1485
+ )
1486
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1487
+
1488
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1489
+
1490
+ # retrieve input_ids and inputs_embeds
1491
+ if input_ids is not None and inputs_embeds is not None:
1492
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1493
+ elif input_ids is not None:
1494
+ batch_size, seq_length = input_ids.shape
1495
+ else:
1496
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1497
+
1498
+ seq_length_with_past = seq_length
1499
+ past_key_values_length = 0
1500
+
1501
+ if past_key_values is not None:
1502
+ past_key_values_length = past_key_values.get_seq_length()
1503
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1504
+
1505
+ if inputs_embeds is None:
1506
+ inputs_embeds = self.embed_tokens(input_ids)
1507
+
1508
+ if image_features is not None:
1509
+ assert self.config.image_token_id is not None
1510
+ image_embeds = self.image_proj(image_features)
1511
+ assert image_embeds.shape == inputs_embeds.shape, (
1512
+ image_embeds.shape,
1513
+ inputs_embeds.shape,
1514
+ )
1515
+ mask = input_ids == self.config.image_token_id
1516
+ inputs_embeds[mask] = image_embeds[mask]
1517
+
1518
+ # embed positions
1519
+ require_attn_mask = False
1520
+ if not self.training or past_key_values is not None:
1521
+ require_attn_mask = True
1522
+ if seq_length_with_past >= self.config.attention_window_size:
1523
+ require_attn_mask = True
1524
+ if require_attn_mask and attention_mask is None:
1525
+ attention_mask = mx.ones(
1526
+ (batch_size, seq_length_with_past),
1527
+ dtype=mx.bool_, # type: ignore
1528
+ )
1529
+ if attention_mask is not None:
1530
+ attention_mask = self._prepare_decoder_attention_mask(
1531
+ attention_mask,
1532
+ (batch_size, seq_length),
1533
+ inputs_embeds,
1534
+ past_key_values_length,
1535
+ )
1536
+
1537
+ hidden_states = inputs_embeds
1538
+
1539
+ if self.gradient_checkpointing and self.training:
1540
+ if use_cache:
1541
+ use_cache = False
1542
+
1543
+ if use_cache and past_key_values is None:
1544
+ past_key_values = PlamoCache(self.config)
1545
+
1546
+ # decoder layers
1547
+ out = self.layers(
1548
+ DecoderInput(
1549
+ hidden_states,
1550
+ attention_mask,
1551
+ past_key_values,
1552
+ output_hidden_states,
1553
+ output_attentions,
1554
+ self.gradient_checkpointing,
1555
+ )
1556
+ )
1557
+
1558
+ assert isinstance(out, DecoderOutput)
1559
+ hidden_states = out.hidden_states
1560
+ all_hidden_states = out.all_hidden_states
1561
+ all_self_attns = out.all_self_attns
1562
+
1563
+ hidden_states = self.norm(hidden_states)
1564
+
1565
+ # add hidden states from the last decoder layer
1566
+ if output_hidden_states:
1567
+ assert all_hidden_states is not None
1568
+ all_hidden_states += (hidden_states,)
1569
+
1570
+ if not return_dict:
1571
+ return tuple(
1572
+ v
1573
+ for v in [
1574
+ hidden_states,
1575
+ past_key_values,
1576
+ all_hidden_states,
1577
+ all_self_attns,
1578
+ ]
1579
+ if v is not None
1580
+ )
1581
+ return BaseModelOutputWithPast(
1582
+ last_hidden_state=hidden_states,
1583
+ past_key_values=past_key_values,
1584
+ hidden_states=all_hidden_states,
1585
+ attentions=all_self_attns,
1586
+ )
1587
+
1588
+
1589
+ class Model(PlamoPreTrainedModel):
1590
+ _tied_weights_keys = ["lm_head.weight"]
1591
+
1592
+ # Without this, the model cannot be loaded into a meta device.
1593
+ # Relevant code:
1594
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
1595
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
1596
+ # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1597
+ _supports_param_buffer_assignment = False
1598
+
1599
+ def __init__(self, config: ModelArgs) -> None:
1600
+ super().__init__(config)
1601
+ self.config = config
1602
+ self.model = PlamoModel(config)
1603
+
1604
+ self.vocab_size = config.vocab_size
1605
+ vocab_size = ((self.vocab_size + 15) // 16) * 16
1606
+
1607
+ if not config.tie_word_embeddings:
1608
+ self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
1609
+
1610
+ self._prefill = True
1611
+
1612
+ # Initialize weights and apply final processing
1613
+ # self.post_init()
1614
+
1615
+ def get_input_embeddings(self) -> nn.Embedding:
1616
+ return self.model.embed_tokens
1617
+
1618
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
1619
+ self.model.embed_tokens = value
1620
+
1621
+ def get_output_embeddings(self) -> nn.Module:
1622
+ return self.lm_head
1623
+
1624
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
1625
+ self.lm_head = new_embeddings
1626
+
1627
+ def set_decoder(self, decoder: PlamoModel) -> None:
1628
+ self.model = decoder
1629
+
1630
+ def get_decoder(self) -> PlamoModel:
1631
+ return self.model
1632
+
1633
+ def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
1634
+ for k, v in weights.items():
1635
+ if "conv1d.weight" in k and v.shape[-1] != 1:
1636
+ weights[k] = v.moveaxis(2, 1)
1637
+ return weights
1638
+
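`sanitize` reorders the depthwise conv weights once at load time: the PyTorch checkpoint stores them as (channels, 1, kernel_size), while mlx's `nn.Conv1d` expects (channels, kernel_size, 1). A shape-only sketch (channel count illustrative):

```python
import mlx.core as mx

d_conv, channels = 4, 4096  # kernel size from config; channel count illustrative

# PyTorch depthwise Conv1d layout: (channels, in_channels // groups, kernel_size)
torch_style = mx.zeros((channels, 1, d_conv))

# mlx expects (channels, kernel_size, in_channels // groups), so sanitize()
# moves the kernel axis whenever the last dimension is not already 1.
mlx_style = torch_style.moveaxis(2, 1) if torch_style.shape[-1] != 1 else torch_style
print(mlx_style.shape)  # (4096, 4, 1)
```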
1639
+ def make_cache(self) -> PlamoCache:
1640
+ return PlamoCache(self.config)
1641
+
1642
+ def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
1643
+ model_inputs = self.prepare_inputs_for_generation(
1644
+ input_ids=inputs,
1645
+ past_key_values=cache,
1646
+ use_cache=self.config.use_cache,
1647
+ )
1648
+ if self._prefill:
1649
+ model_inputs["input_ids"] = inputs
1650
+ self._prefill = False
1651
+ output = self.forward(**model_inputs)
1652
+ if not isinstance(output, CausalLMOutputWithPast):
1653
+ raise ValueError(
1654
+ f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"
1655
+ )
1656
+ if output.logits is not None:
1657
+ return output.logits
1658
+ else:
1659
+ raise ValueError("The model did not return any logits.")
1660
+
1661
+ def forward(
1662
+ self,
1663
+ input_ids: Optional[mx.array] = None,
1664
+ attention_mask: Optional[mx.array] = None,
1665
+ position_ids: Optional[mx.array] = None,
1666
+ past_key_values: Optional[PlamoCache] = None,
1667
+ inputs_embeds: Optional[mx.array] = None,
1668
+ image_features: Optional[mx.array] = None,
1669
+ labels: Optional[mx.array] = None,
1670
+ use_cache: Optional[bool] = None,
1671
+ output_attentions: Optional[bool] = None,
1672
+ output_hidden_states: Optional[bool] = None,
1673
+ return_dict: Optional[bool] = None,
1674
+ ) -> Union[tuple[Any, ...], CausalLMOutputWithPast]:
1675
+ r"""
1676
+ Args:
1677
+ labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*):
1678
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1679
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1680
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1681
+ Returns:
1682
+ Example:
1683
+ ```python
1684
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1685
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1686
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1687
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
1688
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1689
+ >>> # Generate
1690
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1691
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1692
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1693
+ ```"""
1694
+ assert input_ids is not None
1695
+
1696
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1697
+ output_hidden_states = (
1698
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1699
+ )
1700
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1701
+
1702
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1703
+ outputs = self.model(
1704
+ input_ids=input_ids,
1705
+ attention_mask=attention_mask,
1706
+ position_ids=position_ids,
1707
+ past_key_values=past_key_values,
1708
+ inputs_embeds=inputs_embeds,
1709
+ image_features=image_features,
1710
+ use_cache=use_cache,
1711
+ output_attentions=output_attentions,
1712
+ output_hidden_states=output_hidden_states,
1713
+ return_dict=return_dict,
1714
+ )
1715
+ if isinstance(outputs, tuple):
1716
+ hidden_states = outputs[0]
1717
+ elif isinstance(outputs, BaseModelOutputWithPast):
1718
+ hidden_states = outputs.last_hidden_state
1719
+
1720
+ if self.config.tie_word_embeddings:
1721
+ logits = self.model.embed_tokens.as_linear(hidden_states)
1722
+ else:
1723
+ logits = self.lm_head(hidden_states)
1724
+
1725
+ logits = logits[..., : self.vocab_size]
1726
+
1727
+ loss = None
1728
+ if labels is not None:
1729
+ # Shift so that tokens < n predict n
1730
+ shift_logits = logits[..., :-1, :]
1731
+ shift_labels = labels[..., 1:]
1732
+ # Flatten the tokens
1733
+ loss_fct = nn.losses.cross_entropy
1734
+ shift_logits = shift_logits.reshape((-1, self.config.vocab_size))
1735
+ shift_labels = shift_labels.reshape((-1,))
1736
+ # Enable model parallelism
1737
+ loss = loss_fct(shift_logits, shift_labels)
1738
+
1739
+ if not return_dict:
1740
+ output = (logits,) + outputs[1:]
1741
+ return (loss,) + output if loss is not None else output
1742
+
1743
+ if not isinstance(outputs, BaseModelOutputWithPast):
1744
+ raise ValueError(
1745
+ f"Unexpected output type for causal language model: {type(outputs)} != BaseModelOutputWithPast"
1746
+ )
1747
+ return CausalLMOutputWithPast(
1748
+ loss=loss,
1749
+ logits=logits,
1750
+ past_key_values=outputs.past_key_values,
1751
+ hidden_states=outputs.hidden_states,
1752
+ attentions=outputs.attentions,
1753
+ )
1754
+
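The loss path in `forward` pairs the logits at position t with the label at position t+1. A toy-shaped sketch of that shift with mlx's cross entropy (values random, shapes only):

```python
import mlx.core as mx
import mlx.nn as nn

batch, seq_len, vocab = 2, 5, 11
logits = mx.random.normal((batch, seq_len, vocab))
labels = mx.random.randint(0, vocab, (batch, seq_len))

# Tokens < n predict token n: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].reshape((-1, vocab))
shift_labels = labels[..., 1:].reshape((-1,))

loss = nn.losses.cross_entropy(shift_logits, shift_labels, reduction="mean")
print(loss.item())
```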
1755
+ def prepare_inputs_for_generation(
1756
+ self,
1757
+ input_ids: mx.array,
1758
+ past_key_values: Optional[PlamoCache] = None,
1759
+ attention_mask: Optional[mx.array] = None,
1760
+ inputs_embeds: Optional[mx.array] = None,
1761
+ image_features: Optional[mx.array] = None,
1762
+ **kwargs: Any,
1763
+ ) -> dict[str, Any]:
1764
+ if past_key_values:
1765
+ input_ids = input_ids[:, -1:]
1766
+ if image_features is not None:
1767
+ image_features = image_features[:, -1:, :]
1768
+
1769
+ position_ids = kwargs.get("position_ids", None)
1770
+ if attention_mask is not None and position_ids is None:
1771
+ # create position_ids on the fly for batch generation
1772
+ position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1
1773
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)  # masked positions get a dummy id
1774
+ if past_key_values:
1775
+ position_ids = position_ids[:, -1:]
1776
+
1777
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1778
+ if inputs_embeds is not None and past_key_values is None:
1779
+ model_inputs: dict[str, Any] = {"inputs_embeds": inputs_embeds}
1780
+ else:
1781
+ model_inputs = {"input_ids": input_ids}
1782
+
1783
+ model_inputs.update(
1784
+ {
1785
+ "position_ids": position_ids,
1786
+ "past_key_values": past_key_values,
1787
+ "use_cache": kwargs.get("use_cache"),
1788
+ "attention_mask": attention_mask,
1789
+ "image_features": image_features,
1790
+ }
1791
+ )
1792
+ return model_inputs
1793
+
1794
+ @staticmethod
1795
+ def _reorder_cache(past_key_values: PlamoCache, beam_idx: mx.array) -> PlamoCache:
1796
+ past_key_values.reorder_cache(beam_idx)
1797
+ return past_key_values
1798
+
1799
+ @property
1800
+ def layers(self):
1801
+ return self.model.layers
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|plamo:bos|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|plamo:eos|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|plamo:pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|plamo:unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_plamo.py ADDED
@@ -0,0 +1,392 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ # NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
10
+ from numba import njit # type: ignore[attr-defined]
11
+ from numba.core import types
12
+ from numba.typed import Dict, List
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ from transformers.utils import logging
15
+
16
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
17
+ logger = logging.get_logger(__name__)
18
+
19
+ INVALID_SCORE = -20000000
20
+ UNKNOWN_SCORE = -10000000
21
+
22
+ TABLE_PIECE_LENGTH = 0
23
+ TABLE_TOKEN_ID = 1
24
+ TABLE_SCORE = 2
25
+ TABLE_PIECE_ID = 3
26
+
27
+ PATH_TOKEN_LENGTH = 0
28
+ PATH_TOKEN_ID = 1
29
+ PATH_NUM_TOKENS = 2
30
+
31
+
32
+ class AhoCorasick:
33
+ def __init__(self) -> None:
34
+ # List of tokens in the vocabulary.
35
+ self._tokens: list[str]
36
+
37
+ # A mapping from a byte code point to a token ID, used for byte fallback.
38
+ self._bytes: np.ndarray
39
+
40
+ # A mapping from a suffix's piece code to a suffix ID.
41
+ #
42
+ # Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
43
+ # of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
44
+ # a piece code to an edge (in other words, a pair of a node and the next character).
45
+ #
46
+ # A piece code is a 64-bit integer:
47
+ # - The upper 32 bits store the Unicode code point of the first character.
48
+ # - The lower 32 bits store the suffix ID of the remaining suffix.
49
+ #
50
+ # A suffix ID is an integer indicating the starting position in the _table.
51
+ self._to_suffix_id: Dict[types.int64, types.int32]
52
+
53
+ # Flattened table representing the Trie structure for the Aho-Corasick algorithm.
54
+ # It stores information including scores for each piece (prefix) within each suffix.
55
+ # It is flattened for memory efficiency and performance. Suffixes are stored in
56
+ # lexicographical order of their reversed strings, which improves memory access locality
57
+ # when exploring new characters starting from the string's end. Pieces within a suffix are
58
+ # stored in the decreasing order of their lengths.
59
+ #
60
+ # Each piece (a prefix of the suffix) contains four pieces of information:
61
+ # - TABLE_PIECE_LENGTH: Length of the piece.
62
+ # - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
63
+ # - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
64
+ # - TABLE_PIECE_ID: Piece ID of the suffix.
65
+ #
66
+ # Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
67
+ # and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
68
+ self._table: np.ndarray
69
+
70
+ def build(self, vocab: list[Any]) -> None:
71
+ self._bytes = np.zeros(256, dtype=np.int32)
72
+ self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
73
+
74
+ # Build suffix_to_score and token_to_token_id.
75
+ # The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
76
+ # of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
77
+ # valid token, its score is set to math.nan.
78
+ # The token_to_token_id dictionary maps a token to its token ID.
79
+ suffix_to_score: dict[str, float] = {}
80
+ token_to_token_id: dict[str, int] = {}
81
+ self._tokens = []
82
+ for token_id, row in enumerate(vocab):
83
+ assert isinstance(row[0], str), row
84
+ assert isinstance(row[1], (int, float)), row
85
+
86
+ token = str(row[0])
87
+ self._tokens.append(token)
88
+ token_to_token_id[token] = token_id
89
+
90
+ # Special handling for byte tokens.
91
+ if len(row) > 2 and row[2] == "BYTE":
92
+ assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
93
+ self._bytes[int(row[0][3:5], 16)] = token_id
94
+ continue
95
+
96
+ suffix_to_score[token] = float(row[1])
97
+ # Ensure that all suffixes are included in suffix_to_score.
98
+ for i in range(1, len(token)):
99
+ suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
100
+
101
+ # Ensure all byte tokens are set.
102
+ for i in range(256):
103
+ assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
104
+
105
+ # List suffixes in lexicographical order of their reversed strings.
106
+ suffixes = list(suffix_to_score.keys())
107
+ suffixes.append("")
108
+ suffixes.sort(key=lambda x: x[::-1])
109
+
110
+ # Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
111
+ # which is a mapping from a piece code to a suffix ID.
112
+ suffix_to_id: dict[str, int] = {}
113
+ num_pieces = 0
114
+ for s in suffixes:
115
+ suffix_to_id[s] = num_pieces
116
+ if s != "":
117
+ self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
118
+ num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
119
+ assert suffix_to_id[""] == 0, suffix_to_id[""]
120
+
121
+ # Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
122
+ self._table = np.zeros((num_pieces, 4), dtype=np.int32)
123
+ i = 0
124
+ for suffix in suffixes:
125
+ # Add all prefixes of the suffix to the table.
126
+ for piece_length in range(len(suffix), 0, -1):
127
+ piece = suffix[:piece_length]
128
+ score = suffix_to_score.get(piece, None)
129
+ if score is None:
130
+ continue
131
+ self._table[i, TABLE_PIECE_LENGTH] = piece_length
132
+ self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
133
+ self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
134
+ self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
135
+ i += 1
136
+
137
+ # Add a sentinel row.
138
+ self._table[i, TABLE_PIECE_LENGTH] = 1
139
+ self._table[i, TABLE_TOKEN_ID] = -1
140
+ self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
141
+ i += 1
142
+ assert i == num_pieces, (i, num_pieces)
143
+
144
+ @staticmethod
145
+ @njit
146
+ def _encode(
147
+ to_suffix_id: Dict[types.int64, types.int32],
148
+ table: np.ndarray,
149
+ bytes: np.ndarray,
150
+ data: np.ndarray,
151
+ ) -> np.ndarray:
152
+ # Initialize scores array with a high value and set the score at the end to 0.
153
+ # This array keeps track of the minimum cost (best score) to encode from each position to the end.
154
+ scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
155
+ scores[-1] = 0
156
+
157
+ # Path array to store the best path information.
158
+ # The path array keeps track of token length, token ID, and number of tokens needed to encode.
159
+ path = np.zeros((len(data) + 1, 3), dtype=np.int32)
160
+
161
+ # Initialize suffix_id to 0, which represents the root of the Trie.
162
+ suffix_id = 0
163
+
164
+ # Process the input data from the end to the beginning.
165
+ for i in range(len(data) - 1, -1, -1):
166
+ c = data[i]
167
+
168
+ # Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
169
+ # NOTE: If no suffix ID is found, suffix_id will be set to 0.
170
+ for p in range(suffix_id, len(table)):
171
+ suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
172
+ # If a next suffix ID is found or a sentinel row is reached, break the loop.
173
+ if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
174
+ break
175
+
176
+ # Update the best path to the current position. If multiple paths have the same score,
177
+ # this chooses the longest prefix as the best path (table is sorted in the decreasing
178
+ # order of piece length).
179
+ for p in range(suffix_id, len(table)):
180
+ score = table[p, TABLE_SCORE]
181
+ if score > INVALID_SCORE:
182
+ piece_length = table[p, TABLE_PIECE_LENGTH]
183
+ s = scores[i + piece_length] - score
184
+ if s < scores[i]:
185
+ scores[i] = s
186
+ path[i, PATH_TOKEN_LENGTH] = piece_length
187
+ path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
188
+ path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
189
+ if score == UNKNOWN_SCORE:
190
+ # Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
191
+ # added above).
192
+ path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
193
+
194
+ # If it reaches a sentinel row, break the loop.
195
+ if score == UNKNOWN_SCORE:
196
+ break
197
+
198
+ # Decode the best path from the beginning to get the token IDs.
199
+ pos = 0
200
+ token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
201
+ token_pos = 0
202
+ while pos < len(data):
203
+ if path[pos, PATH_TOKEN_ID] >= 0:
204
+ token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
205
+ token_pos += 1
206
+ else:
207
+ # Fall back to byte tokens.
208
+ c = data[pos]
209
+ s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
210
+ # Add byte tokens representing UTF-8 bytes.
211
+ for i in range(s):
212
+ b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
213
+ token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
214
+ token_pos += 1
215
+
216
+ # Ensure that pos increases by at least 1.
217
+ assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
218
+ pos += path[pos, PATH_TOKEN_LENGTH]
219
+
220
+ return token_ids
221
+
222
+ def encode(self, data: str) -> np.ndarray:
223
+ """Encodes a string into a sequence of token IDs."""
224
+ return np.asarray(
225
+ self._encode(
226
+ self._to_suffix_id,
227
+ self._table,
228
+ self._bytes,
229
+ # Convert a string into a numpy array of Unicode code points.
230
+ # NOTE: This skips UTF-32 BOM.
231
+ np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
232
+ )
233
+ )
234
+
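The encoder turns the input string into an array of Unicode code points via UTF-32 and drops the BOM, as noted in the comment. A quick check of that conversion:

```python
import numpy as np

text = "aあ😀"
# UTF-32 prepends a BOM code unit; dropping it leaves one int32 per character.
code_points = np.frombuffer(text.encode("utf-32"), dtype=np.int32)[1:]
print(code_points.tolist())    # [97, 12354, 128512]
print([ord(c) for c in text])  # [97, 12354, 128512]
```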
235
+ def encode_as_tokens(self, data: str) -> list[str]:
236
+ """Encodes a string into a sequence of tokens."""
237
+ return [self._tokens[token_id] for token_id in self.encode(data)]
238
+
239
+
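When no vocabulary piece covers a character, `_encode` falls back to byte tokens using bit arithmetic over the character's UTF-8 encoding. The same arithmetic in isolation, checked against Python's own encoder:

```python
def utf8_byte_values(c: int) -> list[int]:
    """Mirror the byte-fallback computation in AhoCorasick._encode for one code point."""
    s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)  # number of UTF-8 bytes
    out = []
    for i in range(s):
        b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
        out.append(b | ((c >> (s - i - 1) * 6) & 0x3F))
    return out

for ch in ("A", "é", "あ", "😀"):
    assert bytes(utf8_byte_values(ord(ch))) == ch.encode("utf-8")
print("byte fallback arithmetic matches UTF-8")
```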
240
+ class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
241
+ vocab_files_names = VOCAB_FILES_NAMES
242
+ model_input_names = ["input_ids", "attention_mask"]
243
+
244
+ _save_files = [
245
+ "special_tokens_map.json",
246
+ "tokenization_plamo.py",
247
+ "tokenizer.jsonl",
248
+ "tokenizer_config.json",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_file: str,
254
+ unk_token: str = "<|plamo:unk|>",
255
+ bos_token: str = "<|plamo:bos|>",
256
+ eos_token: str = "<|plamo:eos|>",
257
+ pad_token: str = "<|plamo:pad|>",
258
+ cls_token: Optional[str] = None,
259
+ sep_token: Optional[str] = None,
260
+ mask_token: Optional[str] = None,
261
+ clean_up_tokenization_spaces: bool = False,
262
+ **kwargs: Any,
263
+ ) -> None:
264
+ """Tokenizer for PLaMo.
265
+
266
+ Args:
267
+ vocab_file (str): Vocabulary file path.
268
+ unk_token (str): Unknown token.
269
+ bos_token (str): Beginning of sentence token.
270
+ eos_token (str): End of sentence token.
271
+ pad_token (str): Padding token.
272
+ cls_token (str):
273
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
274
+ full depth of the model.
275
+ sep_token (str): Separation token, to separate context and query in an input sequence.
276
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
277
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
278
+ num_threads (int):
279
+ Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
280
+ `RAYON_NUM_THREADS` is set as an environment variable.
281
+ """
282
+ if "add_bos_token" not in kwargs:
283
+ kwargs["add_bos_token"] = False
284
+ if "add_eos_token" not in kwargs:
285
+ kwargs["add_eos_token"] = False
286
+ self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
287
+ self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
288
+ self.aho_corasick = AhoCorasick()
289
+ self.aho_corasick.build(self.data)
290
+ self.vocab_file = vocab_file
291
+ self.add_bos_token = kwargs["add_bos_token"]
292
+ self.add_eos_token = kwargs["add_eos_token"]
293
+
294
+ super().__init__(
295
+ vocab_file=vocab_file,
296
+ unk_token=unk_token,
297
+ bos_token=bos_token,
298
+ eos_token=eos_token,
299
+ pad_token=pad_token,
300
+ cls_token=cls_token,
301
+ sep_token=sep_token,
302
+ mask_token=mask_token,
303
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
304
+ **kwargs,
305
+ )
306
+
307
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
308
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
309
+
310
+ def __getstate__(self) -> dict[str, Any]:
311
+ state = self.__dict__.copy()
312
+ state["aho_corasick"] = None
313
+ return state
314
+
315
+ def __setstate__(self, d: dict[str, Any]) -> None:
316
+ self.__dict__ = d
317
+ self.aho_corasick = AhoCorasick()
318
+ self.aho_corasick.build(self.data)
319
+
320
+ @property
321
+ def vocab_size(self) -> Any:
322
+ """Returns vocab size"""
323
+ return len(self.data)
324
+
325
+ def token_to_score(self, token: str) -> Optional[float]:
326
+ """Returns score of the token"""
327
+ token_id = self.vocab.get(token, None)
328
+ return None if token_id is None else self.data[token_id][1]
329
+
330
+ def get_vocab(self) -> dict[str, int]:
331
+ """Returns vocab as a dict"""
332
+ vocab = self.vocab.copy()
333
+ vocab.update(self.added_tokens_encoder)
334
+ return vocab
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ """Converts a sequence of tokens (string) in a single string."""
338
+ return b"".join(
339
+ [bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
340
+ ).decode("utf-8", errors="replace")
341
+
342
+ def _tokenize(self, text: str) -> Any:
343
+ """Returns a tokenized string."""
344
+ return self.aho_corasick.encode_as_tokens(text)
345
+
346
+ def _convert_token_to_id(self, token: str) -> Any:
347
+ """Converts a token (str) in an id using the vocab."""
348
+ return self.vocab.get(token, 0)
349
+
350
+ def _convert_id_to_token(self, index: int) -> Any:
351
+ """Converts an index (integer) in a token (str) using the vocab."""
352
+ return self.data[index][0]
353
+
354
+ def build_inputs_with_special_tokens(
355
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
356
+ ) -> List[int]:
357
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
358
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
359
+
360
+ output = bos_token_id + token_ids_0 + eos_token_id
361
+
362
+ if token_ids_1 is not None:
363
+ output = output + bos_token_id + token_ids_1 + eos_token_id
364
+
365
+ return output
366
+
367
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
368
+ """
369
+ Save the vocabulary and special tokens file to a directory.
370
+
371
+ Args:
372
+ save_directory (`str`):
373
+ The directory in which to save the vocabulary.
374
+
375
+ Returns:
376
+ `Tuple(str)`: Paths to the files saved.
377
+ """
378
+ if not os.path.isdir(save_directory):
379
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
380
+ return ("",)
381
+ out_vocab_file = os.path.join(
382
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
383
+ )
384
+
385
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
386
+ copyfile(self.vocab_file, out_vocab_file)
387
+ elif not os.path.isfile(self.vocab_file):
388
+ with open(out_vocab_file, "w") as f:
389
+ for token in self.data:
390
+ print(json.dumps(token, ensure_ascii=False), file=f)
391
+
392
+ return (out_vocab_file,)
tokenizer.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.PlamoTokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|plamo:eos|>",
48
+ "local_file_only": true,
49
+ "mask_token": null,
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "pad_token": "<|plamo:pad|>",
52
+ "sep_token": null,
53
+ "tokenizer_class": "PlamoTokenizer",
54
+ "unk_token": "<|plamo:unk|>"
55
+ }