# fewshot_re_bert / modeling_re_bert.py
# Last change: commit 4878584 "Update modeling_re_bert.py" (bstds), 2.97 kB.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel
class BertForRelationClassification(BertPreTrainedModel):
    """BERT relation classifier over a sentence containing two marked entities.

    The dropout-regularised pooled output of BERT is concatenated with one
    vector per entity (see ``_extract_entity``), passed through
    tanh -> Linear -> tanh to obtain a ``relation_emb_dim``-sized relation
    embedding, and finally classified into ``num_labels`` classes.

    Config attributes read (missing defaults are written back onto ``config``):
        relation_emb_dim: size of the relation embedding (default 1024).
        num_labels: number of relation classes (set to 2 when missing or 0).
        relations_definitions: optional path to a JSON-lines file; every
            non-blank line is parsed with ``json.loads`` and kept, in order,
            in ``self.relations_definitions`` (not used by ``forward``).
    """

    def __init__(self, config: BertConfig, **kwargs):
        """Build the encoder and classification head.

        Args:
            config: BERT configuration; may be mutated to fill in defaults.
            **kwargs: forwarded unchanged to ``BertModel``.
        """
        super().__init__(config)
        if not hasattr(config, "relation_emb_dim"):
            config.relation_emb_dim = 1024
        if not hasattr(config, "num_labels") or config.num_labels == 0:
            config.num_labels = 2

        # getattr(..., None) rather than hasattr: a config may carry the key
        # with a None/empty value, which must not be passed to open().
        self.relations_definitions = self._load_relations_definitions(
            getattr(config, "relations_definitions", None)
        )

        self.bert = BertModel(config, **kwargs)
        self.num_labels = config.num_labels
        self.relation_emb_dim = config.relation_emb_dim
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Input is [pooled context; entity 1; entity 2], hence 3 * hidden_size.
        self.fclayer = nn.Linear(config.hidden_size * 3, self.relation_emb_dim)
        self.classifier = nn.Linear(self.relation_emb_dim, config.num_labels)
        # NOTE(review): newer transformers releases recommend post_init();
        # init_weights() is kept for compatibility with the original code.
        self.init_weights()

    @staticmethod
    def _load_relations_definitions(path):
        """Parse a JSON-lines file into a list; return [] when *path* is falsy.

        Blank lines are skipped so a trailing newline or separator line does
        not raise ``json.JSONDecodeError``. The file is read as UTF-8 so the
        result does not depend on the platform's locale encoding.
        """
        if not path:
            return []
        definitions = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    definitions.append(json.loads(line))
        return definitions

    def _extract_entity(self, sequence_output, e_mask):
        """Return the mask-weighted *sum* of token hidden states for one entity.

        Args:
            sequence_output: 3-D token hidden states, presumably
                (batch, seq_len, hidden) as returned by ``BertModel``.
            e_mask: 2-D per-token weights, presumably (batch, seq_len) with
                1 on the entity's tokens and 0 elsewhere.

        Returns:
            (batch, hidden) tensor ``sum_t e_mask[b, t] * sequence_output[b, t]``.
            This is a sum over the masked tokens, not a mean.
        """
        # Match the hidden states' dtype/device so bmm also works for
        # fp16/bf16/fp64 models and for masks left on another device;
        # for fp32 models this is identical to the previous .float() cast.
        weights = e_mask.unsqueeze(1).to(
            device=sequence_output.device, dtype=sequence_output.dtype
        )
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        return torch.bmm(weights, sequence_output).squeeze(1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        e1_mask=None,
        e2_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Classify the relation between the two entities marked by the masks.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: passed straight to ``BertModel``.
            e1_mask, e2_mask: per-token masks selecting the first and second
                entity (required).
            labels: optional class indices; when given, a cross-entropy loss
                over the logits is prepended to the outputs.

        Returns:
            Tuple ``(loss, probs, *extra)`` when ``labels`` is given, else
            ``(probs, *extra)``. ``probs`` is the softmax over the logits
            (not the raw logits). ``extra`` holds any ``BertModel`` outputs
            after the pooled output (e.g. hidden states / attentions when
            enabled).

        Raises:
            ValueError: if ``e1_mask`` or ``e2_mask`` is None.
        """
        # Fail fast, before the encoder pass, with an explicit message instead
        # of an AttributeError on None inside _extract_entity.
        if e1_mask is None or e2_mask is None:
            raise ValueError(
                "e1_mask and e2_mask are required to locate the two entities."
            )

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = bert_outputs[0]
        pooled_output = bert_outputs[1]

        e1_h = self._extract_entity(sequence_output, e1_mask)
        e2_h = self._extract_entity(sequence_output, e2_mask)
        context = self.dropout(pooled_output)

        features = torch.cat([context, e1_h, e2_h], dim=-1)
        features = torch.tanh(features)
        features = self.fclayer(features)
        sent_embedding = torch.tanh(features)
        sent_embedding = self.dropout(sent_embedding)

        # [batch_size x num_labels]
        logits = self.classifier(sent_embedding).to(self.bert.device)

        # Keep any hidden states / attentions BertModel returned.
        outputs = (torch.softmax(logits, -1),) + bert_outputs[2:]
        if labels is not None:
            ce_loss = nn.CrossEntropyLoss()
            labels = labels.to(self.bert.device)
            loss = ce_loss(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs
        return outputs