TestMoE / MoEModel.py
kanhatakeyama's picture
Upload model
c821aa0 verified
raw
history blame
893 Bytes
from transformers import PreTrainedModel
from MoEConfig import MoEConfig
from transformers import AutoModelForCausalLM
import torch
# Default Hugging Face Hub checkpoint loaded as the wrapped backbone model.
model_name = "kanhatakeyama/01b_model_30b_token"
class MoeModel(PreTrainedModel):
    """PreTrainedModel wrapper that delegates to a pretrained causal LM.

    The backbone checkpoint defaults to the module-level ``model_name``,
    but can be overridden per-instance by setting ``model_name`` on the
    config (backward compatible: absent attribute falls back to the
    module default).
    """

    config_class = MoEConfig

    def __init__(self, config):
        """Build the wrapper and eagerly load the backbone.

        Args:
            config: a ``MoEConfig`` instance; ``config.model_name`` (if
                present and truthy) selects the checkpoint to load.
        """
        super().__init__(config)
        # generate() lazily re-loads if this is ever reset to None.
        self.model = None
        self.set_model()

    def set_model(self):
        """Load the backbone via ``AutoModelForCausalLM.from_pretrained``.

        Loads fp16 weights with ``device_map="auto"`` (device placement
        handled by accelerate). Uses ``config.model_name`` when the
        config provides one, otherwise the module-level default.
        """
        checkpoint = getattr(self.config, "model_name", None) or model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
        """Delegate text generation to the wrapped backbone.

        Args:
            input_ids: token ids tensor passed through to the backbone.
            attention_mask: optional mask; when ``None`` the backbone
                infers it (matches the standard HF ``generate`` API).
            **generate_kwargs: forwarded to ``model.generate``.

        Returns:
            Whatever the backbone's ``generate`` returns (token ids).
        """
        # Defensive lazy load in case self.model was cleared externally.
        if self.model is None:
            self.set_model()
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )