from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer
import torch.nn as nn


class CustomBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier with a two-layer MLP head instead of one linear layer.

    The head is ``Linear(hidden_size, custom_head_hidden_size) -> ReLU ->
    Linear(custom_head_hidden_size, num_labels)``. It is stored under the
    parent's attribute name ``self.classifier``, so the inherited ``forward``
    (and the checkpoint key prefix ``classifier.``) keep using it unchanged.
    """

    def __init__(self, config):
        """Build the BERT backbone, then swap in the MLP classification head.

        Args:
            config: BERT model config. ``config.custom_head_hidden_size`` sets
                the width of the head's hidden layer. If it is missing or None,
                ``config.hidden_size`` is used instead, and that value is written
                back onto the config so a saved model reloads with the same head.
        """
        super().__init__(config)

        # Fall back to the encoder width so the class also works with a plain
        # pretrained BERT config that has no custom attribute.
        head_hidden_size = getattr(config, "custom_head_hidden_size", None)
        if head_hidden_size is None:
            head_hidden_size = config.hidden_size
            # Persist the choice so save_pretrained/from_pretrained rebuild the
            # same architecture.
            config.custom_head_hidden_size = head_hidden_size

        # Replace the default classifier (single linear layer) with a Sequential head.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, head_hidden_size),
            nn.ReLU(),
            nn.Linear(head_hidden_size, config.num_labels),
        )

        # The parent constructor already ran post_init() (weight initialization)
        # before this head existed, so the new layers would otherwise keep
        # PyTorch's default init rather than the model's own scheme. Re-apply the
        # model initializer to the new head only; the backbone is left untouched.
        self.classifier.apply(self._init_weights)