import torch
from typing import Optional, Tuple
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
class GPT2ForInContextClassification(GPT2LMHeadModel):
    """GPT-2 head for in-context classification.

    Each example is expanded into one sequence per candidate option. The model scores
    every option by the language-modeling loss over its answer tokens; the option with
    the smallest loss is treated as the prediction.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # input token ids, shape [n, option_size, len]
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        label_masks: Optional[torch.LongTensor] = None,  # mask=1 marks positions where the loss is computed
        options: Optional[list] = None,  # for classification tasks, the candidate labels can be passed here
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # input_ids: [n, option_size, len]
        assert input_ids.dim() == 3 and input_ids.shape[1] == len(options)
        batch_size, option_size, seq_len = input_ids.shape
        # Flatten the option dimension so the transformer sees [n*option_size, len].
        input_ids = input_ids.view(-1, seq_len)
        attention_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, seq_len) if token_type_ids is not None else None
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]                            # [n*option_size, len, hidden_size]
        lm_logits = self.lm_head(hidden_states)                           # [n*option_size, len, vocab_size]
        lm_logits = lm_logits.view(batch_size, option_size, seq_len, -1)  # [n, option_size, len, vocab_size]
        losses = None
        loss_logits = lm_logits
        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction="none")
            if label_masks is None:
                label_masks = torch.ones_like(labels)
            losses = list()
            for label, lm_logit, label_mask in zip(labels, lm_logits, label_masks):
                # label / label_mask: [option_size, len]; lm_logit: [option_size, len, vocab_size]
                # Shift so that tokens < t predict token t.
                shift_logits = lm_logit[..., :-1, :].contiguous()        # [option_size, len-1, vocab_size]
                shift_labels = label[..., 1:].contiguous()               # [option_size, len-1]
                shift_mask = label_mask[..., 1:].contiguous().float()    # [option_size, len-1]
                # Per-token loss, averaged over the masked (answer) positions only.
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                loss = loss.view(shift_labels.shape)                     # [option_size, len-1]
                loss = torch.sum(loss * shift_mask, dim=1) / torch.sum(shift_mask, dim=1).clamp(min=1)  # [option_size]
                losses.append(loss)
            losses = torch.stack(losses)                                 # [n, option_size]
            # Treat each option's negated loss as a logit: the smaller the loss,
            # the larger the probability assigned to that option.
            loss_logits = torch.softmax(-losses, -1)                     # [n, option_size]
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((losses,) + output) if losses is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=losses,
            logits=loss_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
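
# The forward pass above expects each example to be expanded into one sequence per
# candidate option, stacked into [n, option_size, len] tensors, together with a
# label mask that is 1 only at the answer-token positions. The helper below is a
# minimal sketch of how such inputs could be built; `build_option_inputs` and its
# argument names are illustrative, not part of the transformers API.
def build_option_inputs(tokenizer, prompt, options, max_length=60):
    """Hypothetical helper: tokenize `prompt + option` for every candidate option.

    Returns input_ids / attention_mask / labels of shape [1, option_size, max_length]
    and a label_masks tensor that is 1 at the final (answer) token of each sequence.
    Multi-token options would need all of their positions masked instead.
    """
    texts = [prompt + option for option in options]
    enc = tokenizer(texts, return_tensors="pt", max_length=max_length,
                    padding="max_length", truncation=True)
    input_ids = enc["input_ids"].unsqueeze(0)            # [1, option_size, len]
    attention_mask = enc["attention_mask"].unsqueeze(0)  # [1, option_size, len]
    label_masks = torch.zeros_like(attention_mask, dtype=torch.float)
    last_pos = attention_mask.sum(-1) - 1                # index of the last non-pad token per option
    for i in range(attention_mask.shape[1]):
        label_masks[0, i, last_pos[0, i]] = 1
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
        "label_masks": label_masks,
    }
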
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
    model = GPT2ForInContextClassification.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
    # input_text = "The capital city of China is Beijing. The capital city of Japan is Tokyo. The capital city of America"
    input_text1 = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Great"
    input_text2 = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Bad"
    # input_text = "This film is wonderful.\n Great."
    # input_text = "Mr. Chen was born in Shanghai. Obama was born in US. Jinping Xi was born in China."
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        [input_text1, input_text2], return_tensors="pt",
        max_length=60,
        padding="max_length", truncation=True)
inputs["input_ids"] = inputs["input_ids"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1]) | |
# inputs["token_type_ids"] = inputs["token_type_ids"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1]) | |
inputs["attention_mask"] = inputs["attention_mask"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1]) | |
inputs["labels"] = inputs["input_ids"] | |
inputs["options"] = torch.Tensor([[0, 1], [0, 1]]).long() | |
print(inputs["input_ids"].shape) | |
    # Mark the answer token of each candidate sequence (the last non-pad token) so that
    # only that position contributes to the loss.
    label_mask = torch.zeros_like(inputs["attention_mask"], dtype=torch.float)
    answer_pos = inputs["attention_mask"].sum(-1) - 1  # [1, option_size]
    label_mask[0, 0, answer_pos[0, 0]] = 1
    label_mask[0, 1, answer_pos[0, 1]] = 1
    inputs["label_masks"] = label_mask
    print(label_mask)
    output = model(**inputs, return_dict=True)
    losses, logits = output["loss"], output["logits"]
    print("loss=", losses)
    print("logits=", logits)
    # gen_output = model.generate(**inputs, max_length=60)
    # for i in range(len(gen_output)):
    #     gen_result = tokenizer.decode(gen_output[i])
    #     print("gen_result=", gen_result[len(inputs["input_ids"]):])