import torch
from transformers import AutoModelForCausalLM, PreTrainedModel

from MoEConfig import MoEConfig

# Base causal-LM checkpoint wrapped by this model.
model_name = "kanhatakeyama/01b_model_30b_token"


class MoeModel(PreTrainedModel):
    """PreTrainedModel wrapper that delegates to a pretrained causal LM."""

    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)

        # The wrapped model is loaded eagerly here; generate() also
        # reloads it if it has been set back to None (e.g. to free memory).
        self.model = None
        self.set_model()

    def set_model(self):
        # Load the base checkpoint in fp16 and let accelerate place the
        # weights across the available devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def generate(self, input_ids, attention_mask, **generate_kwargs):
        # Reload the wrapped model if it is missing, then delegate
        # generation to it unchanged.
        if self.model is None:
            self.set_model()

        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
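

# A minimal usage sketch: it assumes MoEConfig can be constructed with no
# arguments and that the base checkpoint ships a compatible tokenizer;
# both are assumptions, not guarantees from this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumed tokenizer
    moe = MoeModel(MoEConfig())  # assumes a no-argument MoEConfig

    # Tokenize a prompt and move it to wherever the weights were placed.
    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    inputs = {k: v.to(moe.model.device) for k, v in inputs.items()}

    output_ids = moe.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=32,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))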