bstds commited on
Commit
599c238
·
1 Parent(s): 8fc43d5

Update modeling_re_bert.py

Browse files
Files changed (1) hide show
  1. modeling_re_bert.py +4 -3
modeling_re_bert.py CHANGED
@@ -1,5 +1,7 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
 
 
3
  import json
4
  import torch
5
  import torch.nn as nn
@@ -27,14 +29,13 @@ FEW_REL_DEFINITIONS = [
27
 
28
  class BertForRelationClassification(BertPreTrainedModel):
29
 
30
- def __init__(self, config: BertConfig, **kwargs):
31
  super().__init__(config)
32
  if not hasattr(config, "relation_emb_dim"):
33
  config.relation_emb_dim = 1024
34
  if not hasattr(config, "num_labels") or config.num_labels == 0:
35
  config.num_labels = 2
36
- if not hasattr(config, "relations_definitions"):
37
- self.relations_definitions = FEW_REL_DEFINITIONS
38
  self.bert = BertModel(config, **kwargs)
39
  self.num_labels = config.num_labels
40
  self.relation_emb_dim = config.relation_emb_dim
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
+ from typing import List, Dict
4
+
5
  import json
6
  import torch
7
  import torch.nn as nn
 
29
 
30
  class BertForRelationClassification(BertPreTrainedModel):
31
 
32
+ def __init__(self, config: BertConfig, relations_definitions: List[Dict] = FEW_REL_DEFINITIONS, **kwargs):
33
  super().__init__(config)
34
  if not hasattr(config, "relation_emb_dim"):
35
  config.relation_emb_dim = 1024
36
  if not hasattr(config, "num_labels") or config.num_labels == 0:
37
  config.num_labels = 2
38
+ self.relations_definitions = relations_definitions
 
39
  self.bert = BertModel(config, **kwargs)
40
  self.num_labels = config.num_labels
41
  self.relation_emb_dim = config.relation_emb_dim