Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class IndoBERTMultitask(nn.Module):
    """Multitask classifier built on one shared BERT-style encoder.

    One pooled vector (position 0 of the encoder's ``last_hidden_state``),
    after dropout, feeds three independent linear heads:

    * sentiment: logits over ``num_sentiment_labels`` classes (default 3);
    * event: logits over ``num_event_labels`` classes (default 43);
    * IDX sector: one logit per sector, ``num_idx_sectors`` outputs
      (default 11), described by the original author as multi-label.

    All heads return raw logits; no softmax/sigmoid is applied here.

    Args:
        base_model: Encoder called as
            ``base_model(input_ids=..., attention_mask=...)`` whose output
            exposes ``last_hidden_state``. Presumably a Hugging Face
            ``transformers`` model; TODO confirm against the caller.
        hidden_size: Width of the encoder hidden states (the last dimension
            of ``last_hidden_state``), i.e. the input size of every head.
        num_event_labels: Number of event classes.
        num_sentiment_labels: Number of sentiment classes.
        num_idx_sectors: Number of IDX sector outputs.
        dropout: Dropout probability applied to the pooled vector.
    """

    def __init__(
        self,
        base_model: nn.Module,
        hidden_size: int = 768,
        num_event_labels: int = 43,
        num_sentiment_labels: int = 3,
        num_idx_sectors: int = 11,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.bert = base_model
        self.dropout = nn.Dropout(dropout)

        # Task heads. Construction order matches the original layout so
        # parameter initialization consumes the RNG in the same sequence.
        self.sentiment_head = nn.Linear(hidden_size, num_sentiment_labels)
        self.event_head = nn.Linear(hidden_size, num_event_labels)
        # One single-output head per IDX sector. Kept as separate modules
        # (instead of a single Linear(hidden_size, num_idx_sectors)) so the
        # state_dict keys remain ``idx_heads.{i}.weight`` / ``.bias`` and
        # checkpoints saved with the original layout still load.
        self.idx_heads = nn.ModuleList(
            [nn.Linear(hidden_size, 1) for _ in range(num_idx_sectors)]
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and run every task head on the pooled token.

        Args:
            input_ids: Token ids forwarded unchanged to ``base_model``
                (presumably shape (batch, seq_len); TODO confirm).
            attention_mask: Mask forwarded unchanged to ``base_model``.

        Returns:
            Tuple ``(sentiment_logits, event_logits, idx_logits)``. Assuming
            ``last_hidden_state`` is (batch, seq_len, hidden_size), these are
            (batch, num_sentiment_labels), (batch, num_event_labels) and
            (batch, num_idx_sectors).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis: the [CLS] token for BERT-style
        # tokenization.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])

        sentiment_logits = self.sentiment_head(pooled)
        event_logits = self.event_head(pooled)
        # Each sector head yields one column; concatenating on dim=1 gives
        # one column per sector, in head order.
        idx_logits = torch.cat([head(pooled) for head in self.idx_heads], dim=1)
        return sentiment_logits, event_logits, idx_logits