import os
from threading import Thread
from typing import Union
import torch

from transformers import TextIteratorStreamer, PreTrainedTokenizerFast
from safetensors.torch import load_model

from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# import project-local classes and helper functions
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config

from config import InferConfig, T5ModelConfig

class ChatBot:
    def __init__(self, infer_config: InferConfig) -> None:
        '''
        Load the tokenizer and the model weights described by `infer_config`,
        falling back to accelerate's dispatch loading if the plain load fails.
        '''
        self.infer_config = infer_config
        # initialize the tokenizer
        tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
        self.tokenizer = tokenizer
        # shortcut aliases for the tokenizer methods used below
        self.encode = tokenizer.encode_plus
        self.batch_decode = tokenizer.batch_decode
        self.batch_encode_plus = tokenizer.batch_encode_plus
        
        # build a T5 config sized to this tokenizer; the pad token doubles as
        # the decoder start token, as in the standard T5 setup
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        dispatched = False  # True when accelerate places the weights itself
        try:
            if os.path.isdir(infer_config.model_dir):
                # a directory saved with save_pretrained
                model = TextToTextModel.from_pretrained(infer_config.model_dir)

            elif infer_config.model_dir.endswith('.safetensors'):
                # a single safetensors weight file
                model = TextToTextModel(t5_config)
                load_model(model, infer_config.model_dir)

            else:
                # a plain PyTorch checkpoint; map to CPU first so loading
                # also works on machines without a GPU
                model = TextToTextModel(t5_config)
                model.load_state_dict(torch.load(infer_config.model_dir, map_location='cpu'))

            self.model = model

        except Exception as e:
            print(str(e), '\nplain transformers/pytorch load failed, trying accelerate dispatch loading.')

            # build the model skeleton without allocating real weights, then
            # let accelerate stream the checkpoint onto the available devices
            with init_empty_weights():
                empty_model = TextToTextModel(t5_config)

            self.model = load_checkpoint_and_dispatch(
                model=empty_model,
                checkpoint=infer_config.model_dir,
                device_map='auto',
                dtype=torch.float16,
            )
            dispatched = True

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if not dispatched:
            # a dispatched model is already placed across devices; skip the move
            self.model.to(self.device)
        self.model.eval()

        self.streamer = TextIteratorStreamer(tokenizer=tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True)

    def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
        '''
        Streaming chat: the generation thread is started and this method
        returns immediately; iterate over the returned streamer to read the
        generated text. Only greedy search is supported.
        '''
        # append the EOS token, matching the training data format
        encoded = self.encode(input_txt + '[EOS]')
        
        input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
        attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)

        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_seq_len': self.infer_config.max_seq_len,
            'streamer': self.streamer,
            'search_type': 'greedy',
        }

        thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
        thread.start()
        
        return self.streamer
    
    def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
        '''
        Non-streaming generation; search strategies such as beam search or
        beam sampling can also be used to generate the text.
        '''
        if isinstance(input_txt, str):
            input_txt = [input_txt]
        elif not isinstance(input_txt, list):
            raise TypeError('input_txt must be a str or list[str]')
        
        # add EOS token
        input_txts = [f"{txt}[EOS]" for txt in input_txt]
        encoded = self.batch_encode_plus(input_txts, padding=True)
        input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
        attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)

        outputs = self.model.my_generate(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            max_seq_len=self.infer_config.max_seq_len,
                            search_type='greedy',
                        )

        outputs = self.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True)

        # fallback reply when the model returns an empty string (roughly:
        # "I'm an AI model with very few parameters 🥺 and a small knowledge
        # base, so I can't answer that directly - try another question 👋")
        note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        outputs = [item if len(item) != 0 else note for item in outputs]

        return outputs[0] if len(outputs) == 1 else outputs
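

# A minimal usage sketch, assuming InferConfig can be constructed with its
# defaults and its `model_dir` already points at valid weights; adjust the
# fields to your own setup.
if __name__ == '__main__':
    bot = ChatBot(infer_config=InferConfig())

    # non-streaming: one call returns the complete reply (str in, str out)
    print(bot.chat('你好'))

    # streaming: iterate the streamer while the generation thread is running
    for piece in bot.stream_chat('你好'):
        print(piece, end='', flush=True)
    print()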