Upload model
- MoEConfig.py +13 -0
- MoEModel.py +33 -0
- config.json +12 -0
- generation_config.json +4 -0
- model.safetensors +3 -0
MoEConfig.py
ADDED
@@ -0,0 +1,13 @@
+from transformers import PretrainedConfig
+from typing import List
+
+
+class MoEConfig(PretrainedConfig):
+    model_type = "moewrapper"  # name for this custom model type
+    torch_dtype = "float32"
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
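MoEConfig only tags the wrapper with model_type = "moewrapper" and forwards everything else to PretrainedConfig. A minimal local sketch of wiring it into the Auto factories, assuming the two files in this commit sit next to the script (AutoConfig.register and AutoModelForCausalLM.register are the standard transformers registration hooks):

    from transformers import AutoConfig, AutoModelForCausalLM
    from MoEConfig import MoEConfig
    from MoEModel import MoeModel

    # Register the custom classes so the Auto* factories can resolve
    # the "moewrapper" model_type without remote code.
    AutoConfig.register("moewrapper", MoEConfig)
    AutoModelForCausalLM.register(MoEConfig, MoeModel)

    config = AutoConfig.for_model("moewrapper")  # -> MoEConfig instance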
MoEModel.py
ADDED
@@ -0,0 +1,33 @@
+from transformers import PreTrainedModel
+from MoEConfig import MoEConfig
+from transformers import AutoModelForCausalLM
+import torch
+
+model_name = "kanhatakeyama/01b_model_30b_token"
+
+
+class MoeModel(PreTrainedModel):
+    config_class = MoEConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        # The wrapped causal LM is loaded eagerly on construction.
+        self.model = None
+        self.set_model()
+
+    def set_model(self):
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map="auto",
+            torch_dtype=torch.float16
+        )
+
+    def generate(self, input_ids, attention_mask,
+                 **generate_kwargs):
+        if self.model is None:
+            self.set_model()
+        # Delegate generation to the wrapped model.
+        ret = self.model.generate(input_ids=input_ids,
+                                  attention_mask=attention_mask,
+                                  **generate_kwargs)
+        return ret
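The wrapper defines no forward of its own; it only overrides generate() and delegates to the checkpoint named in model_name. A sketch of end-to-end use, assuming a hypothetical repo id "user/moewrapper" for this repository (the id is not shown on this page) and reusing the wrapped model's tokenizer, since this repo ships none:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Tokenizer of the wrapped checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("kanhatakeyama/01b_model_30b_token")

    # trust_remote_code=True lets transformers import MoEConfig.py and
    # MoEModel.py via the auto_map in config.json below.
    model = AutoModelForCausalLM.from_pretrained(
        "user/moewrapper",  # hypothetical repo id
        trust_remote_code=True,
    )

    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    out = model.generate(input_ids=inputs["input_ids"],
                         attention_mask=inputs["attention_mask"],
                         max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))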
config.json
ADDED
@@ -0,0 +1,12 @@
+{
+  "architectures": [
+    "MoeModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "MoEConfig.MoEConfig",
+    "AutoModelForCausalLM": "MoEModel.MoeModel"
+  },
+  "model_type": "moewrapper",
+  "torch_dtype": "float16",
+  "transformers_version": "4.35.0"
+}
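The auto_map entries are "<module>.<class>" paths into the two Python files above; with trust_remote_code=True the Auto classes import them straight from the repo. A quick check that the mapping resolves, again with the hypothetical repo id:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("user/moewrapper",  # hypothetical repo id
                                        trust_remote_code=True)
    print(type(config).__name__)  # MoEConfig
    print(config.model_type)      # moewrapper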
generation_config.json
ADDED
@@ -0,0 +1,4 @@
+{
+  "_from_model_config": true,
+  "transformers_version": "4.35.0"
+}
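"_from_model_config": true records that these generation defaults were derived from config.json rather than set explicitly. If explicit defaults were wanted, they could be written with the standard GenerationConfig API; the max_new_tokens value below is purely illustrative:

    from transformers import GenerationConfig

    gen_config = GenerationConfig(max_new_tokens=32)  # illustrative default
    gen_config.save_pretrained(".")  # writes generation_config.json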
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4ccf85a7256637e642272f422ffbe4e63cefd41163005811d268276bcd51b6f
+size 273150376