Spaces:
Sleeping
Sleeping
import torch | |
# Define a new classification head: first-token pooling -> dense -> ReLU -> linear projection to num_labels
class NewClassificationHead(torch.nn.Module):
    """Classification head applied to encoder hidden states.

    Pools the first token of each sequence (``features[:, 0, :]``, the
    ``<s>`` / [CLS] position), then applies
    dropout -> dense -> ReLU -> dropout -> linear projection to
    ``config.num_labels`` logits.

    Args:
        config: Model config exposing ``hidden_size``, ``num_labels`` and
            ``hidden_dropout_prob``. If it also defines ``classifier_dropout``
            with a non-``None`` value, that value is used for this head's
            dropout instead of ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        # Prefer a head-specific dropout rate when the config provides one;
        # otherwise fall back to the encoder's hidden dropout (the previous
        # behavior). getattr keeps configs without the attribute working.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Return classification logits for a batch of encoder outputs.

        Args:
            features: Hidden states indexed as ``features[:, 0, :]``;
                presumably ``(batch, seq_len, hidden_size)`` -- TODO confirm
                against the caller.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            Logits tensor whose last dimension is ``config.num_labels``.
        """
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout(x)
        return self.out_proj(x)