# Spaces:
# Runtime error
# Runtime error
import copy
import importlib
from typing import List, Optional

import transformers
from torch import nn
def get_class(_model_package, _model_class):
    """Return the attribute named ``_model_class`` from module ``_model_package``.

    Args:
        _model_package: Absolute dotted module path, e.g. ``"transformers"``.
        _model_class: Name of the attribute (typically a class) to fetch.

    Returns:
        The attribute object. If the module has no such attribute but a
        submodule ``<_model_package>.<_model_class>`` exists, that submodule is
        imported and returned. This matches the previous
        ``__import__(pkg, fromlist=[name])`` behaviour.

    Raises:
        ModuleNotFoundError: if ``_model_package`` itself cannot be imported.
        AttributeError: if neither the attribute nor such a submodule exists.
    """
    module = importlib.import_module(_model_package)
    try:
        return getattr(module, _model_class)
    except AttributeError:
        pass

    # ``__import__`` with a fromlist imports a not-yet-loaded submodule of that
    # name before giving up; keep that for backward compatibility.
    submodule_name = f"{_model_package}.{_model_class}"
    try:
        return importlib.import_module(submodule_name)
    except ModuleNotFoundError as err:
        # Only swallow "this exact submodule does not exist"; a missing
        # dependency *inside* an existing submodule must still surface.
        if err.name != submodule_name:
            raise
    raise AttributeError(
        f"module {_model_package!r} has no attribute {_model_class!r}"
    )
class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head implemented as a small MLP.

    The input is mapped through ``hidden_size -> 256 -> 64 -> 2`` linear
    layers, with an activation (ReLU by default) after each hidden layer.
    The two outputs are the NSP logits.
    """

    def __init__(self, config):
        """Build the head; only ``config.hidden_size`` is read from *config*."""
        super().__init__()
        # Attribute name kept so state-dict keys stay ``seq_relationship.<i>.*``.
        self.seq_relationship = self._build_layer(
            config.hidden_size, layer_dimensions=[256, 64]
        )

    def forward(self, pooled_output):
        """Return the NSP logits (last dimension of size 2) for *pooled_output*.

        NOTE(review): presumably *pooled_output* is the encoder's pooled
        representation with ``hidden_size`` features in its last dimension —
        confirm with the caller.
        """
        return self.seq_relationship(pooled_output)

    def _build_layer(
        self,
        init_size: int,
        layer_dimensions: List[int],
        activation: Optional[nn.Module] = None,
        out_size: int = 2,
    ) -> nn.Sequential:
        """Build ``Linear -> act -> ... -> Linear(out_size)`` as an ``nn.Sequential``.

        Args:
            init_size: Input feature size of the first linear layer.
            layer_dimensions: Output sizes of the hidden linear layers, in order.
                May be empty, giving a single ``Linear(init_size, out_size)``.
            activation: Module to place after every hidden layer. Each layer
                gets its own deep copy. ``None`` (default) means a new
                ``nn.ReLU()`` per layer.
            out_size: Output size of the final linear layer (default 2).

        Returns:
            The assembled ``nn.Sequential``. Linear layers sit at even indices.
        """
        modules: List[nn.Module] = []
        in_size = init_size
        for layer_dimension in layer_dimensions:
            modules.append(nn.Linear(in_size, layer_dimension))
            # Give each layer its own activation. The old default
            # ``nn.ReLU()`` was evaluated once at definition time and shared
            # everywhere, and a parameterised activation (e.g. nn.PReLU) would
            # have tied its weights across layers.
            modules.append(
                nn.ReLU() if activation is None else copy.deepcopy(activation)
            )
            in_size = layer_dimension
        modules.append(nn.Linear(in_size, out_size))
        return nn.Sequential(*modules)
class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """``BertForNextSentencePrediction`` whose ``cls`` head is an ``OwnBertOnlyNSPHead``."""

    def __init__(self, config):
        """Build the parent model, then replace its ``cls`` head with the MLP head."""
        super().__init__(config)
        # Replace the head the parent built with a deeper MLP head
        # (original intent: a "more powerful" NSP classifier).
        self.cls = OwnBertOnlyNSPHead(config)
        # Run post_init() again now that the new head exists ("Initialize
        # weights and apply final processing"). NOTE(review): presumably this is
        # what gives the fresh head layers the model's own weight init — confirm
        # against the installed transformers version.
        self.post_init()