File size: 5,969 Bytes
1d777c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from source.generators.abstract_generator import AbstractGenerator
import os, sys, random
import torch
from typing import List

sys.path.append(os.path.join(os.path.split(__file__)[0], "exllamav2"))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler


class Generator(AbstractGenerator):
    """Text generator backed by an ExLlamaV2 model loaded from a local directory.

    The model, cache, tokenizer and a shared ``ExLlamaV2Sampler.Settings`` object
    are created once in ``__init__``. ``generate_answer`` copies the caller's
    generation parameters into those settings and runs ``generate_custom``, a
    token-by-token sampling loop that supports stopping strings.
    """

    model_change_allowed = False  # if model changing allowed without stopping.
    preset_change_allowed = False  # if preset_file changing allowed.

    def __init__(self, model_path: str, n_ctx=4096, seed=0, n_gpu_layers=0):
        """Load the ExLlamaV2 model stored in ``model_path``.

        Args:
            model_path: Directory holding the ExLlamaV2 model; kept in
                ``self.model_directory``.
            n_ctx: Accepted for interface compatibility; not used here. The
                length used for prompt truncation is ``self.model.config.max_seq_len``.
            seed: Accepted for interface compatibility; not used here (a seed can
                be passed to ``generate_custom`` instead).
            n_gpu_layers: Accepted for interface compatibility; not used here
                (placement is left to ``load_autosplit``).
        """
        # Path to the model directory.
        self.model_directory = model_path

        self.config = ExLlamaV2Config()
        self.config.model_dir = self.model_directory
        self.config.prepare()

        self.model = ExLlamaV2(self.config)

        # Create a lazy cache and let load_autosplit load the model together with it.
        self.cache = ExLlamaV2Cache(self.model, lazy=True)
        self.model.load_autosplit(self.cache)

        self.tokenizer = ExLlamaV2Tokenizer(self.config)

        # Stock exllamav2 generator; generate_custom below runs its own sampling loop.
        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)

        # Default sampler settings; generate_answer overwrites most of them per call.
        self.settings = ExLlamaV2Sampler.Settings()
        self.settings.temperature = 0.85
        self.settings.top_k = 50
        self.settings.top_p = 0.8
        self.settings.token_repetition_penalty = 1.15
        # Presumably keeps EOS from being sampled, so generation would end on a
        # stopping string or the token limit — confirm against exllamav2.
        self.settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])

        # Character length of the decoded (possibly truncated) prompt from the most
        # recent generate_custom call with a str prompt; None otherwise.
        self._prompt_text_len = None

    def generate_answer(
        self, prompt, generation_params, eos_token, stopping_strings, default_answer: str, turn_template="", **kwargs
    ):
        """Generate a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: Prompt text.
            generation_params: Mapping read for the keys ``repetition_penalty``,
                ``temperature``, ``top_p``, ``top_k``, ``typical_p`` and
                ``max_new_tokens``.
            eos_token: Unused by this backend.
            stopping_strings: Strings that end generation (see ``generate_custom``).
            default_answer: Returned unchanged if anything raises.
            turn_template: Unused by this backend.
            **kwargs: Ignored.

        Returns:
            The generated text with the prompt removed, or ``default_answer`` if
            generation failed.
        """
        answer = default_answer

        try:
            # Copy the caller's sampling parameters into the shared settings.
            # The repetition penalty field is ``token_repetition_penalty`` — the same
            # field __init__ sets; ``token_repetition_penalty_max`` is not read.
            self.settings.token_repetition_penalty = generation_params["repetition_penalty"]
            self.settings.temperature = generation_params["temperature"]
            self.settings.top_p = generation_params["top_p"]
            self.settings.top_k = generation_params["top_k"]
            self.settings.typical = generation_params["typical_p"]

            full_text = self.generate_custom(
                prompt,
                stopping_strings=stopping_strings,
                gen_settings=self.settings,
                num_tokens=generation_params["max_new_tokens"],
            )

            # Strip the prompt by its *decoded* length: the prompt may have been
            # left-truncated to fit the context, or may not round-trip through the
            # tokenizer exactly, in which case len(prompt) is the wrong offset.
            prompt_len = getattr(self, "_prompt_text_len", None)
            if prompt_len is None:
                prompt_len = len(prompt)
            answer = full_text[prompt_len:]

        except Exception as exception:
            # Best effort: report the failure and fall back to default_answer.
            print("generator_wrapper get answer error ", str(exception) + str(exception.args))
        return answer

    def generate_custom(
        self,
        prompt: "str | List[str]",
        gen_settings: ExLlamaV2Sampler.Settings,
        num_tokens: int,
        stopping_strings: List,
        seed=None,
        token_healing=False,
        encode_special_tokens=False,
        decode_special_tokens=False,
        loras=None,
    ):
        """Sample up to ``num_tokens`` tokens after ``prompt`` and return the decoded text.

        Args:
            prompt: A single prompt string, or a list of prompts (batched, with a
                padding mask).
            gen_settings: Sampler settings; its filters are begun and fed here.
            num_tokens: Maximum number of tokens to generate.
            stopping_strings: Strings that end generation. Only honoured for a
                str prompt. The returned text is cut at the earliest occurrence of
                any stopping string that overlaps generated text, and the stopping
                string itself is dropped. This also catches stops completed in the
                middle of a token. ``None`` and empty strings are ignored.
            seed: If given, seeds Python's ``random`` module, which supplies the
                random value passed to the sampler on every step.
            token_healing: If True and the prompt has at least two tokens, the last
                prompt token is held back and passed as ``prefix_token`` when
                sampling the first new token.
            encode_special_tokens: Passed to ``tokenizer.encode``.
            decode_special_tokens: Passed to ``tokenizer.decode``.
            loras: Passed to every model forward call.

        Returns:
            Decoded prompt + generated text: a str for a str prompt, otherwise the
            tokenizer's batch decode result. Prompt tokens are dropped from the
            front when prompt length + ``num_tokens`` exceeds ``max_seq_len``.
        """
        self._prompt_text_len = None

        if seed is not None:
            random.seed(seed)

        is_single = isinstance(prompt, str)
        # Empty strings are skipped: "" would match immediately and empty the result.
        stops = [s for s in (stopping_strings or ()) if s] if is_single else []

        # Tokenize input and produce padding mask if needed.
        batch_size = 1 if is_single else len(prompt)
        ids = self.tokenizer.encode(prompt, encode_special_tokens=encode_special_tokens)

        # Drop tokens from the start so prompt + new tokens fit within max_seq_len.
        overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
        if overflow > 0:
            ids = ids[:, overflow:]

        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None

        # Token healing: hold back the last prompt token (needs >= 2 tokens).
        unhealed_token = None
        if token_healing and ids.shape[-1] >= 2:
            unhealed_token = ids[:, -1:]
            ids = ids[:, :-1]

        # Process prompt and begin generation.
        self._gen_begin_base(ids, mask, loras)

        # Decoded prompt length: generated text starts at this character offset.
        prompt_text_len = 0
        if is_single:
            prompt_text_len = len(
                self.tokenizer.decode(self.sequence_ids, decode_special_tokens=decode_special_tokens)[0]
            )
            self._prompt_text_len = prompt_text_len

        # Begin filters, seeded with the held-back token's piece when healing.
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        heal = None
        if unhealed_token is not None:
            heal = [id_to_piece[x] for x in unhealed_token.flatten().tolist()]
        gen_settings.begin_filters(heal)

        # Generate tokens one at a time, feeding back only the newest token.
        for _step in range(num_tokens):
            logits = (
                self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask=mask, loras=loras).float().cpu()
            )
            token, _, eos = ExLlamaV2Sampler.sample(
                logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token=unhealed_token
            )
            self.sequence_ids = torch.cat([self.sequence_ids, token], dim=1)
            gen_settings.feed_filters(token)

            # The healing prefix applies to the first sampled token only.
            unhealed_token = None

            if stops:
                text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens=decode_special_tokens)[0]
                cut = self._find_stop(text, prompt_text_len, stops)
                if cut >= 0:
                    return text[:cut]
            if eos:
                break

        text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens=decode_special_tokens)
        return text[0] if is_single else text

    @staticmethod
    def _find_stop(text, gen_start, stops):
        """Return the index of the earliest stop string in ``text`` that ends after ``gen_start``, or -1.

        A match may begin up to ``len(stop) - 1`` characters before ``gen_start``
        (a stop spanning the prompt/answer boundary). It must contain at least one
        generated character, so a stop string that appears only in the prompt is
        ignored.
        """
        best = -1
        for stop in stops:
            pos = text.find(stop, max(0, gen_start - len(stop) + 1))
            if pos != -1 and (best == -1 or pos < best):
                best = pos
        return best

    def _gen_begin_base(self, input_ids, mask=None, loras=None):
        """Reset the cache, run the prompt through the model, and start ``self.sequence_ids``.

        All but the last prompt token are run with ``preprocess_only=True``. The
        generation loop feeds ``sequence_ids[:, -1:]`` on its first step, which
        covers the last token.
        """
        self.cache.current_seq_len = 0
        self.model.forward(input_ids[:, :-1], self.cache, input_mask=mask, preprocess_only=True, loras=loras)
        self.sequence_ids = input_ids

    def tokens_count(self, text: str):
        """Return the number of tokens ``text`` encodes to."""
        encoded = self.tokenizer.encode(text)
        return len(encoded[0])

    def get_model_list(self):
        """Return the names of ``.bin`` files in ``../../models`` (relative to the working directory)."""
        return [name for name in os.listdir("../../models") if name.endswith(".bin")]

    def load_model(self, model_file: str):
        """Runtime model switching is not supported here (``model_change_allowed`` is False); returns None."""
        return None