try:
import exllama
except ImportError:
raise ImportError(
"Could not import `exllama` package. "
"Please install it using `pip install ctransformers[gptq]`"
)

import re
from pathlib import Path
from typing import (
    Generator,
    Optional,
    Sequence,
    Union,
)

import torch
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator

from ..llm import Config, doc, get


class LLM:
def __init__(
self,
model_path: str,
*,
config: Optional[Config] = None,
):
"""Loads the language model from a local file.
Args:
model_path: The path to a model directory.
config: `Config` object.
"""
model_path = Path(model_path).resolve()
config = config or Config()
self._model_path = model_path
self._config = config
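
        # Select the model weights file: the smallest `.safetensors` file in
        # the directory if more than one is present.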
files = [
(f.stat().st_size, f)
for f in model_path.iterdir()
if f.is_file() and f.name.endswith(".safetensors")
]
if not files:
raise ValueError(f"No model file found in directory '{model_path}'")
model_file = min(files)[1]
model_config = ExLlamaConfig(str(model_path / "config.json"))
model_config.model_path = str(model_file)
model = ExLlama(model_config)
tokenizer = ExLlamaTokenizer(str(model_path / "tokenizer.model"))
cache = ExLlamaCache(model)
generator = ExLlamaGenerator(model, tokenizer, cache)
self._model = model
self._tokenizer = tokenizer
self._generator = generator

    @property
    def model_path(self) -> str:
        """The path to the model directory."""
        return str(self._model_path)

    @property
def config(self) -> Config:
"""The config object."""
return self._config

    @property
def eos_token_id(self) -> int:
"""The end-of-sequence token."""
return self._tokenizer.eos_token_id

    @property
    def vocab_size(self) -> int:
        """The number of tokens in the vocabulary."""
return self._model.config.vocab_size

    @property
    def context_length(self) -> int:
        """The context length of the model."""
return self._model.config.max_seq_len

    def tokenize(self, text: str) -> torch.Tensor:
        """Converts a text into tokens.

        Note: `ExLlamaTokenizer.encode` returns the tokens as a tensor of
        shape `(1, n)` rather than a plain list of ids.

        Args:
            text: The text to tokenize.

        Returns:
            The tokens as a `(1, n)` tensor.
        """
        return self._tokenizer.encode(text)

    def detokenize(
        self,
        tokens: Sequence[int],
        decode: bool = True,
    ) -> Union[str, bytes]:
        """Converts a list of tokens to text.

        Args:
            tokens: The list of tokens.
            decode: Whether to decode the text as UTF-8 string. Accepted for
                API compatibility; `exllama` always returns a decoded string.

        Returns:
            The combined text of all tokens.
        """
        if isinstance(tokens, int):
            tokens = [tokens]
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens)
        return self._tokenizer.decode(tokens)

    def is_eos_token(self, token: int) -> bool:
        """Checks if a token is an end-of-sequence token.

        Args:
            token: The token to check.

        Returns:
            `True` if the token is an end-of-sequence token else `False`.
        """

    def reset(self) -> None:
        """Resets the generator state."""
        self._generator.reset()

    @doc
def generate(
self,
tokens: Sequence[int],
*,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
repetition_penalty: Optional[float] = None,
last_n_tokens: Optional[int] = None,
seed: Optional[int] = None,
batch_size: Optional[int] = None,
threads: Optional[int] = None,
reset: Optional[bool] = None,
) -> Generator[int, None, None]:
"""Generates new tokens from a list of tokens.
Args:
tokens: The list of tokens to generate tokens from.
{params}
Returns:
The generated tokens.
"""
generator = self._generator
config = self.config
top_k = get(top_k, config.top_k)
top_p = get(top_p, config.top_p)
temperature = get(temperature, config.temperature)
repetition_penalty = get(repetition_penalty, config.repetition_penalty)
last_n_tokens = get(last_n_tokens, config.last_n_tokens)
reset = get(reset, config.reset)
if reset:
self.reset()
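        # exllama applies a decaying repetition penalty: full strength over
        # the most recent `last_n_tokens` tokens, decaying over half as many.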
generator.settings.top_k = top_k
generator.settings.top_p = top_p
generator.settings.temperature = temperature
generator.settings.token_repetition_penalty_max = repetition_penalty
generator.settings.token_repetition_penalty_sustain = last_n_tokens
generator.settings.token_repetition_penalty_decay = last_n_tokens // 2
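        # exllama expects a `(batch, seq)` tensor of input ids; only batch
        # size 1 is supported here.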
if isinstance(tokens, list):
tokens = torch.tensor(tokens).unsqueeze(0)
assert tokens.shape[0] == 1
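        # Feed the prompt through the model to prime the cache, then sample
        # one token at a time until an end-of-sequence token is produced.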
generator.gen_begin(tokens)
while True:
token = generator.gen_single_token()
token = token[0][0].item()
if self.is_eos_token(token):
break
yield token

    def _stream(
self,
prompt: str,
*,
max_new_tokens: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
repetition_penalty: Optional[float] = None,
last_n_tokens: Optional[int] = None,
seed: Optional[int] = None,
batch_size: Optional[int] = None,
threads: Optional[int] = None,
stop: Optional[Sequence[str]] = None,
reset: Optional[bool] = None,
) -> Generator[str, None, None]:
generator = self._generator
config = self.config
max_new_tokens = get(max_new_tokens, config.max_new_tokens)
stop = get(stop, config.stop) or []
if isinstance(stop, str):
stop = [stop]
tokens = self.tokenize(prompt)
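        # Clamp `max_new_tokens` so prompt + generation fit in the context
        # window.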
max_new_tokens = min(max_new_tokens, self.context_length - tokens.shape[1])
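        # Build a single regex that matches any stop sequence literally.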
stop_regex = re.compile("|".join(map(re.escape, stop)))
count = 0
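        # Track the decoded length of the prompt so that only newly
        # generated text is yielded.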
length = len(self.detokenize(tokens[0]))
text = ""
for token in self.generate(
tokens,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
last_n_tokens=last_n_tokens,
seed=seed,
batch_size=batch_size,
threads=threads,
reset=reset,
):
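            # Decode the full sequence each step and keep only the text that
            # is new since the last step.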
new_text = self.detokenize(generator.sequence_actual[0])[length:]
length += len(new_text)
text += new_text
# https://github.com/abetlen/llama-cpp-python/blob/1a13d76c487df1c8560132d10bda62d6e2f4fa93/llama_cpp/llama.py#L686-L706
# Check if one of the stop sequences is part of the text.
# Note that the stop sequence may not always be at the end of text.
if stop:
match = stop_regex.search(text)
if match:
text = text[: match.start()]
break
# Avoid sending the longest suffix of text which is also a prefix
# of a stop sequence, as it can form a stop sequence with the text
# generated later.
longest = 0
for s in stop:
for i in range(len(s), 0, -1):
if text.endswith(s[:i]):
longest = max(i, longest)
break
end = len(text) - longest
if end > 0:
yield text[:end]
text = text[end:]
count += 1
if count >= max_new_tokens:
break
if text:
yield text

    @doc
def __call__(
self,
prompt: str,
*,
max_new_tokens: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
repetition_penalty: Optional[float] = None,
last_n_tokens: Optional[int] = None,
seed: Optional[int] = None,
batch_size: Optional[int] = None,
threads: Optional[int] = None,
stop: Optional[Sequence[str]] = None,
stream: Optional[bool] = None,
reset: Optional[bool] = None,
) -> Union[str, Generator[str, None, None]]:
"""Generates text from a prompt.
Args:
prompt: The prompt to generate text from.
{params}
Returns:
The generated text.
"""
config = self.config
stream = get(stream, config.stream)
text = self._stream(
prompt,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
last_n_tokens=last_n_tokens,
seed=seed,
batch_size=batch_size,
threads=threads,
stop=stop,
reset=reset,
)
if stream:
return text
return "".join(text)