import torch
from transformers import AutoModelForCausalLM, PreTrainedModel

from MoEConfig import MoEConfig

model_name = "kanhatakeyama/01b_model_30b_token"


class MoeModel(PreTrainedModel):
    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = None
        self.set_model()

    def set_model(self):
        # Load the underlying causal LM, sharding it across available
        # devices and using fp16 weights to reduce memory use.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def generate(self, input_ids, attention_mask, **generate_kwargs):
        # Lazily (re)load the wrapped model if it is not set yet, then
        # delegate generation to it.
        if self.model is None:
            self.set_model()
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
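
# A minimal usage sketch, assuming MoEConfig() can be constructed with its
# defaults and that the checkpoint's own tokenizer is compatible with the
# wrapped model; the prompt and generation parameters below are illustrative.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = MoeModel(MoEConfig())

    # Tokenize a prompt and move it to the device of the wrapped model.
    inputs = tokenizer("Hello, world.", return_tensors="pt").to(model.model.device)
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=32,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))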