from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
import torch.nn as nn

try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )


def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
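    """Build a LogitsProcessorList that warps next-token logits.

    Temperature scaling is applied first, followed by top-k and then top-p
    (nucleus) filtering. A warper is only added when its argument is set and
    would actually change the distribution.
    """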
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list


def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
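    """Return True when every sequence in the (possibly distributed) batch has finished.

    Under data parallelism the local flags are all-reduced first, so generation
    only stops once all ranks have finished all of their sequences.
    """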
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


def sample(model: nn.Module,
           input_ids: torch.Tensor,
           max_length: int,
           early_stopping: bool = False,
           eos_token_id: Optional[int] = None,
           pad_token_id: Optional[int] = None,
           top_k: Optional[int] = None,
           top_p: Optional[float] = None,
           temperature: Optional[float] = None,
           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
           **model_kwargs) -> torch.Tensor:
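    """Autoregressively sample tokens until ``max_length`` is reached.

    Each step feeds the current ``input_ids`` through ``model``, warps the
    last-position logits with the configured processors, draws one token per
    sequence with ``torch.multinomial``, and appends it. Sequences that have
    already produced EOS keep receiving ``pad_token_id``. Returns ``input_ids``
    with the sampled tokens appended.
    """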
    if input_ids.size(1) >= max_length:
        return input_ids

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(input_ids.size(1), max_length):
        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
            'input_ids': input_ids
        }
        outputs = model(**model_inputs)

        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids


def generate(model: nn.Module,
             input_ids: torch.Tensor,
             max_length: int,
             num_beams: int = 1,
             do_sample: bool = True,
             early_stopping: bool = False,
             eos_token_id: Optional[int] = None,
             pad_token_id: Optional[int] = None,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             temperature: Optional[float] = None,
             prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
             update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
             **model_kwargs) -> torch.Tensor:
| """Generate token sequence. The returned sequence is input_ids + generated_tokens. | |
| Args: | |
| model (nn.Module): model | |
| input_ids (torch.Tensor): input sequence | |
| max_length (int): max length of the returned sequence | |
| num_beams (int, optional): number of beams. Defaults to 1. | |
| do_sample (bool, optional): whether to do sample. Defaults to True. | |
| early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. | |
| eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None. | |
| pad_token_id (Optional[int], optional): pad token id. Defaults to None. | |
| top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. | |
| top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. | |
| temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. | |
| prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. | |
| update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. | |
| """ | |
    is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
    is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
    is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
    if is_greedy_gen_mode:
        # run greedy search
        raise NotImplementedError
    elif is_sample_gen_mode:
        # run sample
        return sample(model,
                      input_ids,
                      max_length,
                      early_stopping=early_stopping,
                      eos_token_id=eos_token_id,
                      pad_token_id=pad_token_id,
                      top_k=top_k,
                      top_p=top_p,
                      temperature=temperature,
                      prepare_inputs_fn=prepare_inputs_fn,
                      update_model_kwargs_fn=update_model_kwargs_fn,
                      **model_kwargs)
    elif is_beam_gen_mode:
        raise NotImplementedError
    else:
        raise ValueError("Unsupported generation mode")
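

# Illustrative usage sketch (not part of the original module). The model and
# tokenizer names below are assumptions for the example only; any Hugging Face
# causal LM whose forward() returns an output indexable as outputs['logits']
# should work the same way, since generate() with do_sample=True and
# num_beams=1 dispatches to sample() above.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained('gpt2')
#     model = AutoModelForCausalLM.from_pretrained('gpt2')
#     input_ids = tokenizer('The quick brown fox', return_tensors='pt')['input_ids']
#     output_ids = generate(model,
#                           input_ids,
#                           max_length=32,
#                           do_sample=True,
#                           early_stopping=True,
#                           eos_token_id=tokenizer.eos_token_id,
#                           pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
#                           top_k=50,
#                           top_p=0.95,
#                           temperature=0.7)
#     print(tokenizer.decode(output_ids[0]))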