import torch
import torch.nn as nn


class IndoBERTMultitask(nn.Module):
    """Multitask classifier on top of a shared transformer encoder.

    A single pooled representation (the first-token / CLS hidden state) is fed
    to three task heads:

    * a sentiment head (``num_sentiment_labels`` logits, single-label),
    * an event head (``num_event_labels`` logits),
    * one binary head per IDX sector (``num_idx_sectors`` heads, each emitting
      a single logit), concatenated into one multi-label logit tensor.

    Args:
        base_model: Encoder module. It is called as
            ``base_model(input_ids=..., attention_mask=...)`` and its output must
            expose ``last_hidden_state`` indexable as ``[:, 0]``. Presumably a
            Hugging Face BERT-style model returning (batch, seq, hidden) — TODO
            confirm against the caller.
        hidden_size: Width of the encoder's hidden states. It must match the
            encoder, because the task heads are built with this input size.
        num_event_labels: Number of event classes.
        num_sentiment_labels: Number of sentiment classes (default 3).
        num_idx_sectors: Number of IDX sectors, i.e. independent binary heads
            (default 11).
        dropout: Dropout probability applied to the pooled representation.

    Raises:
        ValueError: If ``num_idx_sectors`` is less than 1.
    """

    def __init__(
        self,
        base_model,
        hidden_size=768,
        num_event_labels=43,
        num_sentiment_labels=3,
        num_idx_sectors=11,
        dropout=0.3,
    ):
        super().__init__()
        if num_idx_sectors < 1:
            # torch.cat in forward() cannot concatenate an empty list of heads.
            raise ValueError(
                f"num_idx_sectors must be >= 1, got {num_idx_sectors}"
            )

        self.bert = base_model
        self.dropout = nn.Dropout(dropout)

        # Task heads.
        self.sentiment_head = nn.Linear(hidden_size, num_sentiment_labels)
        self.event_head = nn.Linear(hidden_size, num_event_labels)

        # IDX sectors (multi-label): one single-logit head per sector. Kept as
        # a ModuleList (rather than one Linear with num_idx_sectors outputs) so
        # the state_dict keys ("idx_heads.<i>.weight"/".bias") stay compatible
        # with checkpoints saved by earlier versions of this class.
        self.idx_heads = nn.ModuleList(
            [nn.Linear(hidden_size, 1) for _ in range(num_idx_sectors)]
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the encoder and all task heads.

        Args:
            input_ids: Token ids passed straight to ``base_model``.
            attention_mask: Attention mask passed straight to ``base_model``.
            token_type_ids: Optional segment ids. They are forwarded to
                ``base_model`` only when not ``None``, so encoders that do not
                accept this argument keep working.

        Returns:
            Tuple ``(sentiment_logits, event_logits, idx_logits)``, with last
            dimensions ``num_sentiment_labels``, ``num_event_labels`` and
            ``num_idx_sectors`` respectively. These are raw logits; no
            softmax or sigmoid is applied here.
        """
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_kwargs)

        # Take the first token (CLS) of each sequence as the pooled vector.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)

        sentiment_logits = self.sentiment_head(pooled)
        event_logits = self.event_head(pooled)
        # Each sector head yields (batch, 1); concatenate along dim 1.
        idx_logits = torch.cat([head(pooled) for head in self.idx_heads], dim=1)

        return sentiment_logits, event_logits, idx_logits