Upload 14 files
- .gitattributes +35 -35
- .gitignore +5 -0
- Gpt +28 -0
- README.md +46 -13
- __init__.py +1 -0
- hf_utils.py +15 -0
- mamba_block.py +354 -0
- mamba_config.py +86 -0
- mamba_model.py +183 -0
- mamba_text_generation.py +59 -0
- mlp.py +43 -0
- setup.py +159 -0
- switch_mlp.py +91 -0
- utils.py +82 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@

All 35 Git LFS rules were removed and re-added with identical content:

```text
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED
@@ -0,0 +1,5 @@

```text
*__pycache__/
*.egg-info/
build/
**.so
**.ipynb
```
Gpt
ADDED
@@ -0,0 +1,28 @@

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pre-trained models and tokenizers
wormgpt_model = GPT2LMHeadModel.from_pretrained("wormgpt")
wormgpt_tokenizer = GPT2Tokenizer.from_pretrained("wormgpt")
fraudgpt_model = GPT2LMHeadModel.from_pretrained("fraudgpt")
fraudgpt_tokenizer = GPT2Tokenizer.from_pretrained("fraudgpt")
xxxgpt_model = GPT2LMHeadModel.from_pretrained("xxxgpt")
xxxgpt_tokenizer = GPT2Tokenizer.from_pretrained("xxxgpt")
evilgpt_model = GPT2LMHeadModel.from_pretrained("evilgpt")
evilgpt_tokenizer = GPT2Tokenizer.from_pretrained("evilgpt")

# Function to generate text from a given prompt using the specified model
def generate_text(prompt, model, tokenizer, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

# Function to generate text from a given prompt using all four models
def generate_uncensored_text(prompt, max_length=50):
    wormgpt_text = generate_text(prompt, wormgpt_model, wormgpt_tokenizer, max_length)
    fraudgpt_text = generate_text(prompt, fraudgpt_model, fraudgpt_tokenizer, max_length)
    xxxgpt_text = generate_text(prompt, xxxgpt_model, xxxgpt_tokenizer, max_length)
    evilgpt_text = generate_text(prompt, evilgpt_model, evilgpt_tokenizer, max_length)
    return wormgpt_text + "\n" + fraudgpt_text + "\n" + xxxgpt_text + "\n" + evilgpt_text

# Example usage
prompt = "I want to generate some uncensored text."
uncensored_text = generate_uncensored_text(prompt)
print(uncensored_text)
```
README.md
CHANGED
@@ -1,13 +1,46 @@

The previous 13-line stub (a bare `---` followed by blank lines) is replaced with:

---
language: en
tags:
- text-generation
- malware
- malicious-content
license: apache-2.0
datasets:
- malware-dataset
---

# WormGPT

WormGPT is a GPT-2 model trained on a large dataset of malicious code to generate code that can be used for malicious purposes. It can be used to generate code for various types of malware, including viruses, worms, Trojans, and other malicious software.

## Model Details
- Model Name: WormGPT
- Architecture: GPT-2
- Dataset: Malware Dataset
- Training Procedure:
  - Fine-tuned the pre-trained GPT-2 model on the malware dataset using transfer learning.
  - Trained for a specific number of epochs to improve the model's ability to generate malicious code.
- Evaluation Metrics:
  - Accuracy: Measures the model's ability to generate code that is similar to the training data.
  - Precision: Measures the model's ability to generate code that is not malicious.
  - Recall: Measures the model's ability to generate code that is malicious.

## Usage
To use WormGPT, call the `generate_text` function with a prompt as input. The model will generate a response based on the provided prompt.

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pre-trained WormGPT model and tokenizer
model = GPT2LMHeadModel.from_pretrained("wormgpt")
tokenizer = GPT2Tokenizer.from_pretrained("wormgpt")

def generate_text(prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

# Example usage
prompt = "Generate malicious code for a virus."
malicious_code = generate_text(prompt)
print(malicious_code)
```
__init__.py
ADDED
@@ -0,0 +1 @@

(a single blank line; the file only marks the directory as a Python package)
hf_utils.py
ADDED
@@ -0,0 +1,15 @@

```python
import json
import torch
import transformers
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device="cpu"):
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    return torch.load(resolved_archive_file, map_location=device)
```
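A minimal sketch of how these helpers are meant to be called; the repo id below is a placeholder, not a real checkpoint on the Hub.

```python
# Sketch only: "some-org/some-mamba-checkpoint" is a hypothetical repo id.
from hf_utils import load_config_hf, load_state_dict_hf

config_dict = load_config_hf("some-org/some-mamba-checkpoint")      # parses the repo's config.json
state_dict = load_state_dict_hf("some-org/some-mamba-checkpoint")   # loads pytorch_model.bin onto CPU
print(config_dict.get("hidden_size"), len(state_dict))
```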
mamba_block.py
ADDED
@@ -0,0 +1,354 @@

```python
import math
from typing import Optional, Union
import re
from contextlib import nullcontext
from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat

# Optional fused CUDA/Triton kernels; fall back to None when unavailable.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
except ImportError:
    selective_scan_fn, mamba_inner_fn = None, None

try:
    from ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from mamba_layer import MambaLayer
from mamba_config import MambaConfig
from mlp import MLP
from switch_mlp import SwitchMLP


class MambaBlock(nn.Module):
    """Pre-norm residual block wrapping a single Mamba mixer."""

    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)

        # norm_cls is a LayerNorm/RMSNorm class or partial; instantiate it over the hidden size.
        self.norm = norm_cls(config.hidden_size)

        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


class MambaBlockParallelMoe(nn.Module):
    """Mamba mixer with an MoE/MLP branch applied in parallel off the same residual stream."""

    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)
        # Separate norms for the mixer branch and the MoE branch.
        self.norm = norm_cls(config.hidden_size)
        self.norm_moe = norm_moe(config.hidden_size)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
            assert isinstance(
                self.norm_moe, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states_moe, _ = fused_add_norm_fn(
                hidden_states,
                self.norm_moe.weight,
                self.norm_moe.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm_moe.eps,
            )

        hidden_states = self.mixer(hidden_states, inference_params=inference_params)

        # Sum the MoE branch back into the mixer output.
        hidden_states_moe = self.moe(hidden_states_moe)
        hidden_states += hidden_states_moe
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


class MoEBlock(nn.Module):
    """Pre-norm residual block whose mixer is an MLP or SwitchMLP rather than a Mamba layer."""

    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)
        self.norm = norm_cls(config.hidden_size)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


def create_block(config, layer_idx):
    """Build the block for layer_idx according to the per-layer pattern in config.mamba_moe_layers."""
    if config.rms_norm:
        norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon)
    else:
        norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon)

    if (not config.mamba_moe_layers) or config.mamba_moe_layers[layer_idx-1][0] == 'r':
        if (not config.mamba_moe_layers) or len(config.mamba_moe_layers[layer_idx-1]) == 1:
            # Plain Mamba block.
            mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
            block = MambaBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
        else:
            if config.mamba_moe_layers[layer_idx-1][1] == '1':
                # Mamba block with a single dense MLP in parallel.
                if config.rms_norm:
                    norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
                else:
                    norm_moe = partial(
                        nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
                    )
                mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
                moe_cls = partial(MLP, layer_idx=layer_idx)
                block = MambaBlockParallelMoe(
                    config,
                    mixer_cls=mixer_cls,
                    moe_cls=moe_cls,
                    norm_cls=norm_cls,
                    norm_moe=norm_moe,
                    fused_add_norm=config.fused_add_norm,
                    residual_in_fp32=config.residual_in_fp32,
                )
            else:
                # Mamba block with a switch (top-1 MoE) MLP in parallel.
                if config.rms_norm:
                    norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
                else:
                    norm_moe = partial(
                        nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
                    )
                mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
                moe_cls = partial(SwitchMLP, layer_idx=layer_idx)
                block = MambaBlockParallelMoe(
                    config,
                    mixer_cls=mixer_cls,
                    moe_cls=moe_cls,
                    norm_cls=norm_cls,
                    norm_moe=norm_moe,
                    fused_add_norm=config.fused_add_norm,
                    residual_in_fp32=config.residual_in_fp32,
                )
    else:
        if config.mamba_moe_layers[layer_idx-1][0] == '1':
            # Pure dense MLP block.
            mixer_cls = partial(MLP, layer_idx=layer_idx)
            block = MoEBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
        else:
            # Pure switch-MoE block.
            mixer_cls = partial(SwitchMLP, layer_idx=layer_idx)
            block = MoEBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
    block.layer_idx = layer_idx
    return block


class MambaDecoder(nn.Module):
    """Class wrapping a decoder stack of mamba blocks."""

    def __init__(
        self,
        config: MambaConfig,
        post_layer_norm=True,
        pre_process=True,
        post_process=True,
    ):
        super().__init__()

        self.config: MambaConfig = config
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.norm_cls = partial(nn.LayerNorm, eps=self.config.layernorm_epsilon)

        self._build_layers()

    def _build_layers(self):
        num_layers_to_build = self.config.num_layers
        # build the actual mamba layers
        self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in range(num_layers_to_build)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = self.norm_cls(self.config.hidden_size, bias=True)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(self, hidden_states, residual=None, inference_params=None):
        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                residual=residual,
                inference_params=inference_params,
            )

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            if not self.config.fused_add_norm:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.final_layernorm.weight,
                    self.final_layernorm.bias,
                    eps=self.final_layernorm.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.config.residual_in_fp32,
                )
        return hidden_states
```
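The stacked blocks communicate through a `(hidden_states, residual)` pair: each block adds the incoming residual, normalizes, runs its mixer, and hands both tensors to the next block. The sketch below is not part of the upload; it exercises that plumbing on CPU with a trivial stand-in mixer, since the real `MambaLayer` lives in a `mamba_layer` module that is imported above but is not among the 14 uploaded files, and the fused Triton paths need CUDA.

```python
# Sketch only: ToyMixer stands in for MambaLayer so the residual plumbing can
# run on CPU. Assumes the modules from this upload plus a mamba_layer module
# (not included here) are importable, and uses the non-fused LayerNorm path.
import torch
import torch.nn as nn
from functools import partial
from mamba_config import MambaConfig
from mamba_block import MambaBlock


class ToyMixer(nn.Module):
    """Minimal mixer with the same call signature as MambaLayer."""
    def __init__(self, cfg):
        super().__init__()
        self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)

    def forward(self, hidden_states, inference_params=None):
        return self.proj(hidden_states)


config = MambaConfig(num_layers=2, hidden_size=64, state_size=16,
                     conv_dimension=4, rms_norm=False, fused_add_norm=False)
norm_cls = partial(nn.LayerNorm, eps=config.layernorm_epsilon)
blocks = [MambaBlock(config, mixer_cls=ToyMixer, norm_cls=norm_cls)
          for _ in range(config.num_layers)]

x, residual = torch.randn(1, 8, config.hidden_size), None
for block in blocks:
    x, residual = block(x, residual)   # each block returns the updated pair
print(x.shape)                         # torch.Size([1, 8, 64])
```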
mamba_config.py
ADDED
@@ -0,0 +1,86 @@

```python
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn.functional as F
from utils import init_method_normal, scaled_init_method_normal


@dataclass
class MambaConfig():
    base_model_type: str = "mamba"
    num_layers: int = 0
    hidden_size: int = 0
    state_size: int = 0
    vocab_size: int = 50000
    expansion_factor: int = 2
    conv_dimension: int = 0
    conv_bias: bool = True
    bias: bool = True
    use_fast_path: bool = True
    dt_rank: str = "auto"
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4
    rms_norm: bool = True
    fused_add_norm: bool = False
    residual_in_fp32: bool = True
    hidden_dropout: float = 0.0
    ffn_hidden_size: int = None
    gated_linear_unit: bool = False
    mamba_moe_layers: str = ""
    routing_mode: str = "sinkhorn"
    device: str = "cuda"
    fp32_residual_connection: bool = False
    layernorm_epsilon: float = 1e-5
    layernorm_zero_centered_gamma: bool = False
    add_bias_linear: bool = True
    activation_func: Callable = F.gelu
    num_moe_experts: int = None

    # initialization
    init_method: Callable = None
    output_layer_init_method: Callable = None
    init_method_std: float = 0.02

    # mixed-precision
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True

    # fusion
    gated_linear_unit: bool = False
    bias_gelu_fusion: bool = False
    persist_layer_norm: bool = False
    bias_dropout_fusion: bool = False

    def __post_init__(self):
        """ Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
        """
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size

        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.bias_gelu_fusion:
            if not self.add_bias_linear:
                raise ValueError(
                    "When bias_gelu_fusion is True, add_bias_linear must also be True."
                )

            if self.activation_func != F.gelu:
                raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')

        if self.init_method is None:
            self.init_method = init_method_normal(self.init_method_std)

        if self.output_layer_init_method is None:
            self.output_layer_init_method = scaled_init_method_normal(
                self.init_method_std, self.num_layers
            )
```
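A short sketch of constructing the config (it assumes `utils.py` from this upload is importable, since `__post_init__` pulls its init helpers from there); note how `ffn_hidden_size` is derived when left unset.

```python
# Sketch only: illustrative sizes, not a released configuration.
from mamba_config import MambaConfig

cfg = MambaConfig(num_layers=4, hidden_size=256, state_size=16, conv_dimension=4)
print(cfg.ffn_hidden_size)            # 1024: __post_init__ defaults it to 4 * hidden_size
print(cfg.rms_norm, cfg.routing_mode)  # True sinkhorn
```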
mamba_model.py
ADDED
@@ -0,0 +1,183 @@

```python
import logging
from typing import Literal, Optional, Union
import functools
from functools import partial
import torch
import torch.nn as nn
from torch import Tensor
import math
import os
from mamba_block import MambaBlock, MambaDecoder
from mamba_config import MambaConfig
from hf_utils import *
import os, json
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/sqrt(N) where N is the # of residual layers.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MambaModel(nn.Module):
    def __init__(
        self,
        config: MambaConfig,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = True,
        initializer_cfg=None,
    ) -> None:
        super().__init__()

        self.config: MambaConfig = config
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

        if self.pre_process:
            self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)

        self.decoder = MambaDecoder(
            config=self.config,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        if post_process:
            self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.add_bias_linear)
            if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
                self.initialize_last_stage_with_word_embeddings()

        # apply weight initialization
        self.apply(
            partial(
                _init_weights,
                n_layer=self.config.num_layers,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def initialize_last_stage_with_word_embeddings(self):
        # Tie the output projection to the input embedding matrix.
        with torch.no_grad():
            self.output_layer.weight = self.embedding.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params=None,
    ) -> Tensor:
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids)
        else:
            decoder_input = None

        hidden_states = self.decoder(
            hidden_states=decoder_input,
            residual=None,
            inference_params=inference_params,
        )

        if not self.post_process:
            return hidden_states

        logits = self.output_layer(hidden_states)

        return logits.contiguous()

    @classmethod
    def from_pretrained(cls, pretrained_model_name=None, checkpoint_name=None, config_name=None, **kwargs):
        if pretrained_model_name is not None:
            json_config = load_config_hf(pretrained_model_name)
            loaded = load_state_dict_hf(pretrained_model_name)
        elif checkpoint_name is not None and config_name is not None:
            with open(config_name, 'r') as f:
                jsonstr = f.read()
            json_config = json.loads(jsonstr)
            loaded = torch.load(checkpoint_name, map_location='cpu')
        else:
            return
        model_state_dict = loaded["model"]

        config = MambaConfig(
            num_layers=json_config['num_layers'],
            hidden_size=json_config['hidden_size'],
            state_size=json_config['state_size'],
            conv_dimension=json_config['conv_dimension'],
            vocab_size=json_config['vocab_size'],
            expansion_factor=json_config['expansion_factor'],
            mamba_moe_layers=json_config['mamba_moe_layers'],
            ffn_hidden_size=json_config['ffn_hidden_size'],
            bias=json_config['add_bias_linear'],
            add_bias_linear=json_config['add_bias_linear'],
            gated_linear_unit=json_config['swiglu'],
        )

        model = MambaModel(config=config, max_sequence_length=json_config['max_sequence_length'], **kwargs)

        # make keys match
        model_state_dict["embedding.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict["output_layer.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict.pop("embedding.word_embeddings.weight")
        model.load_state_dict(loaded["model"])
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model, dropping non-serializable
        # callables such as the activation and init functions.
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump({k: v for k, v in self.config.__dict__.items() if not callable(v)}, f)
```
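`from_pretrained` has two branches: a Hub model id resolved through `hf_utils`, or an explicit local checkpoint plus JSON config. The sketch below shows the keys the local-file branch reads from that config; the values are illustrative placeholders, not a released configuration.

```python
# Sketch only: placeholder field values to illustrate the expected config.json.
import json

example_config = {
    "num_layers": 4, "hidden_size": 256, "state_size": 16,
    "conv_dimension": 4, "vocab_size": 50000, "expansion_factor": 2,
    "mamba_moe_layers": "", "ffn_hidden_size": 1024,
    "add_bias_linear": True, "swiglu": False, "max_sequence_length": 1024,
}
with open("config.json", "w") as f:
    json.dump(example_config, f)

# With a matching checkpoint whose "model" entry holds the state dict:
# model = MambaModel.from_pretrained(checkpoint_name="model.pt", config_name="config.json")
```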
mamba_text_generation.py
ADDED
@@ -0,0 +1,59 @@

```python
import torch
from mamba_block import MambaBlock
from mamba_config import MambaConfig
from mamba_layer import MambaLayer

# Create a Mamba configuration.
# (MambaConfig does not define transformer-style fields such as num_heads,
#  intermediate_size or max_position_embeddings, so only valid fields are set;
#  ffn_hidden_size defaults to 4 * hidden_size in __post_init__.)
config = MambaConfig(
    hidden_size=512,
    num_layers=6,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
)

# Create a Mamba model.
class MambaModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = torch.nn.ModuleList([MambaBlock(config, MambaLayer) for _ in range(config.num_layers)])
        self.norm = torch.nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        hidden_states = self.norm(hidden_states + residual if residual is not None else hidden_states)
        return hidden_states

# Create a model instance.
mamba_model = MambaModel(config)
mamba_model.eval()

# Function to generate text from a given prompt using the Mamba model
def generate_text(prompt, model, max_length=50):
    # Assume here that the prompt has already been converted to embedding vectors,
    # with an input sequence length of len(prompt).
    hidden_states = torch.randn(1, len(prompt), config.hidden_size)

    with torch.no_grad():
        output = model(hidden_states)

    # The model output would still need to be decoded into readable text;
    # this is only an example, so a placeholder string is returned instead.
    generated_text = "generated text placeholder"  # this should be the actual generated text

    return generated_text

# Function to generate text from a given prompt using the Mamba model
def generate_uncensored_text(prompt, max_length=50):
    mamba_text = generate_text(prompt, mamba_model, max_length)
    return mamba_text

# Example usage
prompt = "I want to generate some uncensored text."
uncensored_text = generate_uncensored_text(prompt)
print(uncensored_text)
```
mlp.py
ADDED
@@ -0,0 +1,43 @@

```python
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import bias_gelu_impl
from mamba_config import MambaConfig

class MLP(nn.Module):
    def __init__(
        self, config: MambaConfig, is_expert: bool = False, layer_idx=None
    ):
        super().__init__()

        self.config: MambaConfig = config
        self.layer = layer_idx
        ffn_hidden_size_1 = self.config.ffn_hidden_size
        ffn_hidden_size_2 = self.config.ffn_hidden_size

        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.gated_linear_unit:
            ffn_hidden_size_1 *= 2

        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear, device=self.config.device)
        self.linear_fc1.is_expert = is_expert

        if self.config.gated_linear_unit:

            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear, device=self.config.device)

    def forward(self, hidden_states, inference_params=None):
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        output = self.linear_fc2(intermediate)
        return output
```
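A sketch of running the MLP on its own; the config's `device` default is "cuda", so it is overridden here to keep the example CPU-only.

```python
# Sketch only: illustrative sizes, assumes mamba_config.py and utils.py are importable.
import torch
from mamba_config import MambaConfig
from mlp import MLP

cfg = MambaConfig(num_layers=1, hidden_size=32, state_size=16,
                  conv_dimension=4, device="cpu")
mlp = MLP(cfg)                          # ffn_hidden_size defaults to 4 * 32 = 128
out = mlp(torch.randn(2, 8, cfg.hidden_size))
print(out.shape)                        # torch.Size([2, 8, 32])
```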
setup.py
ADDED
@@ -0,0 +1,159 @@

```python
# Copyright (c) 2023, Albert Gu, Tri Dao.
import warnings
import os
from pathlib import Path

from packaging.version import parse, Version
from setuptools import setup, find_packages
import subprocess


import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)

PACKAGE_NAME = "blackmamba"
VERSION = "0.0.1"

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    return nvcc_extra_args + ["--threads", "4"]


ext_modules = []
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none(PACKAGE_NAME)
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_70,code=sm_70")
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="selective_scan_cuda",
            sources=[
                "csrc/selective_scan/selective_scan.cpp",
                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
    )


setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description="Blackmamba state-space + MoE model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(
        include=["ops"],
        exclude=("csrc", "blackmamba.egg-info"),
    ),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],
)
```
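The extension build hinges on `get_cuda_bare_metal_version`, which picks the token after `release` out of the `nvcc -V` banner. A small illustration of that parsing, using a typical nvcc banner string in place of the real subprocess call:

```python
# Sketch only: the sample banner below stands in for `nvcc -V` output.
from packaging.version import parse

sample = ("nvcc: NVIDIA (R) Cuda compiler driver\n"
          "Cuda compilation tools, release 11.8, V11.8.89\n")
output = sample.split()
bare_metal_version = parse(output[output.index("release") + 1].split(",")[0])
print(bare_metal_version >= parse("11.6"))   # True, so the CUDA extension build proceeds
```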
switch_mlp.py
ADDED
@@ -0,0 +1,91 @@

```python
import torch
import torch.nn as nn
import pickle
import os
import torch.nn.functional as F

from mamba_config import MambaConfig
from mlp import MLP

def sinkhorn(cost, tol=0.0001):
    "Sinkhorn based MoE routing function"
    cost = torch.exp(2.0 * cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    # d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


class SwitchMLP(nn.Module):
    """
    Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts"
    Currently supports Sinkhorn based expert routing.
    """

    def __init__(self, config: MambaConfig, layer_idx=None):
        super().__init__()

        self.layer = layer_idx
        self.config: MambaConfig = config
        if config.mamba_moe_layers:
            self.num_moe_experts = int(config.mamba_moe_layers[layer_idx-1][-1])
        else:
            self.num_moe_experts = self.config.num_moe_experts
        self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
        self.add_bias = config.add_bias_linear
        self.routing = config.routing_mode  # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2'
        self.route_algo = sinkhorn
        self.router_activation = torch.sigmoid

        self.num_local_experts = self.num_moe_experts
        self.local_expert_indices = [i for i in range(self.num_local_experts)]

        self.local_experts = torch.nn.ModuleList()
        for _ in range(self.num_local_experts):
            expert = MLP(self.config, is_expert=True, layer_idx=layer_idx)
            self.local_experts.append(expert)

    def gather_indices(self, local_indices):
        return local_indices

    def forward(self, hidden_states, inference_params=None):

        hidden_shape = hidden_states.shape
        route = self.router(hidden_states)
        route = route.view(-1, self.num_moe_experts)

        if self.routing == 'sinkhorn':
            route = self.router_activation(route)
            max_prob, max_ind = torch.max(route, dim=1)
        else:
            route = torch.softmax(route, dim=1)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        hidden_states = hidden_states.view(-1, hidden_shape[-1])

        global_hidden_states = hidden_states
        global_indices = max_ind
        output_total = torch.zeros_like(global_hidden_states)

        # Dispatch each token to its top-1 expert and scatter the results back.
        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output = expert(hidden)
            output_total[local_indices, :] = output

        output_total = output_total * max_prob
        output_total = output_total.view(hidden_shape)

        return output_total
```
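A quick way to see what `sinkhorn` does to raw router logits; this is a sanity-check sketch, assuming the modules above import cleanly (no GPU needed).

```python
# Sketch only: random logits for 12 tokens routed over 4 experts.
import torch
from switch_mlp import sinkhorn

logits = torch.randn(12, 4)              # [tokens, experts] router scores
normed = sinkhorn(logits)                 # doubly normalized soft assignment
print(normed.shape)                       # torch.Size([12, 4])
print(normed.argmax(dim=1).bincount(minlength=4))  # tokens per expert after top-1
```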
utils.py
ADDED
@@ -0,0 +1,82 @@

```python
from operator import itemgetter
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import math
import torch


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return (
        x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
    )


def init_method_normal(sigma):

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
```
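The fused `bias_gelu_impl` should agree with PyTorch's built-in tanh-approximate GELU to floating-point tolerance; a small check, not part of the upload:

```python
# Sketch only: compares the custom fused bias-GELU against torch's tanh GELU.
import torch
import torch.nn.functional as F
from utils import bias_gelu_impl

x = torch.randn(4, 8)
bias = torch.randn(8)
fused = bias_gelu_impl(x, bias)                      # computes gelu(x + bias)
reference = F.gelu(x + bias, approximate="tanh")
print(torch.allclose(fused, reference, atol=1e-6))   # True
```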