try:
    import exllama
except ImportError:
    raise ImportError(
        "Could not import `exllama` package. "
        "Please install it using `pip install ctransformers[gptq]`"
    )

import re
from pathlib import Path
from typing import (
    Generator,
    List,
    Optional,
    Sequence,
    Union,
)

import torch
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator

from ..llm import Config, doc, get


class LLM:
    def __init__(
        self,
        model_path: str,
        *,
        config: Optional[Config] = None,
    ):
        """Loads the language model from a local file.

        Args:
            model_path: The path to a model directory.
            config: `Config` object.
        """
        model_path = Path(model_path).resolve()
        config = config or Config()
        self._model_path = model_path
        self._config = config

        files = [
            (f.stat().st_size, f)
            for f in model_path.iterdir()
            if f.is_file() and f.name.endswith(".safetensors")
        ]
        if not files:
            raise ValueError(f"No model file found in directory '{model_path}'")
        # Pick the smallest `.safetensors` file in the directory.
        model_file = min(files)[1]

        model_config = ExLlamaConfig(str(model_path / "config.json"))
        model_config.model_path = str(model_file)

        model = ExLlama(model_config)
        tokenizer = ExLlamaTokenizer(str(model_path / "tokenizer.model"))
        cache = ExLlamaCache(model)
        generator = ExLlamaGenerator(model, tokenizer, cache)

        self._model = model
        self._tokenizer = tokenizer
        self._generator = generator
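
    # Illustrative construction sketch; the directory path is hypothetical
    # and must contain a `*.safetensors` weight file together with
    # `config.json` and `tokenizer.model`:
    #
    #     llm = LLM("/path/to/gptq-model", config=Config(temperature=0.7))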

    @property
    def model_path(self) -> str:
        """The path to the model directory."""
        return self._model_path

    @property
    def config(self) -> Config:
        """The config object."""
        return self._config

    @property
    def eos_token_id(self) -> int:
        """The end-of-sequence token."""
        return self._tokenizer.eos_token_id

    @property
    def vocab_size(self) -> int:
        """The number of tokens in vocabulary."""
        return self._model.config.vocab_size

    @property
    def context_length(self) -> int:
        """The context length of model."""
        return self._model.config.max_seq_len

    def tokenize(self, text: str) -> List[int]:
        """Converts a text into list of tokens.

        Args:
            text: The text to tokenize.

        Returns:
            The list of tokens.
        """
        return self._tokenizer.encode(text)

    def detokenize(
        self,
        tokens: Sequence[int],
        decode: bool = True,
    ) -> Union[str, bytes]:
        """Converts a list of tokens to text.

        Args:
            tokens: The list of tokens.
            decode: Whether to decode the text as UTF-8 string.

        Returns:
            The combined text of all tokens.
        """
        if isinstance(tokens, int):
            tokens = [tokens]
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens)
        return self._tokenizer.decode(tokens)
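
    # Illustrative round-trip (assumes a loaded `llm` instance). As used by
    # `_stream` below, `tokenize` returns a batched token tensor, so index
    # the first row before detokenizing:
    #
    #     ids = llm.tokenize("Hello world")   # batched tensor, shape [1, n]
    #     text = llm.detokenize(ids[0])       # text reconstructed from the tokens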

    def is_eos_token(self, token: int) -> bool:
        """Checks if a token is an end-of-sequence token.

        Args:
            token: The token to check.

        Returns:
            `True` if the token is an end-of-sequence token else `False`.
        """
        return token == self.eos_token_id

    def reset(self) -> None:
        """Resets the generator state."""
        self._generator.reset()

    @doc
    def generate(
        self,
        tokens: Sequence[int],
        *,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        last_n_tokens: Optional[int] = None,
        seed: Optional[int] = None,
        batch_size: Optional[int] = None,
        threads: Optional[int] = None,
        reset: Optional[bool] = None,
    ) -> Generator[int, None, None]:
        """Generates new tokens from a list of tokens.

        Args:
            tokens: The list of tokens to generate tokens from.
            {params}

        Returns:
            The generated tokens.
        """
        generator = self._generator
        config = self.config
        top_k = get(top_k, config.top_k)
        top_p = get(top_p, config.top_p)
        temperature = get(temperature, config.temperature)
        repetition_penalty = get(repetition_penalty, config.repetition_penalty)
        last_n_tokens = get(last_n_tokens, config.last_n_tokens)
        reset = get(reset, config.reset)

        if reset:
            self.reset()

        generator.settings.top_k = top_k
        generator.settings.top_p = top_p
        generator.settings.temperature = temperature
        generator.settings.token_repetition_penalty_max = repetition_penalty
        generator.settings.token_repetition_penalty_sustain = last_n_tokens
        generator.settings.token_repetition_penalty_decay = last_n_tokens // 2

        if isinstance(tokens, list):
            tokens = torch.tensor(tokens).unsqueeze(0)
        assert tokens.shape[0] == 1
        generator.gen_begin(tokens)

        while True:
            token = generator.gen_single_token()
            token = token[0][0].item()
            if self.is_eos_token(token):
                break
            yield token
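
    # Illustrative token-level loop (assumes a loaded `llm` instance). Note
    # that per-token decoding can mangle whitespace; `_stream` avoids this by
    # re-decoding the full sequence, and `__call__` is the higher-level API
    # that also handles stop sequences and the `max_new_tokens` limit:
    #
    #     ids = llm.tokenize("Once upon a time")
    #     for token in llm.generate(ids):
    #         print(llm.detokenize(token), end="", flush=True)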

    def _stream(
        self,
        prompt: str,
        *,
        max_new_tokens: Optional[int] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        last_n_tokens: Optional[int] = None,
        seed: Optional[int] = None,
        batch_size: Optional[int] = None,
        threads: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        reset: Optional[bool] = None,
    ) -> Generator[str, None, None]:
        generator = self._generator
        config = self.config
        max_new_tokens = get(max_new_tokens, config.max_new_tokens)
        stop = get(stop, config.stop) or []
        if isinstance(stop, str):
            stop = [stop]

        tokens = self.tokenize(prompt)
        max_new_tokens = min(max_new_tokens, self.context_length - tokens.shape[1])

        stop_regex = re.compile("|".join(map(re.escape, stop)))
        count = 0
        length = len(self.detokenize(tokens[0]))
        text = ""
        for token in self.generate(
            tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            last_n_tokens=last_n_tokens,
            seed=seed,
            batch_size=batch_size,
            threads=threads,
            reset=reset,
        ):
            new_text = self.detokenize(generator.sequence_actual[0])[length:]
            length += len(new_text)
            text += new_text

            # https://github.com/abetlen/llama-cpp-python/blob/1a13d76c487df1c8560132d10bda62d6e2f4fa93/llama_cpp/llama.py#L686-L706
            # Check if one of the stop sequences is part of the text.
            # Note that the stop sequence may not always be at the end of text.
            if stop:
                match = stop_regex.search(text)
                if match:
                    text = text[: match.start()]
                    break

            # Avoid sending the longest suffix of text which is also a prefix
            # of a stop sequence, as it can form a stop sequence with the text
            # generated later.
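            # For example, with stop=["###"], a trailing "#" in the text is
            # held back until later tokens show whether it grows into the
            # full "###" stop sequence.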
            longest = 0
            for s in stop:
                for i in range(len(s), 0, -1):
                    if text.endswith(s[:i]):
                        longest = max(i, longest)
                        break

            end = len(text) - longest
            if end > 0:
                yield text[:end]
                text = text[end:]

            count += 1
            if count >= max_new_tokens:
                break

        if text:
            yield text

    @doc
    def __call__(
        self,
        prompt: str,
        *,
        max_new_tokens: Optional[int] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        last_n_tokens: Optional[int] = None,
        seed: Optional[int] = None,
        batch_size: Optional[int] = None,
        threads: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        stream: Optional[bool] = None,
        reset: Optional[bool] = None,
    ) -> Union[str, Generator[str, None, None]]:
        """Generates text from a prompt.

        Args:
            prompt: The prompt to generate text from.
            {params}

        Returns:
            The generated text.
        """
        config = self.config
        stream = get(stream, config.stream)
        text = self._stream(
            prompt,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            last_n_tokens=last_n_tokens,
            seed=seed,
            batch_size=batch_size,
            threads=threads,
            stop=stop,
            reset=reset,
        )
        if stream:
            return text
        return "".join(text)