import torch
import torch.nn.functional as F

from fengshen.models.transfo_xl_denoise.tokenization_transfo_xl_denoise import TransfoXLDenoiseTokenizer
from fengshen.models.transfo_xl_denoise.modeling_transfo_xl_denoise import TransfoXLDenoiseModel


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of logits with top-k and/or nucleus (top-p) filtering.

    Mostly taken from the huggingface conversational AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Example (top-k only): the two largest logits survive, the rest are pushed
    to -inf so softmax assigns them zero probability:

        >>> top_k_logits(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), top_k=2)
        tensor([[-inf, 3., 2., -inf]])
    """
    if top_k > 0:
        # Remove all tokens with a logit smaller than the last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order and accumulate their probability mass.
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the sorted-order mask back to the original vocabulary order, row by row.
        for i in range(sorted_indices.size()[0]):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def get_masks_and_position_ids(data, mem_length=None):
    # Extract batch size and sequence length.
    _, seq_length = data.size()

    # Causal band attention mask: each query position may attend to the
    # mem_length cached memory slots plus all earlier positions in the sequence.
    attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
    attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
    attention_mask = attention_mask.unsqueeze(1)

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    return attention_mask, position_ids


def get_batch(context_tokens, mem_length, batch_size=1):
    tokens = context_tokens.view(batch_size, -1).contiguous()
    # Get the masks and position ids.
    attention_mask, position_ids = get_masks_and_position_ids(tokens, mem_length=mem_length)
    return tokens, attention_mask, position_ids


def denoise_generate(model,
                     tokenizer,
                     input_text,
                     device=0,
                     mem_length=512,
                     temperature=1.,
                     top_p=0.9,
                     eod_token=50000):
    """Generate with the fixed prompt the model was pretrained on:
    “<input_text>”改写后是“...  (i.e. “<input_text>” rewritten is “...)."""
    prompt = f"“{input_text}”改写后是“"
    res = []
    counter = 0
    tokens, attention_mask, position_ids = get_batch(
        torch.LongTensor(tokenizer.encode(prompt)), mem_length, batch_size=1)
    tokens, attention_mask, position_ids = tokens.cuda(
        device), attention_mask.cuda(device), position_ids.cuda(device)
    org_context_length = tokens.shape[-1]
    model = model.cuda(device)

    # Sample at most 100 new tokens, reusing the Transformer-XL memory (mems)
    # so that after the first step only the newest token is fed to the model.
    while counter < 100:
        if counter == 0:
            mems = []  # empty at the beginning
            output = model(input_ids=tokens,
                           attention_mask=attention_mask,
                           position_ids=position_ids,
                           hidden_states=mems)
            logits, mems = output.logits, output.hidden_states
        else:
            index = org_context_length + counter
            output = model(input_ids=tokens[:, index - 1: index],
                           position_ids=tokens.new_ones((1, 1)) * (index - 1),
                           attention_mask=tokens.new_ones(1, 1, 1, mem_length + 1,
                                                          device=device,
                                                          dtype=torch.float),
                           hidden_states=mems)
            logits, mems = output.logits, output.hidden_states

        # Nucleus sampling on the last position's logits.
        logits = logits[:, -1]
        logits /= temperature
        logits = top_k_logits(logits, top_k=0, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        prev = torch.multinomial(probs, num_samples=1)[0]

        # Stop at the end-of-document token.
        if prev == eod_token:
            break
        tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
        counter += 1

    res.append(tokenizer.decode(tokens.view(-1).contiguous().tolist()))
    return res


if __name__ == "__main__":
    device = 1  # CUDA device ordinal to run on
    tokenizer = TransfoXLDenoiseTokenizer.from_pretrained('IDEA-CCNL/Bigan-Transformer-XL-denoise-1.1B')
    model = TransfoXLDenoiseModel.from_pretrained('IDEA-CCNL/Bigan-Transformer-XL-denoise-1.1B')
    # Noisy Chinese input for the model to rewrite
    # (roughly: "All accomplished people take life their own seriously").
    input_text = "凡是有成就的人, 都很严肃地对待生命自己的"
    res = denoise_generate(model, tokenizer, input_text, device=device)
    print(res)