from typing import Tuple

import torch
from torch import nn
from transformers import BertModel, PreTrainedModel


class HierarchicalBertModel(PreTrainedModel):
    """BERT encoder with two classification heads: one for main segments
    and one for sub-segments."""

    def __init__(self, config, num_main_segment=None, num_sub_segment=None):
        super().__init__(config=config)
        # Label counts can be passed explicitly or read from the config.
        self.num_main_segment = num_main_segment if num_main_segment else config.num_main_segment
        self.num_sub_segment = num_sub_segment if num_sub_segment else config.num_sub_segment
        # Shared multilingual BERT encoder.
        self.bert = BertModel.from_pretrained("bert-base-multilingual-uncased")
        self.dropout = nn.Dropout(0.1)
        # Extra projection applied to the [CLS] representation.
        self.hidden_2 = nn.Linear(768, 768)
        # Separate output heads for the two label levels.
        self.fc_main = nn.Linear(768, self.num_main_segment)
        self.fc_sub = nn.Linear(768, self.num_sub_segment)

    def forward(self, input_ids: torch.Tensor, attention_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_masks)
        # Hidden state of the [CLS] token as the sequence representation.
        last_hidden_state_cls = outputs[0][:, 0, :]
        # Project the [CLS] representation and apply dropout, then feed the
        # result to both classification heads.
        out = self.dropout(self.hidden_2(last_hidden_state_cls))
        return self.fc_main(out), self.fc_sub(out)
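

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how the model above might be instantiated and run.
# The tokenizer, the config source, and the label counts (12 main / 48 sub
# segments) are assumptions for demonstration only; loading the pretrained
# weights requires network access to the Hugging Face hub.
if __name__ == "__main__":
    from transformers import BertConfig, BertTokenizer

    config = BertConfig.from_pretrained("bert-base-multilingual-uncased")
    model = HierarchicalBertModel(config, num_main_segment=12, num_sub_segment=48)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
    encoded = tokenizer(["an example sentence"], padding=True, truncation=True,
                        return_tensors="pt")

    with torch.no_grad():
        main_logits, sub_logits = model(encoded["input_ids"], encoded["attention_mask"])

    # Expected shapes: (1, 12) for the main head and (1, 48) for the sub head.
    print(main_logits.shape, sub_logits.shape)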