#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""BERT-based relation classification with entity-span pooling.

Defines a small table of relation definitions (``FEW_REL_DEFINITIONS``) and
``BertForRelationClassification``, a BERT encoder topped with a head that
combines the pooled sentence vector with one vector per marked entity.
"""
from typing import Dict, List, Optional
import json

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertPreTrainedModel

# Relation metadata: a human-readable name, a textual description and the
# property identifier (e.g. "P641") for each relation.
FEW_REL_DEFINITIONS = [
    {"name": "sport",
     "description": "sport in which the subject participates or belongs to",
     "relation_type": "P641"},
    {"name": "member of",
     "description": "organization or club to which the subject belongs. Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that).",
     "relation_type": "P463"},
    {"name": "constellation",
     "description": "the area of the celestial sphere of which the subject is a part (from a scientific standpoint, not an astrological one)",
     "relation_type": "P59"},
    {"name": "follows",
     "description": "immediately prior item in a series of which the subject is a part [if the subject has replaced the preceding item, e.g. political offices, use \"replaces\" (P1365)]",
     "relation_type": "P155"},
    {"name": "spouse",
     "description": "the subject has the object as their spouse (husband, wife, partner, etc.). Use \"partner\" (P451) for non-married companions",
     "relation_type": "P26"},
    {"name": "military rank",
     "description": "military rank achieved by a person (should usually have a \"start time\" qualifier), or military rank associated with a position",
     "relation_type": "P410"},
    {"name": "crosses",
     "description": "obstacle (body of water, road, ...) which this bridge crosses over or this tunnel goes under",
     "relation_type": "P177"},
    {"name": "competition class",
     "description": "official classification by a regulating body under which the subject (events, teams, participants, or equipment) qualifies for inclusion",
     "relation_type": "P2094"},
    {"name": "located in or next to body of water",
     "description": "sea, lake or river",
     "relation_type": "P206"},
    {"name": "voice type",
     "description": "person's voice type. expected values: soprano, mezzo-soprano, contralto, countertenor, tenor, baritone, bass (and derivatives)",
     "relation_type": "P412"},
    {"name": "child",
     "description": "subject has object as biological, foster, and/or adoptive child",
     "relation_type": "P40"},
    {"name": "mother",
     "description": "female parent of the subject. For stepmother, use \"stepparent\" (P3448)",
     "relation_type": "P25"},
    {"name": "main subject",
     "description": "primary topic of a work (see also P180: depicts)",
     "relation_type": "P921"},
    {"name": "position played on team / speciality",
     "description": "position or specialism of a player on a team, e.g. Small Forward",
     "relation_type": "P413"},
    {"name": "part of",
     "description": "object of which the subject is a part (it's not useful to link objects which are themselves parts of other objects already listed as parts of the subject). Inverse property of \"has part\" (P527, see also \"has parts of the class\" (P2670)).",
     "relation_type": "P361"},
    {"name": "original language of film or TV show",
     "description": "language in which a film or a performance work was originally created. Deprecated for written works; use P407 (\"language of work or name\") instead.",
     "relation_type": "P364"},
]


class BertForRelationClassification(BertPreTrainedModel):
    """BERT encoder with an entity-aware relation classification head.

    For each example the head concatenates three vectors: the pooler output
    of ``BertModel`` (after dropout) and one pooled vector for each of the
    two entity spans selected by ``e1_mask`` / ``e2_mask``. It then applies
    ``tanh``, projects to ``config.relation_emb_dim`` with ``fclayer``,
    applies ``tanh`` and dropout again, and maps the result to
    ``config.num_labels`` scores with ``classifier``.

    Config defaults filled in (on the passed config object) when absent:
        * ``relation_emb_dim``: 1024
        * ``num_labels``: 2 (also applied when it is 0)
    """

    def __init__(
        self,
        config: BertConfig,
        relations_definitions: Optional[List[Dict]] = None,
        **kwargs,
    ):
        """Build the encoder and the classification head.

        Args:
            config: BERT configuration. It is mutated in place to add the
                ``relation_emb_dim`` / ``num_labels`` defaults when missing.
            relations_definitions: relation metadata stored as
                ``self.relations_definitions``. It is not read anywhere else
                in this class. When omitted, each instance gets its own copy
                of ``FEW_REL_DEFINITIONS``, so mutating it cannot alter the
                module-level table.
            **kwargs: forwarded unchanged to ``BertModel``.
        """
        super().__init__(config)
        if not hasattr(config, "relation_emb_dim"):
            config.relation_emb_dim = 1024
        if not hasattr(config, "num_labels") or config.num_labels == 0:
            config.num_labels = 2

        if relations_definitions is None:
            # Per-instance copy instead of a shared mutable default argument.
            relations_definitions = [dict(d) for d in FEW_REL_DEFINITIONS]
        self.relations_definitions = relations_definitions

        self.bert = BertModel(config, **kwargs)
        self.num_labels = config.num_labels
        self.relation_emb_dim = config.relation_emb_dim

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Input is [pooled ; entity1 ; entity2], hence 3 * hidden_size.
        self.fclayer = nn.Linear(config.hidden_size * 3, self.relation_emb_dim)
        self.classifier = nn.Linear(self.relation_emb_dim, config.num_labels)

        self.init_weights()

    def _extract_entity(self, sequence_output, e_mask):
        """Pool the token states selected by ``e_mask`` into one vector per example.

        Computes the mask-weighted *sum* (not the mean) of ``sequence_output``
        over the sequence axis, using a batched matrix product.

        Args:
            sequence_output: token hidden states, ``(batch, seq_len, hidden)``.
            e_mask: per-token weights, ``(batch, seq_len)``. This is presumably
                a 0/1 mask marking the entity's tokens -- TODO confirm with the
                data pipeline.

        Returns:
            Tensor of shape ``(batch, hidden)`` in ``sequence_output``'s dtype.
        """
        # torch.bmm requires both operands to share a dtype. A hard-coded
        # .float() on the mask fails when the encoder runs in fp16/bf16.
        weights = e_mask.unsqueeze(1).to(sequence_output.dtype)  # (batch, 1, seq_len)
        pooled = torch.bmm(weights, sequence_output).squeeze(1)  # (batch, hidden)
        return pooled.to(sequence_output.dtype)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        e1_mask=None,
        e2_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Score the relation between the two marked entities.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: passed straight to ``BertModel``.
            e1_mask, e2_mask: per-token weights selecting the first and
                second entity (see ``_extract_entity``). Both are required.
            labels: optional class indices. When given, a cross-entropy loss
                is computed on the raw logits.

        Returns:
            ``(loss, probs, *extra)`` when ``labels`` is given, otherwise
            ``(probs, *extra)``. Here ``probs`` is ``softmax(logits, -1)``
            with shape ``(batch, num_labels)``, and ``extra`` is whatever
            ``BertModel`` returns after its first two outputs (hidden states
            / attentions when enabled).

        Raises:
            ValueError: if ``e1_mask`` or ``e2_mask`` is missing.
        """
        if e1_mask is None or e2_mask is None:
            raise ValueError("e1_mask and e2_mask are required to locate the two entities")

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        pooled_output = outputs[1]

        e1_h = self._extract_entity(sequence_output, e1_mask)
        e2_h = self._extract_entity(sequence_output, e2_mask)

        # Dropout is applied to the pooled vector only; the entity vectors
        # enter the concatenation without dropout.
        context = self.dropout(pooled_output)
        features = torch.cat([context, e1_h, e2_h], dim=-1)  # (batch, 3 * hidden)
        features = torch.tanh(features)
        sent_embedding = torch.tanh(self.fclayer(features))
        sent_embedding = self.dropout(sent_embedding)  # (batch, relation_emb_dim)

        # Move logits to the encoder's device, which is also where labels are moved below.
        logits = self.classifier(sent_embedding).to(self.bert.device)

        # NOTE: callers receive probabilities, not raw logits; the loss below
        # is computed on the raw logits, as CrossEntropyLoss expects.
        # Extra encoder outputs (hidden states / attentions) are appended if present.
        outputs = (torch.softmax(logits, -1),) + outputs[2:]

        if labels is not None:
            ce_loss = nn.CrossEntropyLoss()
            labels = labels.to(self.bert.device)
            loss = ce_loss(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs