import json

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class BertForRelationClassification(BertPreTrainedModel):
    """BERT with an entity-aware head for sentence-level relation classification.

    The pooled [CLS] vector is concatenated with the pooled representations of
    the two marked entities, projected to ``relation_emb_dim``, and classified
    into ``num_labels`` relation types. ``forward`` returns ``(probs, ...)``,
    or ``(loss, probs, ...)`` when labels are provided.
    """

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(config)
        # Fall back to defaults when the config does not carry these fields.
        if not hasattr(config, "relation_emb_dim"):
            config.relation_emb_dim = 1024
        if not hasattr(config, "num_labels") or config.num_labels == 0:
            config.num_labels = 2

        # Optionally load relation definitions from a JSON-lines file
        # (one JSON object per line) whose path is stored in the config.
        self.relations_definitions = []
        if hasattr(config, "relations_definitions"):
            with open(config.relations_definitions, "r") as f:
                for line in f:
                    self.relations_definitions.append(json.loads(line))

        self.bert = BertModel(config, **kwargs)
        self.num_labels = config.num_labels
        self.relation_emb_dim = config.relation_emb_dim
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # [CLS] + entity 1 + entity 2 -> relation embedding -> label logits.
        self.fclayer = nn.Linear(config.hidden_size * 3, self.relation_emb_dim)
        self.classifier = nn.Linear(self.relation_emb_dim, config.num_labels)
        self.init_weights()

    def _extract_entity(self, sequence_output, e_mask):
        # e_mask has shape (batch, seq_len) and marks the entity's tokens with 1s.
        # Batch-multiplying the (batch, 1, seq_len) mask with the token states
        # (batch, seq_len, hidden) sums the hidden states of the entity tokens,
        # yielding one (batch, hidden) vector per example.
        extended_e_mask = e_mask.unsqueeze(1)
        extended_e_mask = torch.bmm(extended_e_mask.float(), sequence_output).squeeze(1)
        return extended_e_mask.float()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        e1_mask=None,
        e2_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        # Run the BERT encoder; the entity masks are only used afterwards,
        # by the relation-classification head.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]  # per-token hidden states
        pooled_output = outputs[1]  # pooled [CLS] representation

        # Pool each entity span and fuse it with the sentence-level context.
        e1_h = self._extract_entity(sequence_output, e1_mask)
        e2_h = self._extract_entity(sequence_output, e2_mask)
        context = self.dropout(pooled_output)
        pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)
        pooled_output = torch.tanh(pooled_output)
        pooled_output = self.fclayer(pooled_output)
        sent_embedding = torch.tanh(pooled_output)
        sent_embedding = self.dropout(sent_embedding)

        logits = self.classifier(sent_embedding).to(self.bert.device)

        # Return class probabilities (plus any extra encoder outputs such as
        # hidden states or attentions); prepend the loss when labels are given.
        outputs = (torch.softmax(logits, -1),) + outputs[2:]
        if labels is not None:
            ce_loss = nn.CrossEntropyLoss()
            labels = labels.to(self.bert.device)
            loss = ce_loss(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs
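

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of running the model on one sentence. The checkpoint name,
# config values, and hard-coded entity token positions below are assumptions
# for illustration; real code would derive e1_mask/e2_mask from the actual
# tokenized entity spans.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = BertConfig.from_pretrained("bert-base-uncased")
    config.relation_emb_dim = 1024
    config.num_labels = 2
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForRelationClassification.from_pretrained("bert-base-uncased", config=config)
    model.eval()

    enc = tokenizer("Alice works for Acme Corp.", return_tensors="pt")
    seq_len = enc["input_ids"].shape[1]

    # Binary masks over token positions for the two entities (assumed positions).
    e1_mask = torch.zeros(1, seq_len)
    e2_mask = torch.zeros(1, seq_len)
    e1_mask[0, 1] = 1.0  # "alice"
    e2_mask[0, 4:6] = 1.0  # "acme corp"

    with torch.no_grad():
        outputs = model(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            token_type_ids=enc["token_type_ids"],
            e1_mask=e1_mask,
            e2_mask=e2_mask,
        )
    probs = outputs[0]  # shape (1, num_labels): relation probabilities
    print(probs)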