try:
    from extensions.telegram_bot.source.generators.abstract_generator import AbstractGenerator
except ImportError:
    from source.generators.abstract_generator import AbstractGenerator
import os
import glob
import sys
import torch
from typing import List

# Put the vendored "exllama" directory on sys.path before importing from it.
sys.path.append(os.path.join(os.path.split(__file__)[0], "exllama"))
from source.generators.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from source.generators.exllama.tokenizer import ExLlamaTokenizer
from source.generators.exllama.generator import ExLlamaGenerator


class Generator(AbstractGenerator):
    """Text generator backed by the vendored ExLlama model, tokenizer, cache and generator."""

    model_change_allowed = False  # if model changing allowed without stopping.
    preset_change_allowed = True  # if preset_file changing allowed.

    def __init__(self, model_path: str, n_ctx=4096, seed=0, n_gpu_layers=0):
        """Load config, weights, tokenizer and cache from a model directory.

        Args:
            model_path: directory expected to contain ``config.json``,
                ``tokenizer.model`` and ``.safetensors`` weights.
            n_ctx: context length; used as max sequence length, max input length,
                and (squared) as max attention size.
            seed: stored on the instance; not otherwise used by this class.
            n_gpu_layers: stored on the instance; not otherwise used by this class.

        Raises:
            FileNotFoundError: if no ``.safetensors`` file exists in ``model_path``.
        """
        self.n_ctx = n_ctx
        self.seed = seed
        self.n_gpu_layers = n_gpu_layers
        self.model_directory = model_path

        # Locate files we need within that directory
        self.tokenizer_path = os.path.join(self.model_directory, "tokenizer.model")
        self.model_config_path = os.path.join(self.model_directory, "config.json")
        self.st_pattern = os.path.join(self.model_directory, "model.safetensors")
        # Escape the directory so names containing "[", "*" or "?" are matched literally.
        escaped_dir = glob.escape(self.model_directory)
        self.model_path = glob.glob(os.path.join(escaped_dir, "model.safetensors"))
        if not self.model_path:
            # Fall back to any safetensors file(s) present in the directory.
            self.model_path = sorted(glob.glob(os.path.join(escaped_dir, "*.safetensors")))
        if not self.model_path:
            raise FileNotFoundError(f"No .safetensors weights found in {self.model_directory!r}")

        # Create config, model, tokenizer and generator
        self.ex_config = ExLlamaConfig(self.model_config_path)  # create config from config.json
        # Upstream ExLlama reads the weights location from ``config.model_path``
        # (a single path, or a list of paths in newer versions).
        # NOTE(review): confirm against the vendored copy; ``llm_path`` is also set
        # in case anything reads that attribute name.
        self.ex_config.model_path = self.model_path[0] if len(self.model_path) == 1 else self.model_path
        self.ex_config.llm_path = self.model_path
        self.ex_config.max_seq_len = n_ctx
        self.ex_config.max_input_len = n_ctx
        self.ex_config.max_attention_size = n_ctx**2

        self.model = ExLlama(self.ex_config)  # create ExLlama instance and load the weights
        self.tokenizer = ExLlamaTokenizer(self.tokenizer_path)  # create tokenizer from tokenizer model file
        self.cache = ExLlamaCache(self.model, max_seq_len=n_ctx)  # create cache for inference
        self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)  # create generator

    def generate_answer(
        self, prompt, generation_params, eos_token, stopping_strings, default_answer: str, turn_template="", **kwargs
    ):
        """Generate the continuation of ``prompt``; fall back to ``default_answer`` on error.

        Args:
            prompt: full prompt text.
            generation_params: dict read for the keys "repetition_penalty",
                "temperature", "top_p", "top_k", "typical_p" and "max_new_tokens".
            eos_token: not used by this implementation.
            stopping_strings: generation stops at the first of these found in the
                generated text; the stop string itself is excluded from the answer.
            default_answer: returned unchanged if any exception is raised.
            turn_template: not used by this implementation.

        Returns:
            The decoded text after the first ``len(prompt)`` characters, or
            ``default_answer`` on failure.
        """
        answer = default_answer
        try:
            # Configure generator.
            # NOTE(review): EOS is always banned here, so generation ends only at a
            # stopping string or at max_new_tokens; ``eos_token`` is ignored.
            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
            settings = self.generator.settings
            settings.token_repetition_penalty_max = generation_params["repetition_penalty"]
            settings.temperature = generation_params["temperature"]
            settings.top_p = generation_params["top_p"]
            settings.top_k = generation_params["top_k"]
            settings.typical = generation_params["typical_p"]

            # Reseed torch (CPU and CUDA) from OS entropy on every call.
            random_seed = int.from_bytes(os.urandom(4), byteorder="big")
            torch.manual_seed(random_seed)
            torch.cuda.manual_seed(random_seed)

            answer = self.generate_custom(
                prompt, stopping_strings=stopping_strings, max_new_tokens=generation_params["max_new_tokens"]
            )
            # generate_custom returns prompt + continuation; drop the prompt by length.
            # NOTE(review): assumes decoding the prompt tokens reproduces ``prompt``
            # character-for-character (no truncation/whitespace changes) — confirm.
            answer = answer[len(prompt) :]
        except Exception as exception:
            # Deliberate best-effort boundary: report and return default_answer.
            print("generator_wrapper get answer error ", str(exception) + str(exception.args))
        return answer

    def generate_custom(self, prompt, stopping_strings: List, max_new_tokens=128):
        """Sample up to ``max_new_tokens`` tokens after ``prompt``.

        Args:
            prompt: text to continue; encoded with the model's ``max_seq_len``.
            stopping_strings: strings that end generation when they appear in the
                newly generated text. Empty strings are ignored.
            max_new_tokens: upper bound on new tokens; additionally clamped so the
                sequence never exceeds the model's ``max_seq_len``.

        Returns:
            The decoded full sequence (prompt + continuation). If a stopping string
            is found, the text is cut just before its earliest occurrence.
        """
        self.generator.end_beam_search()
        ids, mask = self.tokenizer.encode(prompt, return_mask=True, max_seq_len=self.model.config.max_seq_len)
        self.generator.gen_begin(ids, mask=mask)
        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])

        stops = [stop for stop in (stopping_strings or []) if stop]
        # Length of the decoded prompt, so stop strings are only searched for in
        # text that includes newly generated characters.
        prompt_text_len = len(self._decode_sequence()) if stops else 0

        eos = torch.zeros((ids.shape[0],), dtype=torch.bool)
        for _ in range(max_new_tokens):
            token = self.generator.gen_single_token(mask=mask)
            for j in range(token.shape[0]):
                if token[j, 0].item() == self.tokenizer.eos_token_id:
                    eos[j] = True
            if stops:
                # A single token may decode to several characters, so search the
                # generated tail instead of only checking the end of the text.
                text = self._decode_sequence()
                cut = self._find_stop(text, stops, prompt_text_len)
                if cut is not None:
                    return text[:cut]
            if eos.all():
                break

        return self._decode_sequence()

    def _decode_sequence(self):
        """Decode the generator's current sequence (row 0 when the batch size is 1)."""
        sequence = self.generator.sequence
        return self.tokenizer.decode(sequence[0] if sequence.shape[0] == 1 else sequence)

    @staticmethod
    def _find_stop(text: str, stopping_strings: List[str], generated_from: int):
        """Return the start index of the earliest stop string ending after ``generated_from``, or None."""
        earliest = None
        for stop in stopping_strings:
            # Start early enough to catch a stop string straddling the prompt/answer
            # boundary, but late enough that every match ends in generated text.
            idx = text.find(stop, max(0, generated_from - len(stop) + 1))
            if idx != -1 and (earliest is None or idx < earliest):
                earliest = idx
        return earliest

    def tokens_count(self, text: str):
        """Return the number of tokens ``text`` encodes to (encoded with ``max_seq_len=20480``)."""
        encoded = self.tokenizer.encode(text, max_seq_len=20480)
        return len(encoded[0])

    def get_model_list(self):
        """Return names of ``.bin`` files in ``../../models`` (relative to the working directory)."""
        return [name for name in os.listdir("../../models") if name.endswith(".bin")]

    def load_model(self, model_file: str):
        """Model switching is not supported (``model_change_allowed`` is False); always returns None."""
        return None