# lorenpe2's picture
# FEAT: new models, reload model each time when something changes (not ideal but it is better than st.cache_resource)
# 822e1b3
from torch import nn
import transformers
from typing import List
def get_class(_model_package, _model_class):
    """Resolve and return the class named ``_model_class`` from module ``_model_package``.

    Args:
        _model_package: dotted module path to import, e.g. ``"transformers"``.
        _model_class: name of the attribute (typically a class) to fetch
            from the imported module.

    Returns:
        The resolved attribute (class object).

    Raises:
        ModuleNotFoundError: if ``_model_package`` cannot be imported.
        AttributeError: if the module has no attribute ``_model_class``.
    """
    # Local import keeps this file's top-level import block unchanged.
    import importlib

    # importlib.import_module is the documented replacement for the
    # __import__(..., fromlist=[...]) idiom and handles dotted paths correctly.
    module = importlib.import_module(_model_package)
    return getattr(module, _model_class)
class OwnBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head backed by a small MLP.

    Replaces the stock single-linear NSP head with a stack of
    Linear + activation layers (hidden widths 256 and 64) followed by a
    final Linear projecting to 2 logits (is-next / not-next).
    """

    def __init__(self, config):
        """
        Args:
            config: model configuration; only ``config.hidden_size`` is read
                here (the width of the pooled encoder output fed to the head).
        """
        super().__init__()
        self.seq_relationship = self._build_layer(config.hidden_size, layer_dimensions=[256, 64])

    def forward(self, pooled_output):
        """Return raw (unsoftmaxed) classification logits.

        Args:
            pooled_output: tensor whose last dimension equals
                ``config.hidden_size`` (typically the pooled [CLS] output —
                shape assumed (batch, hidden_size); confirm with caller).

        Returns:
            Logits tensor whose last dimension is 2.
        """
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

    def _build_layer(self, init_size, layer_dimensions: List[int], activation=None, output_size: int = 2):
        """Assemble the MLP as an ``nn.Sequential``.

        Args:
            init_size: input feature width of the first Linear layer.
            layer_dimensions: hidden-layer widths, applied in order.
            activation: activation module inserted after each hidden Linear;
                defaults to a fresh ``nn.ReLU()``. (A ``None`` sentinel is used
                instead of ``activation=nn.ReLU()`` so a single module instance
                is not shared across calls via the default-argument cache.)
            output_size: width of the final projection; defaults to 2 to keep
                the original binary-NSP behavior.

        Returns:
            ``nn.Sequential`` mapping ``init_size`` features to ``output_size`` logits.
        """
        if activation is None:
            activation = nn.ReLU()
        module_list = []
        _init_size = init_size
        for layer_dimension in layer_dimensions:
            module_list.append(nn.Linear(_init_size, layer_dimension))
            module_list.append(activation)
            _init_size = layer_dimension
        module_list.append(nn.Linear(_init_size, output_size))
        return nn.Sequential(*module_list)
class OwnBertForNextSentencePrediction(transformers.BertForNextSentencePrediction):
    """``BertForNextSentencePrediction`` variant with a deeper custom NSP head.

    Identical to the upstream model except that the classification head
    (``self.cls``) is replaced by :class:`OwnBertOnlyNSPHead`.
    """

    def __init__(self, config):
        # Build the stock model first (backbone + default NSP head).
        super().__init__(config)
        # reinit cls layer to be more powerful: swap the default head for the
        # deeper MLP head defined in OwnBertOnlyNSPHead.
        self.cls = OwnBertOnlyNSPHead(config)
        # Initialize weights and apply final processing — kept after the head
        # swap, presumably so the new head's layers receive the standard
        # weight initialization as well (matches upstream post_init usage).
        self.post_init()