import gc
import copy

from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    """Generate completions (non-streaming) for a single prompt or a padded batch."""
    if len(prompts) == 1:
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
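
# Usage sketch for get_output_batch (illustrative only: the checkpoint name and
# the sampling parameters below are assumptions, not taken from this module).
#
#   from transformers import GenerationConfig
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", padding_side="left")
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()
#   gen_config = GenerationConfig(do_sample=True, temperature=0.7, top_p=0.9)
#   outputs = get_output_batch(
#       model, tokenizer, ["Hello,", "The weather today is"], gen_config
#   )
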
# StreamModel is borrowed from the basaran project;
# more info about it can be found here -> https://github.com/hyperonym/basaran
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Emit decoded text once every `chunk_size` tokens (must be bigger than 1).
        chunk_size = 5
        chunk_count = 0

        # Accumulate completion tokens.
        final_tokens = torch.empty(0, dtype=torch.long, device=self.device)

        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                if chunk_count < chunk_size:
                    chunk_count = chunk_count + 1

                final_tokens = torch.cat((final_tokens, tokens))

                if chunk_count == chunk_size - 1:
                    chunk_count = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            # Flush any tokens accumulated since the last chunk boundary.
            if chunk_count > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
        except RetryError as e:
            print(e)

        del input_ids
        gc.collect()
        del final_tokens

        if self.device == "cuda":
            torch.cuda.empty_cache()

    # NOTE: the retry policy below is an assumption; the tenacity imports and the
    # RetryError handling in __call__ imply this call is meant to be retried.
    @retry(wait=wait_fixed(1), stop=stop_after_attempt(3))
    def _infer(self, model_fn, **kwargs):
        """Call a model function in inference mode with auto retrying."""
        # This is a temporary workaround for bitsandbytes #162:
        # https://github.com/TimDettmers/bitsandbytes/issues/162
        with torch.inference_mode():
            return model_fn(**kwargs)

    def _logits_processor(self, config, input_length):
        """Set up logits processor based on the generation config."""
        processor = LogitsProcessorList()

        # Add processor for enforcing a min-length of new tokens.
        if (
            config.min_new_tokens is not None
            and config.min_new_tokens > 0
            and config.eos_token_id is not None
        ):
            processor.append(
                MinNewTokensLengthLogitsProcessor(
                    prompt_length_to_skip=input_length,
                    min_new_tokens=config.min_new_tokens,
                    eos_token_id=config.eos_token_id,
                )
            )

        # Add processor for scaling output probability distribution.
        if (
            config.temperature is not None
            and config.temperature > 0
            and config.temperature != 1.0
        ):
            processor.append(TemperatureLogitsWarper(config.temperature))

        # Add processor for nucleus sampling.
        if config.top_p is not None and 0 < config.top_p < 1:
            processor.append(TopPLogitsWarper(config.top_p))

        return processor

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from generation config.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True  # config.use_cache

        # Collect special token IDs.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Generate from eos if no input is specified.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Prepare inputs for encoder-decoder models.
        if self.model.config.is_encoder_decoder:
            # Get outputs from the encoder.
            encoder = self.model.get_encoder()
            encoder_kwargs = kwargs.copy()
            encoder_kwargs.pop("use_cache", None)
            encoder_kwargs["input_ids"] = input_ids
            encoder_kwargs["return_dict"] = True
            encoder_outputs = self._infer(encoder, **encoder_kwargs)
            kwargs["encoder_outputs"] = encoder_outputs

            # Reinitialize inputs for the decoder.
            decoder_start_token_id = config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id
            input_ids = input_ids.new_ones((batch_size, 1))
            input_ids = input_ids * decoder_start_token_id
            input_length = 1

        # Set up logits processor.
        processor = self._logits_processor(config, input_length)

        # Keep track of which sequences are already finished.
        unfinished = input_ids.new_ones(batch_size)

        # Start auto-regressive generation.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )  # noqa: E501
            outputs = self._infer(
                self.model,
                **inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = processor(input_ids, logits)
                probs = torch.nn.functional.softmax(logits, dim=-1)

            # Select deterministic or stochastic decoding strategy.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)
            # Squeeze to shape (batch_size,) in either branch.
            tokens = tokens.squeeze(1)

            # Finished sequences should have their next token be a padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append selected tokens to the inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences with eos tokens as finished.
            if eos_token_id is not None:
                not_eos = sum(tokens != i for i in eos_token_id)
                unfinished = unfinished.mul(not_eos.long())

            # Set status to negative once the max length has been exceeded.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield the newly predicted tokens.
            yield tokens

            # Stop when finished or exceeded the max length.
            if status.max() <= 0:
                break
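

# Usage sketch for StreamModel (illustrative only: the checkpoint name and the
# generation parameters are assumptions, not taken from this module).
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#   model = model.to("cuda" if torch.cuda.is_available() else "cpu")
#
#   stream = StreamModel(model, tokenizer)
#   for partial_text in stream(
#       "Once upon a time", max_tokens=64, temperature=0.8, top_p=0.95
#   ):
#       # Each yielded string is the full completion decoded so far.
#       print(partial_text)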