# chatmlTest/model/infer.py
import os
from threading import Thread
from typing import Union

import torch
from transformers import TextIteratorStreamer, PreTrainedTokenizerFast
from safetensors.torch import load_model
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# import custom classes and functions
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config
from config import InferConfig, T5ModelConfig


class ChatBot:
    def __init__(self, infer_config: InferConfig) -> None:
        '''
        Load the tokenizer and model weights described by `infer_config`.
        '''
        self.infer_config = infer_config

        # initialize the tokenizer
        tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
        self.tokenizer = tokenizer
        self.encode = tokenizer.encode_plus
        self.batch_decode = tokenizer.batch_decode
        self.batch_encode_plus = tokenizer.batch_encode_plus

        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        try:
            model = TextToTextModel(t5_config)

            if os.path.isdir(infer_config.model_dir):
                # load a transformers-style checkpoint directory
                model = TextToTextModel.from_pretrained(infer_config.model_dir)
            elif infer_config.model_dir.endswith('.safetensors'):
                # load safetensors weights
                load_model(model, infer_config.model_dir)
            else:
                # load a plain torch checkpoint
                model.load_state_dict(torch.load(infer_config.model_dir))

            self.model = model
        except Exception as e:
            print(str(e), 'loading with transformers/pytorch failed, trying the accelerate load functions.')

            with init_empty_weights():
                empty_model = TextToTextModel(t5_config)

            self.model = load_checkpoint_and_dispatch(
                model=empty_model,
                checkpoint=infer_config.model_dir,
                device_map='auto',
                dtype=torch.float16,
            )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # a single streamer instance is shared across stream_chat calls
        self.streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )

    def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
        '''
        Streaming chat. This method returns as soon as the generation thread
        has started; iterate the returned streamer to receive the generated
        text chunk by chunk. Only greedy search is supported.
        '''
        encoded = self.encode(input_txt + '[EOS]')

        input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
        attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)

        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_seq_len': self.infer_config.max_seq_len,
            'streamer': self.streamer,
            'search_type': 'greedy',
        }

        thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
        thread.start()

        return self.streamer
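
    # Consumption sketch for the streamer above (hypothetical `bot =
    # ChatBot(InferConfig())` instance): the generation thread keeps running
    # in the background, so the caller should drain the streamer until it is
    # exhausted:
    #
    #     for chunk in bot.stream_chat('你好'):
    #         print(chunk, end='', flush=True)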

    def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
        '''
        Non-streaming generation. The underlying `my_generate` also supports
        beam search, beam sample, and similar methods; greedy search is used here.
        '''
        if isinstance(input_txt, str):
            input_txt = [input_txt]
        elif not isinstance(input_txt, list):
            raise TypeError('input_txt must be a str or list[str]')

        # add the EOS token to every prompt
        input_txts = [f"{txt}[EOS]" for txt in input_txt]
        encoded = self.batch_encode_plus(input_txts, padding=True)

        input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
        attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)

        outputs = self.model.my_generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_seq_len=self.infer_config.max_seq_len,
            search_type='greedy',
        )

        outputs = self.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True)

        # fallback reply for empty generations ("I'm an AI model with very few
        # parameters 🥺 and a small knowledge base, so I can't answer your
        # question directly. Try a different question 👋")
        note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        outputs = [item if len(item) != 0 else note for item in outputs]

        return outputs[0] if len(outputs) == 1 else outputs
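

if __name__ == '__main__':
    # Minimal usage sketch, assuming InferConfig() defaults carry a valid
    # `model_dir` and `max_seq_len` (both names are taken from the references
    # above, not verified against config.py).
    bot = ChatBot(infer_config=InferConfig())

    # non-streaming: a single str prompt returns a single str reply
    print(bot.chat('你好'))

    # streaming: print chunks as the generation thread produces them
    for chunk in bot.stream_chat('你好'):
        print(chunk, end='', flush=True)
    print()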