from typing import List, Optional

from torch import nn
import transformers


def get_class(_model_package, _model_class):
    """Dynamically import module ``_model_package`` and return its attribute ``_model_class``.

    Args:
        _model_package: Dotted module path to import (e.g. ``"transformers"``).
        _model_class: Name of the class (or any attribute) to fetch from the module.

    Returns:
        The attribute named ``_model_class`` of the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``_model_class``.
    """
    mod = __import__(_model_package, fromlist=[_model_class])
    return getattr(mod, _model_class)


class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head with a deeper MLP than the stock BERT head.

    The standard BERT NSP head is a single ``Linear(hidden_size, 2)``; this one
    stacks ``hidden_size -> 256 -> 64 -> 2`` with ReLU activations in between.
    """

    def __init__(self, config):
        """Build the head from a BERT config (only ``config.hidden_size`` is read)."""
        super().__init__()
        self.seq_relationship = self._build_layer(config.hidden_size, layer_dimensions=[256, 64])

    def forward(self, pooled_output):
        """Map the pooled [CLS] representation to 2 NSP logits (is-next / not-next)."""
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

    def _build_layer(
        self,
        init_size: int,
        layer_dimensions: List[int],
        activation: Optional[nn.Module] = None,
        out_features: int = 2,
    ) -> nn.Sequential:
        """Assemble an MLP: ``init_size -> layer_dimensions... -> out_features``.

        Args:
            init_size: Input feature dimension of the first linear layer.
            layer_dimensions: Hidden-layer widths, applied in order. May be empty,
                in which case the result is a single ``Linear(init_size, out_features)``.
            activation: Activation module placed after each hidden linear layer.
                Defaults to a fresh ``nn.ReLU()`` per layer. (A default argument of
                ``nn.ReLU()`` would be created once at definition time and shared
                across every call — the mutable-default pitfall.)
            out_features: Width of the final projection; defaults to 2 for NSP.

        Returns:
            An ``nn.Sequential`` containing the linear/activation stack.
        """
        layers: List[nn.Module] = []
        in_features = init_size
        for hidden_dim in layer_dimensions:
            layers.append(nn.Linear(in_features, hidden_dim))
            # Use the caller-supplied activation if given, otherwise a new ReLU
            # instance per layer so no module object is shared between calls.
            layers.append(activation if activation is not None else nn.ReLU())
            in_features = hidden_dim
        layers.append(nn.Linear(in_features, out_features))
        return nn.Sequential(*layers)


class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """BERT NSP model whose classification head is replaced by :class:`OwnBertOnlyNSPHead`."""

    def __init__(self, config):
        super().__init__(config)
        # Reinitialize the cls layer with a deeper MLP head to be more powerful
        # than the default single linear projection.
        self.cls = OwnBertOnlyNSPHead(config)
        # Initialize weights and apply final processing (HF convention).
        self.post_init()