Spaces:
Runtime error
Runtime error
| import logging | |
| import sys | |
| import os | |
| import torch | |
| import json | |
| from typing import Optional, Tuple, Union, List, Callable | |
| from transformers import LlamaForCausalLM | |
| from transformers.generation.logits_process import LogitsProcessor | |
| from transformers.generation.beam_search import BeamSearchScorer | |
| from transformers.deepspeed import is_deepspeed_zero3_enabled | |
| from transformers.generation.utils import ( | |
| LogitsProcessorList, | |
| StoppingCriteriaList, | |
| GenerationConfig, | |
| GenerationMixin, | |
| ) | |
| import warnings | |
| from peft import PeftModel, PeftModelForCausalLM, LoraConfig | |
| import peft | |
| import torch.distributed as dist | |
| from torch import nn | |
| import copy | |
| from accelerate.hooks import ( | |
| AlignDevicesHook, | |
| add_hook_to_module, | |
| remove_hook_from_submodules, | |
| ) | |
| from accelerate.utils import get_balanced_memory | |
| from huggingface_hub import hf_hub_download | |
| from accelerate import dispatch_model, infer_auto_device_map | |
| from peft.utils import PeftType, set_peft_model_state_dict | |
def printf(*args, **kargs):
    """Print ``args`` only when the ``DEBUG`` environment variable is set.

    The raw environment string is used as the switch, so any non-empty value
    (including ``"0"``) enables output. The keyword arguments understood by
    :func:`print` (``sep``, ``end``, ``file``, ``flush``) are forwarded; any
    other keyword is ignored, as before. Output is flushed unless ``flush`` is
    passed explicitly.
    """
    if os.environ.get('DEBUG', False):
        print_kwargs = {
            key: value
            for key, value in kargs.items()
            if key in ('sep', 'end', 'file', 'flush')
        }
        print_kwargs.setdefault('flush', True)
        print(*args, **print_kwargs)
class ColorFormatter(logging.Formatter):
    """Logging formatter that wraps each record in an ANSI colour chosen by level.

    DEBUG is grey, INFO blue, WARNING yellow, ERROR red and CRITICAL bold red.
    Records with any other (custom) level are formatted with ``fmt`` uncoloured.
    The original code passed ``None`` to ``logging.Formatter`` for those levels,
    which silently dropped ``fmt`` and printed only the message.
    """

    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt, datefmt=None):
        """Build the per-level formats.

        Args:
            fmt: ``%``-style format string applied to every record.
            datefmt: optional ``asctime`` date format (defaults to logging's own).
        """
        super().__init__(fmt, datefmt)
        colours = {
            logging.DEBUG: self.grey,
            logging.INFO: self.blue,
            logging.WARNING: self.yellow,
            logging.ERROR: self.red,
            logging.CRITICAL: self.bold_red,
        }
        # Public mapping level -> colourised format string (kept for compatibility).
        self.FORMATS = {level: colour + fmt + self.reset for level, colour in colours.items()}
        # Build the formatters once instead of allocating one per record.
        self._level_formatters = {
            level: logging.Formatter(level_fmt, datefmt)
            for level, level_fmt in self.FORMATS.items()
        }
        self._plain_formatter = logging.Formatter(fmt, datefmt)

    def format(self, record):
        """Format ``record`` with the colour of its level (uncoloured if unknown)."""
        formatter = self._level_formatters.get(record.levelno, self._plain_formatter)
        return formatter.format(record)
