bstds commited on
Commit
4878584
·
1 Parent(s): 247a8c4

Update modeling_re_bert.py

Browse files
Files changed (1) hide show
  1. modeling_re_bert.py +18 -10
modeling_re_bert.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import BertConfig, BertModel, BertPreTrainedModel
@@ -11,6 +14,11 @@ class BertForRelationClassification(BertPreTrainedModel):
11
  config.relation_emb_dim = 1024
12
  if not hasattr(config, "num_labels") or config.num_labels == 0:
13
  config.num_labels = 2
 
 
 
 
 
14
  self.bert = BertModel(config, **kwargs)
15
  self.num_labels = config.num_labels
16
  self.relation_emb_dim = config.relation_emb_dim
@@ -25,16 +33,16 @@ class BertForRelationClassification(BertPreTrainedModel):
25
  return extended_e_mask.float()
26
 
27
  def forward(
28
- self,
29
- input_ids=None,
30
- attention_mask=None,
31
- token_type_ids=None,
32
- position_ids=None,
33
- e1_mask=None,
34
- e2_mask=None,
35
- head_mask=None,
36
- inputs_embeds=None,
37
- labels=None,
38
  ):
39
  outputs = self.bert(
40
  input_ids,
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
  import torch
5
  import torch.nn as nn
6
  from transformers import BertConfig, BertModel, BertPreTrainedModel
 
14
  config.relation_emb_dim = 1024
15
  if not hasattr(config, "num_labels") or config.num_labels == 0:
16
  config.num_labels = 2
17
+ self.relations_definitions = []
18
+ if hasattr(config, "relations_definitions"):
19
+ with open(config.relations_definitions, "r") as f:
20
+ for line in f:
21
+ self.relations_definitions.append(json.loads(line))
22
  self.bert = BertModel(config, **kwargs)
23
  self.num_labels = config.num_labels
24
  self.relation_emb_dim = config.relation_emb_dim
 
33
  return extended_e_mask.float()
34
 
35
  def forward(
36
+ self,
37
+ input_ids=None,
38
+ attention_mask=None,
39
+ token_type_ids=None,
40
+ position_ids=None,
41
+ e1_mask=None,
42
+ e2_mask=None,
43
+ head_mask=None,
44
+ inputs_embeds=None,
45
+ labels=None,
46
  ):
47
  outputs = self.bert(
48
  input_ids,