from transformers import AutoModelForCausalLM, PreTrainedModel
from MoEConfig import MoEConfig
import torch

model_name = "kanhatakeyama/01b_model_30b_token"


class MoeModel(PreTrainedModel):
    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = None
        self.set_model()

    def set_model(self):
        # Load the underlying causal LM in half precision and let
        # accelerate place it across the available devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def generate(self, input_ids, attention_mask, **generate_kwargs):
        # Reload the wrapped model if it is missing (e.g. after
        # deserialization), then delegate generation to it.
        if self.model is None:
            self.set_model()
        return self.model.generate(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   **generate_kwargs)
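

# Minimal usage sketch (not part of the original file). It assumes the
# checkpoint ships a compatible tokenizer and that MoEConfig can be
# instantiated with default arguments; adjust both if your setup differs.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumption: tokenizer lives in the same repo
    model = MoeModel(MoEConfig())  # assumption: defaults are valid

    # Tokenize a prompt and move it to wherever the inner model was placed.
    inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.model.device)
    output_ids = model.generate(input_ids=inputs["input_ids"],
                                attention_mask=inputs["attention_mask"],
                                max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))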