# Custom classification model and config compatible with the
# Hugging Face Transformers model interface.
# Third-party imports (PyTorch + Hugging Face Transformers).
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
class CustomClassificationConfig(PretrainedConfig):
    """Configuration for :class:`CustomClassifier`.

    Args:
        input_dim: Dimension of each input feature vector (in_features of
            the first encoder layer).
        hidden_dim: Width of the encoder's hidden layers.
        num_classes: Number of output classes.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    # Identifier used by the Transformers auto-class registry.
    model_type = "custom_classifier"

    def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        # Record the architecture hyperparameters on the config.
        self.input_dim, self.hidden_dim, self.num_classes = (
            input_dim,
            hidden_dim,
            num_classes,
        )
class CustomClassifier(PreTrainedModel):
    """Small MLP classifier exposed through the Transformers model API."""

    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        # Two Linear+ReLU stages: input_dim -> hidden_dim -> hidden_dim.
        stages = [
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*stages)
        # Final projection from the hidden representation onto class logits.
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Encode the input and optionally compute a classification loss.

        Args:
            input_ids: Tensor of shape (batch_size, input_dim). Despite the
                name, these are presumably dense feature vectors (they feed an
                ``nn.Linear``), not token ids — confirm against callers.
            labels: Optional class-index tensor; when provided, cross-entropy
                loss is computed against the logits.
            **kwargs: Ignored; accepted for Trainer compatibility.

        Returns:
            Dict with ``"loss"`` (``None`` when no labels are given) and
            ``"logits"`` of shape (batch_size, num_classes).
        """
        features = self.encoder(input_ids)
        logits = self.classifier(features)
        if labels is None:
            return {"loss": None, "logits": logits}
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}