"""Custom models for few-shot learning specific operations.""" import torch import torch.nn as nn import transformers import torch.nn.functional as F from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead, RobertaPreTrainedModel from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model, StableDropout, ContextPooler, DebertaV2OnlyMLMHead from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel, DebertaModel, StableDropout, ContextPooler, DebertaOnlyMLMHead from transformers.modeling_outputs import SequenceClassifierOutput from transformers.modeling_utils import PreTrainedModel from transformers.models.bert.configuration_bert import BertConfig import logging from models.basic_modules.adapter import RobertaAdaModel, BertAdaModel import os from models.basic_modules.prefix_encoder import PrefixEncoder from tools.model_utils.parameter_freeze import ParameterFreeze freezer = ParameterFreeze() logger = logging.getLogger(__name__) # Note: 如果mask_pos为None,请检查输入的模板是否有标记,是否修改data_collator文件 """ Vanilla Prompt-tuning BERT """ class PromptBertForSequenceClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.pre_seq_len = self.config.pre_seq_len self.hidden_size = self.config.hidden_size # backbone self.bert = BertModel(config) if self.config.use_freezing: self.bert = freezer.freeze_lm(self.bert) # mlm head self.cls = BertOnlyMLMHead(config) self.init_weights() # These attributes should be assigned once the model is initialized self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # For regression self.lb = None self.ub = None # For label search. self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.bert = freezer.freeze_lm(self.bert) else: self.bert = freezer.unfreeze_lm(self.bert) def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): """ Encoding and obtain logits at masked position """ if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything if inputs_embeds is None: outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) else: outputs = self.bert( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. 
if return_full_softmax: return prediction_mask_scores # Return logits for each label logits = [] for label_id in range(len(self.label_word_list)): logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) logits = torch.cat(logits, -1) # Regression task if self.config.num_labels == 1: logsoftmax = nn.LogSoftmax(-1) logits = logsoftmax(logits) # Log prob of right polarity return logits, sequence_mask_output def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, labels=None, inputs_embeds=None, block_flag=None, return_dict=None, ): logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) loss = None if labels is not None: if self.num_labels == 1: # Regression task loss_fct = nn.KLDivLoss(log_target=True) labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) loss = loss_fct(logits.view(-1, 2), labels) else: if labels.shape == logits.shape: loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), labels, reduction="batchmean") else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) output = (logits,) if self.num_labels == 1: # Regression output output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) if not return_dict: return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, ) """ P-tuning BERT """ class PromptBertPtuningForSequenceClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.pre_seq_len = self.config.pre_seq_len self.hidden_size = self.config.hidden_size # backbone self.bert = BertModel(config) if self.config.use_freezing: self.bert = freezer.freeze_lm(self.bert) # mlm head self.cls = BertOnlyMLMHead(config) # prompt encoder self.prompt_encoder = None # plm embedding layer self.backbone_embeddings = self.bert.embeddings.word_embeddings # prompt embedding layer self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size) self.init_weights() # These attributes should be assigned once the model is initialized self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # For regression self.lb = None self.ub = None # For label search. 
self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.bert = freezer.freeze_lm(self.bert) else: self.bert = freezer.unfreeze_lm(self.bert) def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False): """ Generate continuous prompt embedding """ inputs_embeds = self.backbone_embeddings(input_ids) batch_size = inputs_embeds.shape[0] if block_flag is None: # the first token is set 1, others are set 0 block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device) block_flag[:, 0] = 1 try: replace_embeds = self.prompt_embeddings( torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device)) except: import pdb pdb.set_trace() replace_embeds = self.prompt_embeddings( torch.LongTensor(list(range(self.pre_seq_len)))) replace_embeds = replace_embeds.unsqueeze(0) # [batch_size, prompt_length, embed_size] if self.prompt_encoder is not None: replace_embeds = self.prompt_encoder(replace_embeds) # edit by wjn if reparameterization: # blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1] blocked_indices = (block_flag == 1).nonzero() # reparameterization for bidx in range(batch_size): for i in range(blocked_indices.shape[1]): inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze() else: replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device) inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1) return inputs_embeds def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): """ Encoding and obtain logits at masked position """ batch_size = inputs_embeds.shape[0] if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything if inputs_embeds is None: outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) else: if inputs_embeds.shape[1] == attention_mask.shape[1]: outputs = self.bert( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() else: if attention_mask is not None: prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.bert.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if token_type_ids is not None: prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.bert.device) token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1) outputs = self.bert( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. 
        if return_full_softmax:
            return prediction_mask_scores

        # Return logits for each label
        logits = []
        for label_id in range(len(self.label_word_list)):
            logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
        logits = torch.cat(logits, -1)

        # Regression task
        if self.config.num_labels == 1:
            logsoftmax = nn.LogSoftmax(-1)
            logits = logsoftmax(logits)  # Log prob of right polarity

        return logits, sequence_mask_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        mask_pos=None,
        labels=None,
        inputs_embeds=None,
        block_flag=None,
        return_dict=None,
    ):
        inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
        logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression task
                loss_fct = nn.KLDivLoss(log_target=True)
                labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
                loss = loss_fct(logits.view(-1, 2), labels)
            else:
                if labels.shape == logits.shape:
                    loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
                                    labels, reduction="batchmean")
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        output = (logits,)
        if self.num_labels == 1:
            # Regression output
            output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)

        if not return_dict:
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )


"""
Prefix-tuning BERT
"""
class PromptBertPrefixForSequenceClassification(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pre_seq_len = self.config.pre_seq_len
        self.hidden_size = self.config.hidden_size
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        # backbone
        self.bert = BertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        # mlm head
        self.cls = BertOnlyMLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        # plm embedding layer
        self.backbone_embeddings = self.bert.embeddings.word_embeddings
        # prompt embedding layer
        self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
        # prefix encoder
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)
        self.init_weights()

        # These attributes should be assigned once the model is initialized
        self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device)

        # For regression
        self.lb = None
        self.ub = None

        # For label search.
self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.bert = freezer.freeze_lm(self.bert) else: self.bert = freezer.unfreeze_lm(self.bert) def get_prompt(self, batch_size): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) past_key_values = self.prefix_encoder(prefix_tokens) # bsz, seqlen, _ = past_key_values.shape past_key_values = past_key_values.view( batch_size, self.pre_seq_len, self.n_layer * 2, self.n_head, self.n_embd ) past_key_values = self.dropout(past_key_values) past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) return past_key_values def embed_encode(self, input_ids): embedding_output = self.bert.embeddings.word_embeddings(input_ids) return embedding_output def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): batch_size = input_ids.size(0) # add prefix for prompt-tuning past_key_values = self.get_prompt(batch_size=batch_size) prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, past_key_values=past_key_values, ) # Get token representation sequence_output, pooled_output = outputs[:2] # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. if return_full_softmax: return prediction_mask_scores # print("prediction_mask_scores.shape=", prediction_mask_scores.shape) # [batch_size, seq_len, vocab_size] # Return logits for each label logits = [] for label_id in range(len(self.label_word_list)): logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) logits = torch.cat(logits, -1) # Regression task if self.config.num_labels == 1: logsoftmax = nn.LogSoftmax(-1) logits = logsoftmax(logits) # Log prob of right polarity return logits, sequence_mask_output def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, labels=None, inputs_embeds=None, block_flag=None, return_dict=None, ): logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) loss = None if labels is not None: if self.num_labels == 1: # Regression task loss_fct = nn.KLDivLoss(log_target=True) labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) loss = loss_fct(logits.view(-1, 2), labels) else: if labels.shape == logits.shape: loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), labels, reduction="batchmean") else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) output = (logits,) if self.num_labels == 1: # Regression output output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) if not return_dict: return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, ) """ Adapter-tuning BERT """ class PromptBertAdapterForSequenceClassification(BertPreTrainedModel): def __init__(self, config): 
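        """
        Adapter-tuning variant: the backbone is a BertAdaModel (BERT with adapter
        modules from models.basic_modules.adapter) and the verbalizer-based MLM head
        is the same as in the other prompt models in this file. When
        config.use_freezing is set, freezer.freeze_lm_component(self.bert, "adapter")
        is expected to leave only the adapter parameters trainable.
        """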
super().__init__(config) self.num_labels = config.num_labels self.bert = BertAdaModel(config) self.cls = BertOnlyMLMHead(config) self.init_weights() if self.config.use_freezing: self.bert = freezer.freeze_lm_component(self.bert, "adapter") # These attributes should be assigned once the model is initialized self.model_args = None self.data_args = None self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # For regression self.lb = None self.ub = None # For label search. self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.bert = freezer.freeze_lm_component(self.bert, "adapter") else: self.bert = freezer.unfreeze_lm(self.bert) def embed_encode(self, input_ids): embedding_output = self.bert.embeddings.word_embeddings(input_ids) return embedding_output def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): batch_size = input_ids.size(0) if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything if inputs_embeds is None: outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) else: outputs = self.bert( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. if return_full_softmax: return prediction_mask_scores # Return logits for each label logits = [] for label_id in range(len(self.label_word_list)): logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) logits = torch.cat(logits, -1) # Regression task if self.config.num_labels == 1: logsoftmax = nn.LogSoftmax(-1) logits = logsoftmax(logits) # Log prob of right polarity return logits, sequence_mask_output def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, labels=None, inputs_embeds=None, block_flag=None, return_dict=None, ): logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) loss = None if labels is not None: if self.num_labels == 1: # Regression task loss_fct = nn.KLDivLoss(log_target=True) labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) loss = loss_fct(logits.view(-1, 2), labels) else: if labels.shape == logits.shape: loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), labels, reduction="batchmean") else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) output = (logits,) if self.num_labels == 1: # Regression output output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) if not return_dict: return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, ) """ Vanilla Prompt-tuning RoBERTa """ class PromptRobertaForSequenceClassification(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.pre_seq_len = self.config.pre_seq_len self.hidden_size = self.config.hidden_size # backbone self.roberta = RobertaModel(config) if self.config.use_freezing: self.roberta = 
freezer.freeze_lm(self.roberta) # mlm head self.cls = RobertaLMHead(config) self.init_weights() # These attributes should be assigned once the model is initialized self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) # For regression self.lb = None self.ub = None # For label search. self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.roberta = freezer.freeze_lm(self.roberta) else: self.roberta = freezer.unfreeze_lm(self.roberta) def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): """ Encoding and obtain logits at masked position """ if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything if inputs_embeds is None: outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) else: outputs = self.roberta( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. if return_full_softmax: return prediction_mask_scores # Return logits for each label logits = [] for label_id in range(len(self.label_word_list)): logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) logits = torch.cat(logits, -1) # Regression task if self.config.num_labels == 1: logsoftmax = nn.LogSoftmax(-1) logits = logsoftmax(logits) # Log prob of right polarity return logits, sequence_mask_output def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, labels=None, inputs_embeds=None, block_flag=None, return_dict=None, ): logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) loss = None if labels is not None: if self.num_labels == 1: # Regression task loss_fct = nn.KLDivLoss(log_target=True) labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) loss = loss_fct(logits.view(-1, 2), labels) else: if labels.shape == logits.shape: loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), labels, reduction="batchmean") else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) output = (logits,) if self.num_labels == 1: # Regression output output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) if not return_dict: return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, ) """ P-tuning RoBERTa """ class PromptRobertaPtuningForSequenceClassification(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.pre_seq_len = self.config.pre_seq_len self.hidden_size = self.config.hidden_size # backbone self.roberta = RobertaModel(config) if self.config.use_freezing: self.roberta = freezer.freeze_lm(self.roberta) # mlm head self.cls = RobertaLMHead(config) # prompt encoder self.prompt_encoder = None # plm embedding layer self.backbone_embeddings = self.roberta.embeddings.word_embeddings # prompt embedding layer self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, 
self.hidden_size) self.init_weights() # These attributes should be assigned once the model is initialized self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device) # For regression self.lb = None self.ub = None # For label search. self.return_full_softmax = None def freeze_backbone(self, use_freezing: bool=True): if use_freezing: self.roberta = freezer.freeze_lm(self.roberta) else: self.roberta = freezer.unfreeze_lm(self.roberta) def generate_continuous_prompt_inputs(self, input_ids, block_flag=None, reparameterization=False): """ Generate continuous prompt embedding """ inputs_embeds = self.backbone_embeddings(input_ids) batch_size = inputs_embeds.shape[0] if block_flag is None: # the first token is set 1, others are set 0 block_flag = torch.zeros_like(input_ids).long().to(inputs_embeds.device) block_flag[:, 0] = 1 try: replace_embeds = self.prompt_embeddings( torch.LongTensor(list(range(self.pre_seq_len))).to(inputs_embeds.device)) except: import pdb pdb.set_trace() replace_embeds = self.prompt_embeddings(torch.LongTensor(list(range(self.pre_seq_len)))) replace_embeds = replace_embeds.unsqueeze(0) # [batch_size, prompt_length, embed_size] if self.prompt_encoder is not None: replace_embeds = self.prompt_encoder(replace_embeds) # edit by wjn if reparameterization: # blocked_indices = (block_flag == 1).nonzero(as_tuple=False).reshape((batch_size, self.pre_seq_len, 2))[:, :, 1] blocked_indices = (block_flag == 1).nonzero() # reparameterization for bidx in range(batch_size): for i in range(blocked_indices.shape[1]): inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[:, i, :].squeeze() else: replace_embeds = replace_embeds.expand(batch_size, self.pre_seq_len, -1).to(inputs_embeds.device) inputs_embeds = torch.cat((replace_embeds, inputs_embeds), dim=1) return inputs_embeds def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): """ Encoding and obtain logits at masked position """ batch_size = inputs_embeds.shape[0] if mask_pos is not None: mask_pos = mask_pos.squeeze() # Encode everything if inputs_embeds is None: outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) else: if inputs_embeds.shape[1] == attention_mask.shape[1]: outputs = self.roberta( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() else: if attention_mask is not None: prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).long().to(self.roberta.device) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) if token_type_ids is not None: prefix_token_type_ids = torch.zeros(batch_size, self.pre_seq_len).long().to(self.roberta.device) token_type_ids = torch.cat((prefix_token_type_ids, token_type_ids), dim=1) outputs = self.roberta( None, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds ) # Get token representation sequence_output, pooled_output = outputs[:2] sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # Logits over vocabulary tokens prediction_mask_scores = self.cls(sequence_mask_output) # Exit early and only return mask logits. 
        if return_full_softmax:
            return prediction_mask_scores

        # Return logits for each label
        logits = []
        for label_id in range(len(self.label_word_list)):
            logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
        logits = torch.cat(logits, -1)

        # Regression task
        if self.config.num_labels == 1:
            logsoftmax = nn.LogSoftmax(-1)
            logits = logsoftmax(logits)  # Log prob of right polarity

        return logits, sequence_mask_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        mask_pos=None,
        labels=None,
        inputs_embeds=None,
        block_flag=None,
        return_dict=None,
    ):
        inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag)
        logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression task
                loss_fct = nn.KLDivLoss(log_target=True)
                labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
                loss = loss_fct(logits.view(-1, 2), labels)
            else:
                if labels.shape == logits.shape:
                    loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
                                    labels, reduction="batchmean")
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        output = (logits,)
        if self.num_labels == 1:
            # Regression output
            output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)

        if not return_dict:
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )


"""
Prefix-tuning RoBERTa
"""
class PromptRobertaPrefixForSequenceClassification(RobertaPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pre_seq_len = self.config.pre_seq_len
        self.hidden_size = self.config.hidden_size
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        # backbone
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        # mlm head
        self.cls = RobertaLMHead(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        # plm embedding layer
        self.backbone_embeddings = self.roberta.embeddings.word_embeddings
        # prompt embedding layer
        self.prompt_embeddings = torch.nn.Embedding(self.pre_seq_len, self.hidden_size)
        # prefix encoder
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)
        self.init_weights()

        # These attributes should be assigned once the model is initialized
        self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)

        # For regression
        self.lb = None
        self.ub = None

        # For label search.
        self.return_full_softmax = None

    def freeze_backbone(self, use_freezing: bool=True):
        if use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        # bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def embed_encode(self, input_ids):
        embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
        return embedding_output

    def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
        batch_size = input_ids.size(0)

        # add prefix for prompt-tuning: the prefix is injected as past_key_values,
        # so only the attention mask needs to be extended; the encoder output keeps
        # the original sequence length and mask_pos needs no offset.
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if mask_pos is not None:
            mask_pos = mask_pos.squeeze()

        # Encode everything
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            past_key_values=past_key_values,
        )

        # Get token representation
        sequence_output, pooled_output = outputs[:2]
        # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
        sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]

        # Logits over vocabulary tokens
        prediction_mask_scores = self.cls(sequence_mask_output)

        # Exit early and only return mask logits.
        if return_full_softmax:
            return prediction_mask_scores

        # Return logits for each label
        logits = []
        for label_id in range(len(self.label_word_list)):
            logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
        logits = torch.cat(logits, -1)

        # Regression task
        if self.config.num_labels == 1:
            logsoftmax = nn.LogSoftmax(-1)
            logits = logsoftmax(logits)  # Log prob of right polarity

        return logits, sequence_mask_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        mask_pos=None,
        labels=None,
        inputs_embeds=None,
        block_flag=None,
        return_dict=None,
    ):
        logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression task
                loss_fct = nn.KLDivLoss(log_target=True)
                labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
                loss = loss_fct(logits.view(-1, 2), labels)
            else:
                if labels.shape == logits.shape:
                    loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32),
                                    labels, reduction="batchmean")
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        output = (logits,)
        if self.num_labels == 1:
            # Regression output
            output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)

        if not return_dict:
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )


"""
Adapter-tuning RoBERTa
"""
class PromptRobertaAdapterForSequenceClassification(RobertaPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaAdaModel(config)
        self.cls = RobertaLMHead(config)
        self.init_weights()
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")

        # These attributes should be assigned once the model is initialized
        self.model_args = None
        self.data_args = None
        self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.roberta.device)

        # For regression
        self.lb = None
        self.ub = None

        # For label search.
        self.return_full_softmax = None

    def freeze_backbone(self, use_freezing: bool=True):
        if use_freezing:
            self.roberta = freezer.freeze_lm_component(self.roberta, "adapter")
        else:
            self.roberta = freezer.unfreeze_lm(self.roberta)

    def embed_encode(self, input_ids):
        embedding_output = self.roberta.embeddings.word_embeddings(input_ids)
        return embedding_output

    def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False):
        batch_size = input_ids.size(0)

        if mask_pos is not None:
            mask_pos = mask_pos.squeeze()

        # Encode everything
        if inputs_embeds is None:
            outputs = self.roberta(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )
        else:
            outputs = self.roberta(
                None,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                inputs_embeds=inputs_embeds
            )

        # Get token representation
        sequence_output, pooled_output = outputs[:2]
        sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]

        # Logits over vocabulary tokens
        prediction_mask_scores = self.cls(sequence_mask_output)

        # Exit early and only return mask logits.
if return_full_softmax: return prediction_mask_scores # Return logits for each label logits = [] for label_id in range(len(self.label_word_list)): logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) logits = torch.cat(logits, -1) # Regression task if self.config.num_labels == 1: logsoftmax = nn.LogSoftmax(-1) logits = logsoftmax(logits) # Log prob of right polarity return logits, sequence_mask_output def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, labels=None, inputs_embeds=None, block_flag=None, return_dict=None, ): logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) loss = None if labels is not None: if self.num_labels == 1: # Regression task loss_fct = nn.KLDivLoss(log_target=True) labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) loss = loss_fct(logits.view(-1, 2), labels) else: if labels.shape == logits.shape: loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), labels, reduction="batchmean") else: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) output = (logits,) if self.num_labels == 1: # Regression output output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) if not return_dict: return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, ) # class DebertaForPromptFinetuning(DebertaPreTrainedModel): # _keys_to_ignore_on_load_unexpected = [r"pooler"] # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # #self.deberta = DebertaV2Model(config) # self.deberta = DebertaModel(config) # self.cls = DebertaOnlyMLMHead(config) # if self.config.use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # self.pooler = ContextPooler(config) # output_dim = self.pooler.output_dim # self.classifier = torch.nn.Linear(output_dim, self.num_labels) # drop_out = getattr(config, "cls_dropout", None) # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out # self.dropout = StableDropout(drop_out) # classification_list = [self.pooler, self.dropout,self.classifier] # self.classifier = nn.Sequential(*classification_list) # # self.cls = DebertaV2OnlyMLMHead(config) # self.map = nn.Linear(config.hidden_size, config.hidden_size) # self.init_weights() # # These attributes should be assigned once the model is initialized # self.model_args = None # self.data_args = None # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # self.K = 1 # self.step_size=1e-5 # # import pdb # # pdb.set_trace() # #self.step_size=config.step_size # # For regression # self.lb = None # self.ub = None # self.pre_seq_len = self.config.pre_seq_len # # For auto label search. 
# self.return_full_softmax = None # def freeze_backbone(self, use_freezing: bool=True): # if use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # else: # self.deberta = freezer.unfreeze_lm(self.deberta) # def embed_encode(self, input_ids): # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) # return embedding_output # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, # return_full_softmax=False): # batch_size = input_ids.size(0) # if mask_pos is not None: # mask_pos = mask_pos.squeeze() # # Encode everything # if inputs_embeds is None: # outputs = self.deberta( # input_ids, # attention_mask=attention_mask, # token_type_ids=token_type_ids # ) # else: # outputs = self.deberta( # None, # attention_mask=attention_mask, # token_type_ids=token_type_ids, # inputs_embeds=inputs_embeds # ) # # Get token representation # sequence_output = outputs[0] # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # # Logits over vocabulary tokens # prediction_mask_scores = self.cls(sequence_mask_output) # # sequence_mask_output = self.lm_head.dense(sequence_mask_output) # # Exit early and only return mask logits. # if return_full_softmax: # return prediction_mask_scores # # Return logits for each label # logits = [] # for label_id in range(len(self.label_word_list)): # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) # logits = torch.cat(logits, -1) # # Regression task # if self.config.num_labels == 1: # logsoftmax = nn.LogSoftmax(-1) # logits = logsoftmax(logits) # Log prob of right polarity # if self.model_args.hybrid == 1: # cls_logits = self.classifier(sequence_output) # return (logits, cls_logits), sequence_mask_output # return logits, sequence_mask_output # def forward( # self, # input_ids=None, # attention_mask=None, # token_type_ids=None, # mask_pos=None, # labels=None, # inputs_embeds=None, # fwd_type=0, # block_flag=None # ): # if fwd_type == 2: # assert inputs_embeds is not None # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, # mask_pos=mask_pos, inputs_embeds=inputs_embeds) # elif fwd_type == 1: # return self.embed_encode(input_ids) # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None: # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) # if self.model_args.hybrid == 1: # logits = logits[0] # cls_logits = logits[1] # loss = None # if labels is not None: # if self.num_labels == 1: # # Regression task # loss_fct = nn.KLDivLoss(log_target=True) # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) # loss = loss_fct(logits.view(-1, 2), labels) # else: # if labels.shape == logits.shape: # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), # labels, reduction="batchmean") # else: # loss_fct = nn.CrossEntropyLoss() # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # output = (logits,) # if self.num_labels == 1: # # Regression output # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) # return ((loss,) + output) if loss is not None else output # # add by wjn # # Prefix-tuning for 
Deberta # class DebertaPrefixForPromptFinetuning(DebertaPreTrainedModel): # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # #self.deberta = DebertaV2Model(config) # self.deberta = DebertaModel(config) # self.cls = DebertaOnlyMLMHead(config) # self.pooler = ContextPooler(config) # output_dim = self.pooler.output_dim # self.classifier = torch.nn.Linear(output_dim, self.num_labels) # drop_out = getattr(config, "cls_dropout", None) # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out # self.dropout = StableDropout(drop_out) # classification_list = [self.pooler, self.dropout,self.classifier] # self.classifier = nn.Sequential(*classification_list) # # self.cls = DebertaV2OnlyMLMHead(config) # self.map = nn.Linear(config.hidden_size, config.hidden_size) # self.init_weights() # if self.config.use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # self.pre_seq_len = config.pre_seq_len # self.n_layer = config.num_hidden_layers # self.n_head = config.num_attention_heads # self.n_embd = config.hidden_size // config.num_attention_heads # self.prefix_tokens = torch.arange(self.pre_seq_len).long() # self.prefix_encoder = PrefixEncoder(config) # # These attributes should be assigned once the model is initialized # self.model_args = None # self.data_args = None # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # self.K = 1 # self.step_size=1e-5 # # import pdb # # pdb.set_trace() # #self.step_size=config.step_size # # For regression # self.lb = None # self.ub = None # # For auto label search. # self.return_full_softmax = None # def freeze_backbone(self, use_freezing: bool=True): # if use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # else: # self.deberta = freezer.unfreeze_lm(self.deberta) # def get_prompt(self, batch_size): # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) # past_key_values = self.prefix_encoder(prefix_tokens) # # bsz, seqlen, _ = past_key_values.shape # past_key_values = past_key_values.view( # batch_size, # self.pre_seq_len, # self.n_layer * 2, # self.n_head, # self.n_embd # ) # past_key_values = self.dropout(past_key_values) # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # return past_key_values # def get_constrast_loss(self, # input_ids=None, # attention_mask=None, # mask_pos=None, # labels=None, # inputs_embeds=None): # self.cos = nn.CosineSimilarity(dim=-1) # _, sequence_mask_output_1 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) # _, sequence_mask_output_2 = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) # sequence_mask_output_1= self.lm_head.dense(sequence_mask_output_1) # sequence_mask_output_2 = self.lm_head.dense(sequence_mask_output_2) # # input_args = [input_ids, attention_mask, mask_pos, labels, None, 1] # # embed = self.forward(*input_args) # # # # vat_args = [input_ids, attention_mask, mask_pos, labels, embed, 2] # # # # adv_logits, outputs = self.forward(*vat_args) # # # # logit_mask = F.softmax(logits, dim=-1)[torch.arange(adv_logits.size(0)), labels] > 0.7 # # # # outputs = outputs[logit_mask] # # seq_outputs = sequence_mask_output[logit_mask] # # new_label = labels[logit_mask] # # # # # # # # rand_perm = torch.randperm(outputs.size(0)) # # rand_outputs = outputs[rand_perm, :] # # rand_label = new_label[rand_perm] # # pair_label = (new_label == rand_label).long() # # # # seq_outputs = self.map(seq_outputs) # # rand_outputs = 
self.map(rand_outputs) # pair_labels = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() # # import pdb # # pdb.set_trace() # contra_loss = self.contra_lc(sequence_mask_output_1.unsqueeze(1), sequence_mask_output_2.unsqueeze(0), pair_labels) # if torch.isnan(contra_loss): # return 0 # return contra_loss # def embed_encode(self, input_ids): # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) # return embedding_output # def encode(self, input_ids=None, attention_mask=None, token_type_ids=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): # batch_size = input_ids.size(0) # # add prefix for prompt-tuning # past_key_values = self.get_prompt(batch_size=batch_size) # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) # if mask_pos is not None: # mask_pos = mask_pos.squeeze() # # Encode everything # outputs = self.deberta( # input_ids, # attention_mask=attention_mask, # token_type_ids=token_type_ids, # past_key_values=past_key_values, # ) # # Get token representation # sequence_output, pooled_output = outputs[:2] # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # # Logits over vocabulary tokens # prediction_mask_scores = self.cls(sequence_mask_output) # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) # # Exit early and only return mask logits. # if return_full_softmax: # return prediction_mask_scores # # Return logits for each label # logits = [] # for label_id in range(len(self.label_word_list)): # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) # logits = torch.cat(logits, -1) # # Regression task # if self.config.num_labels == 1: # logsoftmax = nn.LogSoftmax(-1) # logits = logsoftmax(logits) # Log prob of right polarity # if self.model_args.hybrid == 1: # cls_logits = self.classifier(sequence_output) # return (logits, cls_logits), sequence_mask_output # return logits, sequence_mask_output # def forward( # self, # input_ids=None, # attention_mask=None, # token_type_ids=None, # mask_pos=None, # labels=None, # inputs_embeds=None, # fwd_type=0, # block_flag=None, # return_dict=None, # ): # if fwd_type == 2: # assert inputs_embeds is not None # return self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, # mask_pos=mask_pos, inputs_embeds=inputs_embeds) # elif fwd_type == 1: # return self.embed_encode(input_ids) # if (self.model_args.prompt_ptuning or self.model_args.prompt_prefix) and block_flag is not None: # inputs_embeds = self.generate_continuous_prompt_inputs(input_ids, block_flag) # logits, sequence_mask_output = self.encode(input_ids, attention_mask, token_type_ids, mask_pos, inputs_embeds) # if self.model_args.hybrid == 1: # logits = logits[0] # cls_logits = logits[1] # loss = None # if labels is not None: # if self.num_labels == 1: # # Regression task # loss_fct = nn.KLDivLoss(log_target=True) # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), # (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) # loss = loss_fct(logits.view(-1, 2), labels) # else: # if labels.shape == logits.shape: # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), # labels, reduction="batchmean") # else: # loss_fct = nn.CrossEntropyLoss() # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # 
output = (logits,) # if self.num_labels == 1: # # Regression output # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) # if not return_dict: # return ((loss,) + output) if loss is not None else output # return SequenceClassifierOutput( # loss=loss, # logits=logits, # ) # class Debertav2ForPromptFinetuning(DebertaV2PreTrainedModel): # _keys_to_ignore_on_load_unexpected = [r"pooler"] # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # self.deberta = DebertaV2Model(config) # if self.config.use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # self.cls = DebertaV2OnlyMLMHead(config) # #self.deberta = DebertaModel(config) # #self.cls = DebertaOnlyMLMHead(config) # self.pooler = ContextPooler(config) # output_dim = self.pooler.output_dim # self.classifier = torch.nn.Linear(output_dim, self.num_labels) # drop_out = getattr(config, "cls_dropout", None) # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out # self.dropout = StableDropout(drop_out) # classification_list = [self.pooler, self.dropout,self.classifier] # self.classifier = nn.Sequential(*classification_list) # # self.cls = DebertaV2OnlyMLMHead(config) # self.map = nn.Linear(config.hidden_size, config.hidden_size) # self.init_weights() # # These attributes should be assigned once the model is initialized # self.model_args = None # self.data_args = None # self.label_word_list = torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # self.K = 1 # self.step_size=1e-5 # # import pdb # # pdb.set_trace() # #self.step_size=config.step_size # # For regression # self.lb = None # self.ub = None # self.pre_seq_len = self.config.pre_seq_len # # For auto label search. # self.return_full_softmax = None # def freeze_backbone(self, use_freezing: bool=True): # if use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # else: # self.deberta = freezer.unfreeze_lm(self.deberta) # def embed_encode(self, input_ids): # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) # return embedding_output # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): # batch_size = input_ids.size(0) # if mask_pos is not None: # mask_pos = mask_pos.squeeze() # # Encode everything # if inputs_embeds is None: # outputs = self.deberta( # input_ids, # attention_mask=attention_mask # ) # else: # outputs = self.deberta( # None, # attention_mask=attention_mask, # inputs_embeds=inputs_embeds # ) # # Get token representation # sequence_output = outputs[0] # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # # Logits over vocabulary tokens # prediction_mask_scores = self.cls(sequence_mask_output) # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) # # Exit early and only return mask logits. 
# if return_full_softmax: # return prediction_mask_scores # # Return logits for each label # logits = [] # for label_id in range(len(self.label_word_list)): # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) # logits = torch.cat(logits, -1) # # Regression task # if self.config.num_labels == 1: # logsoftmax = nn.LogSoftmax(-1) # logits = logsoftmax(logits) # Log prob of right polarity # return logits, sequence_mask_output # def forward( # self, # input_ids=None, # attention_mask=None, # mask_pos=None, # labels=None, # inputs_embeds=None, # fwd_type=0, # block_flag=None, # return_dict=None # ): # if fwd_type == 2: # assert inputs_embeds is not None # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds) # elif fwd_type == 1: # return self.embed_encode(input_ids) # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) # loss = None # if labels is not None: # if self.num_labels == 1: # # Regression task # loss_fct = nn.KLDivLoss(log_target=True) # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) # loss = loss_fct(logits.view(-1, 2), labels) # else: # if labels.shape == logits.shape: # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), # labels, reduction="batchmean") # else: # loss_fct = nn.CrossEntropyLoss() # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # if self.model_args.hybrid == 1: # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1)) # loss = loss + cls_loss # output = (logits,) # if self.num_labels == 1: # # Regression output # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) # if not return_dict: # return ((loss,) + output) if loss is not None else output # return SequenceClassifierOutput( # loss=loss, # logits=logits, # ) # class Debertav2PrefixForPromptFinetuning(DebertaV2PreTrainedModel): # _keys_to_ignore_on_load_unexpected = [r"pooler"] # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # self.deberta = DebertaV2Model(config) # self.cls = DebertaV2OnlyMLMHead(config) # #self.deberta = DebertaModel(config) # #self.cls = DebertaOnlyMLMHead(config) # self.pooler = ContextPooler(config) # output_dim = self.pooler.output_dim # self.classifier = torch.nn.Linear(output_dim, self.num_labels) # drop_out = getattr(config, "cls_dropout", None) # drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out # self.dropout = StableDropout(drop_out) # classification_list = [self.pooler, self.dropout,self.classifier] # self.classifier = nn.Sequential(*classification_list) # # self.cls = DebertaV2OnlyMLMHead(config) # self.map = nn.Linear(config.hidden_size, config.hidden_size) # self.init_weights() # if self.config.use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # self.pre_seq_len = config.pre_seq_len # self.n_layer = config.num_hidden_layers # self.n_head = config.num_attention_heads # self.n_embd = config.hidden_size // config.num_attention_heads # self.prefix_tokens = torch.arange(self.pre_seq_len).long() # self.prefix_encoder = PrefixEncoder(config) # # These attributes should be assigned once the model is initialized # self.model_args = None # self.data_args = None # self.label_word_list = 
torch.Tensor(self.config.label_word_list).long().to(self.bert.device) # self.K = 1 # self.step_size=1e-5 # # import pdb # # pdb.set_trace() # #self.step_size=config.step_size # # For regression # self.lb = None # self.ub = None # # For auto label search. # self.return_full_softmax = None # def freeze_backbone(self, use_freezing: bool=True): # if use_freezing: # self.deberta = freezer.freeze_lm(self.deberta) # else: # self.deberta = freezer.unfreeze_lm(self.deberta) # def get_prompt(self, batch_size): # prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) # past_key_values = self.prefix_encoder(prefix_tokens) # # bsz, seqlen, _ = past_key_values.shape # past_key_values = past_key_values.view( # batch_size, # self.pre_seq_len, # self.n_layer * 2, # self.n_head, # self.n_embd # ) # past_key_values = self.dropout(past_key_values) # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # return past_key_values # def embed_encode(self, input_ids): # embedding_output = self.deberta.embeddings.word_embeddings(input_ids) # return embedding_output # def encode(self, input_ids=None, attention_mask=None, mask_pos=None, inputs_embeds=None, return_full_softmax=False): # batch_size = input_ids.size(0) # # add prefix for prompt-tuning # past_key_values = self.get_prompt(batch_size=batch_size) # prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) # attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) # if mask_pos is not None: # mask_pos = mask_pos.squeeze() # # Encode everything # outputs = self.deberta( # input_ids, # attention_mask=attention_mask, # past_key_values=past_key_values, # ) # # Get token representation # sequence_output = outputs[0] # # sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() # sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] # # Logits over vocabulary tokens # prediction_mask_scores = self.cls(sequence_mask_output) # #sequence_mask_output = self.lm_head.dense(sequence_mask_output) # # Exit early and only return mask logits. 
# if return_full_softmax: # return prediction_mask_scores # # Return logits for each label # logits = [] # for label_id in range(len(self.label_word_list)): # logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) # logits = torch.cat(logits, -1) # # Regression task # if self.config.num_labels == 1: # logsoftmax = nn.LogSoftmax(-1) # logits = logsoftmax(logits) # Log prob of right polarity # return logits, sequence_mask_output # def forward( # self, # input_ids=None, # attention_mask=None, # mask_pos=None, # labels=None, # inputs_embeds=None, # fwd_type=0, # block_flag=None, # return_dict=None, # ): # if fwd_type == 2: # assert inputs_embeds is not None # return self.encode(input_ids=input_ids, attention_mask=attention_mask, mask_pos=mask_pos, inputs_embeds=inputs_embeds) # elif fwd_type == 1: # return self.embed_encode(input_ids) # logits, sequence_mask_output = self.encode(input_ids, attention_mask, mask_pos, inputs_embeds) # loss = None # if labels is not None: # if self.num_labels == 1: # # Regression task # loss_fct = nn.KLDivLoss(log_target=True) # labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) # loss = loss_fct(logits.view(-1, 2), labels) # else: # if labels.shape == logits.shape: # loss = F.kl_div(F.log_softmax(logits, dim=-1, dtype=torch.float32), # labels, reduction="batchmean") # else: # loss_fct = nn.CrossEntropyLoss() # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # if self.model_args.hybrid == 1: # cls_loss = loss_fct(cls_logits.view(-1, cls_logits.size(-1)), labels.view(-1)) # loss = loss + cls_loss # output = (logits,) # if self.num_labels == 1: # # Regression output # output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) # if not return_dict: # return ((loss,) + output) if loss is not None else output # return SequenceClassifierOutput( # loss=loss, # logits=logits, # )
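

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only; not part of the training pipeline).
# Assumptions: a standard "bert-base-uncased" checkpoint, a binary sentiment
# verbalizer ("terrible"/"great"), and a template that appends "It was [MASK].".
# The extra attributes set on the config (pre_seq_len, label_word_list,
# use_freezing) are the custom fields the classes above read at init time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name, num_labels=2)

    # Custom config fields expected by the prompt models in this module.
    config.pre_seq_len = 8
    config.use_freezing = False
    config.label_word_list = [
        tokenizer.convert_tokens_to_ids("terrible"),  # label 0
        tokenizer.convert_tokens_to_ids("great"),     # label 1
    ]

    model = PromptBertForSequenceClassification.from_pretrained(model_name, config=config)
    model.eval()

    text = "A deeply moving film. It was " + tokenizer.mask_token + " ."
    encoded = tokenizer(text, return_tensors="pt")
    # Position of the [MASK] token, as expected by `mask_pos`.
    mask_pos = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero()[:, 1]

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            token_type_ids=encoded.get("token_type_ids"),
            mask_pos=mask_pos,
        )
    print("per-label logits:", outputs[0])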