Spaces:
Runtime error
Runtime error
File size: 1,325 Bytes
822e1b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from typing import List, Optional, Sequence

import transformers
from torch import nn
def get_class(_model_package, _model_class):
    """Import ``_model_package`` and return its attribute named ``_model_class``.

    Listing the attribute name in ``fromlist`` makes ``__import__`` return the
    innermost module of a dotted path (``"os.path"`` yields ``os.path``, not
    ``os``). It also makes ``__import__`` try to import ``_model_class`` as a
    submodule when the package is a package.

    Raises:
        ImportError: if ``_model_package`` cannot be imported.
        AttributeError: if the imported module has no ``_model_class`` attribute.
    """
    target_module = __import__(_model_package, fromlist=[_model_class])
    resolved = getattr(target_module, _model_class)
    return resolved
class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head built as a small MLP instead of a single Linear.

    Intended as the ``cls`` head of ``transformers.BertForNextSentencePrediction``.
    ``forward`` maps the pooled encoder output to 2 logits.

    Args:
        config: Model configuration. Only ``config.hidden_size`` is read, and it
            must equal the size of the last dimension of ``pooled_output``.
        layer_dimensions: Widths of the hidden layers, in order. Defaults to
            ``(256, 64)``, the architecture this head previously hard-coded.
    """

    def __init__(self, config, layer_dimensions: Sequence[int] = (256, 64)):
        super().__init__()
        self.seq_relationship = self._build_layer(
            config.hidden_size, layer_dimensions=layer_dimensions
        )

    def forward(self, pooled_output):
        """Return 2 logits per input vector.

        ``pooled_output`` must have a last dimension of ``config.hidden_size``.
        The leading dimensions are preserved, because ``nn.Linear`` acts only
        on the last dimension.
        """
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

    def _build_layer(
        self,
        init_size: int,
        layer_dimensions: Sequence[int],
        activation: Optional[nn.Module] = None,
        out_features: int = 2,
    ) -> nn.Sequential:
        """Build ``Linear -> activation`` per hidden width, then a final ``Linear``.

        Args:
            init_size: Input feature size of the first Linear layer.
            layer_dimensions: Hidden-layer widths. If empty, the result is a
                single ``Linear(init_size, out_features)``.
            activation: Module placed after each hidden Linear layer. When
                ``None`` (the default), a fresh ``nn.ReLU()`` is created for
                every layer. A module passed explicitly is reused at every
                position, so a stateful one (e.g. ``nn.PReLU``) shares its
                parameters across layers.
            out_features: Size of the final output layer (2 for NSP).

        Returns:
            The assembled ``nn.Sequential``.
        """
        module_list: List[nn.Module] = []
        in_size = init_size
        for layer_dimension in layer_dimensions:
            module_list.append(nn.Linear(in_size, layer_dimension))
            # The old ``activation=nn.ReLU()`` default was evaluated once, at
            # definition time, and shared by every head. Build one per layer instead.
            module_list.append(nn.ReLU() if activation is None else activation)
            in_size = layer_dimension
        module_list.append(nn.Linear(in_size, out_features))
        return nn.Sequential(*module_list)
class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """``BertForNextSentencePrediction`` whose ``cls`` head is an ``OwnBertOnlyNSPHead``.

    The parent constructor runs first. The ``cls`` head it set up is then
    replaced by the MLP-based head, and ``post_init`` is called again so the
    new head goes through the model's weight initialisation and final
    processing.
    """

    def __init__(self, config):
        super(OwnBertForNextSentencePrediction, self).__init__(config)
        # Replace the parent's head with a deeper, more expressive NSP head.
        mlp_head = OwnBertOnlyNSPHead(config)
        self.cls = mlp_head
        # Re-run weight initialisation / final processing now that cls changed.
        self.post_init()
|