Spaces:
Running
Running
from source.generators.abstract_generator import AbstractGenerator | |
import os, sys, random | |
import torch | |
from typing import List | |
sys.path.append(os.path.join(os.path.split(__file__)[0], "exllamav2")) | |
from exllamav2 import ( | |
ExLlamaV2, | |
ExLlamaV2Config, | |
ExLlamaV2Cache, | |
ExLlamaV2Tokenizer, | |
) | |
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler | |
class Generator(AbstractGenerator):
    """ExLlamaV2-backed text generator implementing the AbstractGenerator interface.

    Loads a model directory via ExLlamaV2's autosplit loader and produces
    completions with a custom token-by-token loop that supports stopping
    strings and optional token healing.
    """

    # Runtime-reconfiguration capability flags read by the surrounding app.
    model_change_allowed = False  # if model changing allowed without stopping.
    preset_change_allowed = False  # if preset_file changing allowed.

    def __init__(self, model_path: str, n_ctx=4096, seed=0, n_gpu_layers=0):
        """Load the model at *model_path* and prepare sampler defaults.

        ``n_ctx``, ``seed`` and ``n_gpu_layers`` are accepted for interface
        compatibility with other backends; the ExLlamaV2 backend derives its
        context length and device placement from the model config itself.
        """
        self.model_directory = model_path
        self.config = ExLlamaV2Config()
        self.config.model_dir = self.model_directory
        self.config.prepare()

        self.model = ExLlamaV2(self.config)
        # Lazy cache + load_autosplit spreads layers across available GPUs.
        self.cache = ExLlamaV2Cache(self.model, lazy=True)
        self.model.load_autosplit(self.cache)
        self.tokenizer = ExLlamaV2Tokenizer(self.config)

        # Base generator instance (kept for interface parity; generation below
        # uses the custom loop in generate_custom()).
        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)

        # Default sampling settings; overridden per call in generate_answer().
        self.settings = ExLlamaV2Sampler.Settings()
        self.settings.temperature = 0.85
        self.settings.top_k = 50
        self.settings.top_p = 0.8
        self.settings.token_repetition_penalty = 1.15
        # Forbid EOS so generation stops only on max tokens or stopping strings.
        self.settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])

    def generate_answer(
        self, prompt, generation_params, eos_token, stopping_strings, default_answer: str, turn_template="", **kwargs
    ):
        """Generate a completion for *prompt*, returning *default_answer* on failure.

        ``generation_params`` must contain ``repetition_penalty``, ``temperature``,
        ``top_p``, ``top_k``, ``typical_p`` and ``max_new_tokens``. ``eos_token``
        and ``turn_template`` are unused by this backend. Exceptions are caught
        and logged deliberately (best-effort contract): the caller always gets
        a string back.
        """
        answer = default_answer
        try:
            # Map caller parameters onto the sampler settings.
            # BUGFIX: the attribute is `token_repetition_penalty` in exllamav2;
            # the old `token_repetition_penalty_max` (exllama v1 name) was an
            # unused attribute, so the caller's penalty never took effect.
            self.settings.token_repetition_penalty = generation_params["repetition_penalty"]
            self.settings.temperature = generation_params["temperature"]
            self.settings.top_p = generation_params["top_p"]
            self.settings.top_k = generation_params["top_k"]
            self.settings.typical = generation_params["typical_p"]
            # Produce a simple generation
            answer = self.generate_custom(
                prompt,
                stopping_strings=stopping_strings,
                gen_settings=self.settings,
                num_tokens=generation_params["max_new_tokens"],
            )
            # Decoded text includes the prompt; strip it off.
            answer = answer[len(prompt) :]
        except Exception as exception:
            print("generator_wrapper get answer error ", str(exception) + str(exception.args))
        return answer

    def generate_custom(
        self,
        prompt: "str | list",
        gen_settings: ExLlamaV2Sampler.Settings,
        num_tokens: int,
        stopping_strings: List,
        seed=None,
        token_healing=False,
        encode_special_tokens=False,
        decode_special_tokens=False,
        loras=None,
    ):
        """Token-by-token generation with stopping-string support.

        Returns the decoded text INCLUDING the prompt (a single string when
        *prompt* is a string, otherwise a list of strings). Generation stops
        at ``num_tokens``, on the sampler's EOS signal, or as soon as the
        decoded text ends with one of ``stopping_strings`` (the stop string
        itself is trimmed off).

        NOTE(review): the stopping-string check calls ``text.endswith`` which
        only works for a single (str) prompt; batch prompts bypass it safely
        only because ``text`` stays a list — confirm intended behavior.
        """
        # Apply seed (seeds the global RNG used for sampling below).
        if seed is not None:
            random.seed(seed)

        # Tokenize input and produce padding mask if needed
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        ids = self.tokenizer.encode(prompt, encode_special_tokens=encode_special_tokens)
        # If prompt + requested tokens exceed the context, drop the oldest tokens.
        overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
        if overflow > 0:
            ids = ids[:, overflow:]
        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None

        # Prepare for token healing: hold back the last prompt token and let
        # the sampler re-complete it. Needs at least 2 tokens to be possible.
        unhealed_token = None
        if ids.shape[-1] < 2:
            token_healing = False
        if token_healing:
            unhealed_token = ids[:, -1:]
            ids = ids[:, :-1]

        # Process prompt and begin gen
        self._gen_begin_base(ids, mask, loras)

        # Begin filters, constrained to continuations of the unhealed token if any.
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        if unhealed_token is not None:
            unhealed_token_list = unhealed_token.flatten().tolist()
            heal = [id_to_piece[x] for x in unhealed_token_list]
        else:
            heal = None
        gen_settings.begin_filters(heal)

        # Generate tokens
        for i in range(num_tokens):
            logits = (
                self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask=mask, loras=loras).float().cpu()
            )
            token, _, eos = ExLlamaV2Sampler.sample(
                logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token=unhealed_token
            )
            self.sequence_ids = torch.cat([self.sequence_ids, token], dim=1)
            gen_settings.feed_filters(token)
            # Healing only constrains the very first sampled token.
            unhealed_token = None

            # Check stopping strings against the full decoded text so far.
            # NOTE(review): decoding the whole sequence every step is O(n^2)
            # overall; acceptable for short generations.
            text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens=decode_special_tokens)
            if isinstance(prompt, str):
                text = text[0]
                for stopping in stopping_strings:
                    if text.endswith(stopping):
                        # Trim the stop string and return early.
                        text = text[: -len(stopping)]
                        return text
            if eos:
                break

        # Decode the final sequence (prompt + generated tokens).
        text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens=decode_special_tokens)
        if isinstance(prompt, str):
            text = text[0]
        return text

    def _gen_begin_base(self, input_ids, mask=None, loras=None):
        """Reset the cache and prefill it with the prompt (all but the last token).

        The last token is left out so the first forward() in the generation
        loop produces logits for it.
        """
        self.cache.current_seq_len = 0
        self.model.forward(input_ids[:, :-1], self.cache, input_mask=mask, preprocess_only=True, loras=loras)
        # BUGFIX: the original cloned input_ids and then immediately rebound
        # to the un-cloned tensor, making the clone dead code. Keep the clone
        # so the caller's tensor is never aliased by generator state.
        self.sequence_ids = input_ids.clone()

    def tokens_count(self, text: str):
        """Return the number of tokens *text* encodes to."""
        encoded = self.tokenizer.encode(text)
        return len(encoded[0])

    def get_model_list(self):
        """List .bin model files in the models directory.

        NOTE(review): the path is relative to the current working directory,
        not this file — fragile; confirm the app always runs from the expected
        location.
        """
        bins = []
        for i in os.listdir("../../models"):
            if i.endswith(".bin"):
                bins.append(i)
        return bins

    def load_model(self, model_file: str):
        """No-op: ExLlamaV2 models are loaded once in __init__ (model_change_allowed is False)."""
        return None