import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class BertForRelationClassification(BertPreTrainedModel):
    """BERT encoder with a relation-classification head.

    The head concatenates the dropped-out pooled output of ``BertModel`` with
    two entity representations pooled from the token states via ``e1_mask`` /
    ``e2_mask``, applies ``tanh``, projects to ``config.relation_emb_dim``,
    applies ``tanh`` and dropout again, and classifies into
    ``config.num_labels`` classes.

    Config defaults filled in when missing: ``relation_emb_dim = 1024``, and
    ``num_labels = 2`` when the attribute is absent or equal to 0.
    """

    def __init__(self, config: BertConfig, **kwargs):
        """Build the encoder and the classification head.

        Args:
            config: BERT configuration. ``relation_emb_dim`` and ``num_labels``
                are written onto this object when missing (it is mutated).
            **kwargs: Forwarded unchanged to ``BertModel``.
        """
        super().__init__(config)
        if not hasattr(config, "relation_emb_dim"):
            config.relation_emb_dim = 1024
        if not hasattr(config, "num_labels") or config.num_labels == 0:
            config.num_labels = 2

        self.bert = BertModel(config, **kwargs)
        self.num_labels = config.num_labels
        self.relation_emb_dim = config.relation_emb_dim

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Input is [pooled output ; entity 1 ; entity 2], each hidden_size wide.
        self.fclayer = nn.Linear(config.hidden_size * 3, self.relation_emb_dim)
        self.classifier = nn.Linear(self.relation_emb_dim, config.num_labels)

        # post_init() supersedes init_weights() in newer transformers releases
        # (it runs weight init plus extra setup); fall back on older versions.
        post_init = getattr(self, "post_init", None)
        if callable(post_init):
            post_init()
        else:
            self.init_weights()

    def _extract_entity(self, sequence_output, e_mask):
        """Pool token hidden states selected by an entity mask.

        Computes ``bmm(e_mask.unsqueeze(1), sequence_output)``, i.e. the
        mask-weighted *sum* (not the mean) of the token hidden states.

        Args:
            sequence_output: Token hidden states, presumably
                ``(batch, seq_len, hidden_size)`` as ``torch.bmm`` needs a
                3-D operand — TODO confirm.
            e_mask: Per-token weights, presumably ``(batch, seq_len)`` with 1 on
                entity tokens and 0 elsewhere — TODO confirm against the data
                pipeline.

        Returns:
            The pooled entity vector (``(batch, hidden_size)`` under the shape
            assumptions above), in ``sequence_output``'s dtype.
        """
        # Use the encoder's dtype instead of hard-coding float32 so bmm also
        # works for half/bfloat16 encoders.
        weights = e_mask.unsqueeze(1).to(sequence_output.dtype)
        pooled = torch.bmm(weights, sequence_output).squeeze(1)
        return pooled.to(sequence_output.dtype)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        e1_mask=None,
        e2_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Encode the input and classify the relation between two entities.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds: Forwarded to ``BertModel``.
            e1_mask, e2_mask: Entity masks pooled by ``_extract_entity``; both
                are required.
            labels: Optional class indices. When given, a cross-entropy loss is
                computed on the logits.

        Returns:
            Tuple ``(loss, probs, *extra)`` when ``labels`` is given, otherwise
            ``(probs, *extra)``. ``probs`` is the softmax over the logits;
            ``extra`` is whatever ``BertModel`` returned after its first two
            outputs.

        Raises:
            ValueError: If ``e1_mask`` or ``e2_mask`` is missing.
        """
        # Fail fast, before the (expensive) encoder pass.
        if e1_mask is None or e2_mask is None:
            raise ValueError(
                "e1_mask and e2_mask are required for relation classification"
            )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        pooled_output = outputs[1]

        e1_h = self._extract_entity(sequence_output, e1_mask)
        e2_h = self._extract_entity(sequence_output, e2_mask)

        context = self.dropout(pooled_output)
        features = torch.tanh(torch.cat([context, e1_h, e2_h], dim=-1))
        sent_embedding = torch.tanh(self.fclayer(features))
        sent_embedding = self.dropout(sent_embedding)  # [batch_size x relation_emb_dim]

        # Logits and labels are both moved to self.bert.device so the loss and
        # the returned probabilities live on a single device.
        logits = self.classifier(sent_embedding).to(self.bert.device)

        # Prepend probabilities to any extra encoder outputs (hidden states,
        # attentions) if the encoder returned them.
        outputs = (torch.softmax(logits, -1),) + outputs[2:]

        if labels is not None:
            labels = labels.to(self.bert.device)
            # CrossEntropyLoss applies log-softmax itself, so it takes raw logits.
            ce_loss = nn.CrossEntropyLoss()
            loss = ce_loss(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs