# Spaces: Running on CPU Upgrade
import torch
class ClassificationHead(torch.nn.Module):
    """Classification head for transformer encoders.

    Maps an encoder hidden state to unnormalized class scores (logits)
    with a single linear layer.

    Args:
        class_size: Number of output classes (size of the logits' last
            dimension).
        embed_size: Size of the last dimension of the incoming hidden
            state.
    """

    def __init__(self, class_size: int, embed_size: int) -> None:
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # NOTE: a two-layer variant (Linear -> ReLU -> Linear) used to sit
        # here as commented-out code; only this single projection is active.
        # The attribute name `mlp` is kept so existing state_dict keys
        # ("mlp.weight", "mlp.bias") keep loading.
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_state`` to class logits.

        Args:
            hidden_state: Tensor whose last dimension is ``embed_size``;
                any leading dimensions are passed through unchanged by
                ``torch.nn.Linear``.

        Returns:
            Tensor of logits with last dimension ``class_size``. No
            softmax or other normalization is applied.
        """
        return self.mlp(hidden_state)