File size: 1,325 Bytes
822e1b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import importlib
from typing import List, Optional

import transformers
from torch import nn


def get_class(_model_package: str, _model_class: str):
    """Import a module by dotted path and return one of its attributes.

    Args:
        _model_package: Absolute dotted module path, e.g. ``"transformers"``
            or ``"json.decoder"``.
        _model_class: Name of the attribute (typically a class) to fetch
            from that module.

    Returns:
        The object bound to ``_model_class`` in the imported module.

    Raises:
        ModuleNotFoundError: If ``_model_package`` cannot be imported.
        AttributeError: If the module has no attribute ``_model_class``.
    """
    # import_module returns the leaf module of a dotted path directly; this is
    # what the ``__import__(..., fromlist=[...])`` idiom was emulating.
    module = importlib.import_module(_model_package)
    return getattr(module, _model_class)


class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head implemented as a small MLP.

    Maps ``pooled_output`` (last dimension ``config.hidden_size``) to two
    logits through ``hidden_size -> 256 -> 64 -> 2`` by default, with a
    ReLU after every hidden linear layer. It is assigned as the ``cls``
    module of ``OwnBertForNextSentencePrediction``.
    """

    # Hidden-layer widths used when no explicit ``layer_dimensions`` is given.
    DEFAULT_LAYER_DIMENSIONS = (256, 64)

    def __init__(self, config, layer_dimensions: Optional[List[int]] = None):
        """
        Args:
            config: Model config; only ``config.hidden_size`` is read here.
            layer_dimensions: Widths of the hidden layers, in order. Defaults
                to ``DEFAULT_LAYER_DIMENSIONS`` (256, 64).
        """
        super().__init__()
        if layer_dimensions is None:
            layer_dimensions = self.DEFAULT_LAYER_DIMENSIONS
        self.seq_relationship = self._build_layer(
            config.hidden_size, layer_dimensions=layer_dimensions
        )

    def forward(self, pooled_output):
        """Return logits of shape ``(*, 2)`` for input of shape ``(*, hidden_size)``."""
        return self.seq_relationship(pooled_output)

    def _build_layer(
        self,
        init_size,
        layer_dimensions: List[int],
        activation=None,
        out_size: int = 2,
    ):
        """Build a ``Linear -> activation`` stack ending in ``Linear(.., out_size)``.

        Args:
            init_size: Input feature size of the first linear layer.
            layer_dimensions: Output width of each hidden linear layer, in order.
            activation: Module inserted after every hidden linear layer. ``None``
                (the default) creates a fresh ``nn.ReLU`` at each position. A
                module passed explicitly is reused at every position, so any
                parameters or hooks it carries are shared between layers.
            out_size: Width of the final linear layer (2 by default).

        Returns:
            ``nn.Sequential`` containing the layers in order.
        """
        modules = []
        in_size = init_size
        for out_dim in layer_dimensions:
            modules.append(nn.Linear(in_size, out_dim))
            # Build a new ReLU per position by default. A module-instance default
            # argument would be created once and shared by every layer and every
            # head, so hooks registered on one head would fire for all of them.
            modules.append(nn.ReLU() if activation is None else activation)
            in_size = out_dim

        modules.append(nn.Linear(in_size, out_size))
        return nn.Sequential(*modules)


class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """``BertForNextSentencePrediction`` whose ``cls`` head is an ``OwnBertOnlyNSPHead``.

    Whatever ``cls`` module the parent constructor sets up is overwritten with
    the MLP head. Afterwards ``post_init()`` is run again so that the final
    weight-initialisation/processing pass also covers the newly attached layers.
    """

    def __init__(self, config):
        super(OwnBertForNextSentencePrediction, self).__init__(config)

        # Replace the NSP head with the deeper MLP variant.
        nsp_head = OwnBertOnlyNSPHead(config)
        self.cls = nsp_head

        # transformers hook: initialise weights and apply final processing.
        self.post_init()