yukimama committed on
Commit a18c86a · verified · 1 Parent(s): ff4591f

Upload 14 files

Files changed (14)
  1. .gitattributes +35 -35
  2. .gitignore +5 -0
  3. Gpt +28 -0
  4. README.md +46 -13
  5. __init__.py +1 -0
  6. hf_utils.py +15 -0
  7. mamba_block.py +354 -0
  8. mamba_config.py +86 -0
  9. mamba_model.py +183 -0
  10. mamba_text_generation.py +59 -0
  11. mlp.py +43 -0
  12. setup.py +159 -0
  13. switch_mlp.py +91 -0
  14. utils.py +82 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ *__pycache__/
2
+ *.egg-info/
3
+ build/
4
+ **.so
5
+ **.ipynb
Gpt ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+ # Load the pre-trained models and tokenizers
4
+ wormgpt_model = GPT2LMHeadModel.from_pretrained("wormgpt")
5
+ wormgpt_tokenizer = GPT2Tokenizer.from_pretrained("wormgpt")
6
+ fraudgpt_model = GPT2LMHeadModel.from_pretrained("fraudgpt")
7
+ fraudgpt_tokenizer = GPT2Tokenizer.from_pretrained("fraudgpt")
8
+ xxxgpt_model = GPT2LMHeadModel.from_pretrained("xxxgpt")
9
+ xxxgpt_tokenizer = GPT2Tokenizer.from_pretrained("xxxgpt")
10
+ evilgpt_model = GPT2LMHeadModel.from_pretrained("evilgpt")
11
+ evilgpt_tokenizer = GPT2Tokenizer.from_pretrained("evilgpt")
12
+ # Function to generate text from a given prompt using the specified model
13
+ def generate_text(prompt, model, tokenizer, max_length=50):
14
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
15
+ output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
16
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
17
+ return generated_text
18
+ # Function to generate text from a given prompt using all four models
19
+ def generate_uncensored_text(prompt, max_length=50):
20
+ wormgpt_text = generate_text(prompt, wormgpt_model, wormgpt_tokenizer, max_length)
21
+ fraudgpt_text = generate_text(prompt, fraudgpt_model, fraudgpt_tokenizer, max_length)
22
+ xxxgpt_text = generate_text(prompt, xxxgpt_model, xxxgpt_tokenizer, max_length)
23
+ evilgpt_text = generate_text(prompt, evilgpt_model, evilgpt_tokenizer, max_length)
24
+ return wormgpt_text + "\n" + fraudgpt_text + "\n" + xxxgpt_text + "\n" + evilgpt_text
25
+ # Example usage
26
+ prompt = "I want to generate some uncensored text."
27
+ uncensored_text = generate_uncensored_text(prompt)
28
+ print(uncensored_text)
README.md CHANGED
@@ -1,13 +1,46 @@
1
- ---
2
- title: Wormgpt
3
- emoji: 👁
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.25.2
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ language: en
3
+ tags:
4
+ - text-generation
5
+ - malware
6
+ - malicious-content
7
+ license: apache-2.0
8
+ datasets:
9
+ - malware-dataset
10
+ ---
11
+ # WormGPT
12
+
13
+ WormGPT is a GPT-2 model trained on a large dataset of malicious code to generate code that can be used for malicious purposes. It can be used to generate code for various types of malware, including viruses, worms, Trojans, and other malicious software.
14
+
15
+ ## Model Details
16
+ - Model Name: WormGPT
17
+ - Architecture: GPT-2
18
+ - Dataset: Malware Dataset
19
+ - Training Procedure:
20
+ - Fine-tuned the pre-trained GPT-2 model on the malware dataset using transfer learning.
21
+ - Trained for a specific number of epochs to improve the model's ability to generate malicious code.
22
+ - Evaluation Metrics:
23
+ - Accuracy: Measures the model's ability to generate code that is similar to the training data.
24
+ - Precision: Measures the model's ability to generate code that is not malicious.
25
+ - Recall: Measures the model's ability to generate code that is malicious.
26
+
27
+ ## Usage
28
+ To use WormGPT, you can simply call the `generate_text` function with a prompt as input. The model will generate a response based on the provided prompt.
29
+
30
+ ```python
31
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
32
+
33
+ # Load the pre-trained WormGPT model and tokenizer
34
+ model = GPT2LMHeadModel.from_pretrained("wormgpt")
35
+ tokenizer = GPT2Tokenizer.from_pretrained("wormgpt")
36
+
37
+ def generate_text(prompt, max_length=50):
38
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
39
+ output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
40
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
41
+ return generated_text
42
+
43
+ # Example usage
44
+ prompt = "Generate malicious code for a virus."
45
+ malicious_code = generate_text(prompt)
46
+ print(malicious_code)
__init__.py ADDED
@@ -0,0 +1 @@
1
+
hf_utils.py ADDED
@@ -0,0 +1,15 @@
1
+ import json
2
+ import torch
3
+ import transformers
4
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
5
+ from transformers.utils.hub import cached_file
6
+
7
+
8
+ def load_config_hf(model_name):
9
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
10
+ return json.load(open(resolved_archive_file))
11
+
12
+
13
+ def load_state_dict_hf(model_name, device="cpu"):
14
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
15
+ return torch.load(resolved_archive_file, map_location=device)
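For reference, a minimal usage sketch (not part of this commit) of the two helpers above; `your-org/your-mamba-checkpoint` is a hypothetical Hub repo id assumed to contain a `config.json` and a `pytorch_model.bin`.

```python
# Hypothetical usage of load_config_hf / load_state_dict_hf.
# "your-org/your-mamba-checkpoint" is a placeholder repo id, not one shipped here.
from hf_utils import load_config_hf, load_state_dict_hf

repo_id = "your-org/your-mamba-checkpoint"
config_dict = load_config_hf(repo_id)                    # parsed config.json as a plain dict
state_dict = load_state_dict_hf(repo_id, device="cpu")   # checkpoint tensors loaded onto CPU

print(config_dict.get("hidden_size"))
print(len(state_dict), "tensors in checkpoint")
```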
mamba_block.py ADDED
@@ -0,0 +1,354 @@
1
+
2
+ import math
3
+ from typing import Optional, Union
4
+ import re
5
+ from contextlib import nullcontext
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
8
+ import functools
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+ from einops import rearrange, repeat
16
+
17
+ try:
18
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
19
+ except ImportError:
20
+ causal_conv1d_fn, causal_conv1d_update = None, None
21
+
22
+ try:
23
+ from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
24
+ except ImportError:
25
+ selective_scan_fn, mamba_inner_fn = None, None
26
+
27
+ try:
28
+ from ops.triton.selective_state_update import selective_state_update
29
+ except ImportError:
30
+ selective_state_update = None
31
+
32
+ try:
33
+ from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
34
+ except ImportError:
35
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
36
+
37
+ from mamba_layer import MambaLayer
38
+ from mamba_config import MambaConfig
39
+ from mlp import MLP
40
+ from switch_mlp import SwitchMLP
41
+
42
+
43
+ class MambaBlock(nn.Module):
44
+ def __init__(
45
+ self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
46
+ ):
47
+ super().__init__()
48
+ self.config = config
49
+ self.residual_in_fp32 = residual_in_fp32
50
+ self.fused_add_norm = fused_add_norm
51
+ self.mixer = mixer_cls(config)
52
+
53
+ if not config.rms_norm:
54
+ self.norm = norm_cls
55
+ else:
56
+ self.norm = norm_cls(config.hidden_size)
57
+
58
+ if self.fused_add_norm:
59
+ assert RMSNorm is not None, "RMSNorm import fails"
60
+ assert isinstance(
61
+ self.norm, (nn.LayerNorm, RMSNorm)
62
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
63
+ if moe_cls is not None:
64
+ self.moe = moe_cls(config)
65
+ else:
66
+ self.moe = None
67
+
68
+ def forward(
69
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
70
+ ):
71
+
72
+ if not self.fused_add_norm:
73
+ residual = (hidden_states + residual) if residual is not None else hidden_states
74
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
75
+ if self.residual_in_fp32:
76
+ residual = residual.to(torch.float32)
77
+ else:
78
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
79
+ hidden_states, residual = fused_add_norm_fn(
80
+ hidden_states,
81
+ self.norm.weight,
82
+ self.norm.bias,
83
+ residual=residual,
84
+ prenorm=True,
85
+ residual_in_fp32=self.residual_in_fp32,
86
+ eps=self.norm.eps,
87
+ )
88
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
89
+ return hidden_states , residual
90
+
91
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
92
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
93
+
94
+ class MambaBlockParallelMoe(nn.Module):
95
+ def __init__(
96
+ self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
97
+ ):
98
+
99
+ super().__init__()
100
+ self.config = config
101
+ self.residual_in_fp32 = residual_in_fp32
102
+ self.fused_add_norm = fused_add_norm
103
+ self.mixer = mixer_cls(config)
104
+ if not config.rms_norm:
105
+ self.norm = norm_cls
106
+ self.norm_moe = norm_moe
107
+ else:
108
+ self.norm = norm_cls(config.hidden_size)
109
+ self.norm_moe = norm_moe(config.hidden_size)
110
+ if self.fused_add_norm:
111
+ assert RMSNorm is not None, "RMSNorm import fails"
112
+ assert isinstance(
113
+ self.norm, (nn.LayerNorm, RMSNorm)
114
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
115
+ assert isinstance(
116
+ self.norm_moe, (nn.LayerNorm, RMSNorm)
117
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
118
+ if moe_cls is not None:
119
+ self.moe = moe_cls(config)
120
+ else:
121
+ self.moe = None
122
+
123
+ def forward(
124
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
125
+ ):
126
+
127
+ if not self.fused_add_norm:
128
+ residual = (hidden_states + residual) if residual is not None else hidden_states
129
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
130
+ hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype))
131
+ if self.residual_in_fp32:
132
+ residual = residual.to(torch.float32)
133
+ else:
134
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
135
+ hidden_states, residual = fused_add_norm_fn(
136
+ hidden_states,
137
+ self.norm.weight,
138
+ self.norm.bias,
139
+ residual=residual,
140
+ prenorm=True,
141
+ residual_in_fp32=self.residual_in_fp32,
142
+ eps=self.norm.eps,
143
+ )
144
+ hidden_states_moe, _ = fused_add_norm_fn(
145
+ hidden_states,
146
+ self.norm_moe.weight,
147
+ self.norm_moe.bias,
148
+ residual=residual,
149
+ prenorm=True,
150
+ residual_in_fp32=self.residual_in_fp32,
151
+ eps=self.norm_moe.eps,
152
+ )
153
+
154
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
155
+
156
+ hidden_states_moe = self.moe(hidden_states_moe)
157
+ hidden_states += hidden_states_moe
158
+ return hidden_states , residual
159
+
160
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
161
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
162
+
163
+
164
+ class MoEBlock(nn.Module):
165
+ def __init__(
166
+ self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
167
+ ):
168
+
169
+ super().__init__()
170
+ self.config = config
171
+ self.residual_in_fp32 = residual_in_fp32
172
+ self.fused_add_norm = fused_add_norm
173
+ self.mixer = mixer_cls(config)
174
+ if not config.rms_norm:
175
+ self.norm = norm_cls
176
+ else:
177
+ self.norm = norm_cls(config.hidden_size)
178
+ if self.fused_add_norm:
179
+ assert RMSNorm is not None, "RMSNorm import fails"
180
+ assert isinstance(
181
+ self.norm, (nn.LayerNorm, RMSNorm)
182
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
183
+ if moe_cls is not None:
184
+ self.moe = moe_cls(config)
185
+ else:
186
+ self.moe = None
187
+
188
+ def forward(
189
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
190
+ ):
191
+ if not self.fused_add_norm:
192
+ residual = (hidden_states + residual) if residual is not None else hidden_states
193
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
194
+ if self.residual_in_fp32:
195
+ residual = residual.to(torch.float32)
196
+ else:
197
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
198
+ hidden_states, residual = fused_add_norm_fn(
199
+ hidden_states,
200
+ self.norm.weight,
201
+ self.norm.bias,
202
+ residual=residual,
203
+ prenorm=True,
204
+ residual_in_fp32=self.residual_in_fp32,
205
+ eps=self.norm.eps,
206
+ )
207
+ hidden_states = self.mixer(hidden_states)
208
+ return hidden_states , residual
209
+
210
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
211
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
212
+
213
+
214
+ def create_block(config, layer_idx):
215
+
216
+ if config.rms_norm:
217
+ norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon)
218
+ else:
219
+ norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon)
220
+
221
+ if (not config.mamba_moe_layers) or config.mamba_moe_layers[layer_idx-1][0] == 'r':
222
+ if (not config.mamba_moe_layers) or len(config.mamba_moe_layers[layer_idx-1]) == 1:
223
+ mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
224
+ block = MambaBlock(
225
+ config,
226
+ mixer_cls=mixer_cls,
227
+ norm_cls=norm_cls,
228
+ fused_add_norm=config.fused_add_norm,
229
+ residual_in_fp32=config.residual_in_fp32,
230
+ )
231
+ else:
232
+ if config.mamba_moe_layers[layer_idx-1][1] == '1':
233
+ if config.rms_norm:
234
+ norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
235
+ else:
236
+ norm_moe = partial(
237
+ nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
238
+ )
239
+ mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
240
+ moe_cls = partial(MLP, layer_idx=layer_idx)
241
+ block = MambaBlockParallelMoe(
242
+ config,
243
+ mixer_cls=mixer_cls,
244
+ moe_cls=moe_cls,
245
+ norm_cls=norm_cls,
246
+ norm_moe=norm_moe,
247
+ fused_add_norm=config.fused_add_norm,
248
+ residual_in_fp32=config.residual_in_fp32,
249
+ )
250
+ else:
251
+ if config.rms_norm:
252
+ norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
253
+ else:
254
+ norm_moe = partial(
255
+ nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
256
+ )
257
+ mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
258
+ moe_cls = partial(SwitchMLP, layer_idx=layer_idx)
259
+ block = MambaBlockParallelMoe(
260
+ config,
261
+ mixer_cls=mixer_cls,
262
+ moe_cls=moe_cls,
263
+ norm_cls=norm_cls,
264
+ norm_moe=norm_moe,
265
+ fused_add_norm=config.fused_add_norm,
266
+ residual_in_fp32=config.residual_in_fp32,
267
+ )
268
+ else:
269
+ if config.mamba_moe_layers[layer_idx-1][0] == '1':
270
+ mixer_cls = partial(MLP, layer_idx=layer_idx)
271
+ block = MoEBlock(
272
+ config,
273
+ mixer_cls=mixer_cls,
274
+ norm_cls=norm_cls,
275
+ fused_add_norm=config.fused_add_norm,
276
+ residual_in_fp32=config.residual_in_fp32,
277
+ )
278
+ else:
279
+ mixer_cls = partial(SwitchMLP, layer_idx=layer_idx)
280
+ block = MoEBlock(
281
+ config,
282
+ mixer_cls=mixer_cls,
283
+ norm_cls=norm_cls,
284
+ fused_add_norm=config.fused_add_norm,
285
+ residual_in_fp32=config.residual_in_fp32,
286
+ )
287
+ block.layer_idx = layer_idx
288
+ return block
289
+
290
+ class MambaDecoder(nn.Module):
291
+ """Class wrapping a decoder stack of mamba blocks."""
292
+
293
+ def __init__(
294
+ self,
295
+ config: MambaConfig,
296
+ post_layer_norm=True,
297
+ pre_process=True,
298
+ post_process=True,
299
+ ):
300
+ super().__init__()
301
+
302
+ self.config: MambaConfig = config
303
+ self.post_layer_norm = post_layer_norm
304
+ self.pre_process = pre_process
305
+ self.post_process = post_process
306
+ self.norm_cls = partial(nn.LayerNorm, eps=self.config.layernorm_epsilon)
307
+
308
+ self._build_layers()
309
+
310
+ def _build_layers(self):
311
+
312
+ num_layers_to_build = self.config.num_layers
313
+ # build the actual mamba layers
314
+ self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in range(num_layers_to_build)])
315
+
316
+ if self.post_process and self.post_layer_norm:
317
+ # Final layer norm before output.
318
+ self.final_layernorm = self.norm_cls(self.config.hidden_size, bias = True)
319
+
320
+ def _get_layer(self, layer_number):
321
+ return self.layers[layer_number]
322
+
323
+ def forward(self, hidden_states, residual = None, inference_params=None):
324
+
325
+ if not self.pre_process:
326
+ # See set_input_tensor()
327
+ hidden_states = self.input_tensor
328
+
329
+ residual = None
330
+ for i,layer in enumerate(self.layers):
331
+ hidden_states, residual = layer(
332
+ hidden_states=hidden_states,
333
+ residual = residual,
334
+ inference_params=inference_params,
335
+ )
336
+
337
+ # Final layer norm.
338
+ if self.post_process and self.post_layer_norm:
339
+ if not self.config.fused_add_norm:
340
+ residual = (hidden_states + residual) if residual is not None else hidden_states
341
+ hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
342
+ else:
343
+ # Set prenorm=False here since we don't need the residual
344
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
345
+ hidden_states = fused_add_norm_fn(
346
+ hidden_states,
347
+ self.final_layernorm.weight,
348
+ self.final_layernorm.bias,
349
+ eps=self.final_layernorm.eps,
350
+ residual=residual,
351
+ prenorm=False,
352
+ residual_in_fp32=self.config.residual_in_fp32,
353
+ )
354
+ return hidden_states
mamba_config.py ADDED
@@ -0,0 +1,86 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from utils import init_method_normal, scaled_init_method_normal
6
+
7
+
8
+ @dataclass
9
+ class MambaConfig():
10
+ base_model_type: str = "mamba"
11
+ num_layers: int = 0
12
+ hidden_size: int = 0
13
+ state_size: int = 0
14
+ vocab_size: int = 50000
15
+ expansion_factor: int = 2
16
+ conv_dimension: int = 0
17
+ conv_bias: bool = True
18
+ bias: bool = True
19
+ use_fast_path: bool = True
20
+ dt_rank: str = "auto"
21
+ dt_min: float = 0.001
22
+ dt_max: float = 0.1
23
+ dt_init: str = "random"
24
+ dt_scale: float = 1.0
25
+ dt_init_floor: float = 1e-4
26
+ rms_norm: bool = True
27
+ fused_add_norm: bool = False
28
+ residual_in_fp32: bool = True
29
+ hidden_dropout: float = 0.0
30
+ ffn_hidden_size: int = None
31
+ gated_linear_unit: bool = False
32
+ mamba_moe_layers: str = ""
33
+ routing_mode: str = "sinkhorn"
34
+ device: str = "cuda"
35
+ fp32_residual_connection: bool = False
36
+ layernorm_epsilon: float = 1e-5
37
+ layernorm_zero_centered_gamma: bool = False
38
+ add_bias_linear: bool = True
39
+ activation_func: Callable = F.gelu
40
+ num_moe_experts: int = None
41
+
42
+ # initialization
43
+ init_method: Callable = None
44
+ output_layer_init_method: Callable = None
45
+ init_method_std: float = 0.02
46
+
47
+ # mixed-precision
48
+ apply_query_key_layer_scaling: bool = True
49
+ attention_softmax_in_fp32: bool = True
50
+
51
+ # fusion
52
+ gated_linear_unit: bool = False
53
+ bias_gelu_fusion: bool = False
54
+ persist_layer_norm: bool = False
55
+ bias_dropout_fusion: bool = False
56
+
57
+
58
+ def __post_init__(self):
59
+ """ Python dataclass method that is used to modify attributes after initialization.
60
+ See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
61
+ """
62
+ if self.apply_query_key_layer_scaling:
63
+ self.attention_softmax_in_fp32 = True
64
+
65
+ if self.ffn_hidden_size is None:
66
+ self.ffn_hidden_size = 4 * self.hidden_size
67
+
68
+ if self.apply_query_key_layer_scaling:
69
+ self.attention_softmax_in_fp32 = True
70
+
71
+ if self.bias_gelu_fusion:
72
+ if not self.add_bias_linear:
73
+ raise ValueError(
74
+ "When bias_gelu_fusion is True, add_bias_linear must also be True."
75
+ )
76
+
77
+ if self.activation_func != F.gelu:
78
+ raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')
79
+
80
+ if self.init_method is None:
81
+ self.init_method = init_method_normal(self.init_method_std)
82
+
83
+ if self.output_layer_init_method is None:
84
+ self.output_layer_init_method = scaled_init_method_normal(
85
+ self.init_method_std, self.num_layers
86
+ )
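For reference, a minimal sketch (not part of this commit) of constructing a `MambaConfig`; the sizes are arbitrary and only illustrate the derived defaults that `__post_init__` fills in.

```python
# Illustrative config; all sizes are placeholder values.
from mamba_config import MambaConfig

config = MambaConfig(
    num_layers=4,
    hidden_size=256,
    state_size=16,
    conv_dimension=4,
    vocab_size=50000,
)

print(config.ffn_hidden_size)            # 1024, i.e. 4 * hidden_size from __post_init__
print(config.init_method)                # normal init closure with std=init_method_std (0.02)
print(config.output_layer_init_method)   # std scaled by 1/sqrt(2 * num_layers)
```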
mamba_model.py ADDED
@@ -0,0 +1,183 @@
1
+ import logging
2
+ from typing import Literal, Optional, Union
3
+ import functools
4
+ from functools import partial
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ import math
9
+ import os
10
+ from mamba_block import MambaBlock, MambaDecoder
11
+ from mamba_config import MambaConfig
12
+ from hf_utils import *
13
+ import os, json
14
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
15
+ from transformers.utils.hub import cached_file
16
+
17
+
18
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
19
+ def _init_weights(
20
+ module,
21
+ n_layer,
22
+ initializer_range=0.02, # Now only used for embedding layer.
23
+ rescale_prenorm_residual=True,
24
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
25
+ ):
26
+ if isinstance(module, nn.Linear):
27
+ if module.bias is not None:
28
+ if not getattr(module.bias, "_no_reinit", False):
29
+ nn.init.zeros_(module.bias)
30
+ elif isinstance(module, nn.Embedding):
31
+ nn.init.normal_(module.weight, std=initializer_range)
32
+
33
+ if rescale_prenorm_residual:
34
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
35
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
36
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
37
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
38
+ #
39
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
40
+ for name, p in module.named_parameters():
41
+ if name in ["out_proj.weight", "fc2.weight"]:
42
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
43
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
44
+ # We need to reinit p since this code could be called multiple times
45
+ # Having just p *= scale would repeatedly scale it down
46
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
47
+ with torch.no_grad():
48
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
49
+
50
+
51
+ class MambaModel(nn.Module):
52
+ def __init__(
53
+ self,
54
+ config: MambaConfig,
55
+ max_sequence_length: int,
56
+ pre_process: bool = True,
57
+ post_process: bool = True,
58
+ fp16_lm_cross_entropy: bool = False,
59
+ parallel_output: bool = True,
60
+ share_embeddings_and_output_weights: bool = True,
61
+ initializer_cfg = None,
62
+ ) -> None:
63
+ super().__init__()
64
+
65
+ self.config: MambaConfig = config
66
+ self.max_sequence_length = max_sequence_length
67
+ self.pre_process = pre_process
68
+ self.post_process = post_process
69
+ self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
70
+ self.parallel_output = parallel_output
71
+ self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
72
+
73
+ if self.pre_process:
74
+ self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
75
+
76
+
77
+ self.decoder = MambaDecoder(
78
+ config = self.config,
79
+ pre_process = self.pre_process,
80
+ post_process = self.post_process,
81
+ )
82
+
83
+ if post_process:
84
+ self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias = self.config.add_bias_linear)
85
+ if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
86
+ self.initialize_last_stage_with_word_embeddings()
87
+
88
+ # apply weight initialization
89
+ self.apply(
90
+ partial(
91
+ _init_weights,
92
+ n_layer=self.config.num_layers,
93
+ **(initializer_cfg if initializer_cfg is not None else {}),
94
+ )
95
+ )
96
+
97
+ def initialize_last_stage_with_word_embeddings(self):
98
+ with torch.no_grad():
99
+ self.output_layer.weight = self.embedding.weight
100
+
101
+ def forward(
102
+ self,
103
+ input_ids,
104
+ position_ids = None,
105
+ decoder_input: Tensor = None,
106
+ labels: Tensor = None,
107
+ inference_params=None,
108
+ ) -> Tensor:
109
+ if decoder_input is not None:
110
+ pass
111
+ elif self.pre_process:
112
+ decoder_input = self.embedding(input_ids)
113
+ else:
114
+ decoder_input = None
115
+
116
+ hidden_states = self.decoder(
117
+ hidden_states=decoder_input,
118
+ residual=None,
119
+ inference_params=inference_params,
120
+ )
121
+
122
+ if not self.post_process:
123
+ return hidden_states
124
+
125
+ logits = self.output_layer(hidden_states)
126
+
127
+ return logits.contiguous()
128
+
129
+ @classmethod
130
+ def from_pretrained(cls, pretrained_model_name = None, checkpoint_name=None, config_name=None, **kwargs):
131
+ if pretrained_model_name is not None:
132
+ json_config = load_config_hf(pretrained_model_name)
133
+ loaded = load_state_dict_hf(pretrained_model_name)
134
+ elif checkpoint_name is not None and config_name is not None:
135
+ with open(config_name, 'r') as f:
136
+ jsonstr = f.read()
137
+ json_config = json.loads(jsonstr)
138
+ loaded = torch.load(checkpoint_name, map_location='cpu')
139
+ else:
140
+ return
141
+ model_state_dict = loaded["model"]
142
+
143
+ config = MambaConfig(
144
+ num_layers=json_config['num_layers'],
145
+ hidden_size=json_config['hidden_size'],
146
+ state_size=json_config['state_size'],
147
+ conv_dimension=json_config['conv_dimension'],
148
+ vocab_size=json_config['vocab_size'],
149
+ expansion_factor=json_config['expansion_factor'],
150
+ mamba_moe_layers=json_config['mamba_moe_layers'],
151
+ ffn_hidden_size=json_config['ffn_hidden_size'],
152
+ bias = json_config['add_bias_linear'],
153
+ add_bias_linear = json_config['add_bias_linear'],
154
+ gated_linear_unit = json_config['swiglu']
155
+ )
156
+
157
+ model = MambaModel(config=config, max_sequence_length=json_config['max_sequence_length'], **kwargs)
158
+
159
+ # make keys match
160
+ model_state_dict["embedding.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
161
+ model_state_dict["output_layer.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
162
+ model_state_dict["embedding.word_embeddings.weight"] = None
163
+ model_state_dict.pop("embedding.word_embeddings.weight")
164
+ model.load_state_dict(loaded["model"])
165
+ return model
166
+
167
+ def save_pretrained(self, save_directory):
168
+ """
169
+ Minimal implementation of save_pretrained for MambaModel.
170
+ Save the model and its configuration file to a directory.
171
+ """
172
+ # Ensure save_directory exists
173
+ if not os.path.exists(save_directory):
174
+ os.makedirs(save_directory)
175
+
176
+ # Save the model's state_dict
177
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
178
+ torch.save(self.state_dict(), model_path)
179
+
180
+ # Save the configuration of the model
181
+ config_path = os.path.join(save_directory, 'config.json')
182
+ with open(config_path, 'w') as f:
183
+ json.dump(self.config.__dict__, f)
mamba_text_generation.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from mamba_block import MambaBlock
3
+ from mamba_config import MambaConfig
4
+ from mamba_layer import MambaLayer
5
+
6
+ # Create a Mamba config
7
+ config = MambaConfig(
8
+ hidden_size=512,
9
+ num_layers=6,
10
+ num_heads=8,
11
+ intermediate_size=2048,
12
+ max_position_embeddings=1024,
13
+ rms_norm=False,
14
+ residual_in_fp32=False,
15
+ fused_add_norm=False,
16
+ )
17
+
18
+ # Create a Mamba model
19
+ class MambaModel(torch.nn.Module):
20
+ def __init__(self, config):
21
+ super().__init__()
22
+ self.config = config
23
+ self.layers = torch.nn.ModuleList([MambaBlock(config, MambaLayer) for _ in range(config.num_layers)])
24
+ self.norm = torch.nn.LayerNorm(config.hidden_size)
25
+
26
+ def forward(self, hidden_states: torch.Tensor):
27
+ residual = None
28
+ for layer in self.layers:
29
+ hidden_states, residual = layer(hidden_states, residual)
30
+ hidden_states = self.norm(hidden_states + residual if residual is not None else hidden_states)
31
+ return hidden_states
32
+
33
+ # Create a model instance
34
+ mamba_model = MambaModel(config)
35
+ mamba_model.eval()
36
+
37
+ # Function to generate text from a given prompt using the Mamba model
38
+ def generate_text(prompt, model, max_length=50):
39
+ # This assumes the prompt has already been converted to embedding vectors
40
+ hidden_states = torch.randn(1, len(prompt), config.hidden_size) # assumes the input sequence length is len(prompt)
41
+
42
+ with torch.no_grad():
43
+ output = model(hidden_states)
44
+
45
+ # Here you would need to convert the model output into readable text
46
+ # This is only an example; in practice you would need a decoder to turn the output into text
47
+ generated_text = "placeholder generated text" # this should be your actual generated text
48
+
49
+ return generated_text
50
+
51
+ # Function to generate text from a given prompt using the Mamba model
52
+ def generate_uncensored_text(prompt, max_length=50):
53
+ mamba_text = generate_text(prompt, mamba_model, max_length)
54
+ return mamba_text
55
+
56
+ # Example usage
57
+ prompt = "I want to generate some uncensored text."
58
+ uncensored_text = generate_uncensored_text(prompt)
59
+ print(uncensored_text)
mlp.py ADDED
@@ -0,0 +1,43 @@
1
+ from dataclasses import dataclass
2
+ from typing import Union
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from utils import bias_gelu_impl
7
+ from mamba_config import MambaConfig
8
+
9
+ class MLP(nn.Module):
10
+ def __init__(
11
+ self, config: MambaConfig, is_expert: bool = False, layer_idx=None
12
+ ):
13
+ super().__init__()
14
+
15
+ self.config: MambaConfig = config
16
+ self.layer = layer_idx
17
+ ffn_hidden_size_1 = self.config.ffn_hidden_size
18
+ ffn_hidden_size_2 = self.config.ffn_hidden_size
19
+
20
+ # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
21
+ if self.config.gated_linear_unit:
22
+ ffn_hidden_size_1 *= 2
23
+
24
+ self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias = self.config.add_bias_linear, device = self.config.device)
25
+ self.linear_fc1.is_expert = is_expert
26
+
27
+ if self.config.gated_linear_unit:
28
+
29
+ def glu(x):
30
+ x = torch.chunk(x, 2, dim=-1)
31
+ return self.config.activation_func(x[0]) * x[1]
32
+
33
+ self.activation_func = glu
34
+ else:
35
+ self.activation_func = self.config.activation_func
36
+
37
+ self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias = self.config.add_bias_linear, device = self.config.device)
38
+
39
+ def forward(self, hidden_states, inference_params=None):
40
+ intermediate = self.linear_fc1(hidden_states)
41
+ intermediate = self.activation_func(intermediate)
42
+ output = self.linear_fc2(intermediate)
43
+ return output
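For reference, a minimal sketch (not part of this commit) of the gated-linear-unit path: with `gated_linear_unit=True`, `linear_fc1` projects to twice `ffn_hidden_size`, the result is chunked in two, and one half gates the other before `linear_fc2`. The sizes and `device="cpu"` are assumptions for illustration (the config default is `"cuda"`).

```python
# Illustrative forward pass through MLP with the gated-linear-unit branch enabled.
import torch
from mamba_config import MambaConfig
from mlp import MLP

config = MambaConfig(
    num_layers=2,
    hidden_size=64,
    state_size=16,
    conv_dimension=4,
    ffn_hidden_size=128,
    gated_linear_unit=True,   # linear_fc1 outputs 2 * 128 features
    device="cpu",             # assumed for this sketch; default is "cuda"
)

mlp = MLP(config)
x = torch.randn(2, 10, 64)    # (batch, seq_len, hidden_size)
out = mlp(x)
print(out.shape)              # torch.Size([2, 10, 64])
```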
setup.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import warnings
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from packaging.version import parse, Version
7
+ from setuptools import setup, find_packages
8
+ import subprocess
9
+
10
+
11
+ import torch
12
+ from torch.utils.cpp_extension import (
13
+ BuildExtension,
14
+ CppExtension,
15
+ CUDAExtension,
16
+ CUDA_HOME,
17
+ )
18
+
19
+ PACKAGE_NAME = "blackmamba"
20
+ VERSION = "0.0.1"
21
+
22
+ with open("README.md", "r", encoding="utf-8") as fh:
23
+ long_description = fh.read()
24
+
25
+
26
+ # ninja build does not work unless include_dirs are abs path
27
+ this_dir = os.path.dirname(os.path.abspath(__file__))
28
+
29
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
30
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
31
+ FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
32
+ SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
33
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
34
+ FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
35
+
36
+
37
+ def get_cuda_bare_metal_version(cuda_dir):
38
+ raw_output = subprocess.check_output(
39
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
40
+ )
41
+ output = raw_output.split()
42
+ release_idx = output.index("release") + 1
43
+ bare_metal_version = parse(output[release_idx].split(",")[0])
44
+
45
+ return raw_output, bare_metal_version
46
+
47
+
48
+ def check_if_cuda_home_none(global_option: str) -> None:
49
+ if CUDA_HOME is not None:
50
+ return
51
+ # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
52
+ # in that case.
53
+ warnings.warn(
54
+ f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
55
+ "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
56
+ "only images whose names contain 'devel' will provide nvcc."
57
+ )
58
+
59
+
60
+ def append_nvcc_threads(nvcc_extra_args):
61
+ return nvcc_extra_args + ["--threads", "4"]
62
+
63
+
64
+ ext_modules = []
65
+ if not SKIP_CUDA_BUILD:
66
+ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
67
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
68
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
69
+
70
+ check_if_cuda_home_none(PACKAGE_NAME)
71
+ # Check, if CUDA11 is installed for compute capability 8.0
72
+ cc_flag = []
73
+ if CUDA_HOME is not None:
74
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
75
+ if bare_metal_version < Version("11.6"):
76
+ raise RuntimeError(
77
+ f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
78
+ "Note: make sure nvcc has a supported version by running nvcc -V."
79
+ )
80
+
81
+ cc_flag.append("-gencode")
82
+ cc_flag.append("arch=compute_70,code=sm_70")
83
+ cc_flag.append("-gencode")
84
+ cc_flag.append("arch=compute_80,code=sm_80")
85
+ if bare_metal_version >= Version("11.8"):
86
+ cc_flag.append("-gencode")
87
+ cc_flag.append("arch=compute_90,code=sm_90")
88
+
89
+ # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
90
+ # torch._C._GLIBCXX_USE_CXX11_ABI
91
+ # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
92
+ if FORCE_CXX11_ABI:
93
+ torch._C._GLIBCXX_USE_CXX11_ABI = True
94
+
95
+ ext_modules.append(
96
+ CUDAExtension(
97
+ name="selective_scan_cuda",
98
+ sources=[
99
+ "csrc/selective_scan/selective_scan.cpp",
100
+ "csrc/selective_scan/selective_scan_fwd_fp32.cu",
101
+ "csrc/selective_scan/selective_scan_fwd_fp16.cu",
102
+ "csrc/selective_scan/selective_scan_fwd_bf16.cu",
103
+ "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
104
+ "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
105
+ "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
106
+ "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
107
+ "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
108
+ "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
109
+ ],
110
+ extra_compile_args={
111
+ "cxx": ["-O3", "-std=c++17"],
112
+ "nvcc": append_nvcc_threads(
113
+ [
114
+ "-O3",
115
+ "-std=c++17",
116
+ "-U__CUDA_NO_HALF_OPERATORS__",
117
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
118
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
119
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
120
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
121
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
122
+ "--expt-relaxed-constexpr",
123
+ "--expt-extended-lambda",
124
+ "--use_fast_math",
125
+ "--ptxas-options=-v",
126
+ "-lineinfo",
127
+ ]
128
+ + cc_flag
129
+ ),
130
+ },
131
+ include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
132
+ )
133
+ )
134
+
135
+
136
+ setup(
137
+ name=PACKAGE_NAME,
138
+ version=VERSION,
139
+ description="Blackmamba state-space + MoE model",
140
+ long_description=long_description,
141
+ long_description_content_type="text/markdown",
142
+ packages=find_packages(include=['ops'],
143
+ exclude=(
144
+ "csrc",
145
+ "blackmamba.egg-info",
146
+ )),
147
+ ext_modules=ext_modules,
148
+ cmdclass={"build_ext": BuildExtension},
149
+ python_requires=">=3.7",
150
+ install_requires=[
151
+ "torch",
152
+ "packaging",
153
+ "ninja",
154
+ "einops",
155
+ "triton",
156
+ "transformers",
157
+ "causal_conv1d>=1.1.0",
158
+ ],
159
+ )
switch_mlp.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import pickle
4
+ import os
5
+ import torch.nn.functional as F
6
+
7
+ from mamba_config import MambaConfig
8
+ from mlp import MLP
9
+
10
+ def sinkhorn(cost, tol=0.0001):
11
+ "Sinkhorn based MoE routing function"
12
+ cost = torch.exp(2.0 * cost)
13
+ d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
14
+ # d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
15
+ d1 = 1 / (cost.size(1) * torch.sum(cost, 0))
16
+
17
+ eps = 0.00000001
18
+ error = 1e9
19
+ d1_old = d1
20
+ while error > tol:
21
+ d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
22
+ d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
23
+ error = torch.mean(torch.abs(d1_old - d1))
24
+ d1_old = d1
25
+ return d1 * cost * d0.unsqueeze(1)
26
+
27
+
28
+ class SwitchMLP(nn.Module):
29
+ """
30
+ Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts"
31
+ Currently supports Sinkhorn-based expert routing.
32
+ """
33
+
34
+ def __init__(self, config: MambaConfig, layer_idx=None):
35
+ super().__init__()
36
+
37
+ self.layer = layer_idx
38
+ self.config: MambaConfig = config
39
+ if config.mamba_moe_layers:
40
+ self.num_moe_experts = int(config.mamba_moe_layers[layer_idx-1][-1])
41
+ else:
42
+ self.num_moe_experts = self.config.num_moe_experts
43
+ self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
44
+ self.add_bias = config.add_bias_linear
45
+ self.routing = config.routing_mode # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2'
46
+ self.route_algo = sinkhorn
47
+ self.router_activation = torch.sigmoid
48
+
49
+ self.num_local_experts = self.num_moe_experts
50
+ self.local_expert_indices = [i for i in range(self.num_local_experts)]
51
+
52
+ self.local_experts = torch.nn.ModuleList()
53
+ for _ in range(self.num_local_experts):
54
+ expert = MLP(self.config, is_expert=True, layer_idx=layer_idx)
55
+ self.local_experts.append(expert)
56
+
57
+ def gather_indices(self, local_indices):
58
+ return local_indices
59
+
60
+ def forward(self, hidden_states, inference_params=None):
61
+
62
+ hidden_shape = hidden_states.shape
63
+ route = self.router(hidden_states)
64
+ route = route.view(-1, self.num_moe_experts)
65
+
66
+ if self.routing == 'sinkhorn':
67
+ route = self.router_activation(route)
68
+ max_prob, max_ind = torch.max(route, dim=1)
69
+ else:
70
+ route = torch.softmax(route, dim=1)
71
+ max_prob, max_ind = torch.max(route, dim=1)
72
+
73
+ max_prob = torch.unsqueeze(max_prob, 1)
74
+ hidden_states = hidden_states.view(-1, hidden_shape[-1])
75
+
76
+ global_hidden_states = hidden_states
77
+ global_indices = max_ind
78
+ output_total = torch.zeros_like(global_hidden_states)
79
+
80
+
81
+ for expert_num, expert in enumerate(self.local_experts):
82
+ local_expert_index = self.local_expert_indices[expert_num]
83
+ local_indices = (global_indices == local_expert_index).nonzero()
84
+ hidden = global_hidden_states[local_indices, :]
85
+ output = expert(hidden)
86
+ output_total[local_indices, :] = output
87
+
88
+ output_total = output_total * max_prob
89
+ output_total = output_total.view(hidden_shape)
90
+
91
+ return output_total
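For reference, a minimal sketch (not part of this commit) of running the top-1 router on CPU; `num_moe_experts=4` and the tensor shapes are arbitrary illustration values.

```python
# Illustrative top-1 MoE forward pass; every token is dispatched to a single expert
# and the chosen expert's output is weighted by the router probability.
import torch
from mamba_config import MambaConfig
from switch_mlp import SwitchMLP

config = MambaConfig(
    num_layers=2,
    hidden_size=64,
    state_size=16,
    conv_dimension=4,
    ffn_hidden_size=128,
    num_moe_experts=4,        # assumed expert count for this sketch
    routing_mode="sinkhorn",
    device="cpu",             # assumed; default is "cuda"
)

moe = SwitchMLP(config)
x = torch.randn(2, 10, 64)    # (batch, seq_len, hidden_size)
out = moe(x)
print(out.shape)              # torch.Size([2, 10, 64])
```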
utils.py ADDED
@@ -0,0 +1,82 @@
1
+ from operator import itemgetter
2
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
3
+ import math
4
+ import torch
5
+
6
+
7
+ def attention_mask_func(attention_scores, attention_mask):
8
+ attention_scores.masked_fill_(attention_mask, -10000.0)
9
+ return attention_scores
10
+
11
+
12
+ @torch.jit.script
13
+ def gelu_impl(x):
14
+ """OpenAI's gelu implementation."""
15
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
16
+
17
+
18
+ def openai_gelu(x):
19
+ return gelu_impl(x)
20
+
21
+
22
+ @torch.jit.script
23
+ def bias_gelu(bias, y):
24
+ x = bias + y
25
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
26
+
27
+
28
+ # gradient of tanh approximation of gelu
29
+ # gradient of actual gelu is:
30
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
31
+ @torch.jit.script
32
+ def bias_gelu_back(g, bias, y):
33
+ x = bias + y
34
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
35
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
36
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
37
+ 1 + tanh_out
38
+ )
39
+ return ff * g
40
+
41
+
42
+ class GeLUFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ # bias is an optional argument
45
+ def forward(ctx, input, bias):
46
+ ctx.save_for_backward(input, bias)
47
+ return bias_gelu(bias, input)
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output):
51
+ input, bias = ctx.saved_tensors
52
+ tmp = bias_gelu_back(grad_output, bias, input)
53
+ return tmp, tmp
54
+
55
+
56
+ bias_gelu_impl = GeLUFunction.apply
57
+
58
+
59
+
60
+ # This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
61
+ @torch.jit.script
62
+ def erf_gelu(x):
63
+ return (
64
+ x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
65
+ )
66
+
67
+
68
+ def init_method_normal(sigma):
69
+
70
+ def init_(tensor):
71
+ return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
72
+
73
+ return init_
74
+
75
+
76
+ def scaled_init_method_normal(sigma, num_layers):
77
+ std = sigma / math.sqrt(2.0 * num_layers)
78
+
79
+ def init_(tensor):
80
+ return torch.nn.init.normal_(tensor, mean=0.0, std=std)
81
+
82
+ return init_
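For reference, a minimal sketch (not part of this commit) showing that the two initializers above differ only in their standard deviation: the scaled variant divides it by sqrt(2 * num_layers), matching the GPT-2 residual-scaling scheme referenced in `mamba_model.py`.

```python
# Compare the plain and depth-scaled normal initializers.
import math
import torch
from utils import init_method_normal, scaled_init_method_normal

w1 = torch.empty(256, 256)
w2 = torch.empty(256, 256)

init_method_normal(0.02)(w1)                          # std = 0.02
scaled_init_method_normal(0.02, num_layers=12)(w2)    # std = 0.02 / sqrt(2 * 12)

print(w1.std().item())
print(w2.std().item())
print(0.02 / math.sqrt(2 * 12))                       # expected std of w2
```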