import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

# Config object holding the model hyperparameters; model_type keys this
# architecture for the transformers auto-class registry.
class CustomClassificationConfig(PretrainedConfig):
    model_type = "custom_classifier"

    def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        
# A minimal PreTrainedModel: a two-layer MLP encoder followed by a linear
# classification head. Inheriting from PreTrainedModel provides
# save_pretrained / from_pretrained support.
class CustomClassifier(PreTrainedModel):
    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        # Despite the name, input_ids carries float features of shape
        # (batch_size, input_dim), not token IDs; the name keeps the
        # signature compatible with transformers Trainer conventions.
        hidden = self.encoder(input_ids)
        logits = self.classifier(hidden)

        loss = None
        if labels is not None:
            # labels: class indices of shape (batch_size,)
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}
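
# --- Minimal usage sketch (illustrative addition, not part of the model
# definition above). Assumes random float features and the config defaults;
# batch size 8 is arbitrary.
if __name__ == "__main__":
    config = CustomClassificationConfig(input_dim=32, hidden_dim=64, num_classes=2)
    model = CustomClassifier(config)

    features = torch.randn(8, config.input_dim)          # (batch_size, input_dim)
    labels = torch.randint(0, config.num_classes, (8,))  # class indices, shape (8,)

    outputs = model(input_ids=features, labels=labels)
    print(outputs["loss"], outputs["logits"].shape)      # scalar loss, torch.Size([8, 2])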