def set_console_logger(name):
    """Return logger ``name`` with a coloured stdout handler for INFO and above.

    The logger itself is set to DEBUG; the handler filters at INFO. Calling
    this again for the same name reuses the existing stdout handler instead of
    adding another one. Otherwise every message would be printed once per call.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and handler.stream is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
        logger.addHandler(console_handler)
    return logger
def set_file_logger(name, dir, use_console=False):
    """Return logger ``name`` that appends INFO+ records to ``<dir>/session.log``.

    Args:
        name: logger name passed to ``logging.getLogger``.
        dir: directory for ``session.log``; created if missing.
        use_console: also print INFO+ records to stdout in colour, and stop
            propagation to ancestor loggers.

    The log file is opened in append mode with UTF-8 encoding, because messages
    may contain non-ASCII text. Repeated calls for the same logger/directory do
    not attach duplicate handlers.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(dir, exist_ok=True)
    fmt = "%(asctime)s | %(levelname)s %(message)s"
    if use_console:
        logger.propagate = False  # disable default handler
        # FileHandler subclasses StreamHandler, so exclude it explicitly.
        has_stdout_handler = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            and handler.stream is sys.stdout
            for handler in logger.handlers
        )
        if not has_stdout_handler:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(ColorFormatter(fmt))
            logger.addHandler(console_handler)
    # FileHandler stores the absolute path in ``baseFilename``; compare against that.
    log_path = os.path.abspath(os.path.join(dir, 'session.log'))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(log_path, mode='a', encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(file_handler)
    return logger
def to_jsonl(data, path):
    """Append every item of ``data`` to ``path`` as one JSON document per line.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``). The file
    is therefore opened as UTF-8 explicitly, rather than relying on the platform
    locale encoding.
    """
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in data)
def from_json(path):
    """Load and return the JSON document stored in ``path`` (read as UTF-8).

    The file is closed before returning; the original left the handle open
    until garbage collection.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def from_jsonl(path):
    """Return the list of JSON documents stored one per line in ``path``.

    The file is read as UTF-8 and closed afterwards. Blank or whitespace-only
    lines are skipped instead of making ``json.loads`` raise.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
def to_json(data, path):
    """Write ``data`` to ``path`` as JSON, overwriting any existing file.

    Non-ASCII characters are written verbatim, so the file is opened as UTF-8.
    The file is closed (and therefore flushed) before returning; the original
    left it open, so the data could still be buffered when a caller read it back.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)
class StreamGenerationMixin(GenerationMixin):
    """``GenerationMixin`` variant whose decoding loops stream their progress.

    ``stream_generate`` mirrors ``GenerationMixin.generate`` but returns a
    generator. Each ``stream_*`` loop yields the running ``input_ids`` tensor
    after every decoding step, then yields one final result when it stops: the
    last ``input_ids`` for greedy/sample, or the finalized beam sequences for
    the beam modes.

    The loops call private ``GenerationMixin`` helpers (``_get_logits_processor``,
    ``_get_logits_warper``, ``_expand_inputs_for_generation``,
    ``_update_model_kwargs_for_generation``, ``_reorder_cache``, ...). They are
    therefore tied to the transformers version this file was written against.

    TODO: group beam search (and constrained / contrastive search).
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _stream_eos_and_pad(generation_config):
        """Return ``(eos_token_ids, pad_token_id)`` read from ``generation_config``.

        A single int EOS id is wrapped in a list so the loops can treat it
        uniformly; ``None`` is passed through unchanged.
        """
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        return eos_token_id, generation_config.pad_token_id

    @staticmethod
    def _stream_all_peers_finished(this_peer_finished, device):
        """Return True once every distributed rank reports it has finished.

        Used only with ``synced_gpus``. Each rank contributes 0.0 when done and
        1.0 otherwise, so an all-reduced SUM of 0.0 means all ranks are done.
        Until then every rank keeps running ``forward`` in lockstep.
        """
        flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() == 0.0

    def _stream_num_virtual_tokens(self):
        """Return the active PEFT config's ``num_virtual_tokens``, or 0 if it has none.

        ``peft_config`` may be a single config (older peft) or an
        adapter-name -> config dict (newer peft). Configs without
        ``num_virtual_tokens`` (e.g. LoRA) and models without any
        ``peft_config`` give 0.
        """
        config = getattr(self, "peft_config", None)
        if isinstance(config, dict):
            config = config.get(getattr(self, "active_adapter", "default"))
        return getattr(config, "num_virtual_tokens", None) or 0

    # --------------------------------------------------------------- dispatcher

    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        """Pick the decoding mode from the generation config and return its streaming generator.

        Args:
            input_ids: prompt token ids; ``shape[0]`` is used as the batch size
                and ``shape[-1]`` as the prompt length.
            generation_config: config to decode with (defaults to
                ``self.generation_config``). A deep copy is updated with
                ``kwargs``, so the caller's object is never mutated.
            logits_processor: extra processors merged with the config-derived ones.
            stopping_criteria: extra criteria merged with the config-derived ones.
            prefix_allowed_tokens_fn: forwarded to ``_get_logits_processor``.
            **kwargs: generation-config overrides. Keys that are not config
                fields are returned by ``GenerationConfig.update`` and passed to
                the model as model kwargs.

        Returns:
            A generator from one of the ``stream_*`` methods.

        Raises:
            ValueError: if ``input_ids`` is None.
            NotImplementedError: for modes without a streaming loop (group
                beam search, constrained or contrastive search). This is a
                subclass of the ``Exception`` raised previously.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided to `stream_generate`.")

        # Under DeepSpeed ZeRO-3 with several ranks, all ranks keep stepping
        # until every rank is finished (see _stream_all_peers_finished).
        synced_gpus = is_deepspeed_zero3_enabled() and dist.get_world_size() > 1

        # Prompt-learning adapters prepend `num_virtual_tokens` virtual tokens, so
        # the attention mask is extended to cover them. LoRA and plain models
        # report 0 and keep their mask unchanged.
        num_virtual_tokens = self._stream_num_virtual_tokens()
        if num_virtual_tokens and kwargs.get("attention_mask", None) is not None:
            attention_mask = kwargs["attention_mask"]
            prefix_attention_mask = torch.ones(
                input_ids.shape[0], num_virtual_tokens, dtype=attention_mask.dtype
            ).to(input_ids.device)
            kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn(
                "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
            )
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn(
                "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
            )
            kwargs["token_type_ids"] = None

        input_ids_seq_length = input_ids.shape[-1]
        generation_config = copy.deepcopy(
            generation_config if generation_config is not None else self.generation_config
        )
        model_kwargs = generation_config.update(**kwargs)

        has_default_max_length = (
            kwargs.get("max_length") is None and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if generation_config.min_new_tokens is not None:
            generation_config.min_length = generation_config.min_new_tokens + input_ids_seq_length
        if input_ids_seq_length >= generation_config.max_length:
            # The original computed this name but never reported anything.
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`.",
                UserWarning,
            )

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        # Determine the generation mode. The `is True` / `is False` checks are
        # intentional: a non-bool do_sample matches no mode.
        num_beams = generation_config.num_beams
        do_sample = generation_config.do_sample
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        single_group = (
            generation_config.num_beam_groups == 1
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_greedy_gen_mode = single_group and num_beams == 1 and do_sample is False
        is_sample_gen_mode = single_group and num_beams == 1 and do_sample is True
        is_beam_gen_mode = single_group and num_beams > 1 and do_sample is False
        is_beam_sample_gen_mode = single_group and num_beams > 1 and do_sample is True

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        if is_greedy_gen_mode:
            return self.stream_greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        if is_sample_gen_mode:
            # Repeat each prompt `num_return_sequences` times; each copy is sampled independently.
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        if is_beam_gen_mode:
            return self.stream_beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        if is_beam_sample_gen_mode:
            return self.stream_beam_sample(
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        raise NotImplementedError(
            "not implement: streaming supports greedy, sample, beam search and beam sample modes only"
        )

    # ------------------------------------------------------------ decode loops

    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Multinomial-sampling loop that yields ``input_ids`` after every step.

        Rows that have produced an EOS token keep receiving ``pad_token_id``
        until every row is finished or ``stopping_criteria`` fires. The final
        ``input_ids`` is then yielded once more.

        Raises:
            ValueError: (on the first step) if an EOS id is configured but
                ``pad_token_id`` is not.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        )
        # 1 while a row is still generating, 0 once it has emitted an EOS token.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        scores = ()  # per-step scores are not collected
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # stay in lockstep with the other ranks, skip the rest
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Finished rows get a padding token instead of the sampled one.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            if eos_token_id_tensor is not None:
                # A row finishes when its new token equals any of the EOS ids.
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        yield input_ids

    def stream_beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam-sample loop: yields the beam ``input_ids`` after every step.

        As in transformers' beam sample, every requested return sequence gets
        its own group of ``num_beams`` beams. The scorer therefore tracks
        ``batch_size * num_return_sequences`` groups and keeps the best
        hypothesis of each. The original sized the scorer and beam scores by
        ``batch_size`` only, which broke for ``num_return_sequences > 1``.
        After the loop, the finalized ``sequences`` tensor is yielded.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        num_return_sequences = generation_config.num_return_sequences
        group_count = input_ids.shape[0] * num_return_sequences
        beam_scorer = BeamSearchScorer(
            batch_size=group_count,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=1,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        cur_len = input_ids.shape[-1]
        scores = ()
        beam_scores = torch.zeros((group_count, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((group_count * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # stay in lockstep with the other ranks, skip the rest
            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian (pad_token_id must not be generated).
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Warpers are applied after adding the running beam scores on purpose, see
            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # Flatten the beams of each group so candidates are drawn across beams.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(group_count, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
            # Split the flat index back into (source beam, token id).
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs.get("past_key_values") is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]

    def stream_greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Greedy (argmax) loop that yields ``input_ids`` after every step.

        Rows that have produced an EOS token keep receiving ``pad_token_id``
        until every row is finished or ``stopping_criteria`` fires. The final
        ``input_ids`` is then yielded once more.

        Raises:
            ValueError: (on the first step) if an EOS id is configured but
                ``pad_token_id`` is not.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        )
        scores = ()  # per-step scores are not collected
        # 1 while a row is still generating, 0 once it has emitted an EOS token.
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, return_dict=True)
            if synced_gpus and this_peer_finished:
                continue  # stay in lockstep with the other ranks, skip the rest
            next_token_logits = outputs.logits[:, -1, :]
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            # Finished rows get a padding token instead of the argmax.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            if eos_token_id_tensor is not None:
                # A row finishes when its new token equals any of the EOS ids.
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        yield input_ids

    def stream_beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Deterministic beam-search loop: yields the beam ``input_ids`` after every step.

        Inputs are expanded to ``num_beams`` rows per prompt. Only the first
        beam of each prompt starts with score 0; the others start at -1e9, so
        step one does not pick ``num_beams`` copies of the same token. After
        the loop, the finalized ``sequences`` tensor is yielded.

        Raises:
            ValueError: if the expanded batch is not ``batch_size * num_beams`` rows.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        batch_size = input_ids.shape[0]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus and self._stream_all_peers_finished(this_peer_finished, input_ids.device):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # stay in lockstep with the other ranks, skip the rest
            next_token_logits = outputs.logits[:, -1, :]
            # (adjust_logits_during_generation, the Marian hack, is deliberately not applied here.)
            next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Flatten the beams of each prompt so the top-k runs across beams.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # Keep 2 * num_beams candidates so beams ending in EOS can be replaced.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Split the flat index back into (source beam, token id).
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs.get("past_key_values") is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            cur_len = cur_len + 1
            yield input_ids
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                this_peer_finished = True
        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]
class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
    """``LlamaForCausalLM`` extended with the streaming ``stream_generate`` API.

    All behaviour comes from the two base classes. ``LlamaForCausalLM`` is
    listed first in the MRO, so its methods take precedence over the mixin's
    for any name both define.
    """
class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):
    """``PeftModelForCausalLM`` with the streaming ``stream_*`` generation API.

    peft's default loader builds
    ``MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)`` rather
    than ``cls(...)``. Inheriting from ``PeftModelForCausalLM`` alone would
    therefore not give a streaming model, so ``from_pretrained`` below
    instantiates ``cls`` directly.
    """

    @staticmethod
    def _release_tuple(version):
        """Return the leading numeric components of ``version`` as a tuple of ints.

        ``"0.3.0.dev0"`` -> ``(0, 3, 0)`` and ``"0.10.1"`` -> ``(0, 10, 1)``.
        Plain string comparison would rank ``"0.10.0"`` below ``"0.3.0"``.
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", str(version))
        return tuple(int(part) for part in match.group(0).split(".")) if match else ()

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        """Wrap ``model`` in ``cls`` and load the LoRA adapter stored at ``model_id``.

        Args:
            model: base causal LM to wrap.
            model_id: local directory or Hub id of the adapter.
            adapter_name: name under which the adapter is registered.
            is_trainable: load the adapter for training instead of inference.
            **kwargs: forwarded to ``load_adapter`` (peft >= 0.3.0) or to
                ``from_pretrained_old_peft_version`` (older peft).

        Returns:
            An instance of ``cls`` wrapping ``model``.

        Raises:
            ValueError: if a prompt-learning adapter is requested as trainable.
        """
        # Works with peft==0.3.0 and later. "0.3.0.dev0" takes the legacy path.
        version = peft.__version__
        if cls._release_tuple(version) >= (0, 3, 0) and version != '0.3.0.dev0':
            try:
                from peft.utils import PromptLearningConfig
            except ImportError:  # later peft releases export it from the package root
                from peft import PromptLearningConfig
            config = LoraConfig.from_pretrained(model_id)
            # When part of the model is mapped to CPU/disk, strip the accelerate hooks first.
            if (getattr(model, "hf_device_map", None) is not None) and len(
                set(model.hf_device_map.values()).intersection({"cpu", "disk"})
            ) > 0:
                remove_hook_from_submodules(model)
            if isinstance(config, PromptLearningConfig) and is_trainable:
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            config.inference_mode = not is_trainable
            # Here is the hack: build cls, not the task-type class, so stream_* exist.
            model = cls(model, config, adapter_name)
            model.load_adapter(model_id, adapter_name, **kwargs)
            # NOTICE: expose the base model's helpers on the wrapper; the beam
            # stream_* loops call self._reorder_cache.
            model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
            model._reorder_cache = model.base_model._reorder_cache
            return model
        return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)

    @classmethod
    def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
        """Legacy loader for peft without the multi-adapter API.

        Known to work with peft@e536616888d51b453ed354a6f1e243fecb02ea08. Loads
        ``adapter_model.bin`` from the local directory ``model_id``, or else
        downloads it from the Hugging Face Hub. If ``model`` has an
        ``hf_device_map``, the wrapped model is re-dispatched across devices
        (``device_map`` / ``max_memory`` may be passed in ``kwargs``).

        Raises:
            ValueError: if the adapter weights can be found neither locally nor on the Hub.
        """
        config = LoraConfig.from_pretrained(model_id)
        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)
        # Here is the hack: build cls so the streaming methods are available.
        model = cls(model, config)
        model._reorder_cache = model.base_model._reorder_cache
        local_weights = os.path.join(model_id, "adapter_model.bin")
        if os.path.exists(local_weights):
            filename = local_weights
        else:
            try:
                filename = hf_hub_download(model_id, "adapter_model.bin")
            except Exception as err:
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {'adapter_model.bin'} is present at {model_id}."
                ) from err
        # NOTE(review): torch.load unpickles the file; only load adapters from trusted sources.
        adapters_weights = torch.load(
            filename,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